sleepyhead111 committed on
Commit
b3360fe
·
verified ·
1 Parent(s): a2608e1

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fairseq-0.10.2/fairseq/__pycache__/__init__.cpython-310.pyc +0 -0
  2. fairseq-0.10.2/fairseq/__pycache__/binarizer.cpython-310.pyc +0 -0
  3. fairseq-0.10.2/fairseq/__pycache__/checkpoint_utils.cpython-310.pyc +0 -0
  4. fairseq-0.10.2/fairseq/__pycache__/distributed_utils.cpython-310.pyc +0 -0
  5. fairseq-0.10.2/fairseq/__pycache__/file_utils.cpython-310.pyc +0 -0
  6. fairseq-0.10.2/fairseq/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc +0 -0
  7. fairseq-0.10.2/fairseq/__pycache__/quantization_utils.cpython-310.pyc +0 -0
  8. fairseq-0.10.2/fairseq/__pycache__/search.cpython-310.pyc +0 -0
  9. fairseq-0.10.2/fairseq/__pycache__/token_generation_constraints.cpython-310.pyc +0 -0
  10. fairseq-0.10.2/fairseq/clib/libbleu/module.cpp +37 -0
  11. fairseq-0.10.2/fairseq/criterions/__pycache__/__init__.cpython-310.pyc +0 -0
  12. fairseq-0.10.2/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc +0 -0
  13. fairseq-0.10.2/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc +0 -0
  14. fairseq-0.10.2/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc +0 -0
  15. fairseq-0.10.2/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc +0 -0
  16. fairseq-0.10.2/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc +0 -0
  17. fairseq-0.10.2/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc +0 -0
  18. fairseq-0.10.2/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc +0 -0
  19. fairseq-0.10.2/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc +0 -0
  20. fairseq-0.10.2/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc +0 -0
  21. fairseq-0.10.2/fairseq/data/__init__.py +124 -0
  22. fairseq-0.10.2/fairseq/data/add_target_dataset.py +70 -0
  23. fairseq-0.10.2/fairseq/data/append_token_dataset.py +41 -0
  24. fairseq-0.10.2/fairseq/data/backtranslation_dataset.py +165 -0
  25. fairseq-0.10.2/fairseq/data/base_wrapper_dataset.py +78 -0
  26. fairseq-0.10.2/fairseq/data/bucket_pad_length_dataset.py +76 -0
  27. fairseq-0.10.2/fairseq/data/colorize_dataset.py +25 -0
  28. fairseq-0.10.2/fairseq/data/data_utils.py +499 -0
  29. fairseq-0.10.2/fairseq/data/data_utils_fast.cpp +0 -0
  30. fairseq-0.10.2/fairseq/data/data_utils_fast.pyx +123 -0
  31. fairseq-0.10.2/fairseq/data/denoising_dataset.py +436 -0
  32. fairseq-0.10.2/fairseq/data/indexed_dataset.py +561 -0
  33. fairseq-0.10.2/fairseq/data/iterators.py +594 -0
  34. fairseq-0.10.2/fairseq/data/legacy/masked_lm_dataset.py +303 -0
  35. fairseq-0.10.2/fairseq/data/lm_context_window_dataset.py +79 -0
  36. fairseq-0.10.2/fairseq/data/monolingual_dataset.py +230 -0
  37. fairseq-0.10.2/fairseq/data/nested_dictionary_dataset.py +125 -0
  38. fairseq-0.10.2/fairseq/data/noising.py +333 -0
  39. fairseq-0.10.2/fairseq/data/numel_dataset.py +31 -0
  40. fairseq-0.10.2/fairseq/data/plasma_utils.py +91 -0
  41. fairseq-0.10.2/fairseq/data/prepend_token_dataset.py +41 -0
  42. fairseq-0.10.2/fairseq/data/raw_label_dataset.py +23 -0
  43. fairseq-0.10.2/fairseq/data/replace_dataset.py +36 -0
  44. fairseq-0.10.2/fairseq/data/roll_dataset.py +18 -0
  45. fairseq-0.10.2/fairseq/data/round_robin_zip_datasets.py +117 -0
  46. fairseq-0.10.2/fairseq/data/sort_dataset.py +21 -0
  47. fairseq-0.10.2/fairseq/data/subsample_dataset.py +72 -0
  48. fairseq-0.10.2/fairseq/data/token_block_utils_fast.cpp +0 -0
  49. fairseq-0.10.2/fairseq/data/token_block_utils_fast.pyx +187 -0
  50. fairseq-0.10.2/fairseq/data/transform_eos_dataset.py +120 -0
fairseq-0.10.2/fairseq/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (780 Bytes). View file
 
fairseq-0.10.2/fairseq/__pycache__/binarizer.cpython-310.pyc ADDED
Binary file (3.11 kB). View file
 
fairseq-0.10.2/fairseq/__pycache__/checkpoint_utils.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
fairseq-0.10.2/fairseq/__pycache__/distributed_utils.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
fairseq-0.10.2/fairseq/__pycache__/file_utils.cpython-310.pyc ADDED
Binary file (8.69 kB). View file
 
fairseq-0.10.2/fairseq/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc ADDED
Binary file (5.01 kB). View file
 
fairseq-0.10.2/fairseq/__pycache__/quantization_utils.cpython-310.pyc ADDED
Binary file (3.53 kB). View file
 
fairseq-0.10.2/fairseq/__pycache__/search.cpython-310.pyc ADDED
Binary file (21.9 kB). View file
 
fairseq-0.10.2/fairseq/__pycache__/token_generation_constraints.cpython-310.pyc ADDED
Binary file (16.2 kB). View file
 
fairseq-0.10.2/fairseq/clib/libbleu/module.cpp ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

/* Minimal CPython extension-module shell for "libbleu". It registers no
 * Python-callable methods. NOTE(review): the method table is empty —
 * presumably the BLEU routines are exported as plain C symbols from a
 * sibling translation unit and bound from Python via ctypes; confirm
 * against the callers before relying on this. */

#include <Python.h>


static PyMethodDef method_def[] = {
  {NULL, NULL, 0, NULL}  /* sentinel terminating the (empty) method table */
};

static struct PyModuleDef module_def = {
  PyModuleDef_HEAD_INIT,
  "libbleu",   /* name of module */
  NULL,        /* module documentation, may be NULL */
  -1,          /* size of per-interpreter state of the module,
                  or -1 if the module keeps state in global variables. */
  method_def
};


/* Python 2 and Python 3 use differently-named module init entry points. */
#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_libbleu()
#else
PyMODINIT_FUNC PyInit_libbleu()
#endif
{
  PyObject *m = PyModule_Create(&module_def);
  if (!m) {
    return NULL;
  }
  return m;
}
fairseq-0.10.2/fairseq/criterions/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.07 kB). View file
 
fairseq-0.10.2/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc ADDED
Binary file (4.68 kB). View file
 
fairseq-0.10.2/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc ADDED
Binary file (4.45 kB). View file
 
fairseq-0.10.2/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc ADDED
Binary file (3.83 kB). View file
 
fairseq-0.10.2/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc ADDED
Binary file (6.12 kB). View file
 
fairseq-0.10.2/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc ADDED
Binary file (4.63 kB). View file
 
fairseq-0.10.2/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc ADDED
Binary file (3.15 kB). View file
 
fairseq-0.10.2/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc ADDED
Binary file (5.93 kB). View file
 
fairseq-0.10.2/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc ADDED
Binary file (4.53 kB). View file
 
fairseq-0.10.2/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc ADDED
Binary file (5.57 kB). View file
 
fairseq-0.10.2/fairseq/data/__init__.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

from .dictionary import Dictionary, TruncatedDictionary

from .fairseq_dataset import FairseqDataset, FairseqIterableDataset

from .base_wrapper_dataset import BaseWrapperDataset

from .add_target_dataset import AddTargetDataset
from .append_token_dataset import AppendTokenDataset
from .audio.raw_audio_dataset import FileAudioDataset
from .backtranslation_dataset import BacktranslationDataset
from .bucket_pad_length_dataset import BucketPadLengthDataset
from .colorize_dataset import ColorizeDataset
from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset
from .denoising_dataset import DenoisingDataset
from .id_dataset import IdDataset
from .indexed_dataset import (
    IndexedCachedDataset,
    IndexedDataset,
    IndexedRawTextDataset,
    MMapIndexedDataset,
)
from .language_pair_dataset import LanguagePairDataset
from .list_dataset import ListDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
from .monolingual_dataset import MonolingualDataset
from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from .nested_dictionary_dataset import NestedDictionaryDataset
from .noising import NoisingDataset
from .numel_dataset import NumelDataset
from .num_samples_dataset import NumSamplesDataset
from .offset_tokens_dataset import OffsetTokensDataset
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
from .prepend_dataset import PrependDataset
from .prepend_token_dataset import PrependTokenDataset
from .raw_label_dataset import RawLabelDataset
from .replace_dataset import ReplaceDataset
from .resampling_dataset import ResamplingDataset
from .roll_dataset import RollDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sort_dataset import SortDataset
from .strip_token_dataset import StripTokenDataset
from .subsample_dataset import SubsampleDataset
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
from .shorten_dataset import TruncateDataset, RandomCropDataset
from .multilingual.sampled_multi_dataset import SampledMultiDataset
from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
from .fasta_dataset import FastaDataset, EncodedFastaDataset

from .iterators import (
    CountingIterator,
    EpochBatchIterator,
    GroupedIterator,
    ShardedIterator,
)

__all__ = [
    "AddTargetDataset",
    "AppendTokenDataset",
    "BacktranslationDataset",
    "BaseWrapperDataset",
    "BucketPadLengthDataset",
    "ColorizeDataset",
    "ConcatDataset",
    "ConcatSentencesDataset",
    "CountingIterator",
    "DenoisingDataset",
    "Dictionary",
    "EncodedFastaDataset",
    "EpochBatchIterator",
    "FairseqDataset",
    "FairseqIterableDataset",
    "FastaDataset",
    "GroupedIterator",
    "IdDataset",
    "IndexedCachedDataset",
    "IndexedDataset",
    "IndexedRawTextDataset",
    "LanguagePairDataset",
    "LeftPadDataset",
    "ListDataset",
    "LMContextWindowDataset",
    "LRUCacheDataset",
    "MaskTokensDataset",
    "MMapIndexedDataset",
    "MonolingualDataset",
    "MultiCorpusSampledDataset",
    "NestedDictionaryDataset",
    "NoisingDataset",
    "NumelDataset",
    "NumSamplesDataset",
    "OffsetTokensDataset",
    "PadDataset",
    "PrependDataset",
    "PrependTokenDataset",
    "ReplaceDataset",
    "RollDataset",
    "FileAudioDataset",
    # RandomCropDataset is imported above (from .shorten_dataset) alongside
    # TruncateDataset; it was previously missing from the public export list.
    "RandomCropDataset",
    "RawLabelDataset",
    "ResamplingDataset",
    "RightPadDataset",
    "RoundRobinZipDatasets",
    "SampledMultiDataset",
    "SampledMultiEpochDataset",
    "ShardedIterator",
    "SortDataset",
    "StripTokenDataset",
    "SubsampleDataset",
    "TokenBlockDataset",
    "TransformEosDataset",
    "TransformEosLangPairDataset",
    "TruncateDataset",
    "TruncatedDictionary",
]
fairseq-0.10.2/fairseq/data/add_target_dataset.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import BaseWrapperDataset, data_utils
9
+
10
+
11
class AddTargetDataset(BaseWrapperDataset):
    """Wraps a dataset and attaches a per-example label under the "label" key.

    Args:
        dataset (FairseqDataset): dataset to wrap
        labels (Sequence): raw label for each example, indexed in parallel
            with *dataset*
        pad (int): padding index used when batching targets
        eos (int): index appended to targets when *add_to_input* is True
        batch_targets (bool): if True, pad labels into a 2d "target" tensor
            and record "target_lengths"/"ntokens"; otherwise keep targets as
            a plain list
        process_label (callable, optional): transform applied lazily to each
            raw label before use
        add_to_input (bool): if True, also build teacher-forcing inputs:
            *eos* is appended to "target" and prepended to form
            "net_input"/"prev_output_tokens"
    """

    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        process_label=None,
        add_to_input=False,
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.process_label = process_label
        self.add_to_input = add_to_input

    def get_label(self, index):
        # Labels are transformed lazily, per lookup, rather than up front.
        return (
            self.labels[index]
            if self.process_label is None
            else self.process_label(self.labels[index])
        )

    def __getitem__(self, index):
        item = self.dataset[index]
        item["label"] = self.get_label(index)
        return item

    def size(self, index):
        # (wrapped size, label length) so size-based filtering can bound both.
        sz = self.dataset.size(index)
        own_sz = len(self.get_label(index))
        return (sz, own_sz)

    def collater(self, samples):
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        # The wrapped collater may drop samples; only keep labels whose ids
        # survived into the collated batch.
        indices = set(collated["id"].tolist())
        target = [s["label"] for s in samples if s["id"] in indices]

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
        else:
            collated["ntokens"] = sum([len(t) for t in target])

        collated["target"] = target

        if self.add_to_input:
            # Teacher forcing: "target" gains a trailing eos column, while
            # "prev_output_tokens" is the eos-prefixed (right-shifted) target.
            eos = target.new_full((target.size(0), 1), self.eos)
            collated["target"] = torch.cat([target, eos], dim=-1).long()
            collated["net_input"]["prev_output_tokens"] = torch.cat(
                [eos, target], dim=-1
            ).long()
            collated["ntokens"] += target.size(0)
        return collated
fairseq-0.10.2/fairseq/data/append_token_dataset.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from . import BaseWrapperDataset
10
+
11
+
12
class AppendTokenDataset(BaseWrapperDataset):
    """Optionally appends a fixed token (e.g. EOS) to every example.

    When *token* is None this wrapper is a transparent pass-through and all
    reported sizes come verbatim from the wrapped dataset; otherwise every
    example and every reported size grows by exactly one position.
    """

    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token
        # Pre-compute the per-example size table once at construction time.
        if token is None:
            self._sizes = dataset.sizes
        else:
            self._sizes = np.array(dataset.sizes) + 1

    def __getitem__(self, idx):
        example = self.dataset[idx]
        if self.token is None:
            return example
        return torch.cat([example, example.new([self.token])])

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        extra = 0 if self.token is None else 1
        return self.dataset.num_tokens(index) + extra

    def size(self, index):
        extra = 0 if self.token is None else 1
        return self.dataset.size(index) + extra
fairseq-0.10.2/fairseq/data/backtranslation_dataset.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from fairseq import utils
8
+
9
+ from . import FairseqDataset
10
+
11
+
12
def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
    """Backtranslate a list of samples.

    Given an input (*samples*) of the form:

        [{'id': 1, 'source': 'hallo welt'}]

    this will return:

        [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]

    Args:
        samples (List[dict]): samples to backtranslate. Individual samples are
            expected to have a 'source' key, which will become the 'target'
            after backtranslation.
        collate_fn (callable): function to collate samples into a mini-batch
        generate_fn (callable): function to generate backtranslations
        cuda (bool): use GPU for generation (default: ``True``)

    Returns:
        List[dict]: an updated list of samples with a backtranslated source
    """
    batch = collate_fn(samples)
    net_input = utils.move_to_cuda(batch) if cuda else batch
    hypotheses = generate_fn(net_input)

    # Remember each original source by id so it can become the new target.
    original_source = {sample["id"]: sample["source"] for sample in samples}

    # Walk the batch ids alongside their hypothesis lists and build the
    # swapped training pairs: the best (first) hypothesis becomes the new
    # source while the original source becomes the target.
    backtranslated = []
    for sample_id, sample_hypos in zip(batch["id"], hypotheses):
        key = sample_id.item()
        backtranslated.append(
            {
                "id": key,
                "target": original_source[key],
                "source": sample_hypos[0]["tokens"].cpu(),
            }
        )
    return backtranslated
