sleepyhead111 committed on
Commit
43a1e46
·
verified ·
1 Parent(s): 491a2bd

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. fairseq-0.10.2/fairseq/data/__pycache__/__init__.cpython-310.pyc +0 -0
  3. fairseq-0.10.2/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc +0 -0
  4. fairseq-0.10.2/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc +0 -0
  5. fairseq-0.10.2/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc +0 -0
  6. fairseq-0.10.2/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc +0 -0
  7. fairseq-0.10.2/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc +0 -0
  8. fairseq-0.10.2/fairseq/data/__pycache__/dictionary.cpython-310.pyc +0 -0
  9. fairseq-0.10.2/fairseq/data/__pycache__/id_dataset.cpython-310.pyc +0 -0
  10. fairseq-0.10.2/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc +0 -0
  11. fairseq-0.10.2/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc +0 -0
  12. fairseq-0.10.2/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc +0 -0
  13. fairseq-0.10.2/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc +0 -0
  14. fairseq-0.10.2/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc +0 -0
  15. fairseq-0.10.2/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc +0 -0
  16. fairseq-0.10.2/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc +0 -0
  17. fairseq-0.10.2/fairseq/data/concat_dataset.py +124 -0
  18. fairseq-0.10.2/fairseq/data/concat_sentences_dataset.py +54 -0
  19. fairseq-0.10.2/fairseq/data/dictionary.py +387 -0
  20. fairseq-0.10.2/fairseq/data/encoders/byte_utils.py +51 -0
  21. fairseq-0.10.2/fairseq/data/encoders/fastbpe.py +35 -0
  22. fairseq-0.10.2/fairseq/data/encoders/gpt2_bpe_utils.py +140 -0
  23. fairseq-0.10.2/fairseq/data/encoders/moses_tokenizer.py +51 -0
  24. fairseq-0.10.2/fairseq/data/encoders/nltk_tokenizer.py +23 -0
  25. fairseq-0.10.2/fairseq/data/encoders/subword_nmt_bpe.py +54 -0
  26. fairseq-0.10.2/fairseq/data/encoders/utils.py +30 -0
  27. fairseq-0.10.2/fairseq/data/fairseq_dataset.py +191 -0
  28. fairseq-0.10.2/fairseq/data/fasta_dataset.py +107 -0
  29. fairseq-0.10.2/fairseq/data/id_dataset.py +19 -0
  30. fairseq-0.10.2/fairseq/data/language_pair_dataset.py +475 -0
  31. fairseq-0.10.2/fairseq/data/legacy/__init__.py +16 -0
  32. fairseq-0.10.2/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc +0 -0
  33. fairseq-0.10.2/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc +0 -0
  34. fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc +0 -0
  35. fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc +0 -0
  36. fairseq-0.10.2/fairseq/data/legacy/block_pair_dataset.py +311 -0
  37. fairseq-0.10.2/fairseq/data/legacy/masked_lm_dictionary.py +60 -0
  38. fairseq-0.10.2/fairseq/data/list_dataset.py +32 -0
  39. fairseq-0.10.2/fairseq/data/lru_cache_dataset.py +21 -0
  40. fairseq-0.10.2/fairseq/data/mask_tokens_dataset.py +178 -0
  41. fairseq-0.10.2/fairseq/data/multi_corpus_dataset.py +159 -0
  42. fairseq-0.10.2/fairseq/data/multi_corpus_sampled_dataset.py +145 -0
  43. fairseq-0.10.2/fairseq/data/num_samples_dataset.py +17 -0
  44. fairseq-0.10.2/fairseq/data/offset_tokens_dataset.py +15 -0
  45. fairseq-0.10.2/fairseq/data/pad_dataset.py +28 -0
  46. fairseq-0.10.2/fairseq/data/prepend_dataset.py +28 -0
  47. fairseq-0.10.2/fairseq/data/resampling_dataset.py +139 -0
  48. fairseq-0.10.2/fairseq/data/shorten_dataset.py +78 -0
  49. fairseq-0.10.2/fairseq/data/strip_token_dataset.py +20 -0
  50. fairseq-0.10.2/fairseq/data/token_block_dataset.py +168 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
fairseq-0.10.2/fairseq/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.55 kB). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc ADDED
Binary file (2.38 kB). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc ADDED
Binary file (1.4 kB). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc ADDED
Binary file (2.49 kB). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc ADDED
Binary file (1.21 kB). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc ADDED
Binary file (5.04 kB). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/dictionary.cpython-310.pyc ADDED
Binary file (11.6 kB). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/id_dataset.cpython-310.pyc ADDED
Binary file (784 Bytes). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc ADDED
Binary file (6.33 kB). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc ADDED
Binary file (5.36 kB). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc ADDED
Binary file (759 Bytes). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc ADDED
Binary file (4.41 kB). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc ADDED
Binary file (777 Bytes). View file
 
fairseq-0.10.2/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc ADDED
Binary file (5.01 kB). View file
 