51
+
52
+
53
class BacktranslationDataset(FairseqDataset):
    """
    Sets up a backtranslation dataset which takes a tgt batch, generates
    a src using a tgt-src backtranslation function (*backtranslation_fn*),
    and returns the corresponding `{generated src, input tgt}` batch.

    Args:
        tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
            backtranslated. Only the source side of this dataset will be used.
            After backtranslation, the source sentences in this dataset will be
            returned as the targets.
        src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
            sentences.
        tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
            sentences to be backtranslated.
        backtranslation_fn (callable, optional): function to call to generate
            backtranslations. This is typically the `generate` method of a
            :class:`~fairseq.sequence_generator.SequenceGenerator` object.
            Pass in None when it is not available at initialization time, and
            use set_backtranslation_fn function to set it when available.
        output_collater (callable, optional): function to call on the
            backtranslated samples to create the final batch
            (default: ``tgt_dataset.collater``).
        cuda: use GPU for generation
    """

    def __init__(
        self,
        tgt_dataset,
        src_dict,
        tgt_dict=None,
        backtranslation_fn=None,
        output_collater=None,
        cuda=True,
        **kwargs
    ):
        self.tgt_dataset = tgt_dataset
        self.backtranslation_fn = backtranslation_fn
        # Default to the wrapped dataset's collater unless overridden.
        self.output_collater = (
            output_collater if output_collater is not None else tgt_dataset.collater
        )
        # The *cuda* request is honored only when a GPU is actually available.
        self.cuda = cuda if torch.cuda.is_available() else False
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    def __getitem__(self, index):
        """
        Returns a single sample from *tgt_dataset*. Note that backtranslation is
        not applied in this step; use :func:`collater` instead to backtranslate
        a batch of samples.
        """
        return self.tgt_dataset[index]

    def __len__(self):
        return len(self.tgt_dataset)

    def set_backtranslation_fn(self, backtranslation_fn):
        # Allows deferred installation of the generator (see class docstring:
        # pass backtranslation_fn=None at init, then set it here).
        self.backtranslation_fn = backtranslation_fn

    def collater(self, samples):
        """Merge and backtranslate a list of samples to form a mini-batch.

        Using the samples from *tgt_dataset*, load a collated target sample to
        feed to the backtranslation model. Then take the backtranslation with
        the best score as the source and the original input as the target.

        Note: we expect *tgt_dataset* to provide a function `collater()` that
        will collate samples into the format expected by *backtranslation_fn*.
        After backtranslation, we will feed the new list of samples (i.e., the
        `(backtranslated source, original source)` pairs) to *output_collater*
        and return the result.

        Args:
            samples (List[dict]): samples to backtranslate and collate

        Returns:
            dict: a mini-batch with keys coming from *output_collater*
        """
        # Batches flagged as dummy carry no real data to translate — pass
        # them through unchanged.
        if samples[0].get("is_dummy", False):
            return samples
        samples = backtranslate_samples(
            samples=samples,
            collate_fn=self.tgt_dataset.collater,
            generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
            cuda=self.cuda,
        )
        return self.output_collater(samples)

    def num_tokens(self, index):
        """Just use the tgt dataset num_tokens"""
        return self.tgt_dataset.num_tokens(index)

    def ordered_indices(self):
        """Just use the tgt dataset ordered_indices"""
        return self.tgt_dataset.ordered_indices()

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``.

        Note: we use *tgt_dataset* to approximate the length of the source
        sentence, since we do not know the actual length until after
        backtranslation.
        """
        tgt_size = self.tgt_dataset.size(index)[0]
        return (tgt_size, tgt_size)

    @property
    def supports_prefetch(self):
        # False when the wrapped dataset does not declare prefetch support.
        return getattr(self.tgt_dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.tgt_dataset.prefetch(indices)
fairseq-0.10.2/fairseq/data/base_wrapper_dataset.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torch.utils.data.dataloader import default_collate
7
+
8
+ from . import FairseqDataset
9
+
10
+
11
class BaseWrapperDataset(FairseqDataset):
    """Base class for datasets that wrap another dataset.

    Every FairseqDataset hook is delegated to the wrapped ``self.dataset``;
    subclasses override only the methods whose behavior they change.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        # Prefer the wrapped dataset's collater; fall back to torch's default
        # collation for plain datasets that do not define one.
        if hasattr(self.dataset, "collater"):
            return self.dataset.collater(samples)
        else:
            return default_collate(samples)

    @property
    def sizes(self):
        return self.dataset.sizes

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        # False when the wrapped dataset does not declare prefetch support.
        return getattr(self.dataset, "supports_prefetch", False)

    def attr(self, attr: str, index: int):
        return self.dataset.attr(attr, index)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)

    def get_batch_shapes(self):
        return self.dataset.get_batch_shapes()

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        return self.dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        return self.dataset.filter_indices_by_size(indices, max_sizes)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return self.dataset.can_reuse_epoch_itr_across_epochs

    def set_epoch(self, epoch):
        # Propagate epoch updates to wrapped datasets that track them.
        super().set_epoch(epoch)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
fairseq-0.10.2/fairseq/data/bucket_pad_length_dataset.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch.nn.functional as F
8
+ from fairseq.data import BaseWrapperDataset
9
+
10
+
11
class BucketPadLengthDataset(BaseWrapperDataset):
    """
    Bucket and pad item lengths to the nearest bucket size. This can be used to
    reduce the number of unique batch shapes, which is important on TPUs since
    each new batch shape requires a recompilation.

    Args:
        dataset (FairseqDatset): dataset to bucket
        sizes (List[int]): all item sizes
        num_buckets (int): number of buckets to create
        pad_idx (int): padding symbol
        left_pad (bool): if True, pad on the left; otherwise right pad
    """

    def __init__(
        self,
        dataset,
        sizes,
        num_buckets,
        pad_idx,
        left_pad,
    ):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

        assert num_buckets > 0
        # Bucket boundaries are the size percentiles (dropping the 0th), so
        # buckets hold roughly equal numbers of examples; "lower" keeps the
        # boundaries at sizes that actually occur in the data.
        percentiles = np.linspace(0, 100, num_buckets + 1)
        try:
            # NumPy >= 1.22 spells this keyword `method`; the old
            # `interpolation` keyword was removed in NumPy 2.0.
            boundaries = np.percentile(sizes, percentiles, method="lower")
        except TypeError:
            # Older NumPy only understands `interpolation`.
            boundaries = np.percentile(sizes, percentiles, interpolation="lower")
        self.buckets = np.unique(boundaries[1:])

        def get_bucketed_sizes(orig_sizes, buckets):
            # Round every size up to its bucket's upper boundary.
            sizes = np.copy(orig_sizes)
            assert np.min(sizes) >= 0
            start_val = -1
            for end_val in buckets:
                mask = (sizes > start_val) & (sizes <= end_val)
                sizes[mask] = end_val
                start_val = end_val
            return sizes

        self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)

    def __getitem__(self, index):
        item = self.dataset[index]
        bucket_size = self._bucketed_sizes[index]
        num_pad = bucket_size - item.size(-1)
        # Pad on the configured side up to the bucketed length.
        return F.pad(
            item,
            (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
            value=self.pad_idx,
        )

    @property
    def sizes(self):
        return self._bucketed_sizes

    def num_tokens(self, index):
        return self._bucketed_sizes[index]

    def size(self, index):
        return self._bucketed_sizes[index]
fairseq-0.10.2/fairseq/data/colorize_dataset.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import BaseWrapperDataset
9
+
10
+
11
class ColorizeDataset(BaseWrapperDataset):
    """Attaches a per-sample "colors" tensor to the collated net input.

    The value for each sample comes from calling *color_getter* with the
    wrapped dataset and the sample id; models can read the resulting long
    tensor from ``net_input["colors"]``.
    """

    def __init__(self, dataset, color_getter):
        super().__init__(dataset)
        self.color_getter = color_getter

    def collater(self, samples):
        batch = super().collater(samples)
        if len(batch) > 0:
            colors = [self.color_getter(self.dataset, s["id"]) for s in samples]
            batch["net_input"]["colors"] = torch.tensor(colors, dtype=torch.long)
        return batch
fairseq-0.10.2/fairseq/data/data_utils.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ try:
7
+ from collections.abc import Iterable
8
+ except ImportError:
9
+ from collections import Iterable
10
+ import contextlib
11
+ import itertools
12
+ import logging
13
+ import os
14
+ import warnings
15
+ from typing import Optional, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
def infer_language_pair(path):
    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
    # Scan the directory for the first file whose second dot-separated field
    # looks like a "src-tgt" pair; otherwise report (None, None).
    for filename in os.listdir(path):
        fields = filename.split(".")
        if len(fields) < 3:
            continue
        pair = fields[1].split("-")
        if len(pair) == 2:
            return pair
    return None, None
32
+
33
+
34
def collate_tokens(
    values,
    pad_idx,
    eos_idx=None,
    left_pad=False,
    move_eos_to_beginning=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    # Width of the padded batch: the longest sequence, optionally raised to
    # *pad_to_length* and then rounded up to a multiple of *pad_to_multiple*.
    width = max(v.size(0) for v in values)
    if pad_to_length is not None:
        width = max(width, pad_to_length)
    if pad_to_multiple != 1 and width % pad_to_multiple != 0:
        width = int(((width - 0.1) // pad_to_multiple + 1) * pad_to_multiple)

    out = values[0].new(len(values), width).fill_(pad_idx)

    def fill_row(src, dst):
        assert dst.numel() == src.numel()
        if move_eos_to_beginning:
            # Right-shift the sequence: eos lands at position 0 (taking the
            # last source token when no explicit eos_idx is given).
            dst[0] = src[-1] if eos_idx is None else eos_idx
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for row, v in enumerate(values):
        dst = out[row][width - len(v) :] if left_pad else out[row][: len(v)]
        fill_row(v, dst)
    return out
65
+
66
+
67
def load_indexed_dataset(
    path, dictionary=None, dataset_impl=None, combine=False, default="cached"
):
    """A helper function for loading indexed datasets.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it will be inferred automatically. For legacy indexed
            data we use the 'cached' implementation by default.
        combine (bool, optional): automatically load and combine multiple
            datasets. For example, if *path* is 'data-bin/train', then we will
            combine 'data-bin/train', 'data-bin/train1', ... and return a
            single ConcatDataset instance.
    """
    # Imported inside the function rather than at module top —
    # presumably to avoid a circular import with fairseq.data; confirm.
    from fairseq.data.concat_dataset import ConcatDataset
    import fairseq.data.indexed_dataset as indexed_dataset

    datasets = []
    for k in itertools.count():
        # Shard naming: 'path', 'path1', 'path2', ... (first has no suffix).
        path_k = path + (str(k) if k > 0 else "")
        path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)

        # Infer the implementation per shard unless one was given explicitly.
        dataset_impl_k = dataset_impl
        if dataset_impl_k is None:
            dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
        dataset = indexed_dataset.make_dataset(
            path_k,
            impl=dataset_impl_k or default,
            fix_lua_indexing=True,
            dictionary=dictionary,
        )
        if dataset is None:
            # No shard at this index: stop scanning.
            break
        logger.info("loaded {} examples from: {}".format(len(dataset), path_k))
        datasets.append(dataset)
        if not combine:
            break
    # Return None / the single dataset / a concatenation, matching how many
    # shards were found.
    if len(datasets) == 0:
        return None
    elif len(datasets) == 1:
        return datasets[0]
    else:
        return ConcatDataset(datasets)
112
+
113
+
114
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Context manager which seeds the NumPy PRNG with the specified seed and
    restores the state afterward.

    Passing ``seed=None`` makes this a no-op. Extra positional seeds are
    folded into the main seed deterministically via ``hash``.
    """
    if seed is None:
        # Nothing to seed: behave as a transparent context manager.
        yield
        return
    if addl_seeds:
        seed = int(hash((seed, *addl_seeds)) % 1e6)
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Restore the caller's PRNG state even if the body raised.
        np.random.set_state(saved_state)
129
+
130
+
131
def collect_filtered(function, iterable, filtered):
    """
    Similar to :func:`filter` but collects filtered elements in ``filtered``.

    Args:
        function (callable): function that returns ``False`` for elements that
            should be filtered
        iterable (iterable): iterable to filter
        filtered (list): list to store filtered elements
    """
    for item in iterable:
        if not function(item):
            # Rejected: remember it for the caller instead of dropping it.
            filtered.append(item)
        else:
            yield item
146
+
147
+
148
def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
    """Partition *indices* into (kept, ignored) by comparing each example's
    size (via ``size_fn``) against ``max_positions``, which may be a scalar,
    a dict of per-key limits, or a tuple."""

    def compare_leq(a, b):
        return a <= b if not isinstance(a, tuple) else max(a) <= b

    def check_size(idx):
        if isinstance(max_positions, (float, int)):
            # Single scalar limit.
            return size_fn(idx) <= max_positions
        if isinstance(max_positions, dict):
            # Per-key limits; only keys present on both sides are compared.
            idx_size = size_fn(idx)
            assert isinstance(idx_size, dict)
            for key in set(max_positions.keys()) & set(idx_size.keys()):
                for a, b in zip(idx_size[key], max_positions[key]):
                    if a is not None and b is not None and not a <= b:
                        return False
            return True
        # Hacky as heck, for the specific case of multilingual training with RoundRobin.
        if isinstance(size_fn(idx), dict) and isinstance(max_positions, tuple):
            return all(
                a is None or b is None or compare_leq(a, b)
                for a, b in zip(size_fn(idx).values(), max_positions)
            )
        # For MultiCorpusSampledDataset, will generalize it later
        if not isinstance(size_fn(idx), Iterable):
            return all(size_fn(idx) <= b for b in max_positions)
        return all(
            a is None or b is None or a <= b
            for a, b in zip(size_fn(idx), max_positions)
        )

    ignored = []
    kept = collect_filtered(check_size, indices, ignored)
    return np.fromiter(kept, dtype=np.int64, count=-1), ignored
185
+
186
+
187
def filter_by_size(indices, dataset, max_positions, raise_exception=False):
    """
    [deprecated] Filter indices based on their size.
    Use `FairseqDataset::filter_indices_by_size` instead.

    Args:
        indices (List[int]): ordered list of dataset indices
        dataset (FairseqDataset): fairseq dataset instance
        max_positions (tuple): filter elements larger than this size.
            Comparisons are done component-wise.
        raise_exception (bool, optional): if ``True``, raise an exception if
            any elements are filtered (default: False).
    """
    warnings.warn(
        "data_utils.filter_by_size is deprecated. "
        "Use `FairseqDataset::filter_indices_by_size` instead.",
        stacklevel=2,
    )
    ignored = None
    if isinstance(max_positions, (float, int)):
        sizes = getattr(dataset, "sizes", None)
        if isinstance(sizes, np.ndarray):
            # Fast vectorized path over the dataset's size array.
            ignored = indices[sizes[indices] > max_positions].tolist()
            indices = indices[sizes[indices] <= max_positions]
        elif isinstance(sizes, list) and len(sizes) == 1:
            ignored = indices[sizes[0][indices] > max_positions].tolist()
            indices = indices[sizes[0][indices] <= max_positions]
    if ignored is None:
        # Fall back to calling dataset.size() per index.
        indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)

    if ignored and raise_exception:
        raise Exception(
            (
                "Size of sample #{} is invalid (={}) since max_positions={}, "
                "skip this example with --skip-invalid-size-inputs-valid-test"
            ).format(ignored[0], dataset.size(ignored[0]), max_positions)
        )
    if ignored:
        logger.warning(
            (
                "{} samples have invalid sizes and will be skipped, "
                "max_positions={}, first few sample ids={}"
            ).format(len(ignored), max_positions, ignored[:10])
        )
    return indices
238
+
239
+
240
def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
    """Filter a list of sample indices. Remove those that are longer
    than specified in max_sizes.

    Args:
        indices (np.array): original array of sample indices
        max_sizes (int or list[int] or tuple[int]): max sample size,
            can be defined separately for src and tgt (then list or tuple)

    Returns:
        np.array: filtered sample array
        list: list of removed indices
    """
    if max_sizes is None:
        return indices, []
    if type(max_sizes) in (int, float):
        # A single scalar limits both sides.
        max_src_size = max_tgt_size = max_sizes
    else:
        max_src_size, max_tgt_size = max_sizes

    # Boolean mask of samples that violate either limit.
    too_long = src_sizes[indices] > max_src_size
    if tgt_sizes is not None:
        too_long = too_long | (tgt_sizes[indices] > max_tgt_size)
    ignored = indices[too_long]
    if len(ignored) > 0:
        indices = indices[~too_long]
    return indices, ignored.tolist()
274
+
275
+
276
def batch_by_size(
    indices,
    num_tokens_fn,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
    fixed_shapes=None,
):
    """
    Yield mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be less than N or a multiple of N (default: 1).
        fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
            only be created with the given shapes. *max_sentences* and
            *required_batch_size_multiple* will be ignored (default: None).
    """
    try:
        from fairseq.data.data_utils_fast import (
            batch_by_size_fast,
            batch_fixed_shapes_fast,
        )
    except ImportError:
        raise ImportError(
            "Please build Cython components with: `pip install --editable .` "
            "or `python setup.py build_ext --inplace`"
        )

    if not isinstance(indices, np.ndarray):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    if fixed_shapes is None:
        # The Cython kernel treats -1 as "no limit".
        return batch_by_size_fast(
            indices,
            num_tokens_fn,
            max_tokens if max_tokens is not None else -1,
            max_sentences if max_sentences is not None else -1,
            required_batch_size_multiple,
        )

    fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
    # NOTE(review): lexsort is applied to the *argsort permutations* of the
    # bsz/length columns (upstream fairseq behavior) — preserved as-is.
    sort_order = np.lexsort(
        [
            fixed_shapes[:, 1].argsort(),  # length
            fixed_shapes[:, 0].argsort(),  # bsz
        ]
    )
    return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes[sort_order])
338
+
339
+
340
def post_process(sentence: str, symbol: str):
    """Undo subword segmentation in *sentence* for the given scheme *symbol*.

    Known schemes map a marker character back to a word boundary; any other
    non-``None``/non-``"none"`` symbol is treated as a BPE continuation
    marker and simply removed.
    """
    # scheme name -> word-boundary marker used by that scheme
    boundary_markers = {
        "sentencepiece": "\u2581",
        "wordpiece": "_",
        "letter": "|",
        "_EOW": "_EOW",
    }
    if symbol in boundary_markers:
        marker = boundary_markers[symbol]
        return sentence.replace(" ", "").replace(marker, " ").strip()
    if symbol is not None and symbol != "none":
        # Generic BPE: strip the continuation symbol (e.g. "@@ ").
        return (sentence + " ").replace(symbol, "").rstrip()
    return sentence
352
+
353
+
354
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from possion distribution with lambda = mask length
        min_masks: minimum number of masked spans
        no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans

    Returns:
        boolean np.ndarray of *shape*; True marks masked positions. Every row
        receives the same number of masked elements (rows are trimmed to the
        minimum count so the batch stays rectangular).
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            # Only the unpadded prefix of this row is maskable.
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            # Guarantee at least one non-empty span per row.
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # Place one span of `length` inside (s, e) and return the
                # remaining sub-intervals still large enough to hold a span.
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - keep_length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                # Weight candidate intervals by how much room they have.
                # BUG FIX: `np.int` was removed in NumPy 1.24 (deprecated
                # since 1.20); use the explicit `np.int64` dtype instead.
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    np.int64,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    # No interval can hold this span; stop placing spans.
                    break
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            # Expand each chosen start into a full span of its length.
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    # Trim every row to the smallest per-row count so all rows mask the
    # same number of elements.
    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        mask[i, mask_idc] = True

    return mask
479
+
480
+
481
def get_mem_usage():
    """Return a human-readable snapshot of system RAM usage, or ``"N/A"``
    when the optional ``psutil`` dependency is not installed."""
    try:
        import psutil
    except ImportError:
        return "N/A"

    mb = 1024 * 1024
    vm = psutil.virtual_memory()
    return f"used={vm.used / mb}Mb; avail={vm.available / mb}Mb"
489
+
490
+
491
def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor:
    """Build a (batch, max_len) bool mask that is True on padded positions,
    i.e. at every position >= the row's length."""
    batch_size = lens.size(0)
    longest = int(torch.max(lens))
    positions = torch.arange(longest).to(lens.device).unsqueeze(0)
    return positions.expand(batch_size, longest) >= lens.unsqueeze(1)
496
+
497
+
498
def lengths_to_mask(lens: torch.LongTensor) -> torch.BoolTensor:
    """Inverse of :func:`lengths_to_padding_mask`: True on real tokens."""
    padding = lengths_to_padding_mask(lens)
    return ~padding
fairseq-0.10.2/fairseq/data/data_utils_fast.cpp ADDED
The diff for this file is too large to render. See raw diff
 
fairseq-0.10.2/fairseq/data/data_utils_fast.pyx ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cython: language_level=3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+
9
+ cimport cython
10
+ cimport numpy as np
11
+
12
+ from libc.stdint cimport int32_t, int64_t
13
+
14
+ ctypedef int64_t DTYPE_t
15
+
16
+
17
cdef _is_batch_full(int64_t num_sentences, int64_t num_tokens, int64_t max_tokens, int64_t max_sentences):
    # Return 1 if a batch of `num_sentences` samples totalling `num_tokens`
    # tokens is at (or over) a limit, else 0. Limits <= 0 mean "unlimited".
    if num_sentences == 0:
        # An empty batch is never full.
        return 0
    if max_sentences > 0 and num_sentences == max_sentences:
        # Sentence-count cap reached.
        return 1
    if max_tokens > 0 and num_tokens > max_tokens:
        # Token budget exceeded.
        return 1
    return 0
25
+
26
+
27
@cython.cdivision(True)
cpdef list batch_by_size_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    """Greedily group *indices* into batches of at most *max_sentences*
    samples / *max_tokens* padded tokens (cost = batch size x longest
    sample). Closed batches are trimmed to a multiple of *bsz_mult* when
    possible; limits <= 0 mean unlimited."""
    cdef int64_t sample_len = 0   # longest sample seen in the current batch
    cdef list sample_lens = []    # per-sample token counts for the current batch
    cdef list batch = []
    cdef list batches = []
    cdef int64_t mod_len
    cdef int64_t i
    cdef int64_t idx
    cdef int64_t num_tokens
    cdef DTYPE_t[:] indices_view = indices

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert max_tokens <= 0 or sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )
        # Padded cost if this sample were added to the current batch.
        num_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(len(batch), num_tokens, max_tokens, max_sentences):
            # Emit the largest prefix whose size is a multiple of bsz_mult
            # (or the whole batch when it is smaller than bsz_mult).
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            # Carry the remainder over into the next batch.
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches
70
+
71
+
72
cdef _find_valid_shape(
    DTYPE_t[:, :] shapes_view,
    int64_t num_sentences,
    int64_t num_tokens,
):
    """Return index of first valid shape or -1 if none is found.

    A shape (bsz, len) is valid when it can hold `num_sentences` samples of
    up to `num_tokens` tokens each.
    """
    for i in range(shapes_view.shape[0]):
        if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]:
            return i
    return -1
82
+
83
+
84
@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
    """Group *indices* into batches that each fit one of the allowed
    (bsz, len) shapes in *fixed_shapes_sorted* (assumed pre-sorted so that
    earlier shapes are the tightest fits)."""
    cdef int64_t sample_len = 0   # longest sample in the current batch
    cdef list sample_lens = []    # per-sample token counts (kept in step with `batch`)
    cdef list batch = []
    cdef list batches = []
    cdef int64_t mod_len
    cdef int64_t i
    cdef int64_t idx
    cdef int64_t num_tokens
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
        if shape_idx == -1:
            # No allowed shape can hold the batch plus this sample:
            # close the current batch and start a fresh one.
            batches.append(batch)
            batch = []
            sample_lens = []
            sample_len = 0
            shapes_view = fixed_shapes_sorted
        elif shape_idx > 0:
            # small optimization for the next call to _find_valid_shape
            shapes_view = shapes_view[shape_idx:]

        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)

    return batches