fairseq-0.10.2/fairseq/data/concat_dataset.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import bisect
7
+
8
+ import numpy as np
9
+ from torch.utils.data.dataloader import default_collate
10
+
11
+ from . import FairseqDataset
12
+
13
+
14
class ConcatDataset(FairseqDataset):
    """Concatenate several datasets into one, optionally up-sampling each
    constituent by a per-dataset ratio (examples repeat via modulo indexing)."""

    @staticmethod
    def cumsum(sequence, sample_ratios):
        # Running totals of ratio-scaled lengths, e.g. lengths [3, 2] with
        # ratios [2, 1] -> [6, 8]; entry i is the first index past dataset i.
        r, s = [], 0
        for e, ratio in zip(sequence, sample_ratios):
            curr_len = int(ratio * len(e))
            r.append(curr_len + s)
            s += curr_len
        return r

    def __init__(self, datasets, sample_ratios=1):
        """
        Args:
            datasets: non-empty iterable of dataset objects.
            sample_ratios: single int applied to all datasets, or one ratio
                per dataset; a ratio > 1 repeats that dataset's examples.
        """
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        self.sample_ratios = sample_ratios
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        # Unscaled lengths, used to wrap up-sampled indices back into range.
        self.real_sizes = [len(d) for d in self.datasets]

    def __len__(self):
        # Total (ratio-scaled) number of examples.
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx][sample_idx]

    def _get_dataset_and_sample_index(self, idx: int):
        # Map a global index to (which dataset, local index within it).
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        # Modulo wraps indices produced by up-sampling (ratio > 1).
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx

    def collater(self, samples, **extra_args):
        # For now only supports datasets with same underlying collater implementations
        if hasattr(self.datasets[0], "collater"):
            return self.datasets[0].collater(samples, **extra_args)
        else:
            return default_collate(samples, **extra_args)

    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def num_tokens(self, index: int):
        # size() may be a tuple (e.g. src/tgt lengths); take the larger one.
        return np.max(self.size(index))

    def attr(self, attr: str, index: int):
        # Only the owning dataset matters here, so no local index is needed.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
        return getattr(self.datasets[dataset_idx], attr, None)

    @property
    def sizes(self):
        _dataset_sizes = []
        for ds, sr in zip(self.datasets, self.sample_ratios):
            if isinstance(ds.sizes, np.ndarray):
                # np.tile repeats the size array to match up-sampling.
                _dataset_sizes.append(np.tile(ds.sizes, sr))
            else:
                # Only support underlying dataset with single size array.
                assert isinstance(ds.sizes, list)
                _dataset_sizes.append(np.tile(ds.sizes[0], sr))
        return np.concatenate(_dataset_sizes)

    @property
    def supports_prefetch(self):
        return all(d.supports_prefetch for d in self.datasets)

    def ordered_indices(self):
        """
        Returns indices sorted by length. So less padding is needed.
        """
        if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
            # special handling for concatenating lang_pair_datasets
            indices = np.arange(len(self))
            sizes = self.sizes
            tgt_sizes = (
                sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
            )
            src_sizes = (
                sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
            )
            # sort by target length, then source length
            if tgt_sizes is not None:
                indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(src_sizes[indices], kind="mergesort")]
        else:
            return np.argsort(self.sizes)

    def prefetch(self, indices):
        # Translate global indices into each dataset's local index space.
        frm = 0
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
fairseq-0.10.2/fairseq/data/concat_sentences_dataset.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import FairseqDataset
9
+
10
+
11
class ConcatSentencesDataset(FairseqDataset):
    """Dataset whose item *i* is the tensor concatenation of item *i* from
    each of several equal-length wrapped datasets."""

    def __init__(self, *datasets):
        super().__init__()
        self.datasets = datasets
        assert all(
            len(ds) == len(datasets[0]) for ds in datasets
        ), "datasets must have the same length"

    def __getitem__(self, index):
        pieces = [dataset[index] for dataset in self.datasets]
        return torch.cat(pieces)

    def __len__(self):
        return len(self.datasets[0])

    def collater(self, samples):
        # Batching is delegated to the first dataset's collater.
        return self.datasets[0].collater(samples)

    @property
    def sizes(self):
        # Element-wise sum of the per-dataset size arrays.
        return sum(dataset.sizes for dataset in self.datasets)

    def num_tokens(self, index):
        total = 0
        for dataset in self.datasets:
            total += dataset.num_tokens(index)
        return total

    def size(self, index):
        total = 0
        for dataset in self.datasets:
            total += dataset.size(index)
        return total

    def ordered_indices(self):
        # The first dataset defines the iteration order.
        return self.datasets[0].ordered_indices()

    @property
    def supports_prefetch(self):
        for dataset in self.datasets:
            if getattr(dataset, "supports_prefetch", False):
                return True
        return False

    def prefetch(self, indices):
        for dataset in self.datasets:
            if getattr(dataset, "supports_prefetch", False):
                dataset.prefetch(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for dataset in self.datasets:
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
fairseq-0.10.2/fairseq/data/dictionary.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from collections import Counter
8
+ from multiprocessing import Pool
9
+
10
+ import torch
11
+ from fairseq import utils
12
+ from fairseq.binarizer import safe_readline
13
+ from fairseq.data import data_utils
14
+ from fairseq.file_io import PathManager
15
+ from fairseq.tokenizer import tokenize_line
16
+
17
+
18
class Dictionary(object):
    """A mapping from symbols to consecutive integers"""

    def __init__(
        self,
        *,  # begin keyword-only arguments
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        extra_special_symbols=None,
    ):
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        # symbols[i] is the string for index i; count[i] is its frequency;
        # indices is the inverse mapping (symbol -> index).
        self.symbols = []
        self.count = []
        self.indices = {}
        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        # Leading special symbols are protected from finalize()/save().
        self.nspecial = len(self.symbols)

    def __eq__(self, other):
        # Two dictionaries are equal iff their symbol->index mappings match.
        return self.indices == other.indices

    def __getitem__(self, idx):
        # Out-of-range indices resolve to the unknown word string.
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        return sym in self.indices

    def index(self, sym):
        """Returns the index of the specified symbol"""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def string(
        self,
        tensor,
        bpe_symbol=None,
        escape_unk=False,
        extra_symbols_to_ignore=None,
        unk_string=None,
    ):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        # 2-D tensors are treated as a batch of sentences, one per line.
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return "\n".join(
                self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore)
                for t in tensor
            )

        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
        extra_symbols_to_ignore.add(self.eos())

        def token_string(i):
            if i == self.unk():
                if unk_string is not None:
                    return unk_string
                else:
                    return self.unk_string(escape_unk)
            else:
                return self[i]

        # Guarded because legacy subclasses may not define a BOS symbol.
        if hasattr(self, "bos_index"):
            extra_symbols_to_ignore.add(self.bos())

        sent = " ".join(
            token_string(i)
            for i in tensor
            if utils.item(i) not in extra_symbols_to_ignore
        )

        return data_utils.post_process(sent, bpe_symbol)

    def unk_string(self, escape=False):
        """Return unknown string, optionally escaped as: <<unk>>"""
        if escape:
            return "<{}>".format(self.unk_word)
        else:
            return self.unk_word

    def add_symbol(self, word, n=1, overwrite=False):
        """Adds a word to the dictionary"""
        if word in self.indices and not overwrite:
            # Existing symbol: just accumulate its count.
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def update(self, new_dict):
        """Updates counts from new dictionary."""
        for word in new_dict.symbols:
            idx2 = new_dict.indices[word]
            if word in self.indices:
                idx = self.indices[word]
                self.count[idx] = self.count[idx] + new_dict.count[idx2]
            else:
                idx = len(self.symbols)
                self.indices[word] = idx
                self.symbols.append(word)
                self.count.append(new_dict.count[idx2])

    def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
        """Sort symbols by frequency in descending order, ignoring special ones.

        Args:
            - threshold defines the minimum word count
            - nwords defines the total number of words in the final dictionary,
                including special symbols
            - padding_factor can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        if nwords <= 0:
            nwords = len(self)

        # Special symbols keep their original indices at the front.
        new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
        new_symbols = self.symbols[: self.nspecial]
        new_count = self.count[: self.nspecial]

        # Building the Counter from alphabetically sorted pairs makes
        # most_common() break count ties alphabetically (insertion order).
        c = Counter(
            dict(
                sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
            )
        )
        for symbol, count in c.most_common(nwords - self.nspecial):
            if count >= threshold:
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(count)
            else:
                # Counts are descending, so the rest are below threshold too.
                break

        assert len(new_symbols) == len(new_indices)

        self.count = list(new_count)
        self.symbols = list(new_symbols)
        self.indices = new_indices

        self.pad_to_multiple_(padding_factor)

    def pad_to_multiple_(self, padding_factor):
        """Pad Dictionary size to be a multiple of *padding_factor*."""
        if padding_factor > 1:
            i = 0
            while len(self) % padding_factor != 0:
                # Filler symbols with zero count; never produced by real data.
                symbol = "madeupword{:04d}".format(i)
                self.add_symbol(symbol, n=0)
                i += 1

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index

    @classmethod
    def load(cls, f):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        d = cls()
        d.add_from_file(f)
        return d

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.
        """
        if isinstance(f, str):
            # Path given: open it and recurse with the file object.
            try:
                with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except FileNotFoundError as fnfe:
                raise fnfe
            except UnicodeError:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                )
            return

        lines = f.readlines()
        # Subclasses may store metadata in a header; skip those lines.
        indices_start_line = self._load_meta(lines)

        for line in lines[indices_start_line:]:
            try:
                # Each line is "<token> <count>", optionally followed by the
                # "#fairseq:overwrite" flag to allow duplicate tokens.
                line, field = line.rstrip().rsplit(" ", 1)
                if field == "#fairseq:overwrite":
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                word = line
                if word in self and not overwrite:
                    raise RuntimeError(
                        "Duplicate word found when loading Dictionary: '{}'. "
                        "Duplicate words can overwrite earlier ones by adding the "
                        "#fairseq:overwrite flag at the end of the corresponding row "
                        "in the dictionary file. If using the Camembert model, please "
                        "download an updated copy of the model file.".format(word)
                    )
                self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError:
                raise ValueError(
                    "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
                )

    def _save(self, f, kv_iterator):
        if isinstance(f, str):
            # Path given: open it and re-enter via save(), which rebuilds the
            # iterator for the opened file object.
            PathManager.mkdirs(os.path.dirname(f))
            with PathManager.open(f, "w", encoding="utf-8") as fd:
                return self.save(fd)
        for k, v in kv_iterator:
            print("{} {}".format(k, v), file=f)

    def _get_meta(self):
        # Hook for subclasses to prepend extra (key, value) header entries.
        return [], []

    def _load_meta(self, lines):
        # Hook for subclasses; returns the first line index of real entries.
        return 0

    def save(self, f):
        """Stores dictionary into a text file"""
        ex_keys, ex_vals = self._get_meta()
        # Special symbols are not written; load() re-creates them.
        self._save(
            f,
            zip(
                ex_keys + self.symbols[self.nspecial :],
                ex_vals + self.count[self.nspecial :],
            ),
        )

    def dummy_sentence(self, length):
        # Random non-special token ids, terminated by EOS.
        t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
        t[-1] = self.eos()
        return t

    def encode_line(
        self,
        line,
        line_tokenizer=tokenize_line,
        add_if_not_exist=True,
        consumer=None,
        append_eos=True,
        reverse_order=False,
    ):
        """Tokenize *line* and return a 1-D IntTensor of symbol indices."""
        words = line_tokenizer(line)
        if reverse_order:
            words = list(reversed(words))
        nwords = len(words)
        ids = torch.IntTensor(nwords + 1 if append_eos else nwords)

        for i, word in enumerate(words):
            if add_if_not_exist:
                idx = self.add_symbol(word)
            else:
                idx = self.index(word)
            if consumer is not None:
                # Callback receives each (word, index) pair as it is encoded.
                consumer(word, idx)
            ids[i] = idx
        if append_eos:
            ids[nwords] = self.eos_index
        return ids

    @staticmethod
    def _add_file_to_dictionary_single_worker(
        filename, tokenize, eos_word, worker_id=0, num_workers=1
    ):
        """Count token frequencies in this worker's byte-range of *filename*."""
        counter = Counter()
        with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
            size = os.fstat(f.fileno()).st_size
            chunk_size = size // num_workers
            offset = worker_id * chunk_size
            end = offset + chunk_size
            f.seek(offset)
            if offset > 0:
                safe_readline(f)  # drop first incomplete line
            line = f.readline()
            while line:
                for word in tokenize(line):
                    counter.update([word])
                counter.update([eos_word])
                # Finish the line that straddles the chunk boundary, then stop.
                if f.tell() > end:
                    break
                line = f.readline()
        return counter

    @staticmethod
    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
        """Populate *dict* with token counts from *filename*, optionally in
        parallel across *num_workers* processes."""

        def merge_result(counter):
            # Sorted iteration keeps symbol insertion order deterministic.
            for w, c in sorted(counter.items()):
                dict.add_symbol(w, c)

        if num_workers > 1:
            pool = Pool(processes=num_workers)
            results = []
            for worker_id in range(num_workers):
                results.append(
                    pool.apply_async(
                        Dictionary._add_file_to_dictionary_single_worker,
                        (filename, tokenize, dict.eos_word, worker_id, num_workers),
                    )
                )
            pool.close()
            pool.join()
            for r in results:
                merge_result(r.get())
        else:
            merge_result(
                Dictionary._add_file_to_dictionary_single_worker(
                    filename, tokenize, dict.eos_word
                )
            )
368
+
369
+
370
class TruncatedDictionary(object):
    """View over a wrapped dictionary that exposes only its first *length*
    entries; any lookup beyond that range resolves to the unknown symbol."""

    def __init__(self, wrapped_dict, length):
        # Dynamically re-class this object as a subclass of both this class
        # and the wrapped dictionary's class, so isinstance checks against
        # the wrapped type still succeed on the truncated view.
        self.__class__ = type(
            wrapped_dict.__class__.__name__,
            (self.__class__, wrapped_dict.__class__),
            {},
        )
        # Share the wrapped dictionary's attribute storage.
        self.__dict__ = wrapped_dict.__dict__
        self.wrapped_dict = wrapped_dict
        self.length = min(len(self.wrapped_dict), length)

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return self.wrapped_dict[i] if i < self.length else self.wrapped_dict.unk()
fairseq-0.10.2/fairseq/data/encoders/byte_utils.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+
9
WHITESPACE_NORMALIZER = re.compile(r"\s+")
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
# excluding non-breaking space (160) here
PRINTABLE_LATIN = set(
    list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1))
)
# Each byte maps to a printable "byte-char": itself when printable Latin,
# otherwise the codepoint shifted up by 256.
BYTE_TO_BCHAR = {
    b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256)
}
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}


def byte_encode(x: str) -> str:
    """Collapse whitespace runs to single spaces, then map each UTF-8 byte
    of *x* to its printable byte-char representation."""
    normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
    return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])


def byte_decode(x: str) -> str:
    """Inverse of byte_encode(); returns "" when *x* cannot be decoded."""
    try:
        return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
    except (ValueError, KeyError):
        # KeyError: *x* contains a character that is not a byte-char at all.
        # ValueError (incl. UnicodeDecodeError): the bytes are not valid
        # UTF-8. Both cases mean "undecodable", signalled by "".
        return ""


def smart_byte_decode(x: str) -> str:
    """Decode *x*; if it is broken UTF-8, recover the longest decodable
    subsequence via dynamic programming (maximizing valid characters)."""
    output = byte_decode(x)
    if output == "":
        # DP the best recovery (max valid chars) if it's broken
        n_bytes = len(x)
        # f[i] = max chars decodable from x[:i]; pt[i] = split point used.
        f = [0 for _ in range(n_bytes + 1)]
        pt = [0 for _ in range(n_bytes + 1)]
        for i in range(1, n_bytes + 1):
            f[i], pt[i] = f[i - 1], i - 1
            # UTF-8 characters are 1-4 bytes long.
            for j in range(1, min(4, i) + 1):
                if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
                    f[i], pt[i] = f[i - j] + 1, i - j
        # Walk the back-pointers, keeping only spans that decoded a char.
        cur_pt = n_bytes
        while cur_pt > 0:
            if f[cur_pt] == f[pt[cur_pt]] + 1:
                output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
            cur_pt = pt[cur_pt]
    return output
fairseq-0.10.2/fairseq/data/encoders/fastbpe.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq import file_utils
7
+ from fairseq.data.encoders import register_bpe
8
+
9
+
10
@register_bpe("fastbpe")
class fastBPE(object):
    """BPE encoder/decoder backed by the fastBPE library."""

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--bpe-codes', type=str,
                            help='path to fastBPE BPE')
        # fmt: on

    def __init__(self, args):
        if args.bpe_codes is None:
            raise ValueError("--bpe-codes is required for --bpe=fastbpe")
        codes = file_utils.cached_path(args.bpe_codes)
        try:
            import fastBPE

            self.bpe = fastBPE.fastBPE(codes)
            self.bpe_symbol = "@@ "
        except ImportError:
            raise ImportError("Please install fastBPE with: pip install fastBPE")

    def encode(self, x: str) -> str:
        # fastBPE operates on batches; wrap and unwrap a single sentence.
        segmented = self.bpe.apply([x])
        return segmented[0]

    def decode(self, x: str) -> str:
        # Removing the continuation markers ("@@ ") undoes the segmentation;
        # the trailing space guards against a marker at the end of *x*.
        padded = x + " "
        return padded.replace(self.bpe_symbol, "").rstrip()
fairseq-0.10.2/fairseq/data/encoders/gpt2_bpe_utils.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte pair encoding utilities from GPT-2.
3
+
4
+ Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
5
+ Original license: MIT
6
+ """
7
+
8
+ import json
9
+ from functools import lru_cache
10
+
11
+
12
+ @lru_cache()
13
+ def bytes_to_unicode():
14
+ """
15
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
16
+ The reversible bpe codes work on unicode strings.
17
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
18
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
19
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
20
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
21
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
22
+ """
23
+ bs = (
24
+ list(range(ord("!"), ord("~") + 1))
25
+ + list(range(ord("¡"), ord("¬") + 1))
26
+ + list(range(ord("®"), ord("ÿ") + 1))
27
+ )
28
+ cs = bs[:]
29
+ n = 0
30
+ for b in range(2 ** 8):
31
+ if b not in bs:
32
+ bs.append(b)
33
+ cs.append(2 ** 8 + n)
34
+ n += 1
35
+ cs = [chr(n) for n in cs]
36
+ return dict(zip(bs, cs))
37
+
38
+
39
def get_pairs(word):
    """Return the set of adjacent symbol bigrams in *word*.

    *word* is a sequence of symbols (symbols being variable-length strings);
    an empty sequence raises IndexError, matching the original contract.
    """
    pairs = set()
    idx = 1
    previous = word[0]
    while idx < len(word):
        current = word[idx]
        pairs.add((previous, current))
        previous = current
        idx += 1
    return pairs
49
+
50
+
51
+ class Encoder:
52
+ def __init__(self, encoder, bpe_merges, errors="replace"):
53
+ self.encoder = encoder
54
+ self.decoder = {v: k for k, v in self.encoder.items()}
55
+ self.errors = errors # how to handle errors in decoding
56
+ self.byte_encoder = bytes_to_unicode()
57
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
58
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
59
+ self.cache = {}
60
+
61
+ try:
62
+ import regex as re
63
+
64
+ self.re = re
65
+ except ImportError:
66
+ raise ImportError("Please install regex with: pip install regex")
67
+
68
+ # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
69
+ self.pat = self.re.compile(
70
+ r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
71
+ )
72
+
73
+ def bpe(self, token):
74
+ if token in self.cache:
75
+ return self.cache[token]
76
+ word = tuple(token)
77
+ pairs = get_pairs(word)
78
+
79
+ if not pairs:
80
+ return token
81
+
82
+ while True:
83
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
84
+ if bigram not in self.bpe_ranks:
85
+ break
86
+ first, second = bigram
87
+ new_word = []
88
+ i = 0
89
+ while i < len(word):
90
+ try:
91
+ j = word.index(first, i)
92
+ new_word.extend(word[i:j])
93
+ i = j
94
+ except:
95
+ new_word.extend(word[i:])
96
+ break
97
+
98
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
99
+ new_word.append(first + second)
100
+ i += 2
101
+ else:
102
+ new_word.append(word[i])
103
+ i += 1
104
+ new_word = tuple(new_word)
105
+ word = new_word
106
+ if len(word) == 1:
107
+ break
108
+ else:
109
+ pairs = get_pairs(word)
110
+ word = " ".join(word)
111
+ self.cache[token] = word
112
+ return word
113
+
114
+ def encode(self, text):
115
+ bpe_tokens = []
116
+ for token in self.re.findall(self.pat, text):
117
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
118
+ bpe_tokens.extend(
119
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
120
+ )
121
+ return bpe_tokens
122
+
123
+ def decode(self, tokens):
124
+ text = "".join([self.decoder.get(token, token) for token in tokens])
125
+ text = bytearray([self.byte_decoder[c] for c in text]).decode(
126
+ "utf-8", errors=self.errors
127
+ )
128
+ return text
129
+
130
+
131
+ def get_encoder(encoder_json_path, vocab_bpe_path):
132
+ with open(encoder_json_path, "r") as f:
133
+ encoder = json.load(f)
134
+ with open(vocab_bpe_path, "r", encoding="utf-8") as f:
135
+ bpe_data = f.read()
136
+ bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
137
+ return Encoder(
138
+ encoder=encoder,
139
+ bpe_merges=bpe_merges,
140
+ )
fairseq-0.10.2/fairseq/data/encoders/moses_tokenizer.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq.data.encoders import register_tokenizer
7
+
8
+
9
@register_tokenizer("moses")
class MosesTokenizer(object):
    """Tokenizer/detokenizer backed by the ``sacremoses`` package."""

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--moses-source-lang', metavar='SRC',
                            help='source language')
        parser.add_argument('--moses-target-lang', metavar='TARGET',
                            help='target language')
        parser.add_argument('--moses-no-dash-splits', action='store_true', default=False,
                            help='don\'t apply dash split rules')
        parser.add_argument('--moses-no-escape', action='store_true', default=False,
                            help='don\'t perform HTML escaping on apostrophy, quotes, etc.')
        # fmt: on

    def __init__(self, args):
        self.args = args

        # fall back to the task-level language pair, defaulting to English
        if getattr(args, "moses_source_lang", None) is None:
            args.moses_source_lang = getattr(args, "source_lang", "en")
        if getattr(args, "moses_target_lang", None) is None:
            args.moses_target_lang = getattr(args, "target_lang", "en")

        try:
            from sacremoses import MosesDetokenizer, MosesTokenizer

            self.tok = MosesTokenizer(args.moses_source_lang)
            self.detok = MosesDetokenizer(args.moses_target_lang)
        except ImportError:
            raise ImportError(
                "Please install Moses tokenizer with: pip install sacremoses"
            )

    def encode(self, x: str) -> str:
        no_dash = self.args.moses_no_dash_splits
        no_escape = self.args.moses_no_escape
        return self.tok.tokenize(
            x,
            aggressive_dash_splits=not no_dash,
            return_str=True,
            escape=not no_escape,
        )

    def decode(self, x: str) -> str:
        # detokenize from a whitespace-split token list
        return self.detok.detokenize(x.split())
fairseq-0.10.2/fairseq/data/encoders/nltk_tokenizer.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq.data.encoders import register_tokenizer
7
+
8
+
9
@register_tokenizer("nltk")
class NLTKTokenizer(object):
    """Space-joined NLTK word tokenization; decoding is a no-op."""

    def __init__(self, source_lang=None, target_lang=None):
        try:
            from nltk.tokenize import word_tokenize
        except ImportError:
            raise ImportError("Please install nltk with: pip install nltk")
        self.word_tokenize = word_tokenize

    def encode(self, x: str) -> str:
        tokens = self.word_tokenize(x)
        return " ".join(tokens)

    def decode(self, x: str) -> str:
        # tokenization is not reversible here; return input unchanged
        return x
fairseq-0.10.2/fairseq/data/encoders/subword_nmt_bpe.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq import file_utils
7
+ from fairseq.data.encoders import register_bpe
8
+
9
+
10
+ @register_bpe("subword_nmt")
11
+ class SubwordNMTBPE(object):
12
+ @staticmethod
13
+ def add_args(parser):
14
+ # fmt: off
15
+ parser.add_argument('--bpe-codes', type=str,
16
+ help='path to subword NMT BPE')
17
+ parser.add_argument('--bpe-separator', default='@@',
18
+ help='BPE separator')
19
+ # fmt: on
20
+
21
+ def __init__(self, args):
22
+ if args.bpe_codes is None:
23
+ raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
24
+ codes = file_utils.cached_path(args.bpe_codes)
25
+ try:
26
+ from subword_nmt import apply_bpe
27
+
28
+ bpe_parser = apply_bpe.create_parser()
29
+ bpe_args = bpe_parser.parse_args(
30
+ [
31
+ "--codes",
32
+ codes,
33
+ "--separator",
34
+ args.bpe_separator,
35
+ ]
36
+ )
37
+ self.bpe = apply_bpe.BPE(
38
+ bpe_args.codes,
39
+ bpe_args.merges,
40
+ bpe_args.separator,
41
+ None,
42
+ bpe_args.glossaries,
43
+ )
44
+ self.bpe_symbol = bpe_args.separator + " "
45
+ except ImportError:
46
+ raise ImportError(
47
+ "Please install subword_nmt with: pip install subword-nmt"
48
+ )
49
+
50
+ def encode(self, x: str) -> str:
51
+ return self.bpe.process_line(x)
52
+
53
+ def decode(self, x: str) -> str:
54
+ return (x + " ").replace(self.bpe_symbol, "").rstrip()
fairseq-0.10.2/fairseq/data/encoders/utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from fairseq.data import encoders
8
+
9
+
10
def get_whole_word_mask(args, dictionary):
    """Return a ByteTensor marking which dictionary entries begin a word.

    Returns ``None`` when *args* configures no BPE scheme.
    """
    bpe = encoders.build_bpe(args)
    if bpe is None:
        return None

    def _starts_word(idx):
        # special symbols and madeupword padding entries always count as starts
        if idx < dictionary.nspecial:
            return True
        symbol = dictionary[idx]
        if symbol.startswith("madeupword"):
            return True
        try:
            return bpe.is_beginning_of_word(symbol)
        except ValueError:
            return True

    return torch.ByteTensor([_starts_word(i) for i in range(len(dictionary))])
fairseq-0.10.2/fairseq/data/fairseq_dataset.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch.utils.data
8
+ from fairseq.data import data_utils
9
+
10
+
11
class EpochListening:
    """Mixin for receiving updates whenever the epoch increments."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """
        Whether the :class:`fairseq.data.EpochBatchIterator` built for this
        dataset can be reused across epochs.

        Return ``False`` when sample sizes can change between epochs (batches
        would then need regenerating), in particular when the dataset relies
        on ``set_epoch``.
        """
        return True

    def set_epoch(self, epoch):
        """Will receive the updated epoch number at the beginning of the epoch."""
        pass
30
+
31
+
32
class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self), dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        # default implementation ignores *index* and reads a dataset-level
        # attribute; subclasses may return per-example values
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """
        Return a list of valid batch shapes, for example::

            [(8, 512), (16, 256), (32, 128)]

        The first dimension of each tuple is the batch size and can be ``None``
        to automatically infer the max batch size based on ``--max-tokens``.
        The second dimension of each tuple is the max supported length as given
        by :func:`fairseq.data.FairseqDataset.num_tokens`.

        This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
        to restrict batch shapes. This is useful on TPUs to avoid too many
        dynamic shapes (and recompilations).
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """
        Given an ordered set of indices, return batches according to
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.
        """
        # NOTE(review): shadows the module-level data_utils import;
        # presumably guards against a circular import at load time — confirm
        from fairseq.data import data_utils

        fixed_shapes = self.get_batch_shapes()
        if fixed_shapes is not None:

            def adjust_bsz(bsz, num_tokens):
                # resolve a None batch size from --max-tokens, then clamp to
                # --max-sentences or round down to the required multiple
                if bsz is None:
                    assert max_tokens is not None, "Must specify --max-tokens"
                    bsz = max_tokens // num_tokens
                if max_sentences is not None:
                    bsz = min(bsz, max_sentences)
                elif (
                    bsz >= required_batch_size_multiple
                    and bsz % required_batch_size_multiple != 0
                ):
                    bsz -= bsz % required_batch_size_multiple
                return bsz

            fixed_shapes = np.array(
                [
                    [adjust_bsz(bsz, num_tokens), num_tokens]
                    for (bsz, num_tokens) in fixed_shapes
                ]
            )

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=fixed_shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """
        Filter a list of sample indices. Remove those that are longer than
        specified in *max_sizes*.

        WARNING: don't update, override method in child classes

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        if isinstance(max_sizes, float) or isinstance(max_sizes, int):
            # scalar limit: vectorized filtering when sizes are available as
            # an ndarray (or a single-element list of ndarrays)
            if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
                ignored = indices[self.sizes[indices] > max_sizes].tolist()
                indices = indices[self.sizes[indices] <= max_sizes]
            elif (
                hasattr(self, "sizes")
                and isinstance(self.sizes, list)
                and len(self.sizes) == 1
            ):
                ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
                indices = indices[self.sizes[0][indices] <= max_sizes]
            else:
                # fall back to per-example self.size() calls
                indices, ignored = data_utils._filter_by_size_dynamic(
                    indices, self.size, max_sizes
                )
        else:
            # per-dimension (e.g. src/tgt) limits always use the dynamic path
            indices, ignored = data_utils._filter_by_size_dynamic(
                indices, self.size, max_sizes
            )
        return indices, ignored

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the dataloader."""
        return True