fairseq-0.10.2/fairseq/data/denoising_dataset.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from . import FairseqDataset, data_utils
12
+
13
+
14
def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    """Collate a list of denoising samples into a padded mini-batch.

    Args:
        samples (List[dict]): each with "id", "source" and optionally "target"
        pad_idx (int): padding symbol index
        eos_idx (int): EOS index (not forwarded to ``collate_tokens``; each
            sample's own last token is used when shifting targets)
        vocab (~fairseq.data.Dictionary): vocabulary (unused in this function)
        left_pad_source (bool): pad source on the left if True
        left_pad_target (bool): pad target on the left if True
        input_feeding (bool): must be True; adds ``prev_output_tokens``
        pad_to_length (dict, optional): minimum padded lengths under the
            "source"/"target" keys

    Returns:
        dict: batch with "id", "ntokens", "net_input", "target",
        "nsentences" and "sort_order"; rows sorted by descending source length
    """
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        # Pad all samples' `key` tensors to a common length.
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": id,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        # NOTE(review): this is the first sample's token count, not the
        # number of samples — looks suspicious; confirm against consumers
        # before changing (upstream behavior preserved).
        "nsentences": samples[0]["source"].size(0),
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch
93
+
94
+
95
+ class DenoisingDataset(FairseqDataset):
96
+ """
97
+ A wrapper around TokenBlockDataset for BART dataset.
98
+
99
+ Args:
100
+ dataset (TokenBlockDataset): dataset to wrap
101
+ sizes (List[int]): sentence lengths
102
+ vocab (~fairseq.data.Dictionary): vocabulary
103
+ mask_idx (int): dictionary index used for masked token
104
+ mask_whole_words: only mask whole words. This should be a byte mask
105
+ over vocab indices, indicating whether it is the beginning of a
106
+ word. We will extend any mask to encompass the whole word.
107
+ shuffle (bool, optional): shuffle the elements before batching.
108
+ Default: ``True``
109
+ seed: Seed for random number generator for reproducibility.
110
+ args: argparse arguments.
111
+ """
112
+
113
    def __init__(
        self,
        dataset,
        sizes,
        vocab,
        mask_idx,
        mask_whole_words,
        shuffle,
        seed,
        args,
        eos=None,
        item_transform_func=None,
    ):
        """See the class docstring. *args* supplies the noising
        hyper-parameters: ``mask``, ``mask_random``, ``insert``, ``rotate``,
        ``permute_sentences``, ``bpe``, ``replace_length``, ``mask_length``
        and ``poisson_lambda``."""
        self.dataset = dataset

        self.sizes = sizes

        self.vocab = vocab
        self.shuffle = shuffle
        self.seed = seed
        self.mask_idx = mask_idx
        self.mask_whole_word = mask_whole_words
        self.mask_ratio = args.mask
        self.random_ratio = args.mask_random
        self.insert_ratio = args.insert
        self.rotate_ratio = args.rotate
        self.permute_sentence_ratio = args.permute_sentences
        self.eos = eos if eos is not None else vocab.eos()
        self.item_transform_func = item_transform_func

        if args.bpe != "gpt2":
            # Sentence boundaries are marked by EOS for non-GPT2 BPE.
            self.full_stop_index = self.vocab.eos()
        else:
            assert args.bpe == "gpt2"
            # GPT-2 BPE: the token spelled "13" is the full stop.
            self.full_stop_index = self.vocab.index("13")

        self.replace_length = args.replace_length
        if self.replace_length not in [-1, 0, 1]:
            raise ValueError(f"invalid arg: replace_length={self.replace_length}")
        if args.mask_length not in ["subword", "word", "span-poisson"]:
            raise ValueError(f"invalid arg: mask-length={args.mask_length}")
        if args.mask_length == "subword" and args.replace_length not in [0, 1]:
            raise ValueError(f"if using subwords, use replace-length=1 or 0")

        self.mask_span_distribution = None
        if args.mask_length == "span-poisson":
            # Precompute a truncated Poisson(lambda) pmf over span lengths
            # 0..127, stopping early once the tail mass is negligible.
            _lambda = args.poisson_lambda

            lambda_to_the_k = 1
            e_to_the_minus_lambda = math.exp(-_lambda)
            k_factorial = 1
            ps = []
            for k in range(0, 128):
                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                lambda_to_the_k *= _lambda
                k_factorial *= k + 1
                if ps[-1] < 0.0000001:
                    break
            ps = torch.FloatTensor(ps)
            self.mask_span_distribution = torch.distributions.Categorical(ps)

        self.epoch = 0
175
+
176
    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """Epoch iterators can be reused: item sizes never change."""
        return True  # only the noise changes, not item sizes
179
+
180
    def set_epoch(self, epoch, **unused):
        """Record the current epoch; it seeds the per-item noise RNG."""
        self.epoch = epoch
182
+
183
    def __getitem__(self, index):
        """Return a noised ``source`` and clean ``target`` for *index*.

        Randomness is seeded by (seed, epoch, index) so the noise is
        reproducible across workers and epochs.
        """
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            tokens = self.dataset[index]
            assert tokens[-1] == self.eos
            source, target = tokens, tokens.clone()

            if self.permute_sentence_ratio > 0.0:
                source = self.permute_sentences(source, self.permute_sentence_ratio)

            if self.mask_ratio > 0:
                source = self.add_whole_word_mask(source, self.mask_ratio)

            if self.insert_ratio > 0:
                source = self.add_insertion_noise(source, self.insert_ratio)

            if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
                source = self.add_rolling_noise(source)
            # there can additional changes to make:
            if self.item_transform_func is not None:
                source, target = self.item_transform_func(source, target)

            # Sanity checks: valid token ids and <bos>/<eos> kept at the edges.
            assert (source >= 0).all()
            assert (source[1:-1] >= 1).all()
            assert (source <= len(self.vocab)).all()
            assert source[0] == self.vocab.bos()
            assert source[-1] == self.eos
            return {
                "id": index,
                "source": source,
                "target": target,
            }
214
+
215
    def __len__(self):
        """Number of examples equals the wrapped dataset's length."""
        return len(self.dataset)
217
+
218
    def permute_sentences(self, source, p=1.0):
        """Shuffle a fraction *p* of the sentences in *source*.

        Sentences are spans ending at ``full_stop_index`` tokens; the leading
        <bos> and trailing <eos> stay in place.
        """
        full_stops = source == self.full_stop_index
        # Pretend it ends with a full stop so last span is a sentence
        full_stops[-2] = 1

        # Tokens that are full stops, where the previous token is not
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        result = source.clone()

        num_sentences = sentence_ends.size(0)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        # Swap a random subset of sentence slots among themselves.
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        ordering = torch.arange(0, num_sentences)
        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]

        # Ignore <bos> at start
        index = 1
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
            result[index : index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result
240
+
241
+ def word_starts(self, source):
242
+ if self.mask_whole_word is not None:
243
+ is_word_start = self.mask_whole_word.gather(0, source)
244
+ else:
245
+ is_word_start = torch.ones(source.size())
246
+ is_word_start[0] = 0
247
+ is_word_start[-1] = 0
248
+ return is_word_start
249
+
250
    def add_whole_word_mask(self, source, p):
        """Mask approximately a fraction *p* of the word starts in *source*.

        Span lengths come from ``mask_span_distribution`` when set, otherwise
        each masked span covers one word. Depending on ``replace_length``,
        masked tokens are deleted (0), replaced by ``mask_idx`` with the first
        token kept (-1), or replaced one-for-one (1); a ``random_ratio``
        fraction of replacements use random vocabulary tokens instead.
        Leftover zero-length spans are turned into insertion noise.
        """
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            # One-word spans when no span-length distribution is configured.
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        # Pick `num_to_mask` random word-start positions as span anchors.
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        # Which spans get random tokens instead of <mask>.
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        is_word_start[
            -1
        ] = 255  # acts as a long length, so spans don't go over the end of doc
        if self.replace_length == 0:
            # Delete the anchor tokens entirely.
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            source[indices] = self.mask_idx
            source[indices[mask_random]] = torch.randint(
                1, len(self.vocab), size=(mask_random.sum(),)
            )

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            # Walk each span rightwards until its sampled length (in words)
            # is exhausted or a new word begins.
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )

            assert source_length - 1 not in indices

        # Drop any tokens marked for deletion.
        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source
356
+
357
+ def add_permuted_noise(self, tokens, p):
358
+ num_words = len(tokens)
359
+ num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
360
+ substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
361
+ tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
362
+ return tokens
363
+
364
+ def add_rolling_noise(self, tokens):
365
+ offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
366
+ tokens = torch.cat(
367
+ (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
368
+ dim=0,
369
+ )
370
+ return tokens
371
+
372
    def add_insertion_noise(self, tokens, p):
        """Insert ``ceil(len(tokens) * p)`` noise tokens at random interior
        positions. A ``random_ratio`` fraction become random vocabulary
        tokens; the rest become ``mask_idx``."""
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        # Choose insertion slots, never position 0 or the final position.
        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_idx
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=len(self.vocab), size=(num_random,)
        )

        # Scatter the original tokens into the non-noise slots.
        result[~noise_mask] = tokens

        # Every slot must have been filled (no -1 placeholders left).
        assert (result >= 0).all()
        return result
394
+
395
+ def collater(self, samples, pad_to_length=None):
396
+ """Merge a list of samples to form a mini-batch.
397
+ Args:
398
+ samples (List[dict]): samples to collate
399
+ Returns:
400
+ dict: a mini-batch of data
401
+ """
402
+ return collate(
403
+ samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
404
+ )
405
+
406
    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        # Sizes are precomputed; the clean (un-noised) length is used.
        return self.sizes[index]
410
+
411
    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        # Same precomputed size as num_tokens().
        return self.sizes[index]
415
+
416
+ def ordered_indices(self):
417
+ """Return an ordered list of indices. Batches will be constructed based
418
+ on this order."""
419
+ if self.shuffle:
420
+ indices = np.random.permutation(len(self))
421
+ else:
422
+ indices = np.arange(len(self))
423
+ return indices[np.argsort(self.sizes[indices], kind="mergesort")]
424
+
425
+ def prefetch(self, indices):
426
+ self.src.prefetch(indices)
427
+ self.tgt.prefetch(indices)
428
+
429
+ @property
430
+ def supports_prefetch(self):
431
+ return (
432
+ hasattr(self.src, "supports_prefetch")
433
+ and self.src.supports_prefetch
434
+ and hasattr(self.tgt, "supports_prefetch")
435
+ and self.tgt.supports_prefetch
436
+ )
fairseq-0.10.2/fairseq/data/indexed_dataset.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import shutil
8
+ import struct
9
+ from functools import lru_cache
10
+
11
+ import numpy as np
12
+ import torch
13
+ from fairseq.data.fasta_dataset import FastaDataset
14
+ from fairseq.file_io import PathManager
15
+
16
+ from . import FairseqDataset
17
+
18
+
19
def __best_fitting_dtype(vocab_size=None):
    """Pick the narrowest integer dtype able to hold token ids.

    uint16 for vocabularies under 65500 symbols, int32 otherwise (or when
    the vocabulary size is unknown).
    """
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16
    return np.int32
24
+
25
+
26
def get_available_dataset_impl():
    """List the supported values for ``--dataset-impl``."""
    return ["raw", "lazy", "cached", "mmap", "fasta"]
28
+
29
+
30
def infer_dataset_impl(path):
    """Guess the on-disk dataset format stored at *path*.

    Returns one of ``"raw"``, ``"cached"``, ``"mmap"``, ``"fasta"``, or
    ``None`` when the format cannot be determined.
    """
    if IndexedRawTextDataset.exists(path):
        return "raw"
    if IndexedDataset.exists(path):
        # sniff the index-file magic to tell the binary formats apart
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
        if magic == IndexedDataset._HDR_MAGIC:
            return "cached"
        if magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
            return "mmap"
        return None
    if FastaDataset.exists(path):
        return "fasta"
    return None
46
+
47
+
48
def make_builder(out_file, impl, vocab_size=None):
    """Create a dataset builder writing to *out_file* for implementation
    *impl*; ``vocab_size`` lets the mmap builder pick a compact dtype."""
    if impl == "fasta":
        raise NotImplementedError
    if impl == "mmap":
        return MMapIndexedDatasetBuilder(
            out_file, dtype=__best_fitting_dtype(vocab_size)
        )
    return IndexedDatasetBuilder(out_file)
57
+
58
+
59
def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
    """Instantiate the dataset at *path* with implementation *impl*.

    Returns ``None`` when *impl* does not match an existing dataset at
    that path.
    """
    if impl == "raw" and IndexedRawTextDataset.exists(path):
        assert dictionary is not None
        return IndexedRawTextDataset(path, dictionary)
    if impl == "lazy" and IndexedDataset.exists(path):
        return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
    if impl == "cached" and IndexedDataset.exists(path):
        return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
    if impl == "mmap" and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path)
    if impl == "fasta" and FastaDataset.exists(path):
        from fairseq.data.fasta_dataset import EncodedFastaDataset

        return EncodedFastaDataset(path, dictionary)
    return None
74
+
75
+
76
def dataset_exists(path, impl):
    """Whether a dataset with implementation *impl* exists at *path*."""
    checkers = {
        "raw": IndexedRawTextDataset.exists,
        "mmap": MMapIndexedDataset.exists,
    }
    # "lazy" and "cached" (and anything else) share the legacy layout
    return checkers.get(impl, IndexedDataset.exists)(path)
83
+
84
+
85
def read_longs(f, n):
    """Read *n* int64 values from binary file *f* into a numpy array."""
    out = np.empty(n, dtype=np.int64)
    f.readinto(out)
    return out
89
+
90
+
91
def write_longs(f, a):
    """Write the values of *a* to binary file *f* as int64."""
    f.write(np.asarray(a, dtype=np.int64))
93
+
94
+
95
# Mapping from the integer dtype codes stored in index files to dtypes.
# Code 6 historically used ``np.float`` -- an alias for the builtin
# ``float`` that was deprecated in NumPy 1.20 and removed in 1.24 -- so
# the builtin is used directly to keep this importable on modern NumPy.
dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: float,
    7: np.double,
    8: np.uint16,
}
105
+
106
+
107
def code(dtype):
    """Return the integer code used in index files for *dtype*.

    Raises:
        ValueError: if *dtype* has no registered code.
    """
    for dtype_code, registered in dtypes.items():
        if registered == dtype:
            return dtype_code
    raise ValueError(dtype)
112
+
113
+
114
def index_file_path(prefix_path):
    """Path of the index (".idx") file for a dataset prefix."""
    return f"{prefix_path}.idx"
116
+
117
+
118
def data_file_path(prefix_path):
    """Path of the data (".bin") file for a dataset prefix."""
    return f"{prefix_path}.bin"
120
+
121
+
122
class IndexedDataset(FairseqDataset):
    """Loader for TorchNet IndexedDataset.

    Reads the legacy two-file layout: a ".idx" index with per-example
    offsets/sizes and a ".bin" file holding the raw token data.
    """

    # magic bytes identifying the legacy TorchNet index format
    _HDR_MAGIC = b"TNTIDX\x00\x00"

    def __init__(self, path, fix_lua_indexing=False):
        super().__init__()
        self.path = path
        # when True, subtract 1 from every token id (Lua ids are 1-based)
        self.fix_lua_indexing = fix_lua_indexing
        # the ".bin" file is opened lazily on first __getitem__
        self.data_file = None
        self.read_index(path)

    def read_index(self, path):
        """Parse the ".idx" file: header, dtype code, then offset tables."""
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, (
                "Index file doesn't match expected format. "
                "Make sure that --dataset-impl is configured properly."
            )
            version = f.read(8)
            assert struct.unpack("<Q", version) == (1,)
            code, self.element_size = struct.unpack("<QQ", f.read(16))
            self.dtype = dtypes[code]
            # _len: number of examples; s: total number of size entries
            self._len, self.s = struct.unpack("<QQ", f.read(16))
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)

    def read_data(self, path):
        # buffering=0 so explicit seeks are not defeated by buffering
        self.data_file = open(data_file_path(path), "rb", buffering=0)

    def check_index(self, i):
        """Raise IndexError if *i* is outside [0, len)."""
        if i < 0 or i >= self._len:
            raise IndexError("index out of range")

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        """Return example *i* as a LongTensor, reading it from disk."""
        if not self.data_file:
            self.read_data(self.path)
        self.check_index(i)
        # per-dimension sizes of example i, e.g. a 1-D example has one entry
        tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        self.data_file.seek(self.data_offsets[i] * self.element_size)
        self.data_file.readinto(a)
        item = torch.from_numpy(a).long()
        if self.fix_lua_indexing:
            item -= 1  # subtract 1 for 0-based indexing
        return item

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        # both the ".idx" and ".bin" files must be present
        return PathManager.exists(index_file_path(path)) and PathManager.exists(
            data_file_path(path)
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory
193
+
194
+
195
class IndexedCachedDataset(IndexedDataset):
    """IndexedDataset variant that preloads requested examples into one
    in-memory cache buffer via :meth:`prefetch`."""

    def __init__(self, path, fix_lua_indexing=False):
        super().__init__(path, fix_lua_indexing=fix_lua_indexing)
        # flat buffer holding all prefetched examples back-to-back
        self.cache = None
        # example index -> offset of that example inside self.cache
        self.cache_index = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        """Load all *indices* from disk into the in-memory cache."""
        if all(i in self.cache_index for i in indices):
            return  # everything already cached
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            # read directly into the corresponding slice of the cache
            a = self.cache[ptx : ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        """Return example *i* from the cache (must have been prefetched)."""
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        # KeyError here means the index was never passed to prefetch()
        ptx = self.cache_index[i]
        np.copyto(a, self.cache[ptx : ptx + a.size])
        item = torch.from_numpy(a).long()
        if self.fix_lua_indexing:
            item -= 1  # subtract 1 for 0-based indexing
        return item
240
+
241
+
242
class IndexedRawTextDataset(FairseqDataset):
    """Takes a text file as input and binarizes it in memory at instantiation.
    Original lines are also kept in memory"""

    def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
        # binarized examples (LongTensors), one per input line
        self.tokens_list = []
        # original text lines, kept for get_original_text()
        self.lines = []
        self.sizes = []
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.read_data(path, dictionary)
        # NOTE(review): this attribute shadows the size() method below on
        # instances, so instance.size(i) is not callable -- preserved as-is
        # because callers may rely on the attribute.
        self.size = len(self.tokens_list)

    def read_data(self, path, dictionary):
        """Read *path* line by line and binarize with *dictionary*."""
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                self.lines.append(line.strip("\n"))
                tokens = dictionary.encode_line(
                    line,
                    add_if_not_exist=False,
                    append_eos=self.append_eos,
                    reverse_order=self.reverse_order,
                ).long()
                self.tokens_list.append(tokens)
                self.sizes.append(len(tokens))
        self.sizes = np.array(self.sizes)

    def check_index(self, i):
        if i < 0 or i >= self.size:
            raise IndexError("index out of range")

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        self.check_index(i)
        return self.tokens_list[i]

    def get_original_text(self, i):
        """Return the raw (non-binarized) text of line *i*."""
        self.check_index(i)
        return self.lines[i]

    def __del__(self):
        pass

    def __len__(self):
        return self.size

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return PathManager.exists(path)
297
+
298
+
299
class IndexedDatasetBuilder(object):
    """Incrementally writes a TorchNet-style IndexedDataset: raw element
    data goes to the ".bin" file as items are added, and :meth:`finalize`
    writes the matching ".idx" index."""

    # Bytes per element for each supported dtype. The ``float`` key stands
    # in for ``np.float`` (an alias of builtin float removed in NumPy
    # 1.24); the 4-byte size is kept from the legacy format definition.
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        float: 4,
        np.double: 8,
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, "wb")
        self.dtype = dtype
        # element offset of each example's data (in elements, not bytes)
        self.data_offsets = [0]
        # offset into self.sizes of each example's dimension sizes
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]

    def add_item(self, tensor):
        """Append one tensor to the data file and record its offsets."""
        # +1 for Lua compatibility (Lua token ids are 1-based)
        nbytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype))
        # integer division: writes are always whole elements, and this
        # keeps offsets as exact ints instead of floats
        self.data_offsets.append(self.data_offsets[-1] + nbytes // self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def merge_file_(self, another_file):
        """Append another finished dataset (index + raw data) in place."""
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        begin = self.data_offsets[-1]
        for offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + offset)
        self.sizes.extend(index.sizes)
        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        # bulk-copy the raw data file (buffered, instead of 1KB reads)
        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self.out_file)

    def finalize(self, index_file):
        """Close the data file and write the ".idx" index to *index_file*."""
        self.out_file.close()
        with open(index_file, "wb") as index:
            index.write(b"TNTIDX\x00\x00")
            index.write(struct.pack("<Q", 1))
            index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
            index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
            write_longs(index, self.dim_offsets)
            write_longs(index, self.data_offsets)
            write_longs(index, self.sizes)
357
+
358
+
359
def _warmup_mmap_file(path):
    """Read the whole file sequentially in 100MB chunks so the OS page
    cache is warm before the file is memory-mapped."""
    chunk_size = 100 * 1024 * 1024
    with open(path, "rb") as stream:
        while stream.read(chunk_size):
            pass
363
+
364
+
365
class MMapIndexedDataset(torch.utils.data.Dataset):
    """Memory-mapped dataset: the ".bin" file is mmapped and examples are
    zero-copy views into it; the ".idx" index stores dtype, sizes and
    byte pointers."""

    class Index(object):
        """Reader/writer for the mmap index (".idx") file."""

        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            """Return a context manager that writes an index file for
            examples of *dtype* to *path*."""

            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, "wb")

                    # header: magic, version, dtype code
                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack("<Q", 1))
                    self._file.write(struct.pack("<B", code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    # cumulative byte offsets of each example's data
                    # (dtype is captured from the enclosing writer() call)
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size

                    return pointers

                def write(self, sizes):
                    """Write example count, sizes (int32) and pointers
                    (int64) to the index file."""
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack("<Q", len(sizes)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path):
            # read and validate the fixed-size header
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                # sizes/pointers tables start right after the header
                offset = stream.tell()

            _warmup_mmap_file(path)

            # mmap the index itself; sizes and pointers are zero-copy views
            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )

        def __del__(self):
            # NOTE(review): raises AttributeError if __init__ failed before
            # the mmap was created -- preserved as-is
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            # (byte pointer, element count) of example i
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path)

    def __getstate__(self):
        # pickle only the path; the mmap is rebuilt in __setstate__
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path):
        """Open the index and memory-map the data file."""
        self._path = path
        self._index = self.Index(index_file_path(self._path))

        _warmup_mmap_file(data_file_path(self._path))
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        """Return example *i* as a LongTensor view into the mmap."""
        ptr, size = self._index[i]
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
        )
        if self._index.dtype != np.int64:
            # torch LongTensor requires int64 (this copies)
            np_array = np_array.astype(np.int64)

        return torch.from_numpy(np_array)

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def supports_prefetch(self):
        # the OS page cache already serves repeated reads
        return False

    @staticmethod
    def exists(path):
        return PathManager.exists(index_file_path(path)) and PathManager.exists(
            data_file_path(path)
        )
518
+
519
+
520
def get_indexed_dataset_to_local(path):
    """Fetch both dataset files to local storage and return the shared
    path prefix (without the ".idx"/".bin" extension)."""
    local_index_path = PathManager.get_local_path(index_file_path(path))
    local_data_path = PathManager.get_local_path(data_file_path(path))

    assert local_index_path.endswith(".idx") and local_data_path.endswith(".bin"), (
        "PathManager.get_local_path does not return files with expected patterns: "
        f"{local_index_path} and {local_data_path}"
    )

    # strip the ".bin" suffix; both files must share the same prefix
    local_path = local_data_path[:-4]
    assert local_path == local_index_path[:-4]
    return local_path