182
+
183
+
184
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """
    Base class for datasets that must be read sequentially — typically
    streamed data, or data too large to manipulate on a single machine.
    Subclasses must implement ``__iter__``.
    """

    def __iter__(self):
        raise NotImplementedError
fairseq-0.10.2/fairseq/data/fasta_dataset.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import subprocess
8
+ import threading
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+
15
def fasta_file_path(prefix_path):
    """Return the on-disk FASTA filename for dataset *prefix_path*."""
    return "{}.fasta".format(prefix_path)
17
+
18
+
19
class FastaDataset(torch.utils.data.Dataset):
    """
    For loading protein sequence datasets in the common FASTA data format.

    Records are addressed by byte offsets built once with grep/awk; file
    handles are kept per-thread so DataLoader workers never share one.
    """

    def __init__(self, path: str, cache_indices=False):
        self.fn = fasta_file_path(path)
        self.threadlocal = threading.local()
        self.cache = Path(f"{path}.fasta.idx.npy")
        if cache_indices:
            if self.cache.exists():
                self.offsets, self.sizes = np.load(self.cache)
            else:
                self.offsets, self.sizes = self._build_index(path)
                np.save(self.cache, np.stack([self.offsets, self.sizes]))
        else:
            self.offsets, self.sizes = self._build_index(path)

    def _get_file(self):
        """Return a file handle private to the calling thread."""
        if not hasattr(self.threadlocal, "f"):
            self.threadlocal.f = open(self.fn, "r")
        return self.threadlocal.f

    def __getitem__(self, idx):
        """Return ``(description, sequence)`` for record *idx*."""
        f = self._get_file()
        f.seek(self.offsets[idx])
        desc = f.readline().strip()
        line = f.readline()
        seq = ""
        # accumulate sequence lines until EOF or the next '>' header
        while line != "" and line[0] != ">":
            seq += line.strip()
            line = f.readline()
        return desc, seq

    def __len__(self):
        return self.offsets.size

    def _build_index(self, path: str):
        """Scan the FASTA file and return (byte offsets, sequence lengths)."""
        # Use grep and awk to get 100M/s on local SSD.
        # Should process your enormous 100G fasta in ~10 min single core...
        path = fasta_file_path(path)
        bytes_offsets = subprocess.check_output(
            f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
            "| grep --byte-offset '^>' -o | cut -d: -f1",
            shell=True,
        )
        fasta_lengths = subprocess.check_output(
            f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
            "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'",
            shell=True,
        )
        # np.fromstring(text, sep=" ") is deprecated and unreliable on modern
        # NumPy; parse the whitespace-separated numbers explicitly instead
        bytes_np = np.array(bytes_offsets.decode().split(), dtype=np.int64)
        sizes_np = np.array(fasta_lengths.decode().split(), dtype=np.int64)
        return bytes_np, sizes_np

    def __setstate__(self, state):
        self.__dict__ = state
        # thread-local handles cannot be pickled; recreate after unpickling
        self.threadlocal = threading.local()

    def __getstate__(self):
        # drop the unpicklable thread-local state
        d = {}
        for i, v in self.__dict__.items():
            if i != "threadlocal":
                d[i] = v
        return d

    def __del__(self):
        if hasattr(self.threadlocal, "f"):
            self.threadlocal.f.close()
            del self.threadlocal.f

    @staticmethod
    def exists(path):
        """Whether ``<path>.fasta`` exists on disk."""
        return os.path.exists(fasta_file_path(path))
93
+
94
+
95
class EncodedFastaDataset(FastaDataset):
    """
    The FastaDataset returns raw sequences - this allows us to return
    indices with a dictionary instead.
    """

    def __init__(self, path, dictionary):
        # cache_indices=True: index is always cached to disk for this variant
        super().__init__(path, cache_indices=True)
        self.dictionary = dictionary

    def __getitem__(self, idx):
        # NOTE(review): the FASTA description line (desc) is discarded here
        desc, seq = super().__getitem__(idx)
        # tokenize character-by-character via line_tokenizer=list
        return self.dictionary.encode_line(seq, line_tokenizer=list).long()
fairseq-0.10.2/fairseq/data/id_dataset.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import FairseqDataset
9
+
10
+
11
class IdDataset(FairseqDataset):
    """Dataset whose examples are simply their own integer indices."""

    def __getitem__(self, index):
        return index

    def __len__(self):
        # NOTE(review): reports length 0; callers appear to rely on companion
        # datasets for sizing — confirm before changing
        return 0

    def collater(self, samples):
        # batch of ids as a 1-D tensor
        return torch.tensor(samples)
fairseq-0.10.2/fairseq/data/language_pair_dataset.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ import numpy as np
9
+ import torch
10
+ from fairseq.data import FairseqDataset, data_utils
11
+
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Collate sample dicts into a padded translation mini-batch.

    Samples are sorted by descending source length. Returns a dict with
    ``id``, ``nsentences``, ``ntokens``, ``net_input`` (``src_tokens``,
    ``src_lengths`` and — when targets exist or *input_feeding* is on —
    ``prev_output_tokens``), ``target``, and optionally
    ``alignments``/``align_weights`` and ``constraints``.
    """
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        # pad (and optionally rotate EOS to the front of) one field of all samples
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx,
            left_pad,
            move_eos_to_beginning,
            pad_to_length=pad_to_length,
            pad_to_multiple=pad_to_multiple,
        )

    def check_alignment(alignment, src_len, tgt_len):
        # reject empty alignments or indices outside the sentence bounds
        if alignment is None or len(alignment) == 0:
            return False
        if (
            alignment[:, 0].max().item() >= src_len - 1
            or alignment[:, 1].max().item() >= tgt_len - 1
        ):
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(
            align_tgt, return_inverse=True, return_counts=True
        )
        align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
        return 1.0 / align_weights.float()

    id = torch.LongTensor([s["id"] for s in samples])

    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor(
        [s["source"].ne(pad_idx).long().sum() for s in samples]
    )
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        ).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
        elif input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
            0, sort_order
        )

    if samples[0].get("alignment", None) is not None:
        bsz, tgt_sz = batch["target"].shape
        src_sz = batch["net_input"]["src_tokens"].shape[1]

        # per-sentence offsets to flatten alignments into one (N, 2) tensor
        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
        if left_pad_source:
            offsets[:, 0] += src_sz - src_lengths
        if left_pad_target:
            offsets[:, 1] += tgt_sz - tgt_lengths

        alignments = [
            alignment + offset
            for align_idx, offset, src_len, tgt_len in zip(
                sort_order, offsets, src_lengths, tgt_lengths
            )
            for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
            if check_alignment(alignment, src_len, tgt_len)
        ]

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch["alignments"] = alignments
            batch["align_weights"] = align_weights

    if samples[0].get("constraints", None) is not None:
        # Collate the packed constraints across the samples, padding to
        # the length of the longest sample.
        lens = [sample.get("constraints").size(0) for sample in samples]
        max_len = max(lens)
        constraints = torch.zeros((len(samples), max_len)).long()
        for i, sample in enumerate(samples):
            constraints[i, 0 : lens[i]] = sample.get("constraints")
        batch["constraints"] = constraints

    return batch
172
+
173
+
174
class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
            of the samples.
    """

    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
    ):
        if tgt_dict is not None:
            # source/target vocabularies must agree on special symbols
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(
                tgt
            ), "Source and target must contain the same number of examples"
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert (
                self.tgt_sizes is not None
            ), "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = eos if eos is not None else src_dict.eos()
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset

            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info(
                    "bucketing target lengths: {}".format(list(self.tgt.buckets))
                )

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset).
            # np.long was removed in NumPy 1.24; int64 is the portable
            # equivalent for token counts
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we use
        # use existing datasets for opposite directions i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if self.tgt and self.tgt[index][-1] != eos:
                tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])

        if self.append_bos:
            # NOTE(review): rebuilds from self.tgt[index], so a previously
            # appended EOS (above) is discarded when both flags are set —
            # matches upstream behavior; confirm before relying on both flags
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if self.tgt and self.tgt[index][0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])

            bos = self.src_dict.bos()
            if self.src[index][0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])

        if self.remove_eos_from_source:
            # NOTE(review): likewise rebuilds from self.src[index], dropping a
            # BOS prepended above when combined with append_bos — confirm
            eos = self.src_dict.eos()
            if self.src[index][-1] == eos:
                src_item = self.src[index][:-1]

        example = {
            "id": index,
            "source": src_item,
            "target": tgt_item,
        }
        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``. Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                  IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is:
            # max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
            ]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False) and (
            getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
        )

    def prefetch(self, indices):
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
        than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )
fairseq-0.10.2/fairseq/data/legacy/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .block_pair_dataset import BlockPairDataset
7
+ from .masked_lm_dataset import MaskedLMDataset
8
+ from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
9
+
10
+
11
+ __all__ = [
12
+ "BertDictionary",
13
+ "BlockPairDataset",
14
+ "MaskedLMDataset",
15
+ "MaskedLMDictionary",
16
+ ]
fairseq-0.10.2/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (413 Bytes). View file
 
fairseq-0.10.2/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc ADDED
Binary file (8.76 kB). View file
 
fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc ADDED
Binary file (2.05 kB). View file
 
fairseq-0.10.2/fairseq/data/legacy/block_pair_dataset.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import numpy as np
9
+ import torch
10
+ from fairseq.data import FairseqDataset
11
+
12
+
13
class BlockPairDataset(FairseqDataset):
    """Break a Dataset of tokens into sentence pair blocks for next sentence
    prediction as well as masked language model.

    High-level logics are:
    1. break input tensor to tensor blocks
    2. pair the blocks with 50% next sentence and 50% random sentence
    3. return paired blocks as well as related segment labels

    Args:
        dataset (~torch.utils.data.Dataset): dataset to break into blocks
        sizes: array of sentence lengths
        dictionary: dictionary for the task
        block_size: maximum block size
        break_mode: mode for breaking corpus into block pairs. currently we support
            2 modes
            doc: respect document boundaries and each part of the pair should belong to one document
            none: don't respect any boundary and cut tokens evenly
        short_seq_prob: probability for generating shorter block pairs
        doc_break_size: Size for empty line separating documents. Typically 1 if
            the sentences have eos, 0 otherwise.
    """

    def __init__(
        self,
        dataset,
        dictionary,
        sizes,
        block_size,
        break_mode="doc",
        short_seq_prob=0.1,
        doc_break_size=1,
    ):
        super().__init__()
        self.dataset = dataset
        # cache the special-symbol indices used when assembling pairs
        self.pad = dictionary.pad()
        self.eos = dictionary.eos()
        self.cls = dictionary.cls()
        self.mask = dictionary.mask()
        self.sep = dictionary.sep()
        self.break_mode = break_mode
        self.dictionary = dictionary
        self.short_seq_prob = short_seq_prob
        # list of documents; each document is a list of sentence ids
        self.block_indices = []

        assert len(dataset) == len(sizes)

        if break_mode == "doc":
            cur_doc = []
            for sent_id, sz in enumerate(sizes):
                assert doc_break_size == 0 or sz != 0, (
                    "when doc_break_size is non-zero, we expect documents to be"
                    "separated by a blank line with a single eos."
                )
                # empty line as document separator
                if sz == doc_break_size:
                    if len(cur_doc) == 0:
                        continue
                    self.block_indices.append(cur_doc)
                    cur_doc = []
                else:
                    cur_doc.append(sent_id)
            max_num_tokens = block_size - 3  # Account for [CLS], [SEP], [SEP]
            self.sent_pairs = []
            self.sizes = []
            for doc_id, doc in enumerate(self.block_indices):
                self._generate_sentence_pair(doc, doc_id, max_num_tokens, sizes)
        elif break_mode is None or break_mode == "none":
            # NOTE(review): this branch never initializes self.sent_pairs /
            # self.sizes before _pair_sentences appends to them -- looks like it
            # would raise AttributeError; confirm against upstream before use.
            # each block should have half of the block size since we are constructing block pair
            sent_length = (block_size - 3) // 2
            total_len = sum(dataset.sizes)
            length = math.ceil(total_len / sent_length)

            def block_at(i):
                # half-open token span [start, end) of the i-th evenly cut block
                start = i * sent_length
                end = min(start + sent_length, total_len)
                return (start, end)

            sent_indices = np.array([block_at(i) for i in range(length)])
            sent_sizes = np.array([e - s for s, e in sent_indices])
            dataset_index = self._sent_to_dataset_index(sent_sizes)

            # pair sentences
            self._pair_sentences(dataset_index)
        else:
            raise ValueError("Invalid break_mode: " + break_mode)

    def _pair_sentences(self, dataset_index):
        """
        Give a list of evenly cut blocks/sentences, pair these sentences with 50%
        consecutive sentences and 50% random sentences.
        This is used for none break mode
        """
        # pair sentences
        for sent_id, sent in enumerate(dataset_index):
            # label 1 = "real next sentence"; forced to 0 for the last block,
            # which has no successor
            next_sent_label = (
                1 if np.random.rand() > 0.5 and sent_id != len(dataset_index) - 1 else 0
            )
            if next_sent_label:
                next_sent = dataset_index[sent_id + 1]
            else:
                next_sent = dataset_index[
                    self._skip_sampling(len(dataset_index), [sent_id, sent_id + 1])
                ]
            self.sent_pairs.append((sent, next_sent, next_sent_label))

            # The current blocks don't include the special tokens but the
            # sizes already account for this
            self.sizes.append(3 + sent[3] + next_sent[3])

    def _sent_to_dataset_index(self, sent_sizes):
        """
        Build index mapping block indices to the underlying dataset indices
        """
        dataset_index = []
        ds_idx, ds_remaining = -1, 0
        for to_consume in sent_sizes:
            sent_size = to_consume
            if ds_remaining == 0:
                ds_idx += 1
                # NOTE(review): ds_remaining is refilled from sent_sizes (the
                # block sizes) rather than the underlying dataset sizes --
                # verify against upstream; self.dataset.sizes looks intended.
                ds_remaining = sent_sizes[ds_idx]
            start_ds_idx = ds_idx
            start_offset = sent_sizes[ds_idx] - ds_remaining
            while to_consume > ds_remaining:
                to_consume -= ds_remaining
                ds_idx += 1
                ds_remaining = sent_sizes[ds_idx]
            ds_remaining -= to_consume
            dataset_index.append(
                (
                    start_ds_idx,  # starting index in dataset
                    start_offset,  # starting offset within starting index
                    ds_idx,  # ending index in dataset
                    sent_size,  # sentence length
                )
            )
        assert ds_remaining == 0
        assert ds_idx == len(self.dataset) - 1
        return dataset_index

    def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes):
        """
        Go through a single document and generate sentence pairs from it
        """
        current_chunk = []
        current_length = 0
        curr = 0
        # To provide more randomness, we decrease target seq length for parts of
        # samples (10% by default). Note that max_num_tokens is the hard threshold
        # for batching and will never be changed.
        target_seq_length = max_num_tokens
        if np.random.random() < self.short_seq_prob:
            target_seq_length = np.random.randint(2, max_num_tokens)
        # loop through all sentences in document
        while curr < len(doc):
            sent_id = doc[curr]
            current_chunk.append(sent_id)
            current_length = sum(sizes[current_chunk])
            # split chunk and generate pair when exceed target_seq_length or
            # finish the loop
            if curr == len(doc) - 1 or current_length >= target_seq_length:
                # split the chunk into 2 parts
                a_end = 1
                if len(current_chunk) > 2:
                    a_end = np.random.randint(1, len(current_chunk) - 1)
                sent_a = current_chunk[:a_end]
                len_a = sum(sizes[sent_a])
                # generate next sentence label, note that if there is only 1 sentence
                # in current chunk, label is always 0
                next_sent_label = (
                    1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0
                )
                if not next_sent_label:
                    # if next sentence label is 0, sample sent_b from a random doc
                    target_b_length = target_seq_length - len_a
                    rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id])
                    random_doc = self.block_indices[rand_doc_id]
                    random_start = np.random.randint(0, len(random_doc))
                    sent_b = []
                    len_b = 0
                    for j in range(random_start, len(random_doc)):
                        sent_b.append(random_doc[j])
                        len_b = sum(sizes[sent_b])
                        if len_b >= target_b_length:
                            break
                    # return the second part of the chunk since it's not used
                    num_unused_segments = len(current_chunk) - a_end
                    curr -= num_unused_segments
                else:
                    # if next sentence label is 1, use the second part of chunk as sent_B
                    sent_b = current_chunk[a_end:]
                    len_b = sum(sizes[sent_b])
                # currently sent_a and sent_B may be longer than max_num_tokens,
                # truncate them and return block idx and offsets for them
                sent_a, sent_b = self._truncate_sentences(
                    sent_a, sent_b, max_num_tokens
                )
                self.sent_pairs.append((sent_a, sent_b, next_sent_label))
                # +3 accounts for the [CLS]/[SEP]/[SEP] special tokens
                self.sizes.append(3 + sent_a[3] + sent_b[3])
                current_chunk = []
            curr += 1

    def _skip_sampling(self, total, skip_ids):
        """
        Generate a random integer which is not in skip_ids. Sample range is [0, total)
        TODO: ids in skip_ids should be consecutive, we can extend it to more generic version later
        """
        rand_id = np.random.randint(total - len(skip_ids))
        # shift past the skipped (consecutive) block of ids
        return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids)

    def _truncate_sentences(self, sent_a, sent_b, max_num_tokens):
        """
        Truncate a pair of sentences to limit total length under max_num_tokens
        Logics:
        1. Truncate longer sentence
        2. Tokens to be truncated could be at the beginning or the end of the sentence
        Returns:
            Truncated sentences represented by dataset idx
        """
        len_a, len_b = sum(self.dataset.sizes[sent_a]), sum(self.dataset.sizes[sent_b])
        front_cut_a = front_cut_b = end_cut_a = end_cut_b = 0

        while True:
            total_length = (
                len_a + len_b - front_cut_a - front_cut_b - end_cut_a - end_cut_b
            )
            if total_length <= max_num_tokens:
                break

            # always trim the currently longer sentence, randomly from either end
            if len_a - front_cut_a - end_cut_a > len_b - front_cut_b - end_cut_b:
                if np.random.rand() < 0.5:
                    front_cut_a += 1
                else:
                    end_cut_a += 1
            else:
                if np.random.rand() < 0.5:
                    front_cut_b += 1
                else:
                    end_cut_b += 1

        # calculate ds indices as well as offsets and return
        truncated_sent_a = self._cut_sentence(sent_a, front_cut_a, end_cut_a)
        truncated_sent_b = self._cut_sentence(sent_b, front_cut_b, end_cut_b)
        return truncated_sent_a, truncated_sent_b

    def _cut_sentence(self, sent, front_cut, end_cut):
        """
        Cut a sentence based on the numbers of tokens to be cut from beginning and end
        Represent the sentence as dataset idx and return
        """
        start_ds_idx, end_ds_idx, offset = sent[0], sent[-1], 0
        target_len = sum(self.dataset.sizes[sent]) - front_cut - end_cut
        # advance the start index/offset past the tokens cut from the front
        while front_cut > 0:
            if self.dataset.sizes[start_ds_idx] > front_cut:
                offset += front_cut
                break
            else:
                front_cut -= self.dataset.sizes[start_ds_idx]
                start_ds_idx += 1
        # pull the end index back past the tokens cut from the end
        while end_cut > 0:
            if self.dataset.sizes[end_ds_idx] > end_cut:
                break
            else:
                end_cut -= self.dataset.sizes[end_ds_idx]
                end_ds_idx -= 1
        return start_ds_idx, offset, end_ds_idx, target_len

    def _fetch_block(self, start_ds_idx, offset, end_ds_idx, length):
        """
        Fetch a block of tokens based on its dataset idx
        """
        buffer = torch.cat(
            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
        )
        s, e = offset, offset + length
        return buffer[s:e]

    def __getitem__(self, index):
        block1, block2, next_sent_label = self.sent_pairs[index]
        block1 = self._fetch_block(*block1)
        block2 = self._fetch_block(*block2)
        return block1, block2, next_sent_label

    def __len__(self):
        return len(self.sizes)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        # collect every underlying dataset index touched by either block
        prefetch_idx = set()
        for index in indices:
            for block1, block2, _ in [self.sent_pairs[index]]:
                for ds_idx in range(block1[0], block1[2] + 1):
                    prefetch_idx.add(ds_idx)
                for ds_idx in range(block2[0], block2[2] + 1):
                    prefetch_idx.add(ds_idx)
        self.dataset.prefetch(prefetch_idx)