532
+
533
+
534
class MMapIndexedDatasetBuilder(object):
    """Builder for the mmap dataset format: raw tensor bytes are streamed
    to the ".bin" file and the ".idx" index is written at finalize time."""

    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, "wb")
        self._dtype = dtype
        self._sizes = []

    def add_item(self, tensor):
        """Append one tensor's raw bytes and remember its element count."""
        arr = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(arr.tobytes(order="C"))
        self._sizes.append(arr.size)

    def merge_file_(self, another_file):
        """Append another mmap dataset's sizes and raw data in place."""
        index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert index.dtype == self._dtype

        for size in index.sizes:
            self._sizes.append(size)

        # concatenate the raw data
        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self._data_file)

    def finalize(self, index_file):
        """Close the data file and write the index to *index_file*."""
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes)
fairseq-0.10.2/fairseq/data/iterators.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import logging
8
+ import math
9
+ import operator
10
+ import os
11
+ import queue
12
+ import time
13
+ from threading import Thread
14
+
15
+ import numpy as np
16
+ import torch
17
+ from fairseq.data import data_utils
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Object used by _background_consumer to signal the source is exhausted
23
+ # to the main thread.
24
+ _sentinel = object()
25
+
26
+
27
class CountingIterator(object):
    """Wrapper around an iterable that maintains the iteration count.

    Args:
        iterable (iterable): iterable to wrap
        start (int): starting iteration count. Note that this doesn't
            actually advance the iterator.
        total (int): override the iterator length returned by
            ``__len__``. This can be used to truncate *iterator*.

    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, start=None, total=None):
        self.iterable = iterable
        # the counting generator from __iter__; the generator body only
        # runs on first next(), after n/total are set below
        self.itr = iter(self)

        if start is None:
            # inherit the count from a wrapped CountingIterator, if any
            self.n = getattr(iterable, "n", 0)
        else:
            self.n = start

        # invariant: total == n + remaining length of the iterable
        if total is None:
            self.total = self.n + len(iterable)
        else:
            self.total = total

    def __len__(self):
        return self.total

    def __iter__(self):
        for x in self.iterable:
            # yielding more than `total` elements means the declared
            # length was wrong -- fail loudly instead of silently
            if self.n >= self.total:
                raise RuntimeError(
                    "Mismatch between actual and expected iterable length. "
                    "This may be caused by resuming training from a checkpoint using "
                    "a different number of GPUs, in which case you can try the "
                    "--reset-dataloader option. Alternatively you may have a train or "
                    "validation set that is smaller than the number of GPUs. If none "
                    "of these apply, please report this to the fairseq developers."
                )
            self.n += 1
            yield x

    def __next__(self):
        return next(self.itr)

    def has_next(self):
        """Whether the iterator has been exhausted."""
        return self.n < len(self)

    def skip(self, num_to_skip):
        """Fast-forward the iterator by skipping *num_to_skip* elements."""
        next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
        return self

    def take(self, n):
        """
        Truncates the iterator to n elements at most.
        """
        self.total = min(self.total, n)

        # Propagate this change to the underlying iterator
        # Only take after what we have already consumed (i.e. after restarting
        # from checkpoint mid epoch, we have to subtract self.n which is the
        # starting point)
        #
        # This to maintain the invariant self.total = self.n + len(iterable),
        # before calling __next__ or __iter__
        propagated_take = max(n - self.n, 0)
        if hasattr(self.iterable, "take"):
            self.iterable.take(propagated_take)
        else:
            self.iterable = itertools.islice(self.iterable, propagated_take)
102
+
103
+
104
class EpochBatchIterating(object):
    """Abstract interface for multi-epoch batch iterators (see
    :class:`EpochBatchIterator` and :class:`StreamingEpochBatchIterator`)."""

    def __len__(self) -> int:
        raise NotImplementedError

    @property
    def next_epoch_idx(self):
        """Return the epoch index after *next_epoch_itr* is called."""
        raise NotImplementedError

    def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
        """Return a new iterator over the dataset.

        Args:
            shuffle (bool, optional): shuffle batches before returning the
                iterator (default: True).
            fix_batches_to_gpus: ensure that batches are always
                allocated to the same shards across epochs. Requires
                that :attr:`dataset` supports prefetching (default: False).
        """
        raise NotImplementedError

    def end_of_epoch(self) -> bool:
        """Returns whether the most recent epoch iterator has been exhausted"""
        raise NotImplementedError

    @property
    def iterations_in_epoch(self) -> int:
        """The number of consumed batches in the current epoch."""
        raise NotImplementedError

    def state_dict(self):
        """Returns a dictionary containing a whole state of the iterator."""
        raise NotImplementedError

    def load_state_dict(self, state_dict):
        """Copies the state of the iterator from the given *state_dict*."""
        raise NotImplementedError
140
+
141
+
142
class StreamingEpochBatchIterator(EpochBatchIterating):
    """Epoch iterator over a :class:`torch.utils.data.IterableDataset`,
    sharded across workers with a :class:`ShardedIterator`."""

    def __init__(
        self,
        dataset,
        epoch=1,
        num_shards=1,
        shard_id=0,
    ):
        assert isinstance(dataset, torch.utils.data.IterableDataset)
        self.dataset = dataset
        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
        self._current_epoch_iterator = None
        self.num_shards = num_shards
        self.shard_id = shard_id

    @property
    def next_epoch_idx(self):
        """Return the epoch index after *next_epoch_itr* is called."""
        if self._current_epoch_iterator is not None and self.end_of_epoch():
            return self.epoch + 1
        else:
            return self.epoch

    def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
        """Start the next epoch and return its counting iterator.

        Note: *shuffle* and *fix_batches_to_gpus* are accepted for
        interface compatibility but not used here.
        """
        self.epoch = self.next_epoch_idx
        if hasattr(self.dataset, "set_epoch"):
            # let the dataset reshuffle/reconfigure itself per epoch
            self.dataset.set_epoch(self.epoch)
        self._current_epoch_iterator = CountingIterator(
            iterable=ShardedIterator(
                iterable=self.dataset,
                num_shards=self.num_shards,
                shard_id=self.shard_id,
            ),
        )
        return self._current_epoch_iterator

    def end_of_epoch(self) -> bool:
        """Whether the current epoch iterator has been exhausted."""
        return not self._current_epoch_iterator.has_next()

    @property
    def iterations_in_epoch(self) -> int:
        """The number of consumed batches in the current epoch."""
        if self._current_epoch_iterator is not None:
            return self._current_epoch_iterator.n
        return 0

    def state_dict(self):
        """Checkpointable state: only the epoch (mid-epoch position is
        not restorable for streaming datasets)."""
        return {
            "epoch": self.epoch,
        }

    def load_state_dict(self, state_dict):
        """Restore the epoch counter from *state_dict*."""
        self.epoch = state_dict["epoch"]
194
+
195
+
196
+ class EpochBatchIterator(EpochBatchIterating):
197
+ """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`.
198
+
199
+ Compared to :class:`torch.utils.data.DataLoader`, this iterator:
200
+
201
+ - can be reused across multiple epochs with the :func:`next_epoch_itr`
202
+ method (optionally shuffled between epochs)
203
+ - can be serialized/deserialized with the :func:`state_dict` and
204
+ :func:`load_state_dict` methods
205
+ - supports sharding with the *num_shards* and *shard_id* arguments
206
+
207
+ Args:
208
+ dataset (~torch.utils.data.Dataset): dataset from which to load the data
209
+ collate_fn (callable): merges a list of samples to form a mini-batch
210
+ batch_sampler (~torch.utils.data.Sampler or a callable): an iterator over batches of
211
+ indices, or a callable to create such an iterator (~torch.utils.data.Sampler).
212
+ A callable batch_sampler will be called for each epoch to enable per epoch dynamic
213
+ batch iterators defined by this callable batch_sampler.
214
+ seed (int, optional): seed for random number generator for
215
+ reproducibility (default: 1).
216
+ num_shards (int, optional): shard the data iterator into N
217
+ shards (default: 1).
218
+ shard_id (int, optional): which shard of the data iterator to
219
+ return (default: 0).
220
+ num_workers (int, optional): how many subprocesses to use for data
221
+ loading. 0 means the data will be loaded in the main process
222
+ (default: 0).
223
+ epoch (int, optional): the epoch to start the iterator from
224
+ (default: 1).
225
+ buffer_size (int, optional): the number of batches to keep ready in the
226
+ queue. Helps speeding up dataloading. When buffer_size is zero, the
227
+ default torch.utils.data.DataLoader preloading is used.
228
+ timeout (int, optional): if positive, the timeout value for collecting a batch
229
+ from workers. Should always be non-negative (default: ``0``).
230
+ disable_shuffling (bool, optional): force disable shuffling
231
+ (default: ``False``).
232
+ """
233
+
234
+ def __init__(
235
+ self,
236
+ dataset,
237
+ collate_fn,
238
+ batch_sampler,
239
+ seed=1,
240
+ num_shards=1,
241
+ shard_id=0,
242
+ num_workers=0,
243
+ epoch=1,
244
+ buffer_size=0,
245
+ timeout=0,
246
+ disable_shuffling=False,
247
+ ):
248
+ assert isinstance(dataset, torch.utils.data.Dataset)
249
+ self.dataset = dataset
250
+ self.collate_fn = collate_fn
251
+ self.batch_sampler = batch_sampler
252
+ self._frozen_batches = (
253
+ tuple(batch_sampler) if not callable(batch_sampler) else None
254
+ )
255
+ self.seed = seed
256
+ self.num_shards = num_shards
257
+ self.shard_id = shard_id
258
+ self.num_workers = num_workers
259
+ # This upper limit here is to prevent people from abusing this feature
260
+ # in a shared computing environment.
261
+ self.buffer_size = min(buffer_size, 20)
262
+ self.timeout = timeout
263
+ self.disable_shuffling = disable_shuffling
264
+
265
+ self.epoch = max(epoch, 1) # we use 1-based indexing for epochs
266
+ self.shuffle = not disable_shuffling
267
+ self._cur_epoch_itr = None
268
+ self._next_epoch_itr = None
269
+ self._supports_prefetch = getattr(dataset, "supports_prefetch", False)
270
+
271
+ @property
272
+ def frozen_batches(self):
273
+ if self._frozen_batches is None:
274
+ self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch))
275
+ return self._frozen_batches
276
+
277
+ @property
278
+ def first_batch(self):
279
+ if len(self.frozen_batches) == 0:
280
+ raise Exception(
281
+ "The dataset is empty. This could indicate "
282
+ "that all elements in the dataset have been skipped. "
283
+ "Try increasing the max number of allowed tokens or using "
284
+ "a larger dataset."
285
+ )
286
+
287
+ if getattr(self.dataset, "supports_fetch_outside_dataloader", True):
288
+ return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]])
289
+ else:
290
+ return "DUMMY"
291
+
292
+ def __len__(self):
293
+ return int(math.ceil(len(self.frozen_batches) / float(self.num_shards)))
294
+
295
+ @property
296
+ def n(self):
297
+ return self.iterations_in_epoch
298
+
299
+ @property
300
+ def next_epoch_idx(self):
301
+ """Return the epoch index after *next_epoch_itr* is called."""
302
+ if self._next_epoch_itr is not None:
303
+ return self.epoch
304
+ elif self._cur_epoch_itr is not None and self.end_of_epoch():
305
+ return self.epoch + 1
306
+ else:
307
+ return self.epoch
308
+
309
+ def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
310
+ """Return a new iterator over the dataset.
311
+
312
+ Args:
313
+ shuffle (bool, optional): shuffle batches before returning the
314
+ iterator (default: True).
315
+ fix_batches_to_gpus: ensure that batches are always
316
+ allocated to the same shards across epochs. Requires
317
+ that :attr:`dataset` supports prefetching (default: False).
318
+ """
319
+ if self.disable_shuffling:
320
+ shuffle = False
321
+ self.epoch = self.next_epoch_idx
322
+ if hasattr(self.dataset, "set_epoch"):
323
+ self.dataset.set_epoch(self.epoch)
324
+ if self._next_epoch_itr is not None:
325
+ self._cur_epoch_itr = self._next_epoch_itr
326
+ self._next_epoch_itr = None
327
+ else:
328
+ if callable(self.batch_sampler):
329
+ # reset _frozen_batches to refresh the next epoch
330
+ self._frozen_batches = None
331
+ self._cur_epoch_itr = self._get_iterator_for_epoch(
332
+ self.epoch,
333
+ shuffle,
334
+ fix_batches_to_gpus=fix_batches_to_gpus,
335
+ )
336
+ self.shuffle = shuffle
337
+ return self._cur_epoch_itr
338
+
339
+ def end_of_epoch(self) -> bool:
340
+ """Returns whether the most recent epoch iterator has been exhausted"""
341
+ return not self._cur_epoch_itr.has_next()
342
+
343
+ @property
344
+ def iterations_in_epoch(self):
345
+ """The number of consumed batches in the current epoch."""
346
+ if self._cur_epoch_itr is not None:
347
+ return self._cur_epoch_itr.n
348
+ elif self._next_epoch_itr is not None:
349
+ return self._next_epoch_itr.n
350
+ return 0
351
+
352
+ def state_dict(self):
353
+ """Returns a dictionary containing a whole state of the iterator."""
354
+ if self.end_of_epoch():
355
+ epoch = self.epoch + 1
356
+ iter_in_epoch = 0
357
+ else:
358
+ epoch = self.epoch
359
+ iter_in_epoch = self.iterations_in_epoch
360
+ return {
361
+ "version": 2,
362
+ "epoch": epoch,
363
+ "iterations_in_epoch": iter_in_epoch,
364
+ "shuffle": self.shuffle,
365
+ }
366
+
367
def load_state_dict(self, state_dict):
    """Copies the state of the iterator from the given *state_dict*."""
    self.epoch = state_dict["epoch"]
    itr_pos = state_dict.get("iterations_in_epoch", 0)
    if itr_pos <= 0:
        # Nothing was consumed in the saved epoch; start it from scratch.
        self._next_epoch_itr = None
        return

    # Fast-forward a fresh epoch iterator to the saved offset.
    self._next_epoch_itr = self._get_iterator_for_epoch(
        self.epoch,
        shuffle=state_dict.get("shuffle", True),
        offset=itr_pos,
    )
    if self._next_epoch_itr is not None:
        return

    # The saved offset lies past the end of the rebuilt epoch.
    if state_dict.get("version", 1) == 1:
        # Legacy behavior: the epoch was actually finished, so advance
        # the epoch counter instead of failing.
        self.epoch += 1
    else:
        raise RuntimeError(
            "Cannot resume training due to dataloader mismatch, please "
            "report this to the fairseq developers. You can relaunch "
            "training with `--reset-dataloader` and it should work."
        )
391
+
392
def _get_iterator_for_epoch(
    self, epoch, shuffle, fix_batches_to_gpus=False, offset=0
):
    # Build the DataLoader for one epoch: optionally shuffle, shard across
    # GPUs/workers, then fast-forward to *offset* already-consumed batches.
    def shuffle_batches(batches, seed):
        # Seeding with (seed + epoch) makes the batch order reproducible
        # for a given epoch across restarts.
        with data_utils.numpy_seed(seed):
            np.random.shuffle(batches)
        return batches

    if self._supports_prefetch:
        batches = self.frozen_batches

        if shuffle and not fix_batches_to_gpus:
            batches = shuffle_batches(list(batches), self.seed + epoch)

        # Shard first so we only prefetch this shard's sample indices.
        batches = list(
            ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
        )
        self.dataset.prefetch([i for s in batches for i in s])

        if shuffle and fix_batches_to_gpus:
            # Keep the same set of batches on each GPU across epochs but
            # vary their order (seed includes the shard id).
            batches = shuffle_batches(batches, self.seed + epoch + self.shard_id)
    else:
        if shuffle:
            batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch)
        else:
            batches = self.frozen_batches
        batches = list(
            ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
        )

    if offset > 0 and offset >= len(batches):
        # Nothing left in this epoch after fast-forwarding.
        return None

    if self.num_workers > 0:
        # Silence a noisy multiprocessing semaphore_tracker warning.
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    # Create data loader
    itr = torch.utils.data.DataLoader(
        self.dataset,
        collate_fn=self.collate_fn,
        batch_sampler=batches[offset:],
        num_workers=self.num_workers,
        timeout=self.timeout,
    )

    # Wrap with a BufferedIterator if needed
    if self.buffer_size > 0:
        itr = BufferedIterator(self.buffer_size, itr)

    # Wrap with CountingIterator so bookkeeping starts at *offset*
    itr = CountingIterator(itr, start=offset)
    return itr
444
+
445
+
446
class GroupedIterator(CountingIterator):
    """Wrapper around an iterable that returns groups (chunks) of items.

    Args:
        iterable (iterable): iterable to wrap
        chunk_size (int): size of each chunk

    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, chunk_size):
        # Both counters are expressed in chunks; partially consumed (or
        # partially filled) chunks round up.
        already_consumed = getattr(iterable, "n", 0)
        super().__init__(
            _chunk_iterator(iterable, chunk_size),
            start=int(math.ceil(already_consumed / float(chunk_size))),
            total=int(math.ceil(len(iterable) / float(chunk_size))),
        )
        self.chunk_size = chunk_size
465
+
466
+
467
+ def _chunk_iterator(itr, chunk_size):
468
+ chunk = []
469
+ for x in itr:
470
+ chunk.append(x)
471
+ if len(chunk) == chunk_size:
472
+ yield chunk
473
+ chunk = []
474
+ if len(chunk) > 0:
475
+ yield chunk
476
+
477
+
478
class ShardedIterator(CountingIterator):
    """A sharded wrapper around an iterable, padded to length.

    Args:
        iterable (iterable): iterable to wrap
        num_shards (int): number of shards to split the iterable into
        shard_id (int): which shard to iterator over
        fill_value (Any, optional): padding value when the iterable doesn't
            evenly divide *num_shards* (default: None).

    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, num_shards, shard_id, fill_value=None):
        if not (0 <= shard_id < num_shards):
            raise ValueError("shard_id must be between 0 and num_shards")
        sharded_len = int(math.ceil(len(iterable) / float(num_shards)))
        # Take every num_shards-th element starting at shard_id; zip against
        # a range so short shards are padded with fill_value up to
        # sharded_len, keeping all shards the same length.
        shard = itertools.islice(iterable, shard_id, len(iterable), num_shards)
        padded = (
            elem
            for _, elem in itertools.zip_longest(
                range(sharded_len), shard, fillvalue=fill_value
            )
        )
        super().__init__(
            padded,
            start=int(math.ceil(getattr(iterable, "n", 0) / float(num_shards))),
            total=sharded_len,
        )
509
+
510
+
511
class BackgroundConsumer(Thread):
    """Thread that drains *source* into *queue* for a buffered consumer.

    Pushes the module-level ``_sentinel`` once *source* is exhausted (or
    *max_len* items have been forwarded), and forwards any exception to the
    consumer side via the queue instead of dying silently.
    """

    def __init__(self, queue, source, max_len):
        Thread.__init__(self)
        self._queue = queue
        self._source = source
        self._max_len = max_len
        # Number of items forwarded so far.
        self.count = 0

    def run(self):
        try:
            for item in self._source:
                self._queue.put(item)
                self.count += 1
                # Stop early once the maximum length has been reached.
                if self._max_len is not None and self.count >= self._max_len:
                    break

            # Signal the consumer we are done.
            self._queue.put(_sentinel)
        except Exception as e:
            # Surface the failure in the consumer thread's __next__.
            self._queue.put(e)
534
+
535
+
536
class BufferedIterator(object):
    """Iterator wrapper that prefetches items into a bounded queue.

    A daemon ``BackgroundConsumer`` thread pulls from *iterable* and
    deposits items into a queue of at most *size* elements, so data
    loading overlaps with consumption.
    """

    def __init__(self, size, iterable):
        self._queue = queue.Queue(size)
        self._iterable = iterable
        # The consumer thread is started lazily on first __next__ so that
        # constructing this wrapper has no side effects.
        self._consumer = None

        self.start_time = time.time()
        # Timestamp of the last slow-loading warning (rate-limited below).
        self.warning_time = None

        self.total = len(iterable)

    def _create_consumer(self):
        self._consumer = BackgroundConsumer(
            self._queue,
            self._iterable,
            self.total,
        )
        # Daemonize so a still-running consumer never blocks interpreter exit.
        self._consumer.daemon = True
        self._consumer.start()

    def __iter__(self):
        return self

    def __len__(self):
        return self.total

    def take(self, n):
        # Cap the reported length at *n* items.
        self.total = min(self.total, n)

        # Propagate this change to the underlying iterator
        if hasattr(self._iterable, "take"):
            self._iterable.take(n)

    def __next__(self):
        # Create consumer if not created yet
        if self._consumer is None:
            self._create_consumer()

        # Notify the user of a data loading bottleneck when the queue is
        # (nearly) empty, at most once every 15 minutes, and only after the
        # first 5 minutes of iteration (to skip warm-up).
        if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)):
            if time.time() - self.start_time > 5 * 60:
                if (
                    self.warning_time is None
                    or time.time() - self.warning_time > 15 * 60
                ):
                    logger.debug(
                        "Data loading buffer is empty or nearly empty. This may "
                        "indicate a data loading bottleneck, and increasing the "
                        "number of workers (--num-workers) may help."
                    )
                    self.warning_time = time.time()

        # Get next example. Exceptions raised in the consumer thread are
        # re-raised here; the sentinel marks exhaustion.
        item = self._queue.get(True)
        if isinstance(item, Exception):
            raise item
        if item is _sentinel:
            raise StopIteration()
        return item
fairseq-0.10.2/fairseq/data/legacy/masked_lm_dataset.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from typing import Dict, List, Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+ from fairseq.data import Dictionary, FairseqDataset, data_utils
12
+ from fairseq.data.concat_dataset import ConcatDataset
13
+ from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
14
+ from fairseq.data.token_block_dataset import TokenBlockDataset
15
+
16
+
17
class MaskedLMDataset(FairseqDataset):
    """
    A wrapper Dataset for masked language modelling. The dataset
    wraps around TokenBlockDataset or BlockedPairDataset and creates a batch
    where the input blocks are masked according to the specified masking
    probability. Additionally the batch can also contain sentence level targets
    if this is specified.

    Args:
        dataset: Dataset which generates blocks of data. Only BlockPairDataset
            and TokenBlockDataset are supported.
        sizes: Sentence lengths
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of padding token in dictionary
        mask_idx: Id of mask token in dictionary
        classif_token_idx: Id of classification token in dictionary. This is the
            token associated with the sentence embedding (Eg: CLS for BERT)
        sep_token_idx: Id of separator token in dictionary
            (Eg: SEP in BERT)
        seed: Seed for random number generator for reproducibility.
        shuffle: Shuffle the elements before batching.
        has_pairs: Specifies whether the underlying dataset
            generates a pair of blocks along with a sentence_target or not.
            Setting it to True assumes that the underlying dataset generates a
            label for the pair of sentences which is surfaced as
            sentence_target. The default value assumes a single block with no
            sentence target.
        segment_id: An optional segment id for filling in the segment labels
            when we are in the single block setting (Eg: XLM). Default is 0.
        masking_ratio: specifies what percentage of the blocks should be masked.
        masking_prob: specifies the probability of a given token being
            replaced with the "MASK" token.
        random_token_prob: specifies the probability of a given token being
            replaced by a random token from the vocabulary.
    """

    def __init__(
        self,
        dataset: FairseqDataset,
        sizes: np.ndarray,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        classif_token_idx: int,
        sep_token_idx: int,
        seed: int = 1,
        shuffle: bool = True,
        has_pairs: bool = True,
        segment_id: int = 0,
        masking_ratio: float = 0.15,
        masking_prob: float = 0.8,
        random_token_prob: float = 0.1,
    ):
        # Make sure the input datasets are the ones supported
        assert (
            isinstance(dataset, TokenBlockDataset)
            or isinstance(dataset, BlockPairDataset)
            or isinstance(dataset, ConcatDataset)
        ), (
            "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or "
            "ConcatDataset"
        )

        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.classif_token_idx = classif_token_idx
        self.sep_token_idx = sep_token_idx
        self.shuffle = shuffle
        self.seed = seed
        self.has_pairs = has_pairs
        self.segment_id = segment_id
        self.masking_ratio = masking_ratio
        self.masking_prob = masking_prob
        self.random_token_prob = random_token_prob

        # If we have only one block then sizes needs to be updated to include
        # the classification token
        if not has_pairs:
            self.sizes = self.sizes + 1

    def __getitem__(self, index: int):
        # if has_pairs, then expect 2 blocks and a sentence target
        if self.has_pairs:
            (block_one, block_two, sentence_target) = self.dataset[index]
        else:
            block_one = self.dataset[index]

        # block_two / sentence_target are only evaluated when has_pairs is
        # True, so they are always bound before use.
        return {
            "id": index,
            "block_one": block_one,
            "block_two": block_two if self.has_pairs else None,
            "sentence_target": sentence_target if self.has_pairs else None,
        }

    def __len__(self):
        return len(self.dataset)

    def _mask_block(
        self,
        sentence: np.ndarray,
        mask_idx: int,
        pad_idx: int,
        dictionary_token_range: Tuple,
    ):
        """
        Mask tokens for Masked Language Model training
        Samples mask_ratio tokens that will be predicted by LM.

        Note:This function may not be efficient enough since we had multiple
        conversions between np and torch, we can replace them with torch
        operators later.

        Args:
            sentence: 1d tensor to be masked
            mask_idx: index to use for masking the sentence
            pad_idx: index to use for masking the target for tokens we aren't
                predicting
            dictionary_token_range: range of indices in dictionary which can
                be used for random word replacement
                (e.g. without special characters)
        Return:
            masked_sent: masked sentence
            target: target with words which we are not predicting replaced
                by pad_idx
        """
        masked_sent = np.copy(sentence)
        sent_length = len(sentence)
        # ceil: at least one token is masked for any non-empty sentence.
        mask_num = math.ceil(sent_length * self.masking_ratio)
        mask = np.random.choice(sent_length, mask_num, replace=False)
        target = np.copy(sentence)

        for i in range(sent_length):
            if i in mask:
                rand = np.random.random()

                # replace with mask if probability is less than masking_prob
                # (Eg: 0.8)
                if rand < self.masking_prob:
                    masked_sent[i] = mask_idx

                # replace with random token if probability is less than
                # masking_prob + random_token_prob (Eg: 0.9)
                elif rand < (self.masking_prob + self.random_token_prob):
                    # sample random token from dictionary
                    masked_sent[i] = np.random.randint(
                        dictionary_token_range[0], dictionary_token_range[1]
                    )
            else:
                # Non-masked positions are excluded from the loss via pad_idx.
                target[i] = pad_idx

        return masked_sent, target

    def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int):
        """
        Does the heavy lifting for creating a batch from the input list of
        examples. The logic is as follows:
            1. Mask the input blocks. In case has_pair is True then we have 2
               blocks to mask.
            2. Prepend the first masked block tensor with the special token
               used as sentence embedding. Eg: CLS in BERT. This happens
               irrespective of the value of has_pair.
            3. If has_pair is True, then append the first masked block with the
               special separator token (eg: SEP for BERT) and compute segment
               label accordingly. In this case, also append the second masked
               block with this special separator token and compute its segment
               label.
            4. For the targets tensor, prepend and append with padding index
               accordingly.
            5. Concatenate all tensors.
        """
        if len(samples) == 0:
            return {}
        # To ensure determinism, we reset the state of the PRNG after every
        # batch based on the seed and the first id of the batch. This ensures
        # that across epochs we get the same mask for the same example. This
        # is needed for reproducibility and is how BERT does masking
        # TODO: Can we add deteminism without this constraint?
        with data_utils.numpy_seed(self.seed + samples[0]["id"]):
            for s in samples:

                # token range is needed for replacing with random token during
                # masking (NOTE(review): loop-invariant, could be hoisted)
                token_range = (self.vocab.nspecial, len(self.vocab))

                # mask according to specified probabilities.
                masked_blk_one, masked_tgt_one = self._mask_block(
                    s["block_one"],
                    self.mask_idx,
                    self.pad_idx,
                    token_range,
                )

                tokens = np.concatenate([[self.classif_token_idx], masked_blk_one])
                targets = np.concatenate([[self.pad_idx], masked_tgt_one])
                segments = np.ones(len(tokens)) * self.segment_id

                # if has_pairs is True then we need to add the SEP token to both
                # the blocks after masking and re-compute segments based on the new
                # lengths.
                if self.has_pairs:
                    tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
                    targets_one = np.concatenate([targets, [self.pad_idx]])

                    masked_blk_two, masked_tgt_two = self._mask_block(
                        s["block_two"], self.mask_idx, self.pad_idx, token_range
                    )
                    tokens_two = np.concatenate([masked_blk_two, [self.sep_token_idx]])
                    targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])

                    # block + 1 sep + 1 special (CLS)
                    segments_one = np.zeros(len(tokens_one))
                    # block + 1 sep
                    segments_two = np.ones(len(tokens_two))

                    tokens = np.concatenate([tokens_one, tokens_two])
                    targets = np.concatenate([targets_one, targets_two])
                    segments = np.concatenate([segments_one, segments_two])

                s["source"] = torch.LongTensor(tokens)
                s["segment_labels"] = torch.LongTensor(segments)
                s["lm_target"] = torch.LongTensor(targets)

        def merge(key):
            return data_utils.collate_tokens(
                [s[key] for s in samples], pad_idx, eos_idx, left_pad=False
            )

        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "ntokens": sum(len(s["source"]) for s in samples),
            "net_input": {
                "src_tokens": merge("source"),
                "segment_labels": merge("segment_labels"),
            },
            "lm_target": merge("lm_target"),
            "sentence_target": torch.LongTensor([s["sentence_target"] for s in samples])
            if self.has_pairs
            else None,
            "nsentences": len(samples),
        }

    def collater(self, samples: List[Dict]):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch of data
        """
        return self._collate(samples, self.vocab.pad(), self.vocab.eos())

    def num_tokens(self, index: int):
        """
        Return the number of tokens in a sample. This value is used to
        enforce max-tokens during batching.
        """
        return self.sizes[index]

    def size(self, index: int):
        """
        Return an example's size as a float or tuple. This value is used when
        filtering a dataset with max-positions.
        """
        return self.sizes[index]

    def ordered_indices(self):
        """
        Return an ordered list of indices. Batches will be constructed based
        on this order.
        """
        if self.shuffle:
            return np.random.permutation(len(self))
        else:
            # Sort by size with original index as tie-breaker.
            order = [np.arange(len(self))]
            order.append(self.sizes)
            return np.lexsort(order)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)
fairseq-0.10.2/fairseq/data/lm_context_window_dataset.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ from fairseq.data.monolingual_dataset import MonolingualDataset
9
+
10
+ from . import FairseqDataset
11
+
12
+
13
class LMContextWindowDataset(FairseqDataset):
    """Wraps a MonolingualDataset and provides more context for evaluation.

    At collation time each sample is left-extended with up to
    *context_window* tokens carried over from previously collated batches
    (``self.prev_tokens``), so perplexity can be evaluated with extra left
    context. Because state is kept across ``collater`` calls, batches must
    be consumed in order — see ``ordered_indices``.
    """

    def __init__(self, dataset, tokens_per_sample, context_window, pad_idx):
        assert isinstance(dataset, MonolingualDataset)
        assert context_window > 0
        self.dataset = dataset
        self.tokens_per_sample = tokens_per_sample
        self.context_window = context_window
        self.pad_idx = pad_idx
        # Tokens carried over from previous batches, used as left context.
        self.prev_tokens = np.empty([0])

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        sample = self.dataset.collater(samples)

        pad = self.pad_idx
        max_sample_len = self.tokens_per_sample + self.context_window

        bsz, tsz = sample["net_input"]["src_tokens"].shape
        start_idxs = [0] * bsz
        toks = sample["net_input"]["src_tokens"]
        lengths = sample["net_input"]["src_lengths"]
        tgt = sample["target"]
        # Widen both tensors by context_window; targets are padded so that
        # the carried-over context positions contribute no loss.
        new_toks = np.empty([bsz, tsz + self.context_window], dtype=np.int64)
        new_tgt = np.full([bsz, tsz + self.context_window], pad, dtype=np.int64)
        sample_lens = toks.ne(pad).long().sum(dim=1).cpu()
        for i in range(bsz):
            sample_len = sample_lens[i]
            # Drop the oldest context tokens if context + sample would
            # exceed the model's maximum input length.
            extra = len(self.prev_tokens) + sample_len - max_sample_len
            if extra > 0:
                self.prev_tokens = self.prev_tokens[extra:]
            pads = np.full(self.context_window - len(self.prev_tokens), pad)
            new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads])
            new_tgt[
                i, len(self.prev_tokens) : len(self.prev_tokens) + len(tgt[i])
            ] = tgt[i]
            start_idxs[i] = len(self.prev_tokens)
            lengths[i] += len(self.prev_tokens)
            # Carry the last context_window non-pad tokens into the next row.
            self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window :]
        sample["net_input"]["src_tokens"] = torch.from_numpy(new_toks)
        sample["target"] = torch.from_numpy(new_tgt)
        sample["start_indices"] = start_idxs

        return sample

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        # NOTE we don't shuffle the data to retain access to the previous dataset elements
        return np.arange(len(self.dataset))

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
fairseq-0.10.2/fairseq/data/monolingual_dataset.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from . import FairseqDataset, data_utils
10
+
11
+
12
def collate(samples, pad_idx, eos_idx):
    """Collate a list of monolingual samples into a padded mini-batch.

    Returns an empty dict for an empty sample list; otherwise a dict with
    ``id``, ``nsentences``, ``ntokens``, ``net_input`` and ``target`` keys.
    """
    if not samples:
        return {}

    def merge(key, is_list=False):
        # Multi-target samples carry a list per sample; collate each
        # target position across the batch separately.
        if is_list:
            return [
                data_utils.collate_tokens(
                    [s[key][i] for s in samples],
                    pad_idx,
                    eos_idx,
                    left_pad=False,
                )
                for i in range(len(samples[0][key]))
            ]
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx,
            left_pad=False,
        )

    src_tokens = merge("source")
    first_target = samples[0]["target"]
    if first_target is not None:
        target = merge("target", is_list=isinstance(first_target, list))
    else:
        # No explicit targets: fall back to the (padded) source tokens.
        target = src_tokens

    return {
        "id": torch.LongTensor([s["id"] for s in samples]),
        "nsentences": len(samples),
        "ntokens": sum(len(s["source"]) for s in samples),
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": torch.LongTensor([s["source"].numel() for s in samples]),
        },
        "target": target,
    }
54
+
55
+
56
class MonolingualDataset(FairseqDataset):
    """
    A wrapper around torch.utils.data.Dataset for monolingual data.

    Args:
        dataset (torch.utils.data.Dataset): dataset to wrap
        sizes (List[int]): sentence lengths
        vocab (~fairseq.data.Dictionary): vocabulary
        shuffle (bool, optional): shuffle the elements before batching
            (default: True).
    """

    def __init__(
        self,
        dataset,
        sizes,
        src_vocab,
        tgt_vocab,
        add_eos_for_other_targets,
        shuffle,
        targets=None,
        add_bos_token=False,
    ):
        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.add_eos_for_other_targets = add_eos_for_other_targets
        self.shuffle = shuffle
        self.add_bos_token = add_bos_token

        assert targets is None or all(
            t in {"self", "future", "past"} for t in targets
        ), "targets must be none or one of 'self', 'future', 'past'"
        # Treat an empty targets list the same as "no targets".
        if targets is not None and len(targets) == 0:
            targets = None
        self.targets = targets

    def __getitem__(self, index):
        if self.targets is not None:
            # *future_target* is the original sentence
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
            #
            # Left-to-right language models should condition on *source* and
            # predict *future_target*.
            # Right-to-left language models should condition on *source* and
            # predict *past_target*.
            source, future_target, past_target = self.dataset[index]
            source, target = self._make_source_target(
                source, future_target, past_target
            )
        else:
            source = self.dataset[index]
            target = None
        source, target = self._maybe_add_bos(source, target)
        return {"id": index, "source": source, "target": target}

    def __len__(self):
        return len(self.dataset)

    def _make_source_target(self, source, future_target, past_target):
        # Build the (possibly multi-)target tensor(s) for one example; see
        # __getitem__ for the meaning of the three inputs.
        if self.targets is not None:
            target = []

            if (
                self.add_eos_for_other_targets
                and (("self" in self.targets) or ("past" in self.targets))
                and source[-1] != self.vocab.eos()
            ):
                # append eos at the end of source
                source = torch.cat([source, source.new([self.vocab.eos()])])

                if "future" in self.targets:
                    future_target = torch.cat(
                        [future_target, future_target.new([self.vocab.pad()])]
                    )
                if "past" in self.targets:
                    # first token is before the start of sentence which is only used in "none" break mode when
                    # add_eos_for_other_targets is False
                    past_target = torch.cat(
                        [
                            past_target.new([self.vocab.pad()]),
                            past_target[1:],
                            source[-2, None],
                        ]
                    )

            for t in self.targets:
                if t == "self":
                    target.append(source)
                elif t == "future":
                    target.append(future_target)
                elif t == "past":
                    target.append(past_target)
                else:
                    raise Exception("invalid target " + t)

            if len(target) == 1:
                # Single target: unwrap the list for the common case.
                target = target[0]
        else:
            target = future_target

        return source, self._filter_vocab(target)

    def _maybe_add_bos(self, source, target):
        # Optionally prepend beginning-of-sentence tokens from the
        # respective vocabularies.
        if self.add_bos_token:
            source = torch.cat([source.new([self.vocab.bos()]), source])
            if target is not None:
                target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
        return source, target

    def _filter_vocab(self, target):
        # Map target ids that fall outside the (smaller) output vocabulary
        # to unk. No-op when source and target vocabularies coincide.
        if len(self.tgt_vocab) != len(self.vocab):

            def _filter(target):
                mask = target.ge(len(self.tgt_vocab))
                if mask.any():
                    target[mask] = self.tgt_vocab.unk()
                return target

            if isinstance(target, list):
                return [_filter(t) for t in target]
            return _filter(target)
        return target

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the right.

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the right.
        """
        return collate(samples, self.vocab.pad(), self.vocab.eos())

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # Secondary sort key: example size (lexsort uses the last key first).
        order.append(self.sizes)
        return np.lexsort(order)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)
fairseq-0.10.2/fairseq/data/nested_dictionary_dataset.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ from torch.utils.data.dataloader import default_collate
10
+
11
+ from . import FairseqDataset
12
+
13
+
14
+ def _flatten(dico, prefix=None):
15
+ """Flatten a nested dictionary."""
16
+ new_dico = OrderedDict()
17
+ if isinstance(dico, dict):
18
+ prefix = prefix + "." if prefix is not None else ""
19
+ for k, v in dico.items():
20
+ if v is None:
21
+ continue
22
+ new_dico.update(_flatten(v, prefix + k))
23
+ elif isinstance(dico, list):
24
+ for i, v in enumerate(dico):
25
+ new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]"))
26
+ else:
27
+ new_dico = OrderedDict({prefix: dico})
28
+ return new_dico
29
+
30
+
31
+ def _unflatten(dico):
32
+ """Unflatten a flattened dictionary into a nested dictionary."""
33
+ new_dico = OrderedDict()
34
+ for full_k, v in dico.items():
35
+ full_k = full_k.split(".")
36
+ node = new_dico
37
+ for k in full_k[:-1]:
38
+ if k.startswith("[") and k.endswith("]"):
39
+ k = int(k[1:-1])
40
+ if k not in node:
41
+ node[k] = OrderedDict()
42
+ node = node[k]
43
+ node[full_k[-1]] = v
44
+ return new_dico
45
+
46
+
47
class NestedDictionaryDataset(FairseqDataset):
    """Dataset built from a nested dict/list of equal-length sub-datasets.

    *defn* is flattened into dotted keys (see ``_flatten``); ``__getitem__``
    indexes every sub-dataset, and ``collater`` re-nests the collated
    results (see ``_unflatten``).
    """

    def __init__(self, defn, sizes=None):
        super().__init__()
        self.defn = _flatten(defn)
        # Normalize to a list of size arrays (one per size criterion).
        self.sizes = [sizes] if not isinstance(sizes, (list, tuple)) else sizes

        first = None
        for v in self.defn.values():
            if not isinstance(
                v,
                (
                    FairseqDataset,
                    torch.utils.data.Dataset,
                ),
            ):
                raise ValueError("Expected Dataset but found: {}".format(v.__class__))
            first = first or v
            if len(v) > 0:
                # All non-empty sub-datasets must agree on length.
                assert len(v) == len(first), "dataset lengths must match"

        self._len = len(first)

    def __getitem__(self, index):
        return OrderedDict((k, ds[index]) for k, ds in self.defn.items())

    def __len__(self):
        return self._len

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        if len(samples) == 0:
            return {}
        sample = OrderedDict()
        for k, ds in self.defn.items():
            try:
                sample[k] = ds.collater([s[k] for s in samples])
            except NotImplementedError:
                # Sub-dataset has no custom collater; use torch's default.
                sample[k] = default_collate([s[k] for s in samples])
        return _unflatten(sample)

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(s[index] for s in self.sizes)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        if len(self.sizes) == 1:
            return self.sizes[0][index]
        else:
            # NOTE(review): returns a generator, not a tuple — callers are
            # expected to iterate it once.
            return (s[index] for s in self.sizes)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return any(ds.supports_prefetch for ds in self.defn.values())

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        for ds in self.defn.values():
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.defn.values():
            ds.set_epoch(epoch)
fairseq-0.10.2/fairseq/data/noising.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ from fairseq.data import data_utils
9
+
10
+
11
class WordNoising(object):
    """Base class for noisers that reorder or drop words in a sentence
    without ever altering the words themselves."""

    def __init__(self, dictionary, bpe_cont_marker="@@", bpe_end_marker=None):
        self.dictionary = dictionary
        # Boolean mask over the vocabulary: True at symbols that terminate a
        # word. Used to group BPE pieces back into whole words.
        self.bpe_end = None
        vocab_size = len(self.dictionary)
        if bpe_cont_marker:
            # Continuation-marker scheme (e.g. "y@@ ou"): a token ends a word
            # iff it does NOT carry the continuation marker.
            mask = [
                not self.dictionary[i].endswith(bpe_cont_marker)
                for i in range(vocab_size)
            ]
            self.bpe_end = np.array(mask)
        elif bpe_end_marker:
            # End-marker scheme: a token ends a word iff it carries the marker.
            mask = [
                self.dictionary[i].endswith(bpe_end_marker)
                for i in range(vocab_size)
            ]
            self.bpe_end = np.array(mask)

        # Pick the grouping strategy once, up front.
        if self.bpe_end is not None:
            self.get_word_idx = self._get_bpe_word_idx
        else:
            self.get_word_idx = self._get_token_idx

    def noising(self, x, lengths, noising_prob=0.0):
        raise NotImplementedError()

    def _get_bpe_word_idx(self, x):
        """
        Given a list of BPE tokens, for every index in the tokens list,
        return the index of the word grouping that it belongs to.
        For example, for input x corresponding to ["how", "are", "y@@", "ou"],
        return [[0], [1], [2], [2]].
        """
        # x: (T x B)
        bpe_end = self.bpe_end[x]

        if x.size(0) == 1 and x.size(1) == 1:
            # Special case when we only have one word in x. If x = [[N]],
            # bpe_end is a scalar (bool) instead of a 2-dim array of bools,
            # which makes the sum operation below fail.
            return np.array([[0]])

        # Reverse-cumulative-sum of word-end flags yields word ids counted
        # from the end; subtracting from the column max flips them to count
        # from the front.
        word_idx = bpe_end[::-1].cumsum(0)[::-1]
        word_idx = word_idx.max(0)[None, :] - word_idx
        return word_idx

    def _get_token_idx(self, x):
        """
        This is to extend noising functions to be able to apply to non-bpe
        tokens, e.g. word or characters.
        """
        # Every token is its own word: position index == word index.
        x = torch.t(x)
        word_idx = np.array([range(len(x_i)) for x_i in x])
        return np.transpose(word_idx)
68
+
69
+
70
class WordDropout(WordNoising):
    """Randomly drop input words. If not passing blank_idx (default is None),
    then dropped words will be removed. Otherwise, it will be replaced by the
    blank_idx."""

    def __init__(
        self,
        dictionary,
        default_dropout_prob=0.1,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
        self.default_dropout_prob = default_dropout_prob

    def noising(self, x, lengths, dropout_prob=None, blank_idx=None):
        """Drop (or blank out) whole words from each sentence in the batch.

        Args:
            x: (T x B) LongTensor of token ids.
            lengths: (B,) LongTensor of sentence lengths.
            dropout_prob (float, optional): per-word drop probability; falls
                back to ``self.default_dropout_prob`` when None.
            blank_idx (int, optional): if None, dropped words are removed;
                otherwise each dropped token is replaced by this index.

        Returns:
            (modified_x, modified_lengths): re-padded token batch and the new
            per-sentence lengths.
        """
        if dropout_prob is None:
            dropout_prob = self.default_dropout_prob
        # x: (T x B), lengths: B
        if dropout_prob == 0:
            return x, lengths

        assert 0 < dropout_prob < 1

        # be sure to drop entire words
        word_idx = self.get_word_idx(x)
        sentences = []
        modified_lengths = []
        for i in range(lengths.size(0)):
            # Since dropout probabilities need to apply over non-pad tokens,
            # it is not trivial to generate the keep mask without consider
            # input lengths; otherwise, this could be done outside the loop

            # We want to drop whole words based on word_idx grouping
            num_words = max(word_idx[:, i]) + 1

            # ith example: [x0, x1, ..., eos, pad, ..., pad]
            # We should only generate keep probs for non-EOS tokens. Thus if the
            # input sentence ends in EOS, the last word idx is not included in
            # the dropout mask generation and we append True to always keep EOS.
            # Otherwise, just generate the dropout mask for all word idx
            # positions.
            has_eos = x[lengths[i] - 1, i] == self.dictionary.eos()
            if has_eos:  # has eos?
                keep = np.random.rand(num_words - 1) >= dropout_prob
                keep = np.append(keep, [True])  # keep EOS symbol
            else:
                keep = np.random.rand(num_words) >= dropout_prob

            words = x[: lengths[i], i].tolist()

            # TODO: speed up the following loop
            # drop words from the input according to keep
            new_s = [
                w if keep[word_idx[j, i]] else blank_idx for j, w in enumerate(words)
            ]
            # blank_idx=None means "remove" — filter out the dropped slots.
            new_s = [w for w in new_s if w is not None]
            # we need to have at least one word in the sentence (more than the
            # start / end sentence symbols)
            if len(new_s) <= 1:
                # insert at beginning in case the only token left is EOS
                # EOS should be at end of list.
                new_s.insert(0, words[np.random.randint(0, len(words))])
            assert len(new_s) >= 1 and (
                not has_eos  # Either don't have EOS at end or last token is EOS
                or (len(new_s) >= 2 and new_s[-1] == self.dictionary.eos())
            ), "New sentence is invalid."
            sentences.append(new_s)
            modified_lengths.append(len(new_s))
        # re-construct input
        modified_lengths = torch.LongTensor(modified_lengths)
        modified_x = torch.LongTensor(
            modified_lengths.max(), modified_lengths.size(0)
        ).fill_(self.dictionary.pad())
        for i in range(modified_lengths.size(0)):
            modified_x[: modified_lengths[i], i].copy_(torch.LongTensor(sentences[i]))

        return modified_x, modified_lengths
148
+
149
+
150
class WordShuffle(WordNoising):
    """Shuffle words by no more than k positions."""

    def __init__(
        self,
        dictionary,
        default_max_shuffle_distance=3,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
        # BUGFIX: previously hard-coded to 3, silently ignoring the
        # ``default_max_shuffle_distance`` constructor argument.
        self.default_max_shuffle_distance = default_max_shuffle_distance

    def noising(self, x, lengths, max_shuffle_distance=None):
        """Shuffle whole words within each sentence by a bounded distance.

        Args:
            x: (T x B) LongTensor of token ids.
            lengths: (B,) LongTensor of sentence lengths.
            max_shuffle_distance (int, optional): cap on how far a word may
                move; falls back to the instance default when None. A value
                of 0 disables shuffling entirely.

        Returns:
            (shuffled_x, lengths): lengths are unchanged.
        """
        if max_shuffle_distance is None:
            max_shuffle_distance = self.default_max_shuffle_distance
        # x: (T x B), lengths: B
        if max_shuffle_distance == 0:
            return x, lengths

        # max_shuffle_distance < 1 will return the same sequence
        assert max_shuffle_distance > 1

        # define noise word scores
        noise = np.random.uniform(
            0,
            max_shuffle_distance,
            size=(x.size(0), x.size(1)),
        )
        noise[0] = -1  # do not move start sentence symbol
        # be sure to shuffle entire words
        word_idx = self.get_word_idx(x)
        x2 = x.clone()
        for i in range(lengths.size(0)):
            # Never move a trailing EOS symbol.
            length_no_eos = lengths[i]
            if x[lengths[i] - 1, i] == self.dictionary.eos():
                length_no_eos = lengths[i] - 1
            # generate a random permutation: each word's sort score is its
            # word index plus bounded uniform noise, so no word can travel
            # farther than max_shuffle_distance.
            scores = word_idx[:length_no_eos, i] + noise[word_idx[:length_no_eos, i], i]
            # ensure no reordering inside a word (tiny position-based tiebreak)
            scores += 1e-6 * np.arange(length_no_eos.item())
            permutation = scores.argsort()
            # shuffle words
            x2[:length_no_eos, i].copy_(
                x2[:length_no_eos, i][torch.from_numpy(permutation)]
            )
        return x2, lengths
197
+
198
+
199
class UnsupervisedMTNoising(WordNoising):
    """
    Implements the default configuration for noising in UnsupervisedMT
    (github.com/facebookresearch/UnsupervisedMT)
    """

    def __init__(
        self,
        dictionary,
        max_word_shuffle_distance,
        word_dropout_prob,
        word_blanking_prob,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        super().__init__(dictionary)
        self.max_word_shuffle_distance = max_word_shuffle_distance
        self.word_dropout_prob = word_dropout_prob
        self.word_blanking_prob = word_blanking_prob

        # Delegate to dedicated noisers that share the same BPE settings;
        # blanking reuses WordDropout with a blank index.
        shared_kwargs = dict(
            dictionary=dictionary,
            bpe_cont_marker=bpe_cont_marker,
            bpe_end_marker=bpe_end_marker,
        )
        self.word_dropout = WordDropout(**shared_kwargs)
        self.word_shuffle = WordShuffle(**shared_kwargs)

    def noising(self, x, lengths):
        """Apply shuffle -> dropout -> blanking and return the noisy tokens."""
        # 1. Word Shuffle
        tokens, tok_lengths = self.word_shuffle.noising(
            x=x,
            lengths=lengths,
            max_shuffle_distance=self.max_word_shuffle_distance,
        )
        # 2. Word Dropout (dropped words are removed)
        tokens, tok_lengths = self.word_dropout.noising(
            x=tokens,
            lengths=tok_lengths,
            dropout_prob=self.word_dropout_prob,
        )
        # 3. Word Blanking (dropped words are replaced by <unk>)
        tokens, tok_lengths = self.word_dropout.noising(
            x=tokens,
            lengths=tok_lengths,
            dropout_prob=self.word_blanking_prob,
            blank_idx=self.dictionary.unk(),
        )

        # Only the noisy tokens are returned; lengths are not propagated.
        return tokens
252
+
253
+
254
class NoisingDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        src_dataset,
        src_dict,
        seed,
        noiser=None,
        noising_class=UnsupervisedMTNoising,
        **kwargs
    ):
        """
        Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the
        samples based on the supplied noising configuration.

        Args:
            src_dataset (~torch.utils.data.Dataset): dataset to wrap.
                to build self.src_dataset --
                a LanguagePairDataset with src dataset as the source dataset and
                None as the target dataset. Should NOT have padding so that
                src_lengths are accurately calculated by language_pair_dataset
                collate function.
                We use language_pair_dataset here to encapsulate the tgt_dataset
                so we can re-use the LanguagePairDataset collater to format the
                batches in the structure that SequenceGenerator expects.
            src_dict (~fairseq.data.Dictionary): source dictionary
            seed (int): seed to use when generating random noise
            noiser (WordNoising): a pre-initialized :class:`WordNoising`
                instance. If this is None, a new instance will be created using
                *noising_class* and *kwargs*.
            noising_class (class, optional): class to use to initialize a
                default :class:`WordNoising` instance.
            kwargs (dict, optional): arguments to initialize the default
                :class:`WordNoising` instance given by *noiser*.
        """
        self.src_dataset = src_dataset
        self.src_dict = src_dict
        self.seed = seed
        if noiser is not None:
            self.noiser = noiser
        else:
            self.noiser = noising_class(
                dictionary=src_dict,
                **kwargs,
            )

    def __getitem__(self, index):
        """
        Returns a single noisy sample. Multiple samples are fed to the collater
        create a noising dataset batch.
        """
        src_tokens = self.src_dataset[index]
        src_lengths = torch.LongTensor([len(src_tokens)])

        # Noising functions expect (sequence length, batch size) layout, so
        # add a batch dimension and transpose.
        src_tokens_t = torch.t(src_tokens.unsqueeze(0))

        # Seeding on (seed + index) makes the noise deterministic per example.
        with data_utils.numpy_seed(self.seed + index):
            noisy_src_tokens = self.noiser.noising(src_tokens_t, src_lengths)

        # Transpose back to (1, sequence length) and strip the batch dim.
        return torch.t(noisy_src_tokens)[0]

    def __len__(self):
        """
        The length of the noising dataset is the length of src.
        """
        return len(self.src_dataset)

    @property
    def supports_prefetch(self):
        return self.src_dataset.supports_prefetch

    def prefetch(self, indices):
        if self.src_dataset.supports_prefetch:
            self.src_dataset.prefetch(indices)
fairseq-0.10.2/fairseq/data/numel_dataset.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from . import BaseWrapperDataset
10
+
11
+
12
class NumelDataset(BaseWrapperDataset):
    """Map every item of the wrapped dataset to its element count."""

    def __init__(self, dataset, reduce=False):
        super().__init__(dataset)
        # When True, collating sums the counts; otherwise they are stacked.
        self.reduce = reduce

    def __getitem__(self, index):
        item = self.dataset[index]
        # torch tensors and generic array-likes report counts differently.
        return torch.numel(item) if torch.is_tensor(item) else np.size(item)

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if self.reduce:
            return sum(samples)
        return torch.tensor(samples)
fairseq-0.10.2/fairseq/data/plasma_utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import subprocess
7
+ import tempfile
8
+
9
+
10
class PlasmaArray(object):
    """
    Wrapper around numpy arrays that automatically moves the data to shared
    memory upon serialization. This is particularly helpful when passing numpy
    arrays through multiprocessing, so that data is not unnecessarily
    duplicated or pickled.
    """

    def __init__(self, array):
        super().__init__()
        self.array = array
        # Small arrays are cheaper to pickle directly than to route through a
        # plasma shared-memory store.
        self.disable = array.nbytes < 134217728  # disable for arrays <128MB
        self.object_id = None
        self.path = None

        # variables with underscores shouldn't be pickled
        self._client = None
        self._server = None
        self._server_tmp = None
        self._plasma = None

    @property
    def plasma(self):
        # Lazily import pyarrow.plasma; resolves to None when unavailable or
        # when shared memory is disabled for this array.
        if self._plasma is None and not self.disable:
            try:
                import pyarrow.plasma as plasma

                self._plasma = plasma
            except ImportError:
                self._plasma = None
        return self._plasma

    def start_server(self):
        # Spawn a dedicated plasma_store process backed by a temp socket file.
        # No-op if plasma is unavailable or a server is already running.
        if self.plasma is None or self._server is not None:
            return
        assert self.object_id is None
        assert self.path is None
        self._server_tmp = tempfile.NamedTemporaryFile()
        self.path = self._server_tmp.name
        self._server = subprocess.Popen(
            [
                "plasma_store",
                "-m",
                # 5% headroom over the payload size
                str(int(1.05 * self.array.nbytes)),
                "-s",
                self.path,
            ]
        )

    @property
    def client(self):
        # Connect on first use; requires start_server() to have set self.path.
        if self._client is None:
            assert self.path is not None
            self._client = self.plasma.connect(self.path)
        return self._client

    def __getstate__(self):
        # On pickling: push the array into the plasma store once, then strip
        # the array and all process-local handles from the pickled state.
        if self.plasma is None:
            return self.__dict__
        if self.object_id is None:
            self.start_server()
            self.object_id = self.client.put(self.array)
        state = self.__dict__.copy()
        del state["array"]
        state["_client"] = None
        state["_server"] = None
        state["_server_tmp"] = None
        state["_plasma"] = None
        return state

    def __setstate__(self, state):
        # On unpickling: reconnect to the store and fetch the shared array.
        self.__dict__.update(state)
        if self.plasma is None:
            return
        self.array = self.client.get(self.object_id)

    def __del__(self):
        # Tear down the store process this instance owns, if any.
        if self._server is not None:
            self._server.kill()
            self._server = None
            self._server_tmp.close()
            self._server_tmp = None
fairseq-0.10.2/fairseq/data/prepend_token_dataset.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from . import BaseWrapperDataset
10
+
11
+
12
class PrependTokenDataset(BaseWrapperDataset):
    """Prepend *token* (when given) to every item; reported sizes grow by one."""

    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token
        # Precompute adjusted sizes once; without a token, pass through.
        if token is not None:
            self._sizes = np.array(dataset.sizes) + 1
        else:
            self._sizes = dataset.sizes

    def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is None:
            return item
        # item.new(...) keeps dtype/device consistent with the item.
        return torch.cat([item.new([self.token]), item])

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        base = self.dataset.num_tokens(index)
        return base if self.token is None else base + 1

    def size(self, index):
        base = self.dataset.size(index)
        return base if self.token is None else base + 1
fairseq-0.10.2/fairseq/data/raw_label_dataset.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import FairseqDataset
9
+
10
+
11
class RawLabelDataset(FairseqDataset):
    """Expose a plain in-memory sequence of labels as a dataset."""

    def __init__(self, labels):
        super().__init__()
        # labels: any indexable sequence of scalar labels.
        self.labels = labels

    def __getitem__(self, index):
        return self.labels[index]

    def __len__(self):
        return len(self.labels)

    def collater(self, samples):
        # Stack the raw labels into a single tensor batch.
        return torch.tensor(samples)
fairseq-0.10.2/fairseq/data/replace_dataset.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import BaseWrapperDataset
7
+
8
+
9
class ReplaceDataset(BaseWrapperDataset):
    """Replaces tokens found in the dataset by a specified replacement token

    Args:
        dataset (~torch.utils.data.Dataset): dataset to replace tokens in
        replace_map(Dictionary[int,int]): map of token to replace -> replacement token
        offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be
            as many as the number of objects returned by the underlying dataset __getitem__ method.
    """

    def __init__(self, dataset, replace_map, offsets):
        super().__init__(dataset)
        assert len(replace_map) > 0
        self.replace_map = replace_map
        self.offsets = offsets

    def __getitem__(self, index):
        item = self.dataset[index]
        is_tuple = isinstance(item, tuple)
        sources = item if is_tuple else [item]

        for offset, source in zip(self.offsets, sources):
            for old_token, new_token in self.replace_map.items():
                # Slice off the protected prefix (offset >= 0) or suffix
                # (offset < 0); the slice is a view, so the in-place
                # masked_fill_ mutates the underlying tensor.
                region = source[offset:] if offset >= 0 else source[:offset]
                region.masked_fill_(region == old_token, new_token)

        return sources if is_tuple else sources[0]
fairseq-0.10.2/fairseq/data/roll_dataset.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import BaseWrapperDataset
9
+
10
+
11
class RollDataset(BaseWrapperDataset):
    """Circularly shift every item of the wrapped dataset by a fixed amount."""

    def __init__(self, dataset, shifts):
        super().__init__(dataset)
        # shifts: passed straight through to torch.roll.
        self.shifts = shifts

    def __getitem__(self, index):
        return torch.roll(self.dataset[index], self.shifts)
fairseq-0.10.2/fairseq/data/round_robin_zip_datasets.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+
8
+ import numpy as np
9
+
10
+ from . import FairseqDataset
11
+
12
+
13
class RoundRobinZipDatasets(FairseqDataset):
    """Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.

    Shorter datasets are repeated in a round-robin fashion to match the length
    of the longest one.

    Args:
        datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of
            :class:`~fairseq.data.FairseqDataset` instances.
        eval_key (str, optional): a key used at evaluation time that causes
            this instance to pass-through batches from *datasets[eval_key]*.
    """

    def __init__(self, datasets, eval_key=None):
        super().__init__()
        # OrderedDict required so per-key iteration order is deterministic.
        assert isinstance(datasets, OrderedDict)
        self.datasets = datasets
        self.eval_key = eval_key

        # Track the longest dataset; it defines this dataset's length.
        self.longest_dataset = None
        self.longest_dataset_key = None
        for key, dataset in datasets.items():
            assert isinstance(dataset, FairseqDataset)
            if self.longest_dataset is None or len(dataset) > len(self.longest_dataset):
                self.longest_dataset = dataset
                self.longest_dataset_key = key

        # Populated lazily by ordered_indices(); maps each key to that
        # sub-dataset's own ordering.
        self._ordered_indices = None

    def _map_index(self, key, index):
        # Wrap the global index modulo the sub-dataset's length (round-robin)
        # and translate it through that sub-dataset's ordering.
        assert (
            self._ordered_indices is not None
        ), "Must call RoundRobinZipDatasets.ordered_indices() first"
        return self._ordered_indices[key][index % len(self.datasets[key])]

    def __getitem__(self, index):
        if self.eval_key is None:
            return OrderedDict(
                [
                    (key, dataset[self._map_index(key, index)])
                    for key, dataset in self.datasets.items()
                ]
            )
        else:
            # at evaluation time it's useful to pass-through batches from a single key
            return self.datasets[self.eval_key][self._map_index(self.eval_key, index)]

    def __len__(self):
        return len(self.longest_dataset)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch."""
        if len(samples) == 0:
            return None
        if self.eval_key is None:
            # Collate per key: each sub-dataset batches its own samples.
            return OrderedDict(
                [
                    (key, dataset.collater([sample[key] for sample in samples]))
                    for key, dataset in self.datasets.items()
                ]
            )
        else:
            # at evaluation time it's useful to pass-through batches from a single key
            return self.datasets[self.eval_key].collater(samples)

    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching."""
        # TODO make it configurable whether to use max() or sum() here
        return max(
            dataset.num_tokens(self._map_index(key, index))
            for key, dataset in self.datasets.items()
        )

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return {
            key: dataset.size(self._map_index(key, index))
            for key, dataset in self.datasets.items()
        }

    def ordered_indices(self):
        """Ordered indices for batching."""
        if self._ordered_indices is None:
            # Call the underlying dataset's ordered_indices() here, so that we
            # get the same random ordering as we would have from using the
            # underlying dataset directly.
            self._ordered_indices = OrderedDict(
                [
                    (key, dataset.ordered_indices())
                    for key, dataset in self.datasets.items()
                ]
            )
        # The returned order is the identity; per-key orderings are applied
        # later inside _map_index().
        return np.arange(len(self))

    @property
    def supports_prefetch(self):
        return all(
            getattr(dataset, "supports_prefetch", False)
            for dataset in self.datasets.values()
        )

    def prefetch(self, indices):
        for key, dataset in self.datasets.items():
            dataset.prefetch([self._map_index(key, index) for index in indices])
fairseq-0.10.2/fairseq/data/sort_dataset.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+
8
+ from . import BaseWrapperDataset
9
+
10
+
11
class SortDataset(BaseWrapperDataset):
    """Order indices via np.lexsort over one or more key arrays
    (the last key in *sort_order* is the primary sort key)."""

    def __init__(self, dataset, sort_order):
        super().__init__(dataset)
        # Accept a single key array or a sequence of them.
        if isinstance(sort_order, (list, tuple)):
            self.sort_order = sort_order
        else:
            self.sort_order = [sort_order]

        assert all(len(key) == len(dataset) for key in self.sort_order)

    def ordered_indices(self):
        return np.lexsort(self.sort_order)
fairseq-0.10.2/fairseq/data/subsample_dataset.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ import numpy as np
9
+
10
+ from . import BaseWrapperDataset
11
+
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class SubsampleDataset(BaseWrapperDataset):
17
+ """Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples
18
+
19
+ Args:
20
+ dataset (~torch.utils.data.Dataset): dataset to subsample
21
+ size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive)
22
+ """
23
+
24
+ def __init__(self, dataset, size_ratio, shuffle=False):
25
+ super().__init__(dataset)
26
+ assert size_ratio < 1
27
+ self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
28
+ self.indices = np.random.choice(
29
+ list(range(len(self.dataset))), self.actual_size, replace=False
30
+ )
31
+ self.shuffle = shuffle
32
+ logger.info(
33
+ "subsampled dataset from {} to {} (ratio={})".format(
34
+ len(self.dataset), self.actual_size, size_ratio
35
+ )
36
+ )
37
+
38
+ def __getitem__(self, index):
39
+ return self.dataset[self.indices[index]]
40
+
41
+ def __len__(self):
42
+ return self.actual_size
43
+
44
+ def collater(self, samples):
45
+ return self.dataset.collater(samples)
46
+
47
+ @property
48
+ def sizes(self):
49
+ return self.dataset.sizes[self.indices]
50
+
51
+ @property
52
+ def name(self):
53
+ return self.dataset.name
54
+
55
+ def num_tokens(self, index):
56
+ return self.dataset.num_tokens(self.indices[index])
57
+
58
+ def size(self, index):
59
+ return self.dataset.size(self.indices[index])
60
+
61
+ def ordered_indices(self):
62
+ """Return an ordered list of indices. Batches will be constructed based
63
+ on this order."""
64
+ if self.shuffle:
65
+ order = [np.random.permutation(len(self))]
66
+ else:
67
+ order = [np.arange(len(self))]
68
+ order.append(self.sizes)
69
+ return np.lexsort(order)
70
+
71
+ def prefetch(self, indices):
72
+ self.dataset.prefetch(self.indices[indices])
fairseq-0.10.2/fairseq/data/token_block_utils_fast.cpp ADDED
The diff for this file is too large to render. See raw diff
 
fairseq-0.10.2/fairseq/data/token_block_utils_fast.pyx ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cython: language_level=3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from itertools import chain
10
+ from libc.math cimport ceil
11
+
12
+ cimport cython
13
+ cimport numpy as np
14
+
15
+ from libc.stdint cimport int32_t, int64_t
16
+
17
+ DTYPE = np.int64
18
+ ctypedef int64_t DTYPE_t
19
+
20
+
21
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size):
    # Break the concatenated token stream into contiguous fixed-size blocks,
    # ignoring sentence boundaries; the final block may be shorter.
    cdef DTYPE_t total_size = sizes.sum()
    # Number of blocks, rounding the last partial block up.
    cdef DTYPE_t length = <DTYPE_t> ceil(total_size / <double> block_size)
    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE)
    cdef DTYPE_t[:, :] slice_indices_view = slice_indices
    cdef DTYPE_t i
    cdef DTYPE_t start
    cdef DTYPE_t end
    for i in range(length):
        start = i * block_size
        # Clamp the final block's end to the stream length.
        end = min(start + block_size, total_size)
        slice_indices_view[i][0] = start
        slice_indices_view[i][1] = end
    return slice_indices
38
+
39
+
40
cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list):
    """
    Faster function to convert DTYPE_t list of list.
    Only fast when there are huge number of rows and low number of columns.
    """
    # Flatten all rows into one 1-D array in a single C-speed pass, then
    # reshape back assuming every row has the same number of columns.
    cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(chain.from_iterable(list_of_list), DTYPE, -1)
    return flat.reshape((len(list_of_list), -1))
47
+
48
+
49
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len):
    # Compute [start, end) token offsets for each block under break_mode:
    #   none/None    -- fixed-size blocks ignoring sentence boundaries
    #   complete     -- pack whole sentences up to block_size
    #   complete_doc -- pack whole sentences, never crossing document breaks
    #   eos          -- one block per sentence
    cdef DTYPE_t tok_idx = 0
    cdef DTYPE_t sz_idx = 0
    cdef DTYPE_t curr_size = 0
    cdef DTYPE_t i = 0
    cdef DTYPE_t length
    cdef DTYPE_t total_size
    cdef DTYPE_t[:] sizes_view = sizes
    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices
    cdef list slice_indices_list = []

    if break_mode is None or break_mode == 'none':
        slice_indices = _get_slice_indices_none_mode(sizes, block_size)
    elif break_mode == 'complete':
        while sz_idx < len(sizes_view):
            # Add the next sentence if it fits, or if the block is empty
            # (a single over-long sentence still gets its own block).
            if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0:
                curr_size += sizes_view[sz_idx]
                sz_idx += 1
            else:
                slice_indices_list.append((tok_idx, tok_idx + curr_size))
                tok_idx += curr_size
                curr_size = 0
        # Flush the trailing partial block.
        if curr_size > 0:
            slice_indices_list.append((tok_idx, tok_idx + curr_size))
        slice_indices = _fast_convert_to_np_array(slice_indices_list)
    elif break_mode == 'complete_doc':
        while sz_idx < len(sizes_view):
            if (
                (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0)
                # an empty sentence indicates end-of-document:
                and sizes_view[sz_idx] != document_sep_len
            ):
                curr_size += sizes_view[sz_idx]
                sz_idx += 1
            else:
                # Only keep non-empty documents.
                if curr_size > 1:
                    slice_indices_list.append((tok_idx, tok_idx + curr_size))
                tok_idx += curr_size
                curr_size = 0
                if sizes_view[sz_idx] == document_sep_len:
                    # Skip past the document separator itself.
                    tok_idx += sizes_view[sz_idx]
                    sz_idx += 1
        if curr_size > 1:
            slice_indices_list.append((tok_idx, tok_idx + curr_size))
        slice_indices = _fast_convert_to_np_array(slice_indices_list)
    elif break_mode == 'eos':
        # One block per sentence: block starts are the exclusive prefix sums
        # of sentence sizes, ends are the inclusive prefix sums.
        slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE)
        cumsum = sizes.cumsum(axis=0)
        slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1]
        slice_indices[:, 1] = cumsum
    else:
        raise ValueError('Invalid break_mode: ' + break_mode)
    return slice_indices
106
+
107
+
108
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices):
    # Translate each flat [start, end) token span into dataset coordinates:
    # (first example index, offset within it, last example index).
    cdef DTYPE_t start_ds_idx
    cdef DTYPE_t start_offset
    cdef DTYPE_t end_ds_idx
    cdef DTYPE_t i
    cdef DTYPE_t s
    cdef DTYPE_t e
    cdef DatasetSearcher ds = DatasetSearcher(sizes)
    cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE)
    cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index
    cdef DTYPE_t[:, :] slice_indices_view = slice_indices
    cdef Py_ssize_t x_max = slice_indices.shape[0]

    for i in range(x_max):
        s = slice_indices_view[i][0]
        e = slice_indices_view[i][1]
        ds.seek(s)
        start_ds_idx = ds.current_index
        start_offset = ds.current_offset
        if e <= s:
            # Empty span: start and end fall in the same example.
            end_ds_idx = start_ds_idx
        else:
            # Seek to the last token of the span (end is exclusive).
            ds.seek(e - 1)
            end_ds_idx = ds.current_index
        block_to_dataset_index_view[i][0] = start_ds_idx  # starting index in dataset
        block_to_dataset_index_view[i][1] = start_offset  # starting offset within starting index
        block_to_dataset_index_view[i][2] = end_ds_idx  # ending index in dataset
    return block_to_dataset_index
139
+
140
+
141
cdef class DatasetSearcher(object):
    """Helper for mapping "flat" indices to indices and offsets in an
    underlying dataset."""
    cdef DTYPE_t current_i
    cdef DTYPE_t current_offset
    cdef DTYPE_t current_index
    cdef DTYPE_t[:] sizes

    def __init__(self, DTYPE_t[:] sizes):
        # sizes: per-example token counts of the underlying dataset.
        self.sizes = sizes
        self.reset()

    cdef reset(self):
        self.current_offset = 0  # offset within current index in underlying dataset
        self.current_i = 0  # "flat" index
        self.current_index = 0  # index in underlying dataset

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.nonecheck(False)
    cdef int step(self, DTYPE_t i):
        # Advance the cursor toward flat position i; returns 1 if more steps
        # are needed, 0 once the cursor sits exactly at i. Seeking backwards
        # restarts from the beginning (the cursor only moves forward).
        cdef DTYPE_t to_consume
        cdef DTYPE_t remaining
        if i < self.current_i:
            self.reset()
        if i > self.current_i:
            to_consume = i - self.current_i
            remaining = self.sizes[self.current_index] - self.current_offset
            if remaining > to_consume:
                # Target lies inside the current example: move the offset.
                self.current_offset += to_consume
                self.current_i += to_consume
            else:
                # Consume the rest of this example and move to the next one.
                assert remaining >= 0
                self.current_i += remaining
                self.current_index += 1
                self.current_offset = 0
                return 1
            return 0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.nonecheck(False)
    cdef seek(self, DTYPE_t i):
        # Repeatedly step until the cursor lands exactly on flat index i.
        cdef int not_done = 1
        while not_done == 1:
            not_done = self.step(i)
        assert self.current_i == i
fairseq-0.10.2/fairseq/data/transform_eos_dataset.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import FairseqDataset
9
+
10
+
11
class TransformEosDataset(FairseqDataset):
    """A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.

    Note that the transformation is applied in :func:`collater`.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset to wrap
        eos (int): index of the end-of-sentence symbol
        append_eos_to_src (bool, optional): append EOS to the end of src
        remove_eos_from_src (bool, optional): remove EOS from the end of src
        append_eos_to_tgt (bool, optional): append EOS to the end of tgt
        remove_eos_from_tgt (bool, optional): remove EOS from the end of tgt
    """

    def __init__(
        self,
        dataset,
        eos,
        append_eos_to_src=False,
        remove_eos_from_src=False,
        append_eos_to_tgt=False,
        remove_eos_from_tgt=False,
        has_target=True,
    ):
        if not isinstance(dataset, FairseqDataset):
            raise ValueError("dataset must be an instance of FairseqDataset")
        # append and remove are mutually exclusive, per side
        if append_eos_to_src and remove_eos_from_src:
            raise ValueError("cannot combine append_eos_to_src and remove_eos_from_src")
        if append_eos_to_tgt and remove_eos_from_tgt:
            raise ValueError("cannot combine append_eos_to_tgt and remove_eos_from_tgt")

        self.dataset = dataset
        self.eos = torch.LongTensor([eos])
        self.append_eos_to_src = append_eos_to_src
        self.remove_eos_from_src = remove_eos_from_src
        self.append_eos_to_tgt = append_eos_to_tgt
        self.remove_eos_from_tgt = remove_eos_from_tgt
        self.has_target = has_target

        # precompute how we should adjust the reported sizes:
        # +1 for an appended EOS, -1 for a removed one
        self._src_delta = int(append_eos_to_src) - int(remove_eos_from_src)
        self._tgt_delta = int(append_eos_to_tgt) - int(remove_eos_from_tgt)

        # one-shot sanity checks, performed on the first collated sample
        self._checked_src = False
        self._checked_tgt = False

    def _check_src(self, src, expect_eos):
        # verify (once) that src ends with EOS iff we expect it to
        if not self._checked_src:
            assert (src[-1] == self.eos[0]) == expect_eos
            self._checked_src = True

    def _check_tgt(self, tgt, expect_eos):
        # verify (once) that tgt ends with EOS iff we expect it to
        if self.has_target and not self._checked_tgt:
            assert (tgt[-1] == self.eos[0]) == expect_eos
            self._checked_tgt = True

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        def _apply(sample):
            # NOTE: append/remove are mutually exclusive per side, so at
            # most one branch fires for source and one for target
            if self.append_eos_to_src:
                src = sample["source"]
                self.eos = self.eos.to(device=src.device)
                self._check_src(src, expect_eos=False)
                sample["source"] = torch.cat([src, self.eos])
            if self.remove_eos_from_src:
                src = sample["source"]
                self.eos = self.eos.to(device=src.device)
                self._check_src(src, expect_eos=True)
                sample["source"] = src[:-1]
            if self.append_eos_to_tgt:
                tgt = sample["target"]
                self.eos = self.eos.to(device=tgt.device)
                self._check_tgt(tgt, expect_eos=False)
                sample["target"] = torch.cat([tgt, self.eos])
            if self.remove_eos_from_tgt:
                tgt = sample["target"]
                self.eos = self.eos.to(device=tgt.device)
                self._check_tgt(tgt, expect_eos=True)
                sample["target"] = tgt[:-1]
            return sample

        return self.dataset.collater([_apply(s) for s in samples])

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        if not self.has_target:
            return self.dataset.size(index)
        src_len, tgt_len = self.dataset.size(index)
        return (src_len + self._src_delta, tgt_len + self._tgt_delta)

    def ordered_indices(self):
        # NOTE: we assume that the ordering does not change based on the
        # addition or removal of eos
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)