fairseq-0.10.2/fairseq/data/legacy/masked_lm_dictionary.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq.data import Dictionary
7
+
8
+
9
class MaskedLMDictionary(Dictionary):
    """
    Dictionary for Masked Language Modelling tasks: a plain fairseq
    Dictionary extended with a dedicated mask symbol.
    """

    def __init__(
        self,
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        mask="<mask>",
    ):
        super().__init__(pad=pad, eos=eos, unk=unk)
        self.mask_word = mask
        self.mask_index = self.add_symbol(mask)
        # the mask symbol counts as a special token (never sampled/replaced)
        self.nspecial = len(self.symbols)

    def mask(self):
        """Helper to get index of mask symbol"""
        return self.mask_index
30
+
31
+
32
class BertDictionary(MaskedLMDictionary):
    """
    Dictionary for BERT-style tasks: a MaskedLMDictionary further extended
    with cls and sep symbols.
    """

    def __init__(
        self,
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        mask="<mask>",
        cls="<cls>",
        sep="<sep>",
    ):
        super().__init__(pad=pad, eos=eos, unk=unk, mask=mask)
        self.cls_word = cls
        self.sep_word = sep
        self.cls_index = self.add_symbol(cls)
        self.sep_index = self.add_symbol(sep)
        # cls/sep are special tokens as well
        self.nspecial = len(self.symbols)

    def cls(self):
        """Helper to get index of cls symbol"""
        return self.cls_index

    def sep(self):
        """Helper to get index of sep symbol"""
        return self.sep_index
fairseq-0.10.2/fairseq/data/list_dataset.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import BaseWrapperDataset
7
+
8
+
9
class ListDataset(BaseWrapperDataset):
    """Thin wrapper that exposes a dataset's items unchanged and returns
    batches as the raw list of samples (no collation).  Optional *sizes*
    supplies per-item lengths used for batching."""

    def __init__(self, dataset, sizes=None):
        super().__init__(dataset)
        self._sizes = sizes

    def __iter__(self):
        yield from self.dataset

    def collater(self, samples):
        # no collation: a batch is simply the list of samples
        return samples

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    def set_epoch(self, epoch):
        # stateless across epochs
        pass
fairseq-0.10.2/fairseq/data/lru_cache_dataset.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from functools import lru_cache
7
+
8
+ from . import BaseWrapperDataset
9
+
10
+
11
class LRUCacheDataset(BaseWrapperDataset):
    """Wrap a dataset so that repeated lookups of the same index (or the same
    batch) are served from a small LRU cache instead of being recomputed."""

    def __init__(self, dataset, token=None):
        # NOTE(review): `token` is accepted but never used -- presumably kept
        # for call-site compatibility with other wrappers; confirm before removing.
        super().__init__(dataset)

    # NOTE(review): lru_cache on an instance method keys on `self`, so the
    # cache keeps wrapped instances alive for its lifetime (ruff B019);
    # acceptable only if instances live for the whole run.
    @lru_cache(maxsize=8)
    def __getitem__(self, index):
        return self.dataset[index]

    # NOTE(review): caching requires `samples` to be hashable (e.g. a tuple);
    # a plain list argument would raise TypeError -- verify callers.
    @lru_cache(maxsize=8)
    def collater(self, samples):
        return self.dataset.collater(samples)
fairseq-0.10.2/fairseq/data/mask_tokens_dataset.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from functools import lru_cache
7
+
8
+ import numpy as np
9
+ import torch
10
+ from fairseq.data import Dictionary, data_utils
11
+
12
+ from . import BaseWrapperDataset, LRUCacheDataset
13
+
14
+
15
class MaskTokensDataset(BaseWrapperDataset):
    """
    A wrapper Dataset for masked language modeling.

    Input items are masked according to the specified masking probability.

    Args:
        dataset: Dataset to wrap.
        sizes: Sentence lengths
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of pad token in vocab
        mask_idx: Id of mask token in vocab
        return_masked_tokens: controls whether to return the non-masked tokens
            (the default) or to return a tensor with the original masked token
            IDs (and *pad_idx* elsewhere). The latter is useful as targets for
            masked LM training.
        seed: Seed for random number generator for reproducibility.
        mask_prob: probability of replacing a token with *mask_idx*.
        leave_unmasked_prob: probability that a masked token is unmasked.
        random_token_prob: probability of replacing a masked token with a
            random token from the vocabulary.
        freq_weighted_replacement: sample random replacement words based on
            word frequencies in the vocab.
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        bpe: BPE to use for whole-word masking.
    """

    @classmethod
    def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
        """Return the source and target datasets for masked LM training."""
        # share one cached view of the raw dataset between the two wrappers
        dataset = LRUCacheDataset(dataset)
        return (
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)),
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)),
        )

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        mask_whole_words: torch.Tensor = None,
    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        assert random_token_prob + leave_unmasked_prob <= 1.0

        self.dataset = dataset
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.return_masked_tokens = return_masked_tokens
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob
        self.mask_whole_words = mask_whole_words

        if random_token_prob > 0.0:
            # sampling distribution over the vocab for random replacement;
            # special symbols are excluded by zeroing their weight
            if freq_weighted_replacement:
                weights = np.array(self.vocab.count)
            else:
                weights = np.ones(len(self.vocab))
            weights[: self.vocab.nspecial] = 0
            self.weights = weights / weights.sum()

        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @lru_cache(maxsize=8)
    def __getitem__(self, index: int):
        # seeding on (seed, epoch, index) makes the noise reproducible per item
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            sz = len(item)

            assert (
                self.mask_idx not in item
            ), "Dataset contains mask_idx (={}), this is not expected!".format(
                self.mask_idx,
            )

            if self.mask_whole_words is not None:
                # operate on word-level positions instead of token positions
                word_begins_mask = self.mask_whole_words.gather(0, item)
                word_begins_idx = word_begins_mask.nonzero().view(-1)
                sz = len(word_begins_idx)
                words = np.split(word_begins_mask, word_begins_idx)[1:]
                assert len(words) == sz
                word_lens = list(map(len, words))

            # decide elements to mask
            mask = np.full(sz, False)
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * sz
                + np.random.rand()
            )
            mask[np.random.choice(sz, num_mask, replace=False)] = True

            if self.return_masked_tokens:
                # exit early if we're just returning the masked tokens
                # (i.e., the targets for masked LM training)
                if self.mask_whole_words is not None:
                    mask = np.repeat(mask, word_lens)
                new_item = np.full(len(mask), self.pad_idx)
                new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
                return torch.from_numpy(new_item)

            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    # split rand_or_unmask between "leave unmasked" and
                    # "replace with random token" proportionally
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                # remove "leave unmasked" positions from the mask
                mask = mask ^ unmask

            if self.mask_whole_words is not None:
                mask = np.repeat(mask, word_lens)

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx
            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    if self.mask_whole_words is not None:
                        rand_mask = np.repeat(rand_mask, word_lens)
                        num_rand = rand_mask.sum()

                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )

            return torch.from_numpy(new_item)
fairseq-0.10.2/fairseq/data/multi_corpus_dataset.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from collections import OrderedDict
8
+ from typing import Dict, List
9
+
10
+ import numpy as np
11
+ from fairseq.data import data_utils
12
+
13
+ from . import FairseqDataset
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class MultiCorpusDataset(FairseqDataset):
    """
    Stores multiple instances of FairseqDataset together. Requires each instance
    to be the same dataset, as the collate method needs to work on batches with
    samples from each dataset.

    Allows specifying a distribution over the datasets to use. Note that unlike
    MultiCorpusSampledDataset, this distribution allows sampling for each item,
    rather than on a batch level.

    Each time ordered_indices() is called, a new sample is generated with
    the specified distribution.

    Args:
        datasets: an OrderedDict of FairseqDataset instances.
        distribution: a List containing the probability of getting an utterance from
            corresponding dataset
        seed: random seed for sampling the datasets
        sort_indices: if true, will sort the ordered indices by size
    """

    def __init__(
        self,
        datasets: Dict[str, FairseqDataset],
        distribution: List[float],
        seed: int,
        sort_indices: bool = False,
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        assert len(datasets) == len(distribution)
        self.datasets = datasets
        self.distribution = distribution
        self.seed = seed
        self.sort_indices = sort_indices
        # Fix: initialize the epoch counter so ordered_indices() (which seeds
        # on self.epoch) can be called before set_epoch() without raising
        # AttributeError.
        self.epoch = 0

        # Avoid repeated conversions to list later
        self.dataset_list = list(datasets.values())
        self.total_num_instances = 0

        first_dataset = self.dataset_list[0]

        # dataset_offsets[i] is the global index of dataset i's first item
        self.dataset_offsets = []
        for dataset in self.dataset_list:
            assert isinstance(dataset, FairseqDataset)
            # all datasets must share a concrete type so a single collater works
            assert type(dataset) is type(first_dataset)
            self.dataset_offsets.append(self.total_num_instances)
            self.total_num_instances += len(dataset)

    def ordered_indices(self):
        """Draw a fresh epoch-seeded ordering over all datasets, sampling each
        position from the configured distribution."""
        with data_utils.numpy_seed(self.seed, self.epoch):
            # Used to store the order of indices of each dataset to use
            indices = [
                np.random.permutation(len(dataset))
                for dataset in self.datasets.values()
            ]
            # Keep track of which samples we've used for each dataset
            counters = [0 for _ in self.datasets]

            sampled_indices = [
                self._sample(indices, counters) for _ in range(self.total_num_instances)
            ]
            if self.sort_indices:
                sampled_indices.sort(key=lambda i: self.num_tokens(i))
            return np.array(sampled_indices, dtype=np.int64)

    def _sample(self, indices, counters):
        """Sample one global index; mutates *indices*/*counters* in place."""
        # First pick dataset
        dataset_idx = np.random.choice(len(self.distribution), p=self.distribution)

        # Then get dataset internal index
        idx = indices[dataset_idx][counters[dataset_idx]]

        # Convert to multi-datasets index
        idx += self.dataset_offsets[dataset_idx]

        counters[dataset_idx] += 1

        # Reset if we reach end: reshuffle so the dataset can be sampled again
        if counters[dataset_idx] == len(self.dataset_list[dataset_idx]):
            counters[dataset_idx] = 0
            indices[dataset_idx] = np.random.permutation(
                len(self.dataset_list[dataset_idx])
            )

        return idx

    def _map_index(self, index: int):
        """
        Map a global index to (local index, dataset key).

        If dataset A has length N and dataset B has length M
        then index 1 maps to index 1 of dataset A, and index N + 1
        maps to index 1 of B.
        """
        counter = 0
        for key, dataset in self.datasets.items():
            if index < counter + len(dataset):
                return index - counter, key
            counter += len(dataset)
        raise ValueError(
            "Invalid index: {}, max: {}".format(index, self.total_num_instances)
        )

    def __len__(self):
        """
        Length of this dataset is the sum of individual datasets
        """
        return self.total_num_instances

    def __getitem__(self, index):
        index, key = self._map_index(index)
        return self.datasets[key][index]

    def collater(self, samples):
        """
        Since we enforce all datasets to be the same, collating is just
        picking the first one and doing collate.
        """
        if len(samples) == 0:
            return None

        return self.dataset_list[0].collater(samples)

    def num_tokens(self, index: int):
        index, key = self._map_index(index)
        return self.datasets[key].num_tokens(index)

    def size(self, index: int):
        index, key = self._map_index(index)
        return self.datasets[key].size(index)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # a new sampling is drawn per epoch, so the iterator cannot be reused
        return False

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @property
    def supports_prefetch(self):
        return False
fairseq-0.10.2/fairseq/data/multi_corpus_sampled_dataset.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+ from typing import Callable, Dict, List
8
+
9
+ import numpy as np
10
+
11
+ from . import FairseqDataset
12
+
13
+
14
def uniform_sampler(x):
    """Draw one element of *x* uniformly at random and return it as a
    Python scalar."""
    drawn = np.random.choice(x, 1)
    return drawn.item()
17
+
18
+
19
class MultiCorpusSampledDataset(FairseqDataset):
    """
    Stores multiple instances of FairseqDataset together and in every iteration
    creates a batch by first sampling a dataset according to a specified
    probability distribution and then getting instances from that dataset.

    Args:
        datasets: an OrderedDict of FairseqDataset instances.
        sampling_func: A function for sampling over list of dataset keys.
            The default strategy is to sample uniformly.
    """

    def __init__(
        self,
        datasets: Dict[str, FairseqDataset],
        sampling_func: Callable[[List], int] = None,
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        self.datasets = datasets
        if sampling_func is None:
            sampling_func = uniform_sampler
        self.sampling_func = sampling_func

        self.total_num_instances = 0
        for _, dataset in datasets.items():
            assert isinstance(dataset, FairseqDataset)
            self.total_num_instances += len(dataset)

        # per-dataset orderings, lazily created by ordered_indices()
        self._ordered_indices = None

    def __len__(self):
        """
        Length of this dataset is the sum of individual datasets
        """
        return self.total_num_instances

    def ordered_indices(self):
        """
        Ordered indices for batching. Here we call the underlying
        dataset's ordered_indices() so that we get the same random ordering
        as we would have from using the underlying dataset directly.
        """
        if self._ordered_indices is None:
            self._ordered_indices = OrderedDict(
                [
                    (key, dataset.ordered_indices())
                    for key, dataset in self.datasets.items()
                ]
            )
        return np.arange(len(self))

    def _map_index_to_dataset(self, key: int, index: int):
        """
        Different underlying datasets have different lengths. In order to ensure
        we are not accessing an index outside the range of the current dataset
        size, we wrap around. This function should be called after we have
        created an ordering for this and all underlying datasets.
        """
        assert (
            self._ordered_indices is not None
        ), "Must call MultiCorpusSampledDataset.ordered_indices() first"
        # wrap global index into the dataset's range, then apply its ordering
        mapped_index = index % len(self.datasets[key])
        return self._ordered_indices[key][mapped_index]

    def __getitem__(self, index: int):
        """
        Get the item associated with index from each underlying dataset.
        Since index is in the range of [0, TotalNumInstances], we need to
        map the index to the dataset before retrieving the item.
        """
        return OrderedDict(
            [
                (key, dataset[self._map_index_to_dataset(key, index)])
                for key, dataset in self.datasets.items()
            ]
        )

    def collater(self, samples: List[Dict]):
        """
        Generate a mini-batch for this dataset.
        To convert this into a regular mini-batch we use the following
        logic:
        1. Select a dataset using the specified probability distribution.
        2. Call the collater function of the selected dataset.
        """
        if len(samples) == 0:
            return None

        # one dataset is chosen per batch; the others' samples are discarded
        selected_key = self.sampling_func(list(self.datasets.keys()))
        selected_samples = [sample[selected_key] for sample in samples]
        return self.datasets[selected_key].collater(selected_samples)

    def num_tokens(self, index: int):
        """
        Return an example's length (number of tokens), used for batching. Here
        we return the max across all examples at index across all underlying
        datasets.
        """
        return max(
            dataset.num_tokens(self._map_index_to_dataset(key, index))
            for key, dataset in self.datasets.items()
        )

    def size(self, index: int):
        """
        Return an example's size as a float or tuple. Here we return the max
        across all underlying datasets. This value is used when filtering a
        dataset with max-positions.
        """
        return max(
            dataset.size(self._map_index_to_dataset(key, index))
            for key, dataset in self.datasets.items()
        )

    @property
    def supports_prefetch(self):
        return all(
            getattr(dataset, "supports_prefetch", False)
            for dataset in self.datasets.values()
        )

    def prefetch(self, indices):
        for key, dataset in self.datasets.items():
            dataset.prefetch(
                [self._map_index_to_dataset(key, index) for index in indices]
            )
fairseq-0.10.2/fairseq/data/num_samples_dataset.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import FairseqDataset
7
+
8
+
9
class NumSamplesDataset(FairseqDataset):
    """Dataset whose items are all ``1`` so that collating yields a sample count.

    NOTE(review): ``__len__`` deliberately returns 0 here; this dataset is
    meant to be nested alongside datasets that define the real length —
    confirm against callers before relying on its length.
    """

    def __getitem__(self, index):
        # Every item contributes a count of one.
        return 1

    def __len__(self):
        return 0

    def collater(self, samples):
        # The collated value is simply the number of samples in the batch.
        return sum(samples)
fairseq-0.10.2/fairseq/data/offset_tokens_dataset.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import BaseWrapperDataset
7
+
8
+
9
class OffsetTokensDataset(BaseWrapperDataset):
    """Wrapper that adds a constant ``offset`` to every item of ``dataset``.

    The wrapped items must support ``+`` (e.g. token-id tensors).
    """

    def __init__(self, dataset, offset):
        super().__init__(dataset)
        self.offset = offset

    def __getitem__(self, idx):
        shifted = self.dataset[idx] + self.offset
        return shifted
fairseq-0.10.2/fairseq/data/pad_dataset.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq.data import data_utils
7
+
8
+ from . import BaseWrapperDataset
9
+
10
+
11
class PadDataset(BaseWrapperDataset):
    """Wrapper whose collater pads token sequences to a common length.

    Args:
        dataset: dataset of 1-D token tensors to wrap.
        pad_idx: token id used for padding.
        left_pad (bool): pad on the left when True, on the right otherwise.
    """

    def __init__(self, dataset, pad_idx, left_pad):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

    def collater(self, samples):
        # Delegate the actual padding to the shared fairseq helper.
        return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad)
19
+
20
+
21
class LeftPadDataset(PadDataset):
    """``PadDataset`` preconfigured to pad on the left."""

    def __init__(self, dataset, pad_idx):
        super().__init__(dataset, pad_idx, left_pad=True)
24
+
25
+
26
class RightPadDataset(PadDataset):
    """``PadDataset`` preconfigured to pad on the right."""

    def __init__(self, dataset, pad_idx):
        super().__init__(dataset, pad_idx, left_pad=False)
fairseq-0.10.2/fairseq/data/prepend_dataset.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from . import BaseWrapperDataset
10
+
11
+
12
class PrependDataset(BaseWrapperDataset):
    """Overwrite the first token of each item with a per-index value.

    NOTE(review): despite the name, nothing is inserted — ``src[0]`` is
    replaced in place with the id returned by ``prepend_getter``. Confirm
    callers expect replacement rather than insertion.
    """

    def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
        super().__init__(dataset)
        # Callable mapping (dataset, idx) -> int token id written at position 0.
        self.prepend_getter = prepend_getter
        # Optional sanity value that the current first token must equal.
        self.ensure_first_token = ensure_first_token_is

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Items may be a bare tensor or a tuple whose first element is the source.
        is_tuple = isinstance(item, tuple)
        src = item[0] if is_tuple else item

        assert self.ensure_first_token is None or src[0] == self.ensure_first_token
        prepend_idx = self.prepend_getter(self.dataset, idx)
        assert isinstance(prepend_idx, int)
        # In-place mutation of the underlying item's first token.
        src[0] = prepend_idx
        item = tuple((src,) + item[1:]) if is_tuple else src
        return item
fairseq-0.10.2/fairseq/data/resampling_dataset.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ import numpy as np
9
+ from fairseq.data import BaseWrapperDataset, plasma_utils
10
+
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class ResamplingDataset(BaseWrapperDataset):
    """Randomly samples from a given dataset at each epoch.

    Sampling is done with or without replacement, depending on the "replace"
    parameter.

    Optionally, the epoch size can be rescaled. This is potentially desirable
    to increase per-epoch coverage of the base dataset (since sampling with
    replacement means that many items in the dataset will be left out). In the
    case of sampling without replacement, size_ratio should be strictly less
    than 1.

    Args:
        dataset (~torch.utils.data.Dataset): dataset on which to sample.
        weights (List[float]): list of probability weights
            (default: None, which corresponds to uniform sampling).
        replace (bool): sampling mode; True for "with replacement", or False
            for "without replacement" (default: True)
        size_ratio (float): the ratio to subsample to; must be positive
            (default: 1.0).
        batch_by_size (bool): whether or not to batch by sequence length
            (default: True).
        seed (int): RNG seed to use (default: 0).
        epoch (int): starting epoch number (default: 1).
    """

    def __init__(
        self,
        dataset,
        weights=None,
        replace=True,
        size_ratio=1.0,
        batch_by_size=True,
        seed=0,
        epoch=1,
    ):
        super().__init__(dataset)

        if weights is None:
            self.weights = None

        else:
            assert len(weights) == len(dataset)
            weights_arr = np.array(weights, dtype=np.float64)
            # Normalize so the weights form a probability distribution.
            weights_arr /= weights_arr.sum()
            # Stored as a PlasmaArray so dataloader workers can share it.
            self.weights = plasma_utils.PlasmaArray(weights_arr)

        self.replace = replace

        assert size_ratio > 0.0
        if not self.replace:
            # Without replacement we cannot draw more items than exist.
            assert size_ratio < 1.0
        self.size_ratio = float(size_ratio)
        # Number of items drawn per epoch (ceil so size_ratio<1 still yields >=1).
        self.actual_size = np.ceil(len(dataset) * self.size_ratio).astype(int)

        self.batch_by_size = batch_by_size
        self.seed = seed

        # Populated by set_epoch() below; _cur_indices holds the resampled
        # index array for the current epoch.
        self._cur_epoch = None
        self._cur_indices = None

        self.set_epoch(epoch)

    def __getitem__(self, index):
        # Indirect through the per-epoch resampled index array.
        return self.dataset[self._cur_indices.array[index]]

    def __len__(self):
        return self.actual_size

    @property
    def sizes(self):
        # Reindex the underlying sizes through the resampled indices.
        if isinstance(self.dataset.sizes, list):
            return [s[self._cur_indices.array] for s in self.dataset.sizes]
        return self.dataset.sizes[self._cur_indices.array]

    def num_tokens(self, index):
        return self.dataset.num_tokens(self._cur_indices.array[index])

    def size(self, index):
        return self.dataset.size(self._cur_indices.array[index])

    def ordered_indices(self):
        # Sort by size (stable via the index tiebreaker) when batching by size.
        if self.batch_by_size:
            order = [
                np.arange(len(self)),
                self.sizes,
            ]  # No need to handle `self.shuffle == True`
            return np.lexsort(order)
        else:
            return np.arange(len(self))

    def prefetch(self, indices):
        self.dataset.prefetch(self._cur_indices.array[indices])

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # Indices are resampled every epoch, so iterators cannot be reused.
        return False

    def set_epoch(self, epoch):
        """Resample the index array for ``epoch`` (no-op if epoch unchanged)."""
        logger.debug("ResamplingDataset.set_epoch: {}".format(epoch))
        super().set_epoch(epoch)

        if epoch == self._cur_epoch:
            return

        self._cur_epoch = epoch

        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.

        rng = np.random.RandomState(
            [
                42,  # magic number
                self.seed % (2 ** 32),  # global seed
                self._cur_epoch,  # epoch index
            ]
        )
        self._cur_indices = plasma_utils.PlasmaArray(
            rng.choice(
                len(self.dataset),
                self.actual_size,
                replace=self.replace,
                p=(None if self.weights is None else self.weights.array),
            )
        )
fairseq-0.10.2/fairseq/data/shorten_dataset.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ from fairseq.data import data_utils
8
+
9
+ from . import BaseWrapperDataset
10
+
11
+
12
class TruncateDataset(BaseWrapperDataset):
    """Truncate a sequence by returning the first truncation_length tokens.

    Args:
        dataset (~torch.utils.data.Dataset): dataset of 1-D tensors to wrap.
        truncation_length (int): maximum number of tokens to keep.
    """

    def __init__(self, dataset, truncation_length):
        super().__init__(dataset)
        assert truncation_length is not None
        self.truncation_length = truncation_length
        # NOTE: the original redundantly re-assigned ``self.dataset`` here;
        # super().__init__ already stores it (every sibling wrapper in this
        # file relies on that), so the duplicate assignment was removed.

    def __getitem__(self, index):
        """Return the item at ``index``, clipped to ``truncation_length``."""
        item = self.dataset[index]
        item_len = item.size(0)
        if item_len > self.truncation_length:
            item = item[: self.truncation_length]
        return item

    @property
    def sizes(self):
        # Element-wise clip of the underlying per-item sizes.
        return np.minimum(self.dataset.sizes, self.truncation_length)

    def __len__(self):
        return len(self.dataset)
34
+
35
+
36
class RandomCropDataset(TruncateDataset):
    """Truncate a sequence by returning a random crop of truncation_length tokens.

    The crop offset is a deterministic function of ``(seed, epoch, index)``,
    so results are reproducible across dataloader workers.
    """

    def __init__(self, dataset, truncation_length, seed=1):
        super().__init__(dataset, truncation_length)
        self.seed = seed
        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # Only the crop position changes between epochs, never item sizes.
        return True

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index):
        # Seed numpy deterministically per (seed, epoch, index).
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            excess = item.size(0) - self.truncation_length
            if excess > 0:
                offset = np.random.randint(0, excess)
                item = item[offset : offset + self.truncation_length]
            return item
61
+
62
+
63
def maybe_shorten_dataset(
    dataset,
    split,
    shorten_data_split_list,
    shorten_method,
    tokens_per_sample,
    seed,
):
    """Optionally wrap ``dataset`` with a shortening wrapper.

    A wrapper is applied only when ``split`` occurs in the comma-separated
    ``shorten_data_split_list`` (or that list is empty) AND
    ``shorten_method`` is ``"truncate"`` or ``"random_crop"``; otherwise the
    dataset is returned unchanged.
    """
    split_applies = (
        len(shorten_data_split_list) == 0
        or split in shorten_data_split_list.split(",")
    )
    if split_applies:
        if shorten_method == "truncate":
            return TruncateDataset(dataset, tokens_per_sample)
        if shorten_method == "random_crop":
            return RandomCropDataset(dataset, tokens_per_sample, seed)
    return dataset
fairseq-0.10.2/fairseq/data/strip_token_dataset.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import BaseWrapperDataset
7
+
8
+
9
class StripTokenDataset(BaseWrapperDataset):
    """Strip every leading and trailing occurrence of ``id_to_strip``."""

    def __init__(self, dataset, id_to_strip):
        super().__init__(dataset)
        self.id_to_strip = id_to_strip

    def __getitem__(self, index):
        item = self.dataset[index]
        # Peel matching tokens off the tail first, then off the head.
        while len(item) > 0 and item[-1] == self.id_to_strip:
            item = item[:-1]
        while len(item) > 0 and item[0] == self.id_to_strip:
            item = item[1:]
        return item
fairseq-0.10.2/fairseq/data/token_block_dataset.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ from fairseq.data import FairseqDataset, plasma_utils
9
+
10
+
11
class TokenBlockDataset(FairseqDataset):
    """Break a Dataset of tokens into blocks.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to break into blocks
        sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
        block_size (int): maximum block size (ignored in 'eos' break mode)
        break_mode (str, optional): Mode used for breaking tokens. Values can
            be one of:
            - 'none': break tokens into equally sized blocks (up to block_size)
            - 'complete': break tokens into blocks (up to block_size) such that
              blocks contains complete sentences, although block_size may be
              exceeded if some sentences exceed block_size
            - 'complete_doc': similar to 'complete' mode, but do not
              cross document boundaries
            - 'eos': each block contains one sentence (block_size is ignored)
        include_targets (bool, optional): return next tokens as targets
            (default: False).
        document_sep_len (int, optional): document separator size (required for
            'complete_doc' break mode). Typically 1 if the sentences have eos
            and 0 otherwise.
    """

    def __init__(
        self,
        dataset,
        sizes,
        block_size,
        pad,
        eos,
        break_mode=None,
        include_targets=False,
        document_sep_len=1,
    ):
        try:
            from fairseq.data.token_block_utils_fast import (
                _get_slice_indices_fast,
                _get_block_to_dataset_index_fast,
            )
        except ImportError:
            raise ImportError(
                "Please build Cython components with: `pip install --editable .` "
                "or `python setup.py build_ext --inplace`"
            )

        super().__init__()
        self.dataset = dataset
        self.pad = pad
        self.eos = eos
        self.include_targets = include_targets

        assert len(dataset) == len(sizes)
        assert len(dataset) > 0

        # Normalize `sizes` to an int64 numpy array for the Cython helpers.
        if isinstance(sizes, list):
            sizes = np.array(sizes, dtype=np.int64)
        else:
            if torch.is_tensor(sizes):
                sizes = sizes.numpy()
            sizes = sizes.astype(np.int64)

        break_mode = break_mode if break_mode is not None else "none"

        # For "eos" break-mode, block_size is not a required parameter.
        if break_mode == "eos" and block_size is None:
            block_size = 0

        slice_indices = _get_slice_indices_fast(
            sizes, str(break_mode), block_size, document_sep_len
        )
        self._sizes = slice_indices[:, 1] - slice_indices[:, 0]

        # build index mapping block indices to the underlying dataset indices
        if break_mode == "eos":
            # much faster version for eos break mode
            block_to_dataset_index = np.stack(
                [
                    np.arange(len(sizes)),  # starting index in dataset
                    # FIX: np.int64 replaces `np.long`, an alias removed in
                    # NumPy >= 1.24; it also matches the dtype of `sizes`.
                    np.zeros(
                        len(sizes), dtype=np.int64
                    ),  # starting offset within starting index
                    np.arange(len(sizes)),  # ending index in dataset
                ],
                1,
            )
        else:
            block_to_dataset_index = _get_block_to_dataset_index_fast(
                sizes,
                slice_indices,
            )
        # Store index arrays in shared (plasma) memory so dataloader workers
        # can access them without copies.
        self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
        self._sizes = plasma_utils.PlasmaArray(self._sizes)
        self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index)

    @property
    def slice_indices(self):
        # (num_blocks, 2) array of [start, end) token offsets per block.
        return self._slice_indices.array

    @property
    def sizes(self):
        # Number of tokens in each block.
        return self._sizes.array

    @property
    def block_to_dataset_index(self):
        # (num_blocks, 3) rows of [start_ds_idx, start_offset, end_ds_idx].
        return self._block_to_dataset_index.array

    def attr(self, attr: str, index: int):
        # Delegate attribute lookup to the first sentence of the block.
        start_ds_idx, _, _ = self.block_to_dataset_index[index]
        return self.dataset.attr(attr, start_ds_idx)

    def __getitem__(self, index):
        """Return the token block at ``index`` (plus shifted targets when
        ``include_targets`` is set)."""
        start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]

        # Concatenate every underlying sentence spanned by this block.
        buffer = torch.cat(
            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
        )

        slice_s, slice_e = self.slice_indices[index]
        length = slice_e - slice_s
        s, e = start_offset, start_offset + length
        item = buffer[s:e]

        if self.include_targets:
            # *target* is the original sentence (=item)
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
            if s == 0:
                source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]])
                past_target = torch.cat(
                    [item.new([self.pad, self.eos]), buffer[0 : e - 2]]
                )
            else:
                source = buffer[s - 1 : e - 1]
                if s == 1:
                    past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]])
                else:
                    past_target = buffer[s - 2 : e - 2]

            return source, item, past_target

        return item

    def __len__(self):
        return len(self.slice_indices)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        # Prefetch every underlying sentence touched by the requested blocks.
        self.dataset.prefetch(
            {
                ds_idx
                for index in indices
                for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
                for ds_idx in range(start_ds_idx, end_ds_idx + 1)
            }
        )