Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- fairseq-0.10.2/fairseq/data/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/dictionary.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/id_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/concat_dataset.py +124 -0
- fairseq-0.10.2/fairseq/data/concat_sentences_dataset.py +54 -0
- fairseq-0.10.2/fairseq/data/dictionary.py +387 -0
- fairseq-0.10.2/fairseq/data/encoders/byte_utils.py +51 -0
- fairseq-0.10.2/fairseq/data/encoders/fastbpe.py +35 -0
- fairseq-0.10.2/fairseq/data/encoders/gpt2_bpe_utils.py +140 -0
- fairseq-0.10.2/fairseq/data/encoders/moses_tokenizer.py +51 -0
- fairseq-0.10.2/fairseq/data/encoders/nltk_tokenizer.py +23 -0
- fairseq-0.10.2/fairseq/data/encoders/subword_nmt_bpe.py +54 -0
- fairseq-0.10.2/fairseq/data/encoders/utils.py +30 -0
- fairseq-0.10.2/fairseq/data/fairseq_dataset.py +191 -0
- fairseq-0.10.2/fairseq/data/fasta_dataset.py +107 -0
- fairseq-0.10.2/fairseq/data/id_dataset.py +19 -0
- fairseq-0.10.2/fairseq/data/language_pair_dataset.py +475 -0
- fairseq-0.10.2/fairseq/data/legacy/__init__.py +16 -0
- fairseq-0.10.2/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/legacy/block_pair_dataset.py +311 -0
- fairseq-0.10.2/fairseq/data/legacy/masked_lm_dictionary.py +60 -0
- fairseq-0.10.2/fairseq/data/list_dataset.py +32 -0
- fairseq-0.10.2/fairseq/data/lru_cache_dataset.py +21 -0
- fairseq-0.10.2/fairseq/data/mask_tokens_dataset.py +178 -0
- fairseq-0.10.2/fairseq/data/multi_corpus_dataset.py +159 -0
- fairseq-0.10.2/fairseq/data/multi_corpus_sampled_dataset.py +145 -0
- fairseq-0.10.2/fairseq/data/num_samples_dataset.py +17 -0
- fairseq-0.10.2/fairseq/data/offset_tokens_dataset.py +15 -0
- fairseq-0.10.2/fairseq/data/pad_dataset.py +28 -0
- fairseq-0.10.2/fairseq/data/prepend_dataset.py +28 -0
- fairseq-0.10.2/fairseq/data/resampling_dataset.py +139 -0
- fairseq-0.10.2/fairseq/data/shorten_dataset.py +78 -0
- fairseq-0.10.2/fairseq/data/strip_token_dataset.py +20 -0
- fairseq-0.10.2/fairseq/data/token_block_dataset.py +168 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
fairseq-0.10.2/fairseq/data/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (3.55 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc
ADDED
|
Binary file (2.38 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc
ADDED
|
Binary file (1.4 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc
ADDED
|
Binary file (2.49 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc
ADDED
|
Binary file (1.21 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc
ADDED
|
Binary file (5.04 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/dictionary.cpython-310.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/id_dataset.cpython-310.pyc
ADDED
|
Binary file (784 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc
ADDED
|
Binary file (6.33 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc
ADDED
|
Binary file (5.36 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc
ADDED
|
Binary file (759 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc
ADDED
|
Binary file (4.41 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc
ADDED
|
Binary file (777 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc
ADDED
|
Binary file (5.01 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/concat_dataset.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import bisect
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from torch.utils.data.dataloader import default_collate
|
| 10 |
+
|
| 11 |
+
from . import FairseqDataset
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ConcatDataset(FairseqDataset):
|
| 15 |
+
@staticmethod
|
| 16 |
+
def cumsum(sequence, sample_ratios):
|
| 17 |
+
r, s = [], 0
|
| 18 |
+
for e, ratio in zip(sequence, sample_ratios):
|
| 19 |
+
curr_len = int(ratio * len(e))
|
| 20 |
+
r.append(curr_len + s)
|
| 21 |
+
s += curr_len
|
| 22 |
+
return r
|
| 23 |
+
|
| 24 |
+
def __init__(self, datasets, sample_ratios=1):
|
| 25 |
+
super(ConcatDataset, self).__init__()
|
| 26 |
+
assert len(datasets) > 0, "datasets should not be an empty iterable"
|
| 27 |
+
self.datasets = list(datasets)
|
| 28 |
+
if isinstance(sample_ratios, int):
|
| 29 |
+
sample_ratios = [sample_ratios] * len(self.datasets)
|
| 30 |
+
self.sample_ratios = sample_ratios
|
| 31 |
+
self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
|
| 32 |
+
self.real_sizes = [len(d) for d in self.datasets]
|
| 33 |
+
|
| 34 |
+
def __len__(self):
|
| 35 |
+
return self.cumulative_sizes[-1]
|
| 36 |
+
|
| 37 |
+
def __getitem__(self, idx):
|
| 38 |
+
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
| 39 |
+
return self.datasets[dataset_idx][sample_idx]
|
| 40 |
+
|
| 41 |
+
def _get_dataset_and_sample_index(self, idx: int):
|
| 42 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
| 43 |
+
if dataset_idx == 0:
|
| 44 |
+
sample_idx = idx
|
| 45 |
+
else:
|
| 46 |
+
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
| 47 |
+
sample_idx = sample_idx % self.real_sizes[dataset_idx]
|
| 48 |
+
return dataset_idx, sample_idx
|
| 49 |
+
|
| 50 |
+
def collater(self, samples, **extra_args):
|
| 51 |
+
# For now only supports datasets with same underlying collater implementations
|
| 52 |
+
if hasattr(self.datasets[0], "collater"):
|
| 53 |
+
return self.datasets[0].collater(samples, **extra_args)
|
| 54 |
+
else:
|
| 55 |
+
return default_collate(samples, **extra_args)
|
| 56 |
+
|
| 57 |
+
def size(self, idx: int):
|
| 58 |
+
"""
|
| 59 |
+
Return an example's size as a float or tuple.
|
| 60 |
+
"""
|
| 61 |
+
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
| 62 |
+
return self.datasets[dataset_idx].size(sample_idx)
|
| 63 |
+
|
| 64 |
+
def num_tokens(self, index: int):
|
| 65 |
+
return np.max(self.size(index))
|
| 66 |
+
|
| 67 |
+
def attr(self, attr: str, index: int):
|
| 68 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
|
| 69 |
+
return getattr(self.datasets[dataset_idx], attr, None)
|
| 70 |
+
|
| 71 |
+
@property
|
| 72 |
+
def sizes(self):
|
| 73 |
+
_dataset_sizes = []
|
| 74 |
+
for ds, sr in zip(self.datasets, self.sample_ratios):
|
| 75 |
+
if isinstance(ds.sizes, np.ndarray):
|
| 76 |
+
_dataset_sizes.append(np.tile(ds.sizes, sr))
|
| 77 |
+
else:
|
| 78 |
+
# Only support underlying dataset with single size array.
|
| 79 |
+
assert isinstance(ds.sizes, list)
|
| 80 |
+
_dataset_sizes.append(np.tile(ds.sizes[0], sr))
|
| 81 |
+
return np.concatenate(_dataset_sizes)
|
| 82 |
+
|
| 83 |
+
@property
|
| 84 |
+
def supports_prefetch(self):
|
| 85 |
+
return all(d.supports_prefetch for d in self.datasets)
|
| 86 |
+
|
| 87 |
+
def ordered_indices(self):
|
| 88 |
+
"""
|
| 89 |
+
Returns indices sorted by length. So less padding is needed.
|
| 90 |
+
"""
|
| 91 |
+
if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
|
| 92 |
+
# special handling for concatenating lang_pair_datasets
|
| 93 |
+
indices = np.arange(len(self))
|
| 94 |
+
sizes = self.sizes
|
| 95 |
+
tgt_sizes = (
|
| 96 |
+
sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
|
| 97 |
+
)
|
| 98 |
+
src_sizes = (
|
| 99 |
+
sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
|
| 100 |
+
)
|
| 101 |
+
# sort by target length, then source length
|
| 102 |
+
if tgt_sizes is not None:
|
| 103 |
+
indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
|
| 104 |
+
return indices[np.argsort(src_sizes[indices], kind="mergesort")]
|
| 105 |
+
else:
|
| 106 |
+
return np.argsort(self.sizes)
|
| 107 |
+
|
| 108 |
+
def prefetch(self, indices):
|
| 109 |
+
frm = 0
|
| 110 |
+
for to, ds in zip(self.cumulative_sizes, self.datasets):
|
| 111 |
+
real_size = len(ds)
|
| 112 |
+
if getattr(ds, "supports_prefetch", False):
|
| 113 |
+
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
|
| 114 |
+
frm = to
|
| 115 |
+
|
| 116 |
+
@property
|
| 117 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
| 118 |
+
return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
|
| 119 |
+
|
| 120 |
+
def set_epoch(self, epoch):
|
| 121 |
+
super().set_epoch(epoch)
|
| 122 |
+
for ds in self.datasets:
|
| 123 |
+
if hasattr(ds, "set_epoch"):
|
| 124 |
+
ds.set_epoch(epoch)
|
fairseq-0.10.2/fairseq/data/concat_sentences_dataset.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from . import FairseqDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ConcatSentencesDataset(FairseqDataset):
|
| 12 |
+
def __init__(self, *datasets):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.datasets = datasets
|
| 15 |
+
assert all(
|
| 16 |
+
len(ds) == len(datasets[0]) for ds in datasets
|
| 17 |
+
), "datasets must have the same length"
|
| 18 |
+
|
| 19 |
+
def __getitem__(self, index):
|
| 20 |
+
return torch.cat([ds[index] for ds in self.datasets])
|
| 21 |
+
|
| 22 |
+
def __len__(self):
|
| 23 |
+
return len(self.datasets[0])
|
| 24 |
+
|
| 25 |
+
def collater(self, samples):
|
| 26 |
+
return self.datasets[0].collater(samples)
|
| 27 |
+
|
| 28 |
+
@property
|
| 29 |
+
def sizes(self):
|
| 30 |
+
return sum(ds.sizes for ds in self.datasets)
|
| 31 |
+
|
| 32 |
+
def num_tokens(self, index):
|
| 33 |
+
return sum(ds.num_tokens(index) for ds in self.datasets)
|
| 34 |
+
|
| 35 |
+
def size(self, index):
|
| 36 |
+
return sum(ds.size(index) for ds in self.datasets)
|
| 37 |
+
|
| 38 |
+
def ordered_indices(self):
|
| 39 |
+
return self.datasets[0].ordered_indices()
|
| 40 |
+
|
| 41 |
+
@property
|
| 42 |
+
def supports_prefetch(self):
|
| 43 |
+
return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)
|
| 44 |
+
|
| 45 |
+
def prefetch(self, indices):
|
| 46 |
+
for ds in self.datasets:
|
| 47 |
+
if getattr(ds, "supports_prefetch", False):
|
| 48 |
+
ds.prefetch(indices)
|
| 49 |
+
|
| 50 |
+
def set_epoch(self, epoch):
|
| 51 |
+
super().set_epoch(epoch)
|
| 52 |
+
for ds in self.datasets:
|
| 53 |
+
if hasattr(ds, "set_epoch"):
|
| 54 |
+
ds.set_epoch(epoch)
|
fairseq-0.10.2/fairseq/data/dictionary.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from collections import Counter
|
| 8 |
+
from multiprocessing import Pool
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from fairseq import utils
|
| 12 |
+
from fairseq.binarizer import safe_readline
|
| 13 |
+
from fairseq.data import data_utils
|
| 14 |
+
from fairseq.file_io import PathManager
|
| 15 |
+
from fairseq.tokenizer import tokenize_line
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Dictionary(object):
|
| 19 |
+
"""A mapping from symbols to consecutive integers"""
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
*, # begin keyword-only arguments
|
| 24 |
+
bos="<s>",
|
| 25 |
+
pad="<pad>",
|
| 26 |
+
eos="</s>",
|
| 27 |
+
unk="<unk>",
|
| 28 |
+
extra_special_symbols=None,
|
| 29 |
+
):
|
| 30 |
+
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
|
| 31 |
+
self.symbols = []
|
| 32 |
+
self.count = []
|
| 33 |
+
self.indices = {}
|
| 34 |
+
self.bos_index = self.add_symbol(bos)
|
| 35 |
+
self.pad_index = self.add_symbol(pad)
|
| 36 |
+
self.eos_index = self.add_symbol(eos)
|
| 37 |
+
self.unk_index = self.add_symbol(unk)
|
| 38 |
+
if extra_special_symbols:
|
| 39 |
+
for s in extra_special_symbols:
|
| 40 |
+
self.add_symbol(s)
|
| 41 |
+
self.nspecial = len(self.symbols)
|
| 42 |
+
|
| 43 |
+
def __eq__(self, other):
|
| 44 |
+
return self.indices == other.indices
|
| 45 |
+
|
| 46 |
+
def __getitem__(self, idx):
|
| 47 |
+
if idx < len(self.symbols):
|
| 48 |
+
return self.symbols[idx]
|
| 49 |
+
return self.unk_word
|
| 50 |
+
|
| 51 |
+
def __len__(self):
|
| 52 |
+
"""Returns the number of symbols in the dictionary"""
|
| 53 |
+
return len(self.symbols)
|
| 54 |
+
|
| 55 |
+
def __contains__(self, sym):
|
| 56 |
+
return sym in self.indices
|
| 57 |
+
|
| 58 |
+
def index(self, sym):
|
| 59 |
+
"""Returns the index of the specified symbol"""
|
| 60 |
+
assert isinstance(sym, str)
|
| 61 |
+
if sym in self.indices:
|
| 62 |
+
return self.indices[sym]
|
| 63 |
+
return self.unk_index
|
| 64 |
+
|
| 65 |
+
def string(
|
| 66 |
+
self,
|
| 67 |
+
tensor,
|
| 68 |
+
bpe_symbol=None,
|
| 69 |
+
escape_unk=False,
|
| 70 |
+
extra_symbols_to_ignore=None,
|
| 71 |
+
unk_string=None,
|
| 72 |
+
):
|
| 73 |
+
"""Helper for converting a tensor of token indices to a string.
|
| 74 |
+
|
| 75 |
+
Can optionally remove BPE symbols or escape <unk> words.
|
| 76 |
+
"""
|
| 77 |
+
if torch.is_tensor(tensor) and tensor.dim() == 2:
|
| 78 |
+
return "\n".join(
|
| 79 |
+
self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore)
|
| 80 |
+
for t in tensor
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
|
| 84 |
+
extra_symbols_to_ignore.add(self.eos())
|
| 85 |
+
|
| 86 |
+
def token_string(i):
|
| 87 |
+
if i == self.unk():
|
| 88 |
+
if unk_string is not None:
|
| 89 |
+
return unk_string
|
| 90 |
+
else:
|
| 91 |
+
return self.unk_string(escape_unk)
|
| 92 |
+
else:
|
| 93 |
+
return self[i]
|
| 94 |
+
|
| 95 |
+
if hasattr(self, "bos_index"):
|
| 96 |
+
extra_symbols_to_ignore.add(self.bos())
|
| 97 |
+
|
| 98 |
+
sent = " ".join(
|
| 99 |
+
token_string(i)
|
| 100 |
+
for i in tensor
|
| 101 |
+
if utils.item(i) not in extra_symbols_to_ignore
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
return data_utils.post_process(sent, bpe_symbol)
|
| 105 |
+
|
| 106 |
+
def unk_string(self, escape=False):
|
| 107 |
+
"""Return unknown string, optionally escaped as: <<unk>>"""
|
| 108 |
+
if escape:
|
| 109 |
+
return "<{}>".format(self.unk_word)
|
| 110 |
+
else:
|
| 111 |
+
return self.unk_word
|
| 112 |
+
|
| 113 |
+
def add_symbol(self, word, n=1, overwrite=False):
|
| 114 |
+
"""Adds a word to the dictionary"""
|
| 115 |
+
if word in self.indices and not overwrite:
|
| 116 |
+
idx = self.indices[word]
|
| 117 |
+
self.count[idx] = self.count[idx] + n
|
| 118 |
+
return idx
|
| 119 |
+
else:
|
| 120 |
+
idx = len(self.symbols)
|
| 121 |
+
self.indices[word] = idx
|
| 122 |
+
self.symbols.append(word)
|
| 123 |
+
self.count.append(n)
|
| 124 |
+
return idx
|
| 125 |
+
|
| 126 |
+
def update(self, new_dict):
|
| 127 |
+
"""Updates counts from new dictionary."""
|
| 128 |
+
for word in new_dict.symbols:
|
| 129 |
+
idx2 = new_dict.indices[word]
|
| 130 |
+
if word in self.indices:
|
| 131 |
+
idx = self.indices[word]
|
| 132 |
+
self.count[idx] = self.count[idx] + new_dict.count[idx2]
|
| 133 |
+
else:
|
| 134 |
+
idx = len(self.symbols)
|
| 135 |
+
self.indices[word] = idx
|
| 136 |
+
self.symbols.append(word)
|
| 137 |
+
self.count.append(new_dict.count[idx2])
|
| 138 |
+
|
| 139 |
+
def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
|
| 140 |
+
"""Sort symbols by frequency in descending order, ignoring special ones.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
- threshold defines the minimum word count
|
| 144 |
+
- nwords defines the total number of words in the final dictionary,
|
| 145 |
+
including special symbols
|
| 146 |
+
- padding_factor can be used to pad the dictionary size to be a
|
| 147 |
+
multiple of 8, which is important on some hardware (e.g., Nvidia
|
| 148 |
+
Tensor Cores).
|
| 149 |
+
"""
|
| 150 |
+
if nwords <= 0:
|
| 151 |
+
nwords = len(self)
|
| 152 |
+
|
| 153 |
+
new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
|
| 154 |
+
new_symbols = self.symbols[: self.nspecial]
|
| 155 |
+
new_count = self.count[: self.nspecial]
|
| 156 |
+
|
| 157 |
+
c = Counter(
|
| 158 |
+
dict(
|
| 159 |
+
sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
|
| 160 |
+
)
|
| 161 |
+
)
|
| 162 |
+
for symbol, count in c.most_common(nwords - self.nspecial):
|
| 163 |
+
if count >= threshold:
|
| 164 |
+
new_indices[symbol] = len(new_symbols)
|
| 165 |
+
new_symbols.append(symbol)
|
| 166 |
+
new_count.append(count)
|
| 167 |
+
else:
|
| 168 |
+
break
|
| 169 |
+
|
| 170 |
+
assert len(new_symbols) == len(new_indices)
|
| 171 |
+
|
| 172 |
+
self.count = list(new_count)
|
| 173 |
+
self.symbols = list(new_symbols)
|
| 174 |
+
self.indices = new_indices
|
| 175 |
+
|
| 176 |
+
self.pad_to_multiple_(padding_factor)
|
| 177 |
+
|
| 178 |
+
def pad_to_multiple_(self, padding_factor):
|
| 179 |
+
"""Pad Dictionary size to be a multiple of *padding_factor*."""
|
| 180 |
+
if padding_factor > 1:
|
| 181 |
+
i = 0
|
| 182 |
+
while len(self) % padding_factor != 0:
|
| 183 |
+
symbol = "madeupword{:04d}".format(i)
|
| 184 |
+
self.add_symbol(symbol, n=0)
|
| 185 |
+
i += 1
|
| 186 |
+
|
| 187 |
+
def bos(self):
|
| 188 |
+
"""Helper to get index of beginning-of-sentence symbol"""
|
| 189 |
+
return self.bos_index
|
| 190 |
+
|
| 191 |
+
def pad(self):
|
| 192 |
+
"""Helper to get index of pad symbol"""
|
| 193 |
+
return self.pad_index
|
| 194 |
+
|
| 195 |
+
def eos(self):
|
| 196 |
+
"""Helper to get index of end-of-sentence symbol"""
|
| 197 |
+
return self.eos_index
|
| 198 |
+
|
| 199 |
+
def unk(self):
|
| 200 |
+
"""Helper to get index of unk symbol"""
|
| 201 |
+
return self.unk_index
|
| 202 |
+
|
| 203 |
+
@classmethod
|
| 204 |
+
def load(cls, f):
|
| 205 |
+
"""Loads the dictionary from a text file with the format:
|
| 206 |
+
|
| 207 |
+
```
|
| 208 |
+
<symbol0> <count0>
|
| 209 |
+
<symbol1> <count1>
|
| 210 |
+
...
|
| 211 |
+
```
|
| 212 |
+
"""
|
| 213 |
+
d = cls()
|
| 214 |
+
d.add_from_file(f)
|
| 215 |
+
return d
|
| 216 |
+
|
| 217 |
+
def add_from_file(self, f):
|
| 218 |
+
"""
|
| 219 |
+
Loads a pre-existing dictionary from a text file and adds its symbols
|
| 220 |
+
to this instance.
|
| 221 |
+
"""
|
| 222 |
+
if isinstance(f, str):
|
| 223 |
+
try:
|
| 224 |
+
with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
|
| 225 |
+
self.add_from_file(fd)
|
| 226 |
+
except FileNotFoundError as fnfe:
|
| 227 |
+
raise fnfe
|
| 228 |
+
except UnicodeError:
|
| 229 |
+
raise Exception(
|
| 230 |
+
"Incorrect encoding detected in {}, please "
|
| 231 |
+
"rebuild the dataset".format(f)
|
| 232 |
+
)
|
| 233 |
+
return
|
| 234 |
+
|
| 235 |
+
lines = f.readlines()
|
| 236 |
+
indices_start_line = self._load_meta(lines)
|
| 237 |
+
|
| 238 |
+
for line in lines[indices_start_line:]:
|
| 239 |
+
try:
|
| 240 |
+
line, field = line.rstrip().rsplit(" ", 1)
|
| 241 |
+
if field == "#fairseq:overwrite":
|
| 242 |
+
overwrite = True
|
| 243 |
+
line, field = line.rsplit(" ", 1)
|
| 244 |
+
else:
|
| 245 |
+
overwrite = False
|
| 246 |
+
count = int(field)
|
| 247 |
+
word = line
|
| 248 |
+
if word in self and not overwrite:
|
| 249 |
+
raise RuntimeError(
|
| 250 |
+
"Duplicate word found when loading Dictionary: '{}'. "
|
| 251 |
+
"Duplicate words can overwrite earlier ones by adding the "
|
| 252 |
+
"#fairseq:overwrite flag at the end of the corresponding row "
|
| 253 |
+
"in the dictionary file. If using the Camembert model, please "
|
| 254 |
+
"download an updated copy of the model file.".format(word)
|
| 255 |
+
)
|
| 256 |
+
self.add_symbol(word, n=count, overwrite=overwrite)
|
| 257 |
+
except ValueError:
|
| 258 |
+
raise ValueError(
|
| 259 |
+
"Incorrect dictionary format, expected '<token> <cnt> [flags]'"
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
def _save(self, f, kv_iterator):
|
| 263 |
+
if isinstance(f, str):
|
| 264 |
+
PathManager.mkdirs(os.path.dirname(f))
|
| 265 |
+
with PathManager.open(f, "w", encoding="utf-8") as fd:
|
| 266 |
+
return self.save(fd)
|
| 267 |
+
for k, v in kv_iterator:
|
| 268 |
+
print("{} {}".format(k, v), file=f)
|
| 269 |
+
|
| 270 |
+
def _get_meta(self):
|
| 271 |
+
return [], []
|
| 272 |
+
|
| 273 |
+
def _load_meta(self, lines):
|
| 274 |
+
return 0
|
| 275 |
+
|
| 276 |
+
def save(self, f):
|
| 277 |
+
"""Stores dictionary into a text file"""
|
| 278 |
+
ex_keys, ex_vals = self._get_meta()
|
| 279 |
+
self._save(
|
| 280 |
+
f,
|
| 281 |
+
zip(
|
| 282 |
+
ex_keys + self.symbols[self.nspecial :],
|
| 283 |
+
ex_vals + self.count[self.nspecial :],
|
| 284 |
+
),
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
def dummy_sentence(self, length):
|
| 288 |
+
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
|
| 289 |
+
t[-1] = self.eos()
|
| 290 |
+
return t
|
| 291 |
+
|
| 292 |
+
def encode_line(
|
| 293 |
+
self,
|
| 294 |
+
line,
|
| 295 |
+
line_tokenizer=tokenize_line,
|
| 296 |
+
add_if_not_exist=True,
|
| 297 |
+
consumer=None,
|
| 298 |
+
append_eos=True,
|
| 299 |
+
reverse_order=False,
|
| 300 |
+
):
|
| 301 |
+
words = line_tokenizer(line)
|
| 302 |
+
if reverse_order:
|
| 303 |
+
words = list(reversed(words))
|
| 304 |
+
nwords = len(words)
|
| 305 |
+
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
|
| 306 |
+
|
| 307 |
+
for i, word in enumerate(words):
|
| 308 |
+
if add_if_not_exist:
|
| 309 |
+
idx = self.add_symbol(word)
|
| 310 |
+
else:
|
| 311 |
+
idx = self.index(word)
|
| 312 |
+
if consumer is not None:
|
| 313 |
+
consumer(word, idx)
|
| 314 |
+
ids[i] = idx
|
| 315 |
+
if append_eos:
|
| 316 |
+
ids[nwords] = self.eos_index
|
| 317 |
+
return ids
|
| 318 |
+
|
| 319 |
+
@staticmethod
|
| 320 |
+
def _add_file_to_dictionary_single_worker(
|
| 321 |
+
filename, tokenize, eos_word, worker_id=0, num_workers=1
|
| 322 |
+
):
|
| 323 |
+
counter = Counter()
|
| 324 |
+
with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
|
| 325 |
+
size = os.fstat(f.fileno()).st_size
|
| 326 |
+
chunk_size = size // num_workers
|
| 327 |
+
offset = worker_id * chunk_size
|
| 328 |
+
end = offset + chunk_size
|
| 329 |
+
f.seek(offset)
|
| 330 |
+
if offset > 0:
|
| 331 |
+
safe_readline(f) # drop first incomplete line
|
| 332 |
+
line = f.readline()
|
| 333 |
+
while line:
|
| 334 |
+
for word in tokenize(line):
|
| 335 |
+
counter.update([word])
|
| 336 |
+
counter.update([eos_word])
|
| 337 |
+
if f.tell() > end:
|
| 338 |
+
break
|
| 339 |
+
line = f.readline()
|
| 340 |
+
return counter
|
| 341 |
+
|
| 342 |
+
@staticmethod
|
| 343 |
+
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
|
| 344 |
+
def merge_result(counter):
|
| 345 |
+
for w, c in sorted(counter.items()):
|
| 346 |
+
dict.add_symbol(w, c)
|
| 347 |
+
|
| 348 |
+
if num_workers > 1:
|
| 349 |
+
pool = Pool(processes=num_workers)
|
| 350 |
+
results = []
|
| 351 |
+
for worker_id in range(num_workers):
|
| 352 |
+
results.append(
|
| 353 |
+
pool.apply_async(
|
| 354 |
+
Dictionary._add_file_to_dictionary_single_worker,
|
| 355 |
+
(filename, tokenize, dict.eos_word, worker_id, num_workers),
|
| 356 |
+
)
|
| 357 |
+
)
|
| 358 |
+
pool.close()
|
| 359 |
+
pool.join()
|
| 360 |
+
for r in results:
|
| 361 |
+
merge_result(r.get())
|
| 362 |
+
else:
|
| 363 |
+
merge_result(
|
| 364 |
+
Dictionary._add_file_to_dictionary_single_worker(
|
| 365 |
+
filename, tokenize, dict.eos_word
|
| 366 |
+
)
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class TruncatedDictionary(object):
|
| 371 |
+
def __init__(self, wrapped_dict, length):
|
| 372 |
+
self.__class__ = type(
|
| 373 |
+
wrapped_dict.__class__.__name__,
|
| 374 |
+
(self.__class__, wrapped_dict.__class__),
|
| 375 |
+
{},
|
| 376 |
+
)
|
| 377 |
+
self.__dict__ = wrapped_dict.__dict__
|
| 378 |
+
self.wrapped_dict = wrapped_dict
|
| 379 |
+
self.length = min(len(self.wrapped_dict), length)
|
| 380 |
+
|
| 381 |
+
def __len__(self):
|
| 382 |
+
return self.length
|
| 383 |
+
|
| 384 |
+
def __getitem__(self, i):
|
| 385 |
+
if i < self.length:
|
| 386 |
+
return self.wrapped_dict[i]
|
| 387 |
+
return self.wrapped_dict.unk()
|
fairseq-0.10.2/fairseq/data/encoders/byte_utils.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
WHITESPACE_NORMALIZER = re.compile(r"\s+")
|
| 10 |
+
SPACE = chr(32)
|
| 11 |
+
SPACE_ESCAPE = chr(9601)
|
| 12 |
+
# excluding non-breaking space (160) here
|
| 13 |
+
PRINTABLE_LATIN = set(
|
| 14 |
+
list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1))
|
| 15 |
+
)
|
| 16 |
+
BYTE_TO_BCHAR = {
|
| 17 |
+
b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256)
|
| 18 |
+
}
|
| 19 |
+
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def byte_encode(x: str) -> str:
|
| 23 |
+
normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
|
| 24 |
+
return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def byte_decode(x: str) -> str:
|
| 28 |
+
try:
|
| 29 |
+
return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
|
| 30 |
+
except ValueError:
|
| 31 |
+
return ""
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def smart_byte_decode(x: str) -> str:
|
| 35 |
+
output = byte_decode(x)
|
| 36 |
+
if output == "":
|
| 37 |
+
# DP the best recovery (max valid chars) if it's broken
|
| 38 |
+
n_bytes = len(x)
|
| 39 |
+
f = [0 for _ in range(n_bytes + 1)]
|
| 40 |
+
pt = [0 for _ in range(n_bytes + 1)]
|
| 41 |
+
for i in range(1, n_bytes + 1):
|
| 42 |
+
f[i], pt[i] = f[i - 1], i - 1
|
| 43 |
+
for j in range(1, min(4, i) + 1):
|
| 44 |
+
if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
|
| 45 |
+
f[i], pt[i] = f[i - j] + 1, i - j
|
| 46 |
+
cur_pt = n_bytes
|
| 47 |
+
while cur_pt > 0:
|
| 48 |
+
if f[cur_pt] == f[pt[cur_pt]] + 1:
|
| 49 |
+
output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
|
| 50 |
+
cur_pt = pt[cur_pt]
|
| 51 |
+
return output
|
fairseq-0.10.2/fairseq/data/encoders/fastbpe.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from fairseq import file_utils
|
| 7 |
+
from fairseq.data.encoders import register_bpe
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@register_bpe("fastbpe")
|
| 11 |
+
class fastBPE(object):
|
| 12 |
+
@staticmethod
|
| 13 |
+
def add_args(parser):
|
| 14 |
+
# fmt: off
|
| 15 |
+
parser.add_argument('--bpe-codes', type=str,
|
| 16 |
+
help='path to fastBPE BPE')
|
| 17 |
+
# fmt: on
|
| 18 |
+
|
| 19 |
+
def __init__(self, args):
|
| 20 |
+
if args.bpe_codes is None:
|
| 21 |
+
raise ValueError("--bpe-codes is required for --bpe=fastbpe")
|
| 22 |
+
codes = file_utils.cached_path(args.bpe_codes)
|
| 23 |
+
try:
|
| 24 |
+
import fastBPE
|
| 25 |
+
|
| 26 |
+
self.bpe = fastBPE.fastBPE(codes)
|
| 27 |
+
self.bpe_symbol = "@@ "
|
| 28 |
+
except ImportError:
|
| 29 |
+
raise ImportError("Please install fastBPE with: pip install fastBPE")
|
| 30 |
+
|
| 31 |
+
def encode(self, x: str) -> str:
|
| 32 |
+
return self.bpe.apply([x])[0]
|
| 33 |
+
|
| 34 |
+
def decode(self, x: str) -> str:
|
| 35 |
+
return (x + " ").replace(self.bpe_symbol, "").rstrip()
|
fairseq-0.10.2/fairseq/data/encoders/gpt2_bpe_utils.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte pair encoding utilities from GPT-2.
|
| 3 |
+
|
| 4 |
+
Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
| 5 |
+
Original license: MIT
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@lru_cache()
|
| 13 |
+
def bytes_to_unicode():
|
| 14 |
+
"""
|
| 15 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
| 16 |
+
The reversible bpe codes work on unicode strings.
|
| 17 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
| 18 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
| 19 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
| 20 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
| 21 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
| 22 |
+
"""
|
| 23 |
+
bs = (
|
| 24 |
+
list(range(ord("!"), ord("~") + 1))
|
| 25 |
+
+ list(range(ord("¡"), ord("¬") + 1))
|
| 26 |
+
+ list(range(ord("®"), ord("ÿ") + 1))
|
| 27 |
+
)
|
| 28 |
+
cs = bs[:]
|
| 29 |
+
n = 0
|
| 30 |
+
for b in range(2 ** 8):
|
| 31 |
+
if b not in bs:
|
| 32 |
+
bs.append(b)
|
| 33 |
+
cs.append(2 ** 8 + n)
|
| 34 |
+
n += 1
|
| 35 |
+
cs = [chr(n) for n in cs]
|
| 36 |
+
return dict(zip(bs, cs))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_pairs(word):
|
| 40 |
+
"""Return set of symbol pairs in a word.
|
| 41 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
| 42 |
+
"""
|
| 43 |
+
pairs = set()
|
| 44 |
+
prev_char = word[0]
|
| 45 |
+
for char in word[1:]:
|
| 46 |
+
pairs.add((prev_char, char))
|
| 47 |
+
prev_char = char
|
| 48 |
+
return pairs
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Encoder:
|
| 52 |
+
def __init__(self, encoder, bpe_merges, errors="replace"):
|
| 53 |
+
self.encoder = encoder
|
| 54 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
| 55 |
+
self.errors = errors # how to handle errors in decoding
|
| 56 |
+
self.byte_encoder = bytes_to_unicode()
|
| 57 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
| 58 |
+
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
| 59 |
+
self.cache = {}
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
import regex as re
|
| 63 |
+
|
| 64 |
+
self.re = re
|
| 65 |
+
except ImportError:
|
| 66 |
+
raise ImportError("Please install regex with: pip install regex")
|
| 67 |
+
|
| 68 |
+
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
| 69 |
+
self.pat = self.re.compile(
|
| 70 |
+
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
def bpe(self, token):
|
| 74 |
+
if token in self.cache:
|
| 75 |
+
return self.cache[token]
|
| 76 |
+
word = tuple(token)
|
| 77 |
+
pairs = get_pairs(word)
|
| 78 |
+
|
| 79 |
+
if not pairs:
|
| 80 |
+
return token
|
| 81 |
+
|
| 82 |
+
while True:
|
| 83 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
| 84 |
+
if bigram not in self.bpe_ranks:
|
| 85 |
+
break
|
| 86 |
+
first, second = bigram
|
| 87 |
+
new_word = []
|
| 88 |
+
i = 0
|
| 89 |
+
while i < len(word):
|
| 90 |
+
try:
|
| 91 |
+
j = word.index(first, i)
|
| 92 |
+
new_word.extend(word[i:j])
|
| 93 |
+
i = j
|
| 94 |
+
except:
|
| 95 |
+
new_word.extend(word[i:])
|
| 96 |
+
break
|
| 97 |
+
|
| 98 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
| 99 |
+
new_word.append(first + second)
|
| 100 |
+
i += 2
|
| 101 |
+
else:
|
| 102 |
+
new_word.append(word[i])
|
| 103 |
+
i += 1
|
| 104 |
+
new_word = tuple(new_word)
|
| 105 |
+
word = new_word
|
| 106 |
+
if len(word) == 1:
|
| 107 |
+
break
|
| 108 |
+
else:
|
| 109 |
+
pairs = get_pairs(word)
|
| 110 |
+
word = " ".join(word)
|
| 111 |
+
self.cache[token] = word
|
| 112 |
+
return word
|
| 113 |
+
|
| 114 |
+
def encode(self, text):
|
| 115 |
+
bpe_tokens = []
|
| 116 |
+
for token in self.re.findall(self.pat, text):
|
| 117 |
+
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
| 118 |
+
bpe_tokens.extend(
|
| 119 |
+
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
|
| 120 |
+
)
|
| 121 |
+
return bpe_tokens
|
| 122 |
+
|
| 123 |
+
def decode(self, tokens):
|
| 124 |
+
text = "".join([self.decoder.get(token, token) for token in tokens])
|
| 125 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode(
|
| 126 |
+
"utf-8", errors=self.errors
|
| 127 |
+
)
|
| 128 |
+
return text
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def get_encoder(encoder_json_path, vocab_bpe_path):
|
| 132 |
+
with open(encoder_json_path, "r") as f:
|
| 133 |
+
encoder = json.load(f)
|
| 134 |
+
with open(vocab_bpe_path, "r", encoding="utf-8") as f:
|
| 135 |
+
bpe_data = f.read()
|
| 136 |
+
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
|
| 137 |
+
return Encoder(
|
| 138 |
+
encoder=encoder,
|
| 139 |
+
bpe_merges=bpe_merges,
|
| 140 |
+
)
|
fairseq-0.10.2/fairseq/data/encoders/moses_tokenizer.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from fairseq.data.encoders import register_tokenizer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@register_tokenizer("moses")
|
| 10 |
+
class MosesTokenizer(object):
|
| 11 |
+
@staticmethod
|
| 12 |
+
def add_args(parser):
|
| 13 |
+
# fmt: off
|
| 14 |
+
parser.add_argument('--moses-source-lang', metavar='SRC',
|
| 15 |
+
help='source language')
|
| 16 |
+
parser.add_argument('--moses-target-lang', metavar='TARGET',
|
| 17 |
+
help='target language')
|
| 18 |
+
parser.add_argument('--moses-no-dash-splits', action='store_true', default=False,
|
| 19 |
+
help='don\'t apply dash split rules')
|
| 20 |
+
parser.add_argument('--moses-no-escape', action='store_true', default=False,
|
| 21 |
+
help='don\'t perform HTML escaping on apostrophy, quotes, etc.')
|
| 22 |
+
# fmt: on
|
| 23 |
+
|
| 24 |
+
def __init__(self, args):
|
| 25 |
+
self.args = args
|
| 26 |
+
|
| 27 |
+
if getattr(args, "moses_source_lang", None) is None:
|
| 28 |
+
args.moses_source_lang = getattr(args, "source_lang", "en")
|
| 29 |
+
if getattr(args, "moses_target_lang", None) is None:
|
| 30 |
+
args.moses_target_lang = getattr(args, "target_lang", "en")
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
from sacremoses import MosesTokenizer, MosesDetokenizer
|
| 34 |
+
|
| 35 |
+
self.tok = MosesTokenizer(args.moses_source_lang)
|
| 36 |
+
self.detok = MosesDetokenizer(args.moses_target_lang)
|
| 37 |
+
except ImportError:
|
| 38 |
+
raise ImportError(
|
| 39 |
+
"Please install Moses tokenizer with: pip install sacremoses"
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
def encode(self, x: str) -> str:
|
| 43 |
+
return self.tok.tokenize(
|
| 44 |
+
x,
|
| 45 |
+
aggressive_dash_splits=(not self.args.moses_no_dash_splits),
|
| 46 |
+
return_str=True,
|
| 47 |
+
escape=(not self.args.moses_no_escape),
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
def decode(self, x: str) -> str:
|
| 51 |
+
return self.detok.detokenize(x.split())
|
fairseq-0.10.2/fairseq/data/encoders/nltk_tokenizer.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from fairseq.data.encoders import register_tokenizer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@register_tokenizer("nltk")
|
| 10 |
+
class NLTKTokenizer(object):
|
| 11 |
+
def __init__(self, source_lang=None, target_lang=None):
|
| 12 |
+
try:
|
| 13 |
+
from nltk.tokenize import word_tokenize
|
| 14 |
+
|
| 15 |
+
self.word_tokenize = word_tokenize
|
| 16 |
+
except ImportError:
|
| 17 |
+
raise ImportError("Please install nltk with: pip install nltk")
|
| 18 |
+
|
| 19 |
+
def encode(self, x: str) -> str:
|
| 20 |
+
return " ".join(self.word_tokenize(x))
|
| 21 |
+
|
| 22 |
+
def decode(self, x: str) -> str:
|
| 23 |
+
return x
|
fairseq-0.10.2/fairseq/data/encoders/subword_nmt_bpe.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from fairseq import file_utils
|
| 7 |
+
from fairseq.data.encoders import register_bpe
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@register_bpe("subword_nmt")
|
| 11 |
+
class SubwordNMTBPE(object):
|
| 12 |
+
@staticmethod
|
| 13 |
+
def add_args(parser):
|
| 14 |
+
# fmt: off
|
| 15 |
+
parser.add_argument('--bpe-codes', type=str,
|
| 16 |
+
help='path to subword NMT BPE')
|
| 17 |
+
parser.add_argument('--bpe-separator', default='@@',
|
| 18 |
+
help='BPE separator')
|
| 19 |
+
# fmt: on
|
| 20 |
+
|
| 21 |
+
def __init__(self, args):
|
| 22 |
+
if args.bpe_codes is None:
|
| 23 |
+
raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
|
| 24 |
+
codes = file_utils.cached_path(args.bpe_codes)
|
| 25 |
+
try:
|
| 26 |
+
from subword_nmt import apply_bpe
|
| 27 |
+
|
| 28 |
+
bpe_parser = apply_bpe.create_parser()
|
| 29 |
+
bpe_args = bpe_parser.parse_args(
|
| 30 |
+
[
|
| 31 |
+
"--codes",
|
| 32 |
+
codes,
|
| 33 |
+
"--separator",
|
| 34 |
+
args.bpe_separator,
|
| 35 |
+
]
|
| 36 |
+
)
|
| 37 |
+
self.bpe = apply_bpe.BPE(
|
| 38 |
+
bpe_args.codes,
|
| 39 |
+
bpe_args.merges,
|
| 40 |
+
bpe_args.separator,
|
| 41 |
+
None,
|
| 42 |
+
bpe_args.glossaries,
|
| 43 |
+
)
|
| 44 |
+
self.bpe_symbol = bpe_args.separator + " "
|
| 45 |
+
except ImportError:
|
| 46 |
+
raise ImportError(
|
| 47 |
+
"Please install subword_nmt with: pip install subword-nmt"
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
def encode(self, x: str) -> str:
|
| 51 |
+
return self.bpe.process_line(x)
|
| 52 |
+
|
| 53 |
+
def decode(self, x: str) -> str:
|
| 54 |
+
return (x + " ").replace(self.bpe_symbol, "").rstrip()
|
fairseq-0.10.2/fairseq/data/encoders/utils.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from fairseq.data import encoders
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_whole_word_mask(args, dictionary):
|
| 11 |
+
bpe = encoders.build_bpe(args)
|
| 12 |
+
if bpe is not None:
|
| 13 |
+
|
| 14 |
+
def is_beginning_of_word(i):
|
| 15 |
+
if i < dictionary.nspecial:
|
| 16 |
+
# special elements are always considered beginnings
|
| 17 |
+
return True
|
| 18 |
+
tok = dictionary[i]
|
| 19 |
+
if tok.startswith("madeupword"):
|
| 20 |
+
return True
|
| 21 |
+
try:
|
| 22 |
+
return bpe.is_beginning_of_word(tok)
|
| 23 |
+
except ValueError:
|
| 24 |
+
return True
|
| 25 |
+
|
| 26 |
+
mask_whole_words = torch.ByteTensor(
|
| 27 |
+
list(map(is_beginning_of_word, range(len(dictionary))))
|
| 28 |
+
)
|
| 29 |
+
return mask_whole_words
|
| 30 |
+
return None
|
fairseq-0.10.2/fairseq/data/fairseq_dataset.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch.utils.data
|
| 8 |
+
from fairseq.data import data_utils
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class EpochListening:
|
| 12 |
+
"""Mixin for receiving updates whenever the epoch increments."""
|
| 13 |
+
|
| 14 |
+
@property
|
| 15 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
| 16 |
+
"""
|
| 17 |
+
Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for
|
| 18 |
+
this dataset across epochs.
|
| 19 |
+
|
| 20 |
+
This needs to return ``False`` if the sample sizes can change across
|
| 21 |
+
epochs, in which case we may need to regenerate batches at each epoch.
|
| 22 |
+
If your dataset relies in ``set_epoch`` then you should consider setting
|
| 23 |
+
this to ``False``.
|
| 24 |
+
"""
|
| 25 |
+
return True
|
| 26 |
+
|
| 27 |
+
def set_epoch(self, epoch):
|
| 28 |
+
"""Will receive the updated epoch number at the beginning of the epoch."""
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class FairseqDataset(torch.utils.data.Dataset, EpochListening):
|
| 33 |
+
"""A dataset that provides helpers for batching."""
|
| 34 |
+
|
| 35 |
+
def __getitem__(self, index):
|
| 36 |
+
raise NotImplementedError
|
| 37 |
+
|
| 38 |
+
def __len__(self):
|
| 39 |
+
raise NotImplementedError
|
| 40 |
+
|
| 41 |
+
def collater(self, samples):
|
| 42 |
+
"""Merge a list of samples to form a mini-batch.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
samples (List[dict]): samples to collate
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
dict: a mini-batch suitable for forwarding with a Model
|
| 49 |
+
"""
|
| 50 |
+
raise NotImplementedError
|
| 51 |
+
|
| 52 |
+
def num_tokens(self, index):
|
| 53 |
+
"""Return the number of tokens in a sample. This value is used to
|
| 54 |
+
enforce ``--max-tokens`` during batching."""
|
| 55 |
+
raise NotImplementedError
|
| 56 |
+
|
| 57 |
+
def size(self, index):
|
| 58 |
+
"""Return an example's size as a float or tuple. This value is used when
|
| 59 |
+
filtering a dataset with ``--max-positions``."""
|
| 60 |
+
raise NotImplementedError
|
| 61 |
+
|
| 62 |
+
def ordered_indices(self):
|
| 63 |
+
"""Return an ordered list of indices. Batches will be constructed based
|
| 64 |
+
on this order."""
|
| 65 |
+
return np.arange(len(self), dtype=np.int64)
|
| 66 |
+
|
| 67 |
+
@property
|
| 68 |
+
def supports_prefetch(self):
|
| 69 |
+
"""Whether this dataset supports prefetching."""
|
| 70 |
+
return False
|
| 71 |
+
|
| 72 |
+
def attr(self, attr: str, index: int):
|
| 73 |
+
return getattr(self, attr, None)
|
| 74 |
+
|
| 75 |
+
def prefetch(self, indices):
|
| 76 |
+
"""Prefetch the data required for this epoch."""
|
| 77 |
+
raise NotImplementedError
|
| 78 |
+
|
| 79 |
+
def get_batch_shapes(self):
|
| 80 |
+
"""
|
| 81 |
+
Return a list of valid batch shapes, for example::
|
| 82 |
+
|
| 83 |
+
[(8, 512), (16, 256), (32, 128)]
|
| 84 |
+
|
| 85 |
+
The first dimension of each tuple is the batch size and can be ``None``
|
| 86 |
+
to automatically infer the max batch size based on ``--max-tokens``.
|
| 87 |
+
The second dimension of each tuple is the max supported length as given
|
| 88 |
+
by :func:`fairseq.data.FairseqDataset.num_tokens`.
|
| 89 |
+
|
| 90 |
+
This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
|
| 91 |
+
to restrict batch shapes. This is useful on TPUs to avoid too many
|
| 92 |
+
dynamic shapes (and recompilations).
|
| 93 |
+
"""
|
| 94 |
+
return None
|
| 95 |
+
|
| 96 |
+
def batch_by_size(
|
| 97 |
+
self,
|
| 98 |
+
indices,
|
| 99 |
+
max_tokens=None,
|
| 100 |
+
max_sentences=None,
|
| 101 |
+
required_batch_size_multiple=1,
|
| 102 |
+
):
|
| 103 |
+
"""
|
| 104 |
+
Given an ordered set of indices, return batches according to
|
| 105 |
+
*max_tokens*, *max_sentences* and *required_batch_size_multiple*.
|
| 106 |
+
"""
|
| 107 |
+
from fairseq.data import data_utils
|
| 108 |
+
|
| 109 |
+
fixed_shapes = self.get_batch_shapes()
|
| 110 |
+
if fixed_shapes is not None:
|
| 111 |
+
|
| 112 |
+
def adjust_bsz(bsz, num_tokens):
|
| 113 |
+
if bsz is None:
|
| 114 |
+
assert max_tokens is not None, "Must specify --max-tokens"
|
| 115 |
+
bsz = max_tokens // num_tokens
|
| 116 |
+
if max_sentences is not None:
|
| 117 |
+
bsz = min(bsz, max_sentences)
|
| 118 |
+
elif (
|
| 119 |
+
bsz >= required_batch_size_multiple
|
| 120 |
+
and bsz % required_batch_size_multiple != 0
|
| 121 |
+
):
|
| 122 |
+
bsz -= bsz % required_batch_size_multiple
|
| 123 |
+
return bsz
|
| 124 |
+
|
| 125 |
+
fixed_shapes = np.array(
|
| 126 |
+
[
|
| 127 |
+
[adjust_bsz(bsz, num_tokens), num_tokens]
|
| 128 |
+
for (bsz, num_tokens) in fixed_shapes
|
| 129 |
+
]
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
return data_utils.batch_by_size(
|
| 133 |
+
indices,
|
| 134 |
+
num_tokens_fn=self.num_tokens,
|
| 135 |
+
max_tokens=max_tokens,
|
| 136 |
+
max_sentences=max_sentences,
|
| 137 |
+
required_batch_size_multiple=required_batch_size_multiple,
|
| 138 |
+
fixed_shapes=fixed_shapes,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
def filter_indices_by_size(self, indices, max_sizes):
|
| 142 |
+
"""
|
| 143 |
+
Filter a list of sample indices. Remove those that are longer than
|
| 144 |
+
specified in *max_sizes*.
|
| 145 |
+
|
| 146 |
+
WARNING: don't update, override method in child classes
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
indices (np.array): original array of sample indices
|
| 150 |
+
max_sizes (int or list[int] or tuple[int]): max sample size,
|
| 151 |
+
can be defined separately for src and tgt (then list or tuple)
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
np.array: filtered sample array
|
| 155 |
+
list: list of removed indices
|
| 156 |
+
"""
|
| 157 |
+
if isinstance(max_sizes, float) or isinstance(max_sizes, int):
|
| 158 |
+
if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
|
| 159 |
+
ignored = indices[self.sizes[indices] > max_sizes].tolist()
|
| 160 |
+
indices = indices[self.sizes[indices] <= max_sizes]
|
| 161 |
+
elif (
|
| 162 |
+
hasattr(self, "sizes")
|
| 163 |
+
and isinstance(self.sizes, list)
|
| 164 |
+
and len(self.sizes) == 1
|
| 165 |
+
):
|
| 166 |
+
ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
|
| 167 |
+
indices = indices[self.sizes[0][indices] <= max_sizes]
|
| 168 |
+
else:
|
| 169 |
+
indices, ignored = data_utils._filter_by_size_dynamic(
|
| 170 |
+
indices, self.size, max_sizes
|
| 171 |
+
)
|
| 172 |
+
else:
|
| 173 |
+
indices, ignored = data_utils._filter_by_size_dynamic(
|
| 174 |
+
indices, self.size, max_sizes
|
| 175 |
+
)
|
| 176 |
+
return indices, ignored
|
| 177 |
+
|
| 178 |
+
@property
|
| 179 |
+
def supports_fetch_outside_dataloader(self):
|
| 180 |
+
"""Whether this dataset supports fetching outside the workers of the dataloader."""
|
| 181 |
+
return True
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
|
| 185 |
+
"""
|
| 186 |
+
For datasets that need to be read sequentially, usually because the data is
|
| 187 |
+
being streamed or otherwise can't be manipulated on a single machine.
|
| 188 |
+
"""
|
| 189 |
+
|
| 190 |
+
def __iter__(self):
|
| 191 |
+
raise NotImplementedError
|
fairseq-0.10.2/fairseq/data/fasta_dataset.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import subprocess
|
| 8 |
+
import threading
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def fasta_file_path(prefix_path):
|
| 16 |
+
return prefix_path + ".fasta"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class FastaDataset(torch.utils.data.Dataset):
|
| 20 |
+
"""
|
| 21 |
+
For loading protein sequence datasets in the common FASTA data format
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, path: str, cache_indices=False):
|
| 25 |
+
self.fn = fasta_file_path(path)
|
| 26 |
+
self.threadlocal = threading.local()
|
| 27 |
+
self.cache = Path(f"{path}.fasta.idx.npy")
|
| 28 |
+
if cache_indices:
|
| 29 |
+
if self.cache.exists():
|
| 30 |
+
self.offsets, self.sizes = np.load(self.cache)
|
| 31 |
+
else:
|
| 32 |
+
self.offsets, self.sizes = self._build_index(path)
|
| 33 |
+
np.save(self.cache, np.stack([self.offsets, self.sizes]))
|
| 34 |
+
else:
|
| 35 |
+
self.offsets, self.sizes = self._build_index(path)
|
| 36 |
+
|
| 37 |
+
def _get_file(self):
|
| 38 |
+
if not hasattr(self.threadlocal, "f"):
|
| 39 |
+
self.threadlocal.f = open(self.fn, "r")
|
| 40 |
+
return self.threadlocal.f
|
| 41 |
+
|
| 42 |
+
def __getitem__(self, idx):
|
| 43 |
+
f = self._get_file()
|
| 44 |
+
f.seek(self.offsets[idx])
|
| 45 |
+
desc = f.readline().strip()
|
| 46 |
+
line = f.readline()
|
| 47 |
+
seq = ""
|
| 48 |
+
while line != "" and line[0] != ">":
|
| 49 |
+
seq += line.strip()
|
| 50 |
+
line = f.readline()
|
| 51 |
+
return desc, seq
|
| 52 |
+
|
| 53 |
+
def __len__(self):
|
| 54 |
+
return self.offsets.size
|
| 55 |
+
|
| 56 |
+
def _build_index(self, path: str):
|
| 57 |
+
# Use grep and awk to get 100M/s on local SSD.
|
| 58 |
+
# Should process your enormous 100G fasta in ~10 min single core...
|
| 59 |
+
path = fasta_file_path(path)
|
| 60 |
+
bytes_offsets = subprocess.check_output(
|
| 61 |
+
f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
|
| 62 |
+
"| grep --byte-offset '^>' -o | cut -d: -f1",
|
| 63 |
+
shell=True,
|
| 64 |
+
)
|
| 65 |
+
fasta_lengths = subprocess.check_output(
|
| 66 |
+
f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
|
| 67 |
+
"| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'",
|
| 68 |
+
shell=True,
|
| 69 |
+
)
|
| 70 |
+
bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ")
|
| 71 |
+
sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ")
|
| 72 |
+
return bytes_np, sizes_np
|
| 73 |
+
|
| 74 |
+
def __setstate__(self, state):
|
| 75 |
+
self.__dict__ = state
|
| 76 |
+
self.threadlocal = threading.local()
|
| 77 |
+
|
| 78 |
+
def __getstate__(self):
|
| 79 |
+
d = {}
|
| 80 |
+
for i, v in self.__dict__.items():
|
| 81 |
+
if i != "threadlocal":
|
| 82 |
+
d[i] = v
|
| 83 |
+
return d
|
| 84 |
+
|
| 85 |
+
def __del__(self):
|
| 86 |
+
if hasattr(self.threadlocal, "f"):
|
| 87 |
+
self.threadlocal.f.close()
|
| 88 |
+
del self.threadlocal.f
|
| 89 |
+
|
| 90 |
+
@staticmethod
|
| 91 |
+
def exists(path):
|
| 92 |
+
return os.path.exists(fasta_file_path(path))
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class EncodedFastaDataset(FastaDataset):
|
| 96 |
+
"""
|
| 97 |
+
The FastaDataset returns raw sequences - this allows us to return
|
| 98 |
+
indices with a dictionary instead.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
def __init__(self, path, dictionary):
|
| 102 |
+
super().__init__(path, cache_indices=True)
|
| 103 |
+
self.dictionary = dictionary
|
| 104 |
+
|
| 105 |
+
def __getitem__(self, idx):
|
| 106 |
+
desc, seq = super().__getitem__(idx)
|
| 107 |
+
return self.dictionary.encode_line(seq, line_tokenizer=list).long()
|
fairseq-0.10.2/fairseq/data/id_dataset.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from . import FairseqDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class IdDataset(FairseqDataset):
|
| 12 |
+
def __getitem__(self, index):
|
| 13 |
+
return index
|
| 14 |
+
|
| 15 |
+
def __len__(self):
|
| 16 |
+
return 0
|
| 17 |
+
|
| 18 |
+
def collater(self, samples):
|
| 19 |
+
return torch.tensor(samples)
|
fairseq-0.10.2/fairseq/data/language_pair_dataset.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from fairseq.data import FairseqDataset, data_utils
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def collate(
|
| 17 |
+
samples,
|
| 18 |
+
pad_idx,
|
| 19 |
+
eos_idx,
|
| 20 |
+
left_pad_source=True,
|
| 21 |
+
left_pad_target=False,
|
| 22 |
+
input_feeding=True,
|
| 23 |
+
pad_to_length=None,
|
| 24 |
+
pad_to_multiple=1,
|
| 25 |
+
):
|
| 26 |
+
if len(samples) == 0:
|
| 27 |
+
return {}
|
| 28 |
+
|
| 29 |
+
def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
|
| 30 |
+
return data_utils.collate_tokens(
|
| 31 |
+
[s[key] for s in samples],
|
| 32 |
+
pad_idx,
|
| 33 |
+
eos_idx,
|
| 34 |
+
left_pad,
|
| 35 |
+
move_eos_to_beginning,
|
| 36 |
+
pad_to_length=pad_to_length,
|
| 37 |
+
pad_to_multiple=pad_to_multiple,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
def check_alignment(alignment, src_len, tgt_len):
|
| 41 |
+
if alignment is None or len(alignment) == 0:
|
| 42 |
+
return False
|
| 43 |
+
if (
|
| 44 |
+
alignment[:, 0].max().item() >= src_len - 1
|
| 45 |
+
or alignment[:, 1].max().item() >= tgt_len - 1
|
| 46 |
+
):
|
| 47 |
+
logger.warning("alignment size mismatch found, skipping alignment!")
|
| 48 |
+
return False
|
| 49 |
+
return True
|
| 50 |
+
|
| 51 |
+
def compute_alignment_weights(alignments):
|
| 52 |
+
"""
|
| 53 |
+
Given a tensor of shape [:, 2] containing the source-target indices
|
| 54 |
+
corresponding to the alignments, a weight vector containing the
|
| 55 |
+
inverse frequency of each target index is computed.
|
| 56 |
+
For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
|
| 57 |
+
a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
|
| 58 |
+
index 3 is repeated twice)
|
| 59 |
+
"""
|
| 60 |
+
align_tgt = alignments[:, 1]
|
| 61 |
+
_, align_tgt_i, align_tgt_c = torch.unique(
|
| 62 |
+
align_tgt, return_inverse=True, return_counts=True
|
| 63 |
+
)
|
| 64 |
+
align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
|
| 65 |
+
return 1.0 / align_weights.float()
|
| 66 |
+
|
| 67 |
+
id = torch.LongTensor([s["id"] for s in samples])
|
| 68 |
+
|
| 69 |
+
# import pdb;pdb.set_trace()
|
| 70 |
+
|
| 71 |
+
src_tokens = merge(
|
| 72 |
+
"source",
|
| 73 |
+
left_pad=left_pad_source,
|
| 74 |
+
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
|
| 75 |
+
)
|
| 76 |
+
# sort by descending source length
|
| 77 |
+
src_lengths = torch.LongTensor(
|
| 78 |
+
[s["source"].ne(pad_idx).long().sum() for s in samples]
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# import pdb;pdb.set_trace()
|
| 82 |
+
|
| 83 |
+
src_lengths, sort_order = src_lengths.sort(descending=True)
|
| 84 |
+
id = id.index_select(0, sort_order)
|
| 85 |
+
src_tokens = src_tokens.index_select(0, sort_order)
|
| 86 |
+
|
| 87 |
+
prev_output_tokens = None
|
| 88 |
+
target = None
|
| 89 |
+
if samples[0].get("target", None) is not None:
|
| 90 |
+
target = merge(
|
| 91 |
+
"target",
|
| 92 |
+
left_pad=left_pad_target,
|
| 93 |
+
pad_to_length=pad_to_length["target"]
|
| 94 |
+
if pad_to_length is not None
|
| 95 |
+
else None,
|
| 96 |
+
)
|
| 97 |
+
target = target.index_select(0, sort_order)
|
| 98 |
+
tgt_lengths = torch.LongTensor(
|
| 99 |
+
[s["target"].ne(pad_idx).long().sum() for s in samples]
|
| 100 |
+
).index_select(0, sort_order)
|
| 101 |
+
ntokens = tgt_lengths.sum().item()
|
| 102 |
+
|
| 103 |
+
if samples[0].get("prev_output_tokens", None) is not None:
|
| 104 |
+
prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
|
| 105 |
+
elif input_feeding:
|
| 106 |
+
# we create a shifted version of targets for feeding the
|
| 107 |
+
# previous output token(s) into the next decoder step
|
| 108 |
+
prev_output_tokens = merge(
|
| 109 |
+
"target",
|
| 110 |
+
left_pad=left_pad_target,
|
| 111 |
+
move_eos_to_beginning=True,
|
| 112 |
+
pad_to_length=pad_to_length["target"]
|
| 113 |
+
if pad_to_length is not None
|
| 114 |
+
else None,
|
| 115 |
+
)
|
| 116 |
+
else:
|
| 117 |
+
ntokens = src_lengths.sum().item()
|
| 118 |
+
|
| 119 |
+
batch = {
|
| 120 |
+
"id": id,
|
| 121 |
+
"nsentences": len(samples),
|
| 122 |
+
"ntokens": ntokens,
|
| 123 |
+
"net_input": {
|
| 124 |
+
"src_tokens": src_tokens,
|
| 125 |
+
"src_lengths": src_lengths,
|
| 126 |
+
},
|
| 127 |
+
"target": target,
|
| 128 |
+
}
|
| 129 |
+
if prev_output_tokens is not None:
|
| 130 |
+
batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
|
| 131 |
+
0, sort_order
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
if samples[0].get("alignment", None) is not None:
|
| 135 |
+
bsz, tgt_sz = batch["target"].shape
|
| 136 |
+
src_sz = batch["net_input"]["src_tokens"].shape[1]
|
| 137 |
+
|
| 138 |
+
offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
|
| 139 |
+
offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
|
| 140 |
+
if left_pad_source:
|
| 141 |
+
offsets[:, 0] += src_sz - src_lengths
|
| 142 |
+
if left_pad_target:
|
| 143 |
+
offsets[:, 1] += tgt_sz - tgt_lengths
|
| 144 |
+
|
| 145 |
+
alignments = [
|
| 146 |
+
alignment + offset
|
| 147 |
+
for align_idx, offset, src_len, tgt_len in zip(
|
| 148 |
+
sort_order, offsets, src_lengths, tgt_lengths
|
| 149 |
+
)
|
| 150 |
+
for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
|
| 151 |
+
if check_alignment(alignment, src_len, tgt_len)
|
| 152 |
+
]
|
| 153 |
+
|
| 154 |
+
if len(alignments) > 0:
|
| 155 |
+
alignments = torch.cat(alignments, dim=0)
|
| 156 |
+
align_weights = compute_alignment_weights(alignments)
|
| 157 |
+
|
| 158 |
+
batch["alignments"] = alignments
|
| 159 |
+
batch["align_weights"] = align_weights
|
| 160 |
+
|
| 161 |
+
if samples[0].get("constraints", None) is not None:
|
| 162 |
+
# Collate the packed constraints across the samples, padding to
|
| 163 |
+
# the length of the longest sample.
|
| 164 |
+
lens = [sample.get("constraints").size(0) for sample in samples]
|
| 165 |
+
max_len = max(lens)
|
| 166 |
+
constraints = torch.zeros((len(samples), max(lens))).long()
|
| 167 |
+
for i, sample in enumerate(samples):
|
| 168 |
+
constraints[i, 0 : lens[i]] = samples[i].get("constraints")
|
| 169 |
+
batch["constraints"] = constraints
|
| 170 |
+
|
| 171 |
+
return batch
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class LanguagePairDataset(FairseqDataset):
|
| 175 |
+
"""
|
| 176 |
+
A pair of torch.utils.data.Datasets.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
src (torch.utils.data.Dataset): source dataset to wrap
|
| 180 |
+
src_sizes (List[int]): source sentence lengths
|
| 181 |
+
src_dict (~fairseq.data.Dictionary): source vocabulary
|
| 182 |
+
tgt (torch.utils.data.Dataset, optional): target dataset to wrap
|
| 183 |
+
tgt_sizes (List[int], optional): target sentence lengths
|
| 184 |
+
tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
|
| 185 |
+
left_pad_source (bool, optional): pad source tensors on the left side
|
| 186 |
+
(default: True).
|
| 187 |
+
left_pad_target (bool, optional): pad target tensors on the left side
|
| 188 |
+
(default: False).
|
| 189 |
+
shuffle (bool, optional): shuffle dataset elements before batching
|
| 190 |
+
(default: True).
|
| 191 |
+
input_feeding (bool, optional): create a shifted version of the targets
|
| 192 |
+
to be passed into the model for teacher forcing (default: True).
|
| 193 |
+
remove_eos_from_source (bool, optional): if set, removes eos from end
|
| 194 |
+
of source if it's present (default: False).
|
| 195 |
+
append_eos_to_target (bool, optional): if set, appends eos to end of
|
| 196 |
+
target if it's absent (default: False).
|
| 197 |
+
align_dataset (torch.utils.data.Dataset, optional): dataset
|
| 198 |
+
containing alignments.
|
| 199 |
+
constraints (Tensor, optional): 2d tensor with a concatenated, zero-
|
| 200 |
+
delimited list of constraints for each sentence.
|
| 201 |
+
append_bos (bool, optional): if set, appends bos to the beginning of
|
| 202 |
+
source/target sentence.
|
| 203 |
+
num_buckets (int, optional): if set to a value greater than 0, then
|
| 204 |
+
batches will be bucketed into the given number of batch shapes.
|
| 205 |
+
src_lang_id (int, optional): source language ID, if set, the collated batch
|
| 206 |
+
will contain a field 'src_lang_id' in 'net_input' which indicates the
|
| 207 |
+
source language of the samples.
|
| 208 |
+
tgt_lang_id (int, optional): target language ID, if set, the collated batch
|
| 209 |
+
will contain a field 'tgt_lang_id' which indicates the target language
|
| 210 |
+
of the samples.
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
def __init__(
|
| 214 |
+
self,
|
| 215 |
+
src,
|
| 216 |
+
src_sizes,
|
| 217 |
+
src_dict,
|
| 218 |
+
tgt=None,
|
| 219 |
+
tgt_sizes=None,
|
| 220 |
+
tgt_dict=None,
|
| 221 |
+
left_pad_source=True,
|
| 222 |
+
left_pad_target=False,
|
| 223 |
+
shuffle=True,
|
| 224 |
+
input_feeding=True,
|
| 225 |
+
remove_eos_from_source=False,
|
| 226 |
+
append_eos_to_target=False,
|
| 227 |
+
align_dataset=None,
|
| 228 |
+
constraints=None,
|
| 229 |
+
append_bos=False,
|
| 230 |
+
eos=None,
|
| 231 |
+
num_buckets=0,
|
| 232 |
+
src_lang_id=None,
|
| 233 |
+
tgt_lang_id=None,
|
| 234 |
+
pad_to_multiple=1,
|
| 235 |
+
):
|
| 236 |
+
if tgt_dict is not None:
|
| 237 |
+
assert src_dict.pad() == tgt_dict.pad()
|
| 238 |
+
assert src_dict.eos() == tgt_dict.eos()
|
| 239 |
+
assert src_dict.unk() == tgt_dict.unk()
|
| 240 |
+
if tgt is not None:
|
| 241 |
+
assert len(src) == len(
|
| 242 |
+
tgt
|
| 243 |
+
), "Source and target must contain the same number of examples"
|
| 244 |
+
self.src = src
|
| 245 |
+
self.tgt = tgt
|
| 246 |
+
self.src_sizes = np.array(src_sizes)
|
| 247 |
+
self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
|
| 248 |
+
self.sizes = (
|
| 249 |
+
np.vstack((self.src_sizes, self.tgt_sizes)).T
|
| 250 |
+
if self.tgt_sizes is not None
|
| 251 |
+
else self.src_sizes
|
| 252 |
+
)
|
| 253 |
+
self.src_dict = src_dict
|
| 254 |
+
self.tgt_dict = tgt_dict
|
| 255 |
+
self.left_pad_source = left_pad_source
|
| 256 |
+
self.left_pad_target = left_pad_target
|
| 257 |
+
self.shuffle = shuffle
|
| 258 |
+
self.input_feeding = input_feeding
|
| 259 |
+
self.remove_eos_from_source = remove_eos_from_source
|
| 260 |
+
self.append_eos_to_target = append_eos_to_target
|
| 261 |
+
self.align_dataset = align_dataset
|
| 262 |
+
if self.align_dataset is not None:
|
| 263 |
+
assert (
|
| 264 |
+
self.tgt_sizes is not None
|
| 265 |
+
), "Both source and target needed when alignments are provided"
|
| 266 |
+
self.constraints = constraints
|
| 267 |
+
self.append_bos = append_bos
|
| 268 |
+
self.eos = eos if eos is not None else src_dict.eos()
|
| 269 |
+
self.src_lang_id = src_lang_id
|
| 270 |
+
self.tgt_lang_id = tgt_lang_id
|
| 271 |
+
if num_buckets > 0:
|
| 272 |
+
from fairseq.data import BucketPadLengthDataset
|
| 273 |
+
|
| 274 |
+
self.src = BucketPadLengthDataset(
|
| 275 |
+
self.src,
|
| 276 |
+
sizes=self.src_sizes,
|
| 277 |
+
num_buckets=num_buckets,
|
| 278 |
+
pad_idx=self.src_dict.pad(),
|
| 279 |
+
left_pad=self.left_pad_source,
|
| 280 |
+
)
|
| 281 |
+
self.src_sizes = self.src.sizes
|
| 282 |
+
logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
|
| 283 |
+
if self.tgt is not None:
|
| 284 |
+
self.tgt = BucketPadLengthDataset(
|
| 285 |
+
self.tgt,
|
| 286 |
+
sizes=self.tgt_sizes,
|
| 287 |
+
num_buckets=num_buckets,
|
| 288 |
+
pad_idx=self.tgt_dict.pad(),
|
| 289 |
+
left_pad=self.left_pad_target,
|
| 290 |
+
)
|
| 291 |
+
self.tgt_sizes = self.tgt.sizes
|
| 292 |
+
logger.info(
|
| 293 |
+
"bucketing target lengths: {}".format(list(self.tgt.buckets))
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# determine bucket sizes using self.num_tokens, which will return
|
| 297 |
+
# the padded lengths (thanks to BucketPadLengthDataset)
|
| 298 |
+
num_tokens = np.vectorize(self.num_tokens, otypes=[np.long])
|
| 299 |
+
self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
|
| 300 |
+
self.buckets = [
|
| 301 |
+
(None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
|
| 302 |
+
]
|
| 303 |
+
else:
|
| 304 |
+
self.buckets = None
|
| 305 |
+
self.pad_to_multiple = pad_to_multiple
|
| 306 |
+
|
| 307 |
+
def get_batch_shapes(self):
|
| 308 |
+
return self.buckets
|
| 309 |
+
|
| 310 |
+
def __getitem__(self, index):
|
| 311 |
+
tgt_item = self.tgt[index] if self.tgt is not None else None
|
| 312 |
+
src_item = self.src[index]
|
| 313 |
+
# Append EOS to end of tgt sentence if it does not have an EOS and remove
|
| 314 |
+
# EOS from end of src sentence if it exists. This is useful when we use
|
| 315 |
+
# use existing datasets for opposite directions i.e., when we want to
|
| 316 |
+
# use tgt_dataset as src_dataset and vice versa
|
| 317 |
+
if self.append_eos_to_target:
|
| 318 |
+
eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
|
| 319 |
+
if self.tgt and self.tgt[index][-1] != eos:
|
| 320 |
+
tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])
|
| 321 |
+
|
| 322 |
+
if self.append_bos:
|
| 323 |
+
bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
|
| 324 |
+
if self.tgt and self.tgt[index][0] != bos:
|
| 325 |
+
tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])
|
| 326 |
+
|
| 327 |
+
bos = self.src_dict.bos()
|
| 328 |
+
if self.src[index][0] != bos:
|
| 329 |
+
src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])
|
| 330 |
+
|
| 331 |
+
if self.remove_eos_from_source:
|
| 332 |
+
eos = self.src_dict.eos()
|
| 333 |
+
if self.src[index][-1] == eos:
|
| 334 |
+
src_item = self.src[index][:-1]
|
| 335 |
+
|
| 336 |
+
example = {
|
| 337 |
+
"id": index,
|
| 338 |
+
"source": src_item,
|
| 339 |
+
"target": tgt_item,
|
| 340 |
+
}
|
| 341 |
+
if self.align_dataset is not None:
|
| 342 |
+
example["alignment"] = self.align_dataset[index]
|
| 343 |
+
if self.constraints is not None:
|
| 344 |
+
example["constraints"] = self.constraints[index]
|
| 345 |
+
return example
|
| 346 |
+
|
| 347 |
+
def __len__(self):
|
| 348 |
+
return len(self.src)
|
| 349 |
+
|
| 350 |
+
def collater(self, samples, pad_to_length=None):
|
| 351 |
+
"""Merge a list of samples to form a mini-batch.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
samples (List[dict]): samples to collate
|
| 355 |
+
pad_to_length (dict, optional): a dictionary of
|
| 356 |
+
{'source': source_pad_to_length, 'target': target_pad_to_length}
|
| 357 |
+
to indicate the max length to pad to in source and target respectively.
|
| 358 |
+
|
| 359 |
+
Returns:
|
| 360 |
+
dict: a mini-batch with the following keys:
|
| 361 |
+
|
| 362 |
+
- `id` (LongTensor): example IDs in the original input order
|
| 363 |
+
- `ntokens` (int): total number of tokens in the batch
|
| 364 |
+
- `net_input` (dict): the input to the Model, containing keys:
|
| 365 |
+
|
| 366 |
+
- `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
|
| 367 |
+
the source sentence of shape `(bsz, src_len)`. Padding will
|
| 368 |
+
appear on the left if *left_pad_source* is ``True``.
|
| 369 |
+
- `src_lengths` (LongTensor): 1D Tensor of the unpadded
|
| 370 |
+
lengths of each source sentence of shape `(bsz)`
|
| 371 |
+
- `prev_output_tokens` (LongTensor): a padded 2D Tensor of
|
| 372 |
+
tokens in the target sentence, shifted right by one
|
| 373 |
+
position for teacher forcing, of shape `(bsz, tgt_len)`.
|
| 374 |
+
This key will not be present if *input_feeding* is
|
| 375 |
+
``False``. Padding will appear on the left if
|
| 376 |
+
*left_pad_target* is ``True``.
|
| 377 |
+
- `src_lang_id` (LongTensor): a long Tensor which contains source
|
| 378 |
+
language IDs of each sample in the batch
|
| 379 |
+
|
| 380 |
+
- `target` (LongTensor): a padded 2D Tensor of tokens in the
|
| 381 |
+
target sentence of shape `(bsz, tgt_len)`. Padding will appear
|
| 382 |
+
on the left if *left_pad_target* is ``True``.
|
| 383 |
+
- `tgt_lang_id` (LongTensor): a long Tensor which contains target language
|
| 384 |
+
IDs of each sample in the batch
|
| 385 |
+
"""
|
| 386 |
+
res = collate(
|
| 387 |
+
samples,
|
| 388 |
+
pad_idx=self.src_dict.pad(),
|
| 389 |
+
eos_idx=self.eos,
|
| 390 |
+
left_pad_source=self.left_pad_source,
|
| 391 |
+
left_pad_target=self.left_pad_target,
|
| 392 |
+
input_feeding=self.input_feeding,
|
| 393 |
+
pad_to_length=pad_to_length,
|
| 394 |
+
pad_to_multiple=self.pad_to_multiple,
|
| 395 |
+
)
|
| 396 |
+
if self.src_lang_id is not None or self.tgt_lang_id is not None:
|
| 397 |
+
src_tokens = res["net_input"]["src_tokens"]
|
| 398 |
+
bsz = src_tokens.size(0)
|
| 399 |
+
if self.src_lang_id is not None:
|
| 400 |
+
res["net_input"]["src_lang_id"] = (
|
| 401 |
+
torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
|
| 402 |
+
)
|
| 403 |
+
if self.tgt_lang_id is not None:
|
| 404 |
+
res["tgt_lang_id"] = (
|
| 405 |
+
torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
|
| 406 |
+
)
|
| 407 |
+
return res
|
| 408 |
+
|
| 409 |
+
def num_tokens(self, index):
|
| 410 |
+
"""Return the number of tokens in a sample. This value is used to
|
| 411 |
+
enforce ``--max-tokens`` during batching."""
|
| 412 |
+
return max(
|
| 413 |
+
self.src_sizes[index],
|
| 414 |
+
self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
def size(self, index):
|
| 418 |
+
"""Return an example's size as a float or tuple. This value is used when
|
| 419 |
+
filtering a dataset with ``--max-positions``."""
|
| 420 |
+
return (
|
| 421 |
+
self.src_sizes[index],
|
| 422 |
+
self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
def ordered_indices(self):
|
| 426 |
+
"""Return an ordered list of indices. Batches will be constructed based
|
| 427 |
+
on this order."""
|
| 428 |
+
if self.shuffle:
|
| 429 |
+
indices = np.random.permutation(len(self)).astype(np.int64)
|
| 430 |
+
else:
|
| 431 |
+
indices = np.arange(len(self), dtype=np.int64)
|
| 432 |
+
if self.buckets is None:
|
| 433 |
+
# sort by target length, then source length
|
| 434 |
+
if self.tgt_sizes is not None:
|
| 435 |
+
indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
|
| 436 |
+
return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
|
| 437 |
+
else:
|
| 438 |
+
# sort by bucketed_num_tokens, which is:
|
| 439 |
+
# max(padded_src_len, padded_tgt_len)
|
| 440 |
+
return indices[
|
| 441 |
+
np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
|
| 442 |
+
]
|
| 443 |
+
|
| 444 |
+
@property
|
| 445 |
+
def supports_prefetch(self):
|
| 446 |
+
return getattr(self.src, "supports_prefetch", False) and (
|
| 447 |
+
getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
def prefetch(self, indices):
|
| 451 |
+
self.src.prefetch(indices)
|
| 452 |
+
if self.tgt is not None:
|
| 453 |
+
self.tgt.prefetch(indices)
|
| 454 |
+
if self.align_dataset is not None:
|
| 455 |
+
self.align_dataset.prefetch(indices)
|
| 456 |
+
|
| 457 |
+
def filter_indices_by_size(self, indices, max_sizes):
|
| 458 |
+
"""Filter a list of sample indices. Remove those that are longer
|
| 459 |
+
than specified in max_sizes.
|
| 460 |
+
|
| 461 |
+
Args:
|
| 462 |
+
indices (np.array): original array of sample indices
|
| 463 |
+
max_sizes (int or list[int] or tuple[int]): max sample size,
|
| 464 |
+
can be defined separately for src and tgt (then list or tuple)
|
| 465 |
+
|
| 466 |
+
Returns:
|
| 467 |
+
np.array: filtered sample array
|
| 468 |
+
list: list of removed indices
|
| 469 |
+
"""
|
| 470 |
+
return data_utils.filter_paired_dataset_indices_by_size(
|
| 471 |
+
self.src_sizes,
|
| 472 |
+
self.tgt_sizes,
|
| 473 |
+
indices,
|
| 474 |
+
max_sizes,
|
| 475 |
+
)
|
fairseq-0.10.2/fairseq/data/legacy/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .block_pair_dataset import BlockPairDataset
|
| 7 |
+
from .masked_lm_dataset import MaskedLMDataset
|
| 8 |
+
from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"BertDictionary",
|
| 13 |
+
"BlockPairDataset",
|
| 14 |
+
"MaskedLMDataset",
|
| 15 |
+
"MaskedLMDictionary",
|
| 16 |
+
]
|
fairseq-0.10.2/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (413 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc
ADDED
|
Binary file (8.76 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc
ADDED
|
Binary file (2.05 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/legacy/block_pair_dataset.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from fairseq.data import FairseqDataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BlockPairDataset(FairseqDataset):
|
| 14 |
+
"""Break a Dataset of tokens into sentence pair blocks for next sentence
|
| 15 |
+
prediction as well as masked language model.
|
| 16 |
+
|
| 17 |
+
High-level logics are:
|
| 18 |
+
1. break input tensor to tensor blocks
|
| 19 |
+
2. pair the blocks with 50% next sentence and 50% random sentence
|
| 20 |
+
3. return paired blocks as well as related segment labels
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
dataset (~torch.utils.data.Dataset): dataset to break into blocks
|
| 24 |
+
sizes: array of sentence lengths
|
| 25 |
+
dictionary: dictionary for the task
|
| 26 |
+
block_size: maximum block size
|
| 27 |
+
break_mode: mode for breaking copurs into block pairs. currently we support
|
| 28 |
+
2 modes
|
| 29 |
+
doc: respect document boundaries and each part of the pair should belong to on document
|
| 30 |
+
none: don't respect any boundary and cut tokens evenly
|
| 31 |
+
short_seq_prob: probability for generating shorter block pairs
|
| 32 |
+
doc_break_size: Size for empty line separating documents. Typically 1 if
|
| 33 |
+
the sentences have eos, 0 otherwise.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
dataset,
|
| 39 |
+
dictionary,
|
| 40 |
+
sizes,
|
| 41 |
+
block_size,
|
| 42 |
+
break_mode="doc",
|
| 43 |
+
short_seq_prob=0.1,
|
| 44 |
+
doc_break_size=1,
|
| 45 |
+
):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.dataset = dataset
|
| 48 |
+
self.pad = dictionary.pad()
|
| 49 |
+
self.eos = dictionary.eos()
|
| 50 |
+
self.cls = dictionary.cls()
|
| 51 |
+
self.mask = dictionary.mask()
|
| 52 |
+
self.sep = dictionary.sep()
|
| 53 |
+
self.break_mode = break_mode
|
| 54 |
+
self.dictionary = dictionary
|
| 55 |
+
self.short_seq_prob = short_seq_prob
|
| 56 |
+
self.block_indices = []
|
| 57 |
+
|
| 58 |
+
assert len(dataset) == len(sizes)
|
| 59 |
+
|
| 60 |
+
if break_mode == "doc":
|
| 61 |
+
cur_doc = []
|
| 62 |
+
for sent_id, sz in enumerate(sizes):
|
| 63 |
+
assert doc_break_size == 0 or sz != 0, (
|
| 64 |
+
"when doc_break_size is non-zero, we expect documents to be"
|
| 65 |
+
"separated by a blank line with a single eos."
|
| 66 |
+
)
|
| 67 |
+
# empty line as document separator
|
| 68 |
+
if sz == doc_break_size:
|
| 69 |
+
if len(cur_doc) == 0:
|
| 70 |
+
continue
|
| 71 |
+
self.block_indices.append(cur_doc)
|
| 72 |
+
cur_doc = []
|
| 73 |
+
else:
|
| 74 |
+
cur_doc.append(sent_id)
|
| 75 |
+
max_num_tokens = block_size - 3 # Account for [CLS], [SEP], [SEP]
|
| 76 |
+
self.sent_pairs = []
|
| 77 |
+
self.sizes = []
|
| 78 |
+
for doc_id, doc in enumerate(self.block_indices):
|
| 79 |
+
self._generate_sentence_pair(doc, doc_id, max_num_tokens, sizes)
|
| 80 |
+
elif break_mode is None or break_mode == "none":
|
| 81 |
+
# each block should have half of the block size since we are constructing block pair
|
| 82 |
+
sent_length = (block_size - 3) // 2
|
| 83 |
+
total_len = sum(dataset.sizes)
|
| 84 |
+
length = math.ceil(total_len / sent_length)
|
| 85 |
+
|
| 86 |
+
def block_at(i):
|
| 87 |
+
start = i * sent_length
|
| 88 |
+
end = min(start + sent_length, total_len)
|
| 89 |
+
return (start, end)
|
| 90 |
+
|
| 91 |
+
sent_indices = np.array([block_at(i) for i in range(length)])
|
| 92 |
+
sent_sizes = np.array([e - s for s, e in sent_indices])
|
| 93 |
+
dataset_index = self._sent_to_dataset_index(sent_sizes)
|
| 94 |
+
|
| 95 |
+
# pair sentences
|
| 96 |
+
self._pair_sentences(dataset_index)
|
| 97 |
+
else:
|
| 98 |
+
raise ValueError("Invalid break_mode: " + break_mode)
|
| 99 |
+
|
| 100 |
+
def _pair_sentences(self, dataset_index):
|
| 101 |
+
"""
|
| 102 |
+
Give a list of evenly cut blocks/sentences, pair these sentences with 50%
|
| 103 |
+
consecutive sentences and 50% random sentences.
|
| 104 |
+
This is used for none break mode
|
| 105 |
+
"""
|
| 106 |
+
# pair sentences
|
| 107 |
+
for sent_id, sent in enumerate(dataset_index):
|
| 108 |
+
next_sent_label = (
|
| 109 |
+
1 if np.random.rand() > 0.5 and sent_id != len(dataset_index) - 1 else 0
|
| 110 |
+
)
|
| 111 |
+
if next_sent_label:
|
| 112 |
+
next_sent = dataset_index[sent_id + 1]
|
| 113 |
+
else:
|
| 114 |
+
next_sent = dataset_index[
|
| 115 |
+
self._skip_sampling(len(dataset_index), [sent_id, sent_id + 1])
|
| 116 |
+
]
|
| 117 |
+
self.sent_pairs.append((sent, next_sent, next_sent_label))
|
| 118 |
+
|
| 119 |
+
# The current blocks don't include the special tokens but the
|
| 120 |
+
# sizes already account for this
|
| 121 |
+
self.sizes.append(3 + sent[3] + next_sent[3])
|
| 122 |
+
|
| 123 |
+
def _sent_to_dataset_index(self, sent_sizes):
|
| 124 |
+
"""
|
| 125 |
+
Build index mapping block indices to the underlying dataset indices
|
| 126 |
+
"""
|
| 127 |
+
dataset_index = []
|
| 128 |
+
ds_idx, ds_remaining = -1, 0
|
| 129 |
+
for to_consume in sent_sizes:
|
| 130 |
+
sent_size = to_consume
|
| 131 |
+
if ds_remaining == 0:
|
| 132 |
+
ds_idx += 1
|
| 133 |
+
ds_remaining = sent_sizes[ds_idx]
|
| 134 |
+
start_ds_idx = ds_idx
|
| 135 |
+
start_offset = sent_sizes[ds_idx] - ds_remaining
|
| 136 |
+
while to_consume > ds_remaining:
|
| 137 |
+
to_consume -= ds_remaining
|
| 138 |
+
ds_idx += 1
|
| 139 |
+
ds_remaining = sent_sizes[ds_idx]
|
| 140 |
+
ds_remaining -= to_consume
|
| 141 |
+
dataset_index.append(
|
| 142 |
+
(
|
| 143 |
+
start_ds_idx, # starting index in dataset
|
| 144 |
+
start_offset, # starting offset within starting index
|
| 145 |
+
ds_idx, # ending index in dataset
|
| 146 |
+
sent_size, # sentence length
|
| 147 |
+
)
|
| 148 |
+
)
|
| 149 |
+
assert ds_remaining == 0
|
| 150 |
+
assert ds_idx == len(self.dataset) - 1
|
| 151 |
+
return dataset_index
|
| 152 |
+
|
| 153 |
+
def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes):
|
| 154 |
+
"""
|
| 155 |
+
Go through a single document and genrate sentence paris from it
|
| 156 |
+
"""
|
| 157 |
+
current_chunk = []
|
| 158 |
+
current_length = 0
|
| 159 |
+
curr = 0
|
| 160 |
+
# To provide more randomness, we decrease target seq length for parts of
|
| 161 |
+
# samples (10% by default). Note that max_num_tokens is the hard threshold
|
| 162 |
+
# for batching and will never be changed.
|
| 163 |
+
target_seq_length = max_num_tokens
|
| 164 |
+
if np.random.random() < self.short_seq_prob:
|
| 165 |
+
target_seq_length = np.random.randint(2, max_num_tokens)
|
| 166 |
+
# loop through all sentences in document
|
| 167 |
+
while curr < len(doc):
|
| 168 |
+
sent_id = doc[curr]
|
| 169 |
+
current_chunk.append(sent_id)
|
| 170 |
+
current_length = sum(sizes[current_chunk])
|
| 171 |
+
# split chunk and generate pair when exceed target_seq_length or
|
| 172 |
+
# finish the loop
|
| 173 |
+
if curr == len(doc) - 1 or current_length >= target_seq_length:
|
| 174 |
+
# split the chunk into 2 parts
|
| 175 |
+
a_end = 1
|
| 176 |
+
if len(current_chunk) > 2:
|
| 177 |
+
a_end = np.random.randint(1, len(current_chunk) - 1)
|
| 178 |
+
sent_a = current_chunk[:a_end]
|
| 179 |
+
len_a = sum(sizes[sent_a])
|
| 180 |
+
# generate next sentence label, note that if there is only 1 sentence
|
| 181 |
+
# in current chunk, label is always 0
|
| 182 |
+
next_sent_label = (
|
| 183 |
+
1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0
|
| 184 |
+
)
|
| 185 |
+
if not next_sent_label:
|
| 186 |
+
# if next sentence label is 0, sample sent_b from a random doc
|
| 187 |
+
target_b_length = target_seq_length - len_a
|
| 188 |
+
rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id])
|
| 189 |
+
random_doc = self.block_indices[rand_doc_id]
|
| 190 |
+
random_start = np.random.randint(0, len(random_doc))
|
| 191 |
+
sent_b = []
|
| 192 |
+
len_b = 0
|
| 193 |
+
for j in range(random_start, len(random_doc)):
|
| 194 |
+
sent_b.append(random_doc[j])
|
| 195 |
+
len_b = sum(sizes[sent_b])
|
| 196 |
+
if len_b >= target_b_length:
|
| 197 |
+
break
|
| 198 |
+
# return the second part of the chunk since it's not used
|
| 199 |
+
num_unused_segments = len(current_chunk) - a_end
|
| 200 |
+
curr -= num_unused_segments
|
| 201 |
+
else:
|
| 202 |
+
# if next sentence label is 1, use the second part of chunk as sent_B
|
| 203 |
+
sent_b = current_chunk[a_end:]
|
| 204 |
+
len_b = sum(sizes[sent_b])
|
| 205 |
+
# currently sent_a and sent_B may be longer than max_num_tokens,
|
| 206 |
+
# truncate them and return block idx and offsets for them
|
| 207 |
+
sent_a, sent_b = self._truncate_sentences(
|
| 208 |
+
sent_a, sent_b, max_num_tokens
|
| 209 |
+
)
|
| 210 |
+
self.sent_pairs.append((sent_a, sent_b, next_sent_label))
|
| 211 |
+
self.sizes.append(3 + sent_a[3] + sent_b[3])
|
| 212 |
+
current_chunk = []
|
| 213 |
+
curr += 1
|
| 214 |
+
|
| 215 |
+
def _skip_sampling(self, total, skip_ids):
|
| 216 |
+
"""
|
| 217 |
+
Generate a random integer which is not in skip_ids. Sample range is [0, total)
|
| 218 |
+
TODO: ids in skip_ids should be consecutive, we can extend it to more generic version later
|
| 219 |
+
"""
|
| 220 |
+
rand_id = np.random.randint(total - len(skip_ids))
|
| 221 |
+
return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids)
|
| 222 |
+
|
| 223 |
+
def _truncate_sentences(self, sent_a, sent_b, max_num_tokens):
|
| 224 |
+
"""
|
| 225 |
+
Trancate a pair of sentence to limit total length under max_num_tokens
|
| 226 |
+
Logics:
|
| 227 |
+
1. Truncate longer sentence
|
| 228 |
+
2. Tokens to be truncated could be at the beginning or the end of the sentnce
|
| 229 |
+
Returns:
|
| 230 |
+
Truncated sentences represented by dataset idx
|
| 231 |
+
"""
|
| 232 |
+
len_a, len_b = sum(self.dataset.sizes[sent_a]), sum(self.dataset.sizes[sent_b])
|
| 233 |
+
front_cut_a = front_cut_b = end_cut_a = end_cut_b = 0
|
| 234 |
+
|
| 235 |
+
while True:
|
| 236 |
+
total_length = (
|
| 237 |
+
len_a + len_b - front_cut_a - front_cut_b - end_cut_a - end_cut_b
|
| 238 |
+
)
|
| 239 |
+
if total_length <= max_num_tokens:
|
| 240 |
+
break
|
| 241 |
+
|
| 242 |
+
if len_a - front_cut_a - end_cut_a > len_b - front_cut_b - end_cut_b:
|
| 243 |
+
if np.random.rand() < 0.5:
|
| 244 |
+
front_cut_a += 1
|
| 245 |
+
else:
|
| 246 |
+
end_cut_a += 1
|
| 247 |
+
else:
|
| 248 |
+
if np.random.rand() < 0.5:
|
| 249 |
+
front_cut_b += 1
|
| 250 |
+
else:
|
| 251 |
+
end_cut_b += 1
|
| 252 |
+
|
| 253 |
+
# calculate ds indices as well as offsets and return
|
| 254 |
+
truncated_sent_a = self._cut_sentence(sent_a, front_cut_a, end_cut_a)
|
| 255 |
+
truncated_sent_b = self._cut_sentence(sent_b, front_cut_b, end_cut_b)
|
| 256 |
+
return truncated_sent_a, truncated_sent_b
|
| 257 |
+
|
| 258 |
+
def _cut_sentence(self, sent, front_cut, end_cut):
|
| 259 |
+
"""
|
| 260 |
+
Cut a sentence based on the numbers of tokens to be cut from beginning and end
|
| 261 |
+
Represent the sentence as dataset idx and return
|
| 262 |
+
"""
|
| 263 |
+
start_ds_idx, end_ds_idx, offset = sent[0], sent[-1], 0
|
| 264 |
+
target_len = sum(self.dataset.sizes[sent]) - front_cut - end_cut
|
| 265 |
+
while front_cut > 0:
|
| 266 |
+
if self.dataset.sizes[start_ds_idx] > front_cut:
|
| 267 |
+
offset += front_cut
|
| 268 |
+
break
|
| 269 |
+
else:
|
| 270 |
+
front_cut -= self.dataset.sizes[start_ds_idx]
|
| 271 |
+
start_ds_idx += 1
|
| 272 |
+
while end_cut > 0:
|
| 273 |
+
if self.dataset.sizes[end_ds_idx] > end_cut:
|
| 274 |
+
break
|
| 275 |
+
else:
|
| 276 |
+
end_cut -= self.dataset.sizes[end_ds_idx]
|
| 277 |
+
end_ds_idx -= 1
|
| 278 |
+
return start_ds_idx, offset, end_ds_idx, target_len
|
| 279 |
+
|
| 280 |
+
def _fetch_block(self, start_ds_idx, offset, end_ds_idx, length):
|
| 281 |
+
"""
|
| 282 |
+
Fetch a block of tokens based on its dataset idx
|
| 283 |
+
"""
|
| 284 |
+
buffer = torch.cat(
|
| 285 |
+
[self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
|
| 286 |
+
)
|
| 287 |
+
s, e = offset, offset + length
|
| 288 |
+
return buffer[s:e]
|
| 289 |
+
|
| 290 |
+
def __getitem__(self, index):
|
| 291 |
+
block1, block2, next_sent_label = self.sent_pairs[index]
|
| 292 |
+
block1 = self._fetch_block(*block1)
|
| 293 |
+
block2 = self._fetch_block(*block2)
|
| 294 |
+
return block1, block2, next_sent_label
|
| 295 |
+
|
| 296 |
+
def __len__(self):
|
| 297 |
+
return len(self.sizes)
|
| 298 |
+
|
| 299 |
+
@property
|
| 300 |
+
def supports_prefetch(self):
|
| 301 |
+
return getattr(self.dataset, "supports_prefetch", False)
|
| 302 |
+
|
| 303 |
+
def prefetch(self, indices):
|
| 304 |
+
prefetch_idx = set()
|
| 305 |
+
for index in indices:
|
| 306 |
+
for block1, block2, _ in [self.sent_pairs[index]]:
|
| 307 |
+
for ds_idx in range(block1[0], block1[2] + 1):
|
| 308 |
+
prefetch_idx.add(ds_idx)
|
| 309 |
+
for ds_idx in range(block2[0], block2[2] + 1):
|
| 310 |
+
prefetch_idx.add(ds_idx)
|
| 311 |
+
self.dataset.prefetch(prefetch_idx)
|
fairseq-0.10.2/fairseq/data/legacy/masked_lm_dictionary.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from fairseq.data import Dictionary
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MaskedLMDictionary(Dictionary):
|
| 10 |
+
"""
|
| 11 |
+
Dictionary for Masked Language Modelling tasks. This extends Dictionary by
|
| 12 |
+
adding the mask symbol.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
pad="<pad>",
|
| 18 |
+
eos="</s>",
|
| 19 |
+
unk="<unk>",
|
| 20 |
+
mask="<mask>",
|
| 21 |
+
):
|
| 22 |
+
super().__init__(pad=pad, eos=eos, unk=unk)
|
| 23 |
+
self.mask_word = mask
|
| 24 |
+
self.mask_index = self.add_symbol(mask)
|
| 25 |
+
self.nspecial = len(self.symbols)
|
| 26 |
+
|
| 27 |
+
def mask(self):
|
| 28 |
+
"""Helper to get index of mask symbol"""
|
| 29 |
+
return self.mask_index
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class BertDictionary(MaskedLMDictionary):
|
| 33 |
+
"""
|
| 34 |
+
Dictionary for BERT task. This extends MaskedLMDictionary by adding support
|
| 35 |
+
for cls and sep symbols.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
pad="<pad>",
|
| 41 |
+
eos="</s>",
|
| 42 |
+
unk="<unk>",
|
| 43 |
+
mask="<mask>",
|
| 44 |
+
cls="<cls>",
|
| 45 |
+
sep="<sep>",
|
| 46 |
+
):
|
| 47 |
+
super().__init__(pad=pad, eos=eos, unk=unk, mask=mask)
|
| 48 |
+
self.cls_word = cls
|
| 49 |
+
self.sep_word = sep
|
| 50 |
+
self.cls_index = self.add_symbol(cls)
|
| 51 |
+
self.sep_index = self.add_symbol(sep)
|
| 52 |
+
self.nspecial = len(self.symbols)
|
| 53 |
+
|
| 54 |
+
def cls(self):
|
| 55 |
+
"""Helper to get index of cls symbol"""
|
| 56 |
+
return self.cls_index
|
| 57 |
+
|
| 58 |
+
def sep(self):
|
| 59 |
+
"""Helper to get index of sep symbol"""
|
| 60 |
+
return self.sep_index
|
fairseq-0.10.2/fairseq/data/list_dataset.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from . import BaseWrapperDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ListDataset(BaseWrapperDataset):
|
| 10 |
+
def __init__(self, dataset, sizes=None):
|
| 11 |
+
super().__init__(dataset)
|
| 12 |
+
self._sizes = sizes
|
| 13 |
+
|
| 14 |
+
def __iter__(self):
|
| 15 |
+
for x in self.dataset:
|
| 16 |
+
yield x
|
| 17 |
+
|
| 18 |
+
def collater(self, samples):
|
| 19 |
+
return samples
|
| 20 |
+
|
| 21 |
+
@property
|
| 22 |
+
def sizes(self):
|
| 23 |
+
return self._sizes
|
| 24 |
+
|
| 25 |
+
def num_tokens(self, index):
|
| 26 |
+
return self.sizes[index]
|
| 27 |
+
|
| 28 |
+
def size(self, index):
|
| 29 |
+
return self.sizes[index]
|
| 30 |
+
|
| 31 |
+
def set_epoch(self, epoch):
|
| 32 |
+
pass
|
fairseq-0.10.2/fairseq/data/lru_cache_dataset.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
|
| 8 |
+
from . import BaseWrapperDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class LRUCacheDataset(BaseWrapperDataset):
|
| 12 |
+
def __init__(self, dataset, token=None):
|
| 13 |
+
super().__init__(dataset)
|
| 14 |
+
|
| 15 |
+
@lru_cache(maxsize=8)
|
| 16 |
+
def __getitem__(self, index):
|
| 17 |
+
return self.dataset[index]
|
| 18 |
+
|
| 19 |
+
@lru_cache(maxsize=8)
|
| 20 |
+
def collater(self, samples):
|
| 21 |
+
return self.dataset.collater(samples)
|
fairseq-0.10.2/fairseq/data/mask_tokens_dataset.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from fairseq.data import Dictionary, data_utils
|
| 11 |
+
|
| 12 |
+
from . import BaseWrapperDataset, LRUCacheDataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MaskTokensDataset(BaseWrapperDataset):
|
| 16 |
+
"""
|
| 17 |
+
A wrapper Dataset for masked language modeling.
|
| 18 |
+
|
| 19 |
+
Input items are masked according to the specified masking probability.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
dataset: Dataset to wrap.
|
| 23 |
+
sizes: Sentence lengths
|
| 24 |
+
vocab: Dictionary with the vocabulary and special tokens.
|
| 25 |
+
pad_idx: Id of pad token in vocab
|
| 26 |
+
mask_idx: Id of mask token in vocab
|
| 27 |
+
return_masked_tokens: controls whether to return the non-masked tokens
|
| 28 |
+
(the default) or to return a tensor with the original masked token
|
| 29 |
+
IDs (and *pad_idx* elsewhere). The latter is useful as targets for
|
| 30 |
+
masked LM training.
|
| 31 |
+
seed: Seed for random number generator for reproducibility.
|
| 32 |
+
mask_prob: probability of replacing a token with *mask_idx*.
|
| 33 |
+
leave_unmasked_prob: probability that a masked token is unmasked.
|
| 34 |
+
random_token_prob: probability of replacing a masked token with a
|
| 35 |
+
random token from the vocabulary.
|
| 36 |
+
freq_weighted_replacement: sample random replacement words based on
|
| 37 |
+
word frequencies in the vocab.
|
| 38 |
+
mask_whole_words: only mask whole words. This should be a byte mask
|
| 39 |
+
over vocab indices, indicating whether it is the beginning of a
|
| 40 |
+
word. We will extend any mask to encompass the whole word.
|
| 41 |
+
bpe: BPE to use for whole-word masking.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
@classmethod
|
| 45 |
+
def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
|
| 46 |
+
"""Return the source and target datasets for masked LM training."""
|
| 47 |
+
dataset = LRUCacheDataset(dataset)
|
| 48 |
+
return (
|
| 49 |
+
LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)),
|
| 50 |
+
LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)),
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
dataset: torch.utils.data.Dataset,
|
| 56 |
+
vocab: Dictionary,
|
| 57 |
+
pad_idx: int,
|
| 58 |
+
mask_idx: int,
|
| 59 |
+
return_masked_tokens: bool = False,
|
| 60 |
+
seed: int = 1,
|
| 61 |
+
mask_prob: float = 0.15,
|
| 62 |
+
leave_unmasked_prob: float = 0.1,
|
| 63 |
+
random_token_prob: float = 0.1,
|
| 64 |
+
freq_weighted_replacement: bool = False,
|
| 65 |
+
mask_whole_words: torch.Tensor = None,
|
| 66 |
+
):
|
| 67 |
+
assert 0.0 < mask_prob < 1.0
|
| 68 |
+
assert 0.0 <= random_token_prob <= 1.0
|
| 69 |
+
assert 0.0 <= leave_unmasked_prob <= 1.0
|
| 70 |
+
assert random_token_prob + leave_unmasked_prob <= 1.0
|
| 71 |
+
|
| 72 |
+
self.dataset = dataset
|
| 73 |
+
self.vocab = vocab
|
| 74 |
+
self.pad_idx = pad_idx
|
| 75 |
+
self.mask_idx = mask_idx
|
| 76 |
+
self.return_masked_tokens = return_masked_tokens
|
| 77 |
+
self.seed = seed
|
| 78 |
+
self.mask_prob = mask_prob
|
| 79 |
+
self.leave_unmasked_prob = leave_unmasked_prob
|
| 80 |
+
self.random_token_prob = random_token_prob
|
| 81 |
+
self.mask_whole_words = mask_whole_words
|
| 82 |
+
|
| 83 |
+
if random_token_prob > 0.0:
|
| 84 |
+
if freq_weighted_replacement:
|
| 85 |
+
weights = np.array(self.vocab.count)
|
| 86 |
+
else:
|
| 87 |
+
weights = np.ones(len(self.vocab))
|
| 88 |
+
weights[: self.vocab.nspecial] = 0
|
| 89 |
+
self.weights = weights / weights.sum()
|
| 90 |
+
|
| 91 |
+
self.epoch = 0
|
| 92 |
+
|
| 93 |
+
@property
|
| 94 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
| 95 |
+
return True # only the noise changes, not item sizes
|
| 96 |
+
|
| 97 |
+
def set_epoch(self, epoch, **unused):
|
| 98 |
+
super().set_epoch(epoch)
|
| 99 |
+
self.epoch = epoch
|
| 100 |
+
|
| 101 |
+
@lru_cache(maxsize=8)
|
| 102 |
+
def __getitem__(self, index: int):
|
| 103 |
+
with data_utils.numpy_seed(self.seed, self.epoch, index):
|
| 104 |
+
item = self.dataset[index]
|
| 105 |
+
sz = len(item)
|
| 106 |
+
|
| 107 |
+
assert (
|
| 108 |
+
self.mask_idx not in item
|
| 109 |
+
), "Dataset contains mask_idx (={}), this is not expected!".format(
|
| 110 |
+
self.mask_idx,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
if self.mask_whole_words is not None:
|
| 114 |
+
word_begins_mask = self.mask_whole_words.gather(0, item)
|
| 115 |
+
word_begins_idx = word_begins_mask.nonzero().view(-1)
|
| 116 |
+
sz = len(word_begins_idx)
|
| 117 |
+
words = np.split(word_begins_mask, word_begins_idx)[1:]
|
| 118 |
+
assert len(words) == sz
|
| 119 |
+
word_lens = list(map(len, words))
|
| 120 |
+
|
| 121 |
+
# decide elements to mask
|
| 122 |
+
mask = np.full(sz, False)
|
| 123 |
+
num_mask = int(
|
| 124 |
+
# add a random number for probabilistic rounding
|
| 125 |
+
self.mask_prob * sz
|
| 126 |
+
+ np.random.rand()
|
| 127 |
+
)
|
| 128 |
+
mask[np.random.choice(sz, num_mask, replace=False)] = True
|
| 129 |
+
|
| 130 |
+
if self.return_masked_tokens:
|
| 131 |
+
# exit early if we're just returning the masked tokens
|
| 132 |
+
# (i.e., the targets for masked LM training)
|
| 133 |
+
if self.mask_whole_words is not None:
|
| 134 |
+
mask = np.repeat(mask, word_lens)
|
| 135 |
+
new_item = np.full(len(mask), self.pad_idx)
|
| 136 |
+
new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
|
| 137 |
+
return torch.from_numpy(new_item)
|
| 138 |
+
|
| 139 |
+
# decide unmasking and random replacement
|
| 140 |
+
rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
|
| 141 |
+
if rand_or_unmask_prob > 0.0:
|
| 142 |
+
rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
|
| 143 |
+
if self.random_token_prob == 0.0:
|
| 144 |
+
unmask = rand_or_unmask
|
| 145 |
+
rand_mask = None
|
| 146 |
+
elif self.leave_unmasked_prob == 0.0:
|
| 147 |
+
unmask = None
|
| 148 |
+
rand_mask = rand_or_unmask
|
| 149 |
+
else:
|
| 150 |
+
unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
|
| 151 |
+
decision = np.random.rand(sz) < unmask_prob
|
| 152 |
+
unmask = rand_or_unmask & decision
|
| 153 |
+
rand_mask = rand_or_unmask & (~decision)
|
| 154 |
+
else:
|
| 155 |
+
unmask = rand_mask = None
|
| 156 |
+
|
| 157 |
+
if unmask is not None:
|
| 158 |
+
mask = mask ^ unmask
|
| 159 |
+
|
| 160 |
+
if self.mask_whole_words is not None:
|
| 161 |
+
mask = np.repeat(mask, word_lens)
|
| 162 |
+
|
| 163 |
+
new_item = np.copy(item)
|
| 164 |
+
new_item[mask] = self.mask_idx
|
| 165 |
+
if rand_mask is not None:
|
| 166 |
+
num_rand = rand_mask.sum()
|
| 167 |
+
if num_rand > 0:
|
| 168 |
+
if self.mask_whole_words is not None:
|
| 169 |
+
rand_mask = np.repeat(rand_mask, word_lens)
|
| 170 |
+
num_rand = rand_mask.sum()
|
| 171 |
+
|
| 172 |
+
new_item[rand_mask] = np.random.choice(
|
| 173 |
+
len(self.vocab),
|
| 174 |
+
num_rand,
|
| 175 |
+
p=self.weights,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
return torch.from_numpy(new_item)
|
fairseq-0.10.2/fairseq/data/multi_corpus_dataset.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
from typing import Dict, List
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from fairseq.data import data_utils
|
| 12 |
+
|
| 13 |
+
from . import FairseqDataset
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class MultiCorpusDataset(FairseqDataset):
|
| 20 |
+
"""
|
| 21 |
+
Stores multiple instances of FairseqDataset together. Requires each instance
|
| 22 |
+
to be the same dataset, as the collate method needs to work on batches with
|
| 23 |
+
samples from each dataset.
|
| 24 |
+
|
| 25 |
+
Allows specifying a distribution over the datasets to use. Note that unlike
|
| 26 |
+
MultiCorpusSampledDataset, this distribution allows sampling for each item,
|
| 27 |
+
rather than on a batch level.
|
| 28 |
+
|
| 29 |
+
Each time ordered_indices() is called, a new sample is generated with
|
| 30 |
+
the specified distribution.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
datasets: a OrderedDict of FairseqDataset instances.
|
| 34 |
+
distribution: a List containing the probability of getting an utterance from
|
| 35 |
+
corresponding dataset
|
| 36 |
+
seed: random seed for sampling the datsets
|
| 37 |
+
sort_indices: if true, will sort the ordered indices by size
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
datasets: Dict[str, FairseqDataset],
|
| 43 |
+
distribution: List[float],
|
| 44 |
+
seed: int,
|
| 45 |
+
sort_indices: bool = False,
|
| 46 |
+
):
|
| 47 |
+
super().__init__()
|
| 48 |
+
assert isinstance(datasets, OrderedDict)
|
| 49 |
+
assert len(datasets) == len(distribution)
|
| 50 |
+
self.datasets = datasets
|
| 51 |
+
self.distribution = distribution
|
| 52 |
+
self.seed = seed
|
| 53 |
+
self.sort_indices = sort_indices
|
| 54 |
+
|
| 55 |
+
# Avoid repeated conversions to list later
|
| 56 |
+
self.dataset_list = list(datasets.values())
|
| 57 |
+
self.total_num_instances = 0
|
| 58 |
+
|
| 59 |
+
first_dataset = list(self.datasets.values())[0]
|
| 60 |
+
|
| 61 |
+
self.dataset_offsets = []
|
| 62 |
+
for dataset in datasets.values():
|
| 63 |
+
assert isinstance(dataset, FairseqDataset)
|
| 64 |
+
assert type(dataset) is type(first_dataset)
|
| 65 |
+
self.dataset_offsets.append(self.total_num_instances)
|
| 66 |
+
self.total_num_instances += len(dataset)
|
| 67 |
+
|
| 68 |
+
def ordered_indices(self):
|
| 69 |
+
with data_utils.numpy_seed(self.seed, self.epoch):
|
| 70 |
+
# Used to store the order of indices of each dataset to use
|
| 71 |
+
indices = [
|
| 72 |
+
np.random.permutation(len(dataset))
|
| 73 |
+
for dataset in self.datasets.values()
|
| 74 |
+
]
|
| 75 |
+
# Keep track of which samples we've used for each dataset
|
| 76 |
+
counters = [0 for _ in self.datasets]
|
| 77 |
+
|
| 78 |
+
sampled_indices = [
|
| 79 |
+
self._sample(indices, counters) for _ in range(self.total_num_instances)
|
| 80 |
+
]
|
| 81 |
+
if self.sort_indices:
|
| 82 |
+
sampled_indices.sort(key=lambda i: self.num_tokens(i))
|
| 83 |
+
return np.array(sampled_indices, dtype=np.int64)
|
| 84 |
+
|
| 85 |
+
def _sample(self, indices, counters):
|
| 86 |
+
# First pick dataset
|
| 87 |
+
dataset_idx = np.random.choice(len(self.distribution), p=self.distribution)
|
| 88 |
+
|
| 89 |
+
# Then get dataset internal index
|
| 90 |
+
idx = indices[dataset_idx][counters[dataset_idx]]
|
| 91 |
+
|
| 92 |
+
# Convert to multi-datasets index
|
| 93 |
+
idx += self.dataset_offsets[dataset_idx]
|
| 94 |
+
|
| 95 |
+
counters[dataset_idx] += 1
|
| 96 |
+
|
| 97 |
+
# Reset if we reach end
|
| 98 |
+
if counters[dataset_idx] == len(self.dataset_list[dataset_idx]):
|
| 99 |
+
counters[dataset_idx] = 0
|
| 100 |
+
indices[dataset_idx] = np.random.permutation(
|
| 101 |
+
len(self.dataset_list[dataset_idx])
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
return idx
|
| 105 |
+
|
| 106 |
+
def _map_index(self, index: int):
|
| 107 |
+
"""
|
| 108 |
+
If dataset A has length N and dataset B has length M
|
| 109 |
+
then index 1 maps to index 1 of dataset A, and index N + 1
|
| 110 |
+
maps to index 1 of B.
|
| 111 |
+
"""
|
| 112 |
+
counter = 0
|
| 113 |
+
for key, dataset in self.datasets.items():
|
| 114 |
+
if index < counter + len(dataset):
|
| 115 |
+
return index - counter, key
|
| 116 |
+
counter += len(dataset)
|
| 117 |
+
raise ValueError(
|
| 118 |
+
"Invalid index: {}, max: {}".format(index, self.total_num_instances)
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
def __len__(self):
|
| 122 |
+
"""
|
| 123 |
+
Length of this dataset is the sum of individual datasets
|
| 124 |
+
"""
|
| 125 |
+
return self.total_num_instances
|
| 126 |
+
|
| 127 |
+
def __getitem__(self, index):
|
| 128 |
+
index, key = self._map_index(index)
|
| 129 |
+
return self.datasets[key][index]
|
| 130 |
+
|
| 131 |
+
def collater(self, samples):
|
| 132 |
+
"""
|
| 133 |
+
Since we enforce all datsets to be the same, collating is just
|
| 134 |
+
picking the first one and doing collate.
|
| 135 |
+
"""
|
| 136 |
+
if len(samples) == 0:
|
| 137 |
+
return None
|
| 138 |
+
|
| 139 |
+
return list(self.datasets.values())[0].collater(samples)
|
| 140 |
+
|
| 141 |
+
def num_tokens(self, index: int):
|
| 142 |
+
index, key = self._map_index(index)
|
| 143 |
+
return self.datasets[key].num_tokens(index)
|
| 144 |
+
|
| 145 |
+
def size(self, index: int):
|
| 146 |
+
index, key = self._map_index(index)
|
| 147 |
+
return self.datasets[key].size(index)
|
| 148 |
+
|
| 149 |
+
@property
|
| 150 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
| 151 |
+
return False
|
| 152 |
+
|
| 153 |
+
def set_epoch(self, epoch, **unused):
|
| 154 |
+
super().set_epoch(epoch)
|
| 155 |
+
self.epoch = epoch
|
| 156 |
+
|
| 157 |
+
@property
|
| 158 |
+
def supports_prefetch(self):
|
| 159 |
+
return False
|
fairseq-0.10.2/fairseq/data/multi_corpus_sampled_dataset.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
from typing import Callable, Dict, List
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from . import FairseqDataset
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def uniform_sampler(x):
|
| 15 |
+
# Sample from uniform distribution
|
| 16 |
+
return np.random.choice(x, 1).item()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class MultiCorpusSampledDataset(FairseqDataset):
|
| 20 |
+
"""
|
| 21 |
+
Stores multiple instances of FairseqDataset together and in every iteration
|
| 22 |
+
creates a batch by first sampling a dataset according to a specified
|
| 23 |
+
probability distribution and then getting instances from that dataset.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
datasets: an OrderedDict of FairseqDataset instances.
|
| 27 |
+
sampling_func: A function for sampling over list of dataset keys.
|
| 28 |
+
The default strategy is to sample uniformly.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
datasets: Dict[str, FairseqDataset],
|
| 34 |
+
sampling_func: Callable[[List], int] = None,
|
| 35 |
+
):
|
| 36 |
+
super().__init__()
|
| 37 |
+
assert isinstance(datasets, OrderedDict)
|
| 38 |
+
self.datasets = datasets
|
| 39 |
+
if sampling_func is None:
|
| 40 |
+
sampling_func = uniform_sampler
|
| 41 |
+
self.sampling_func = sampling_func
|
| 42 |
+
|
| 43 |
+
self.total_num_instances = 0
|
| 44 |
+
for _, dataset in datasets.items():
|
| 45 |
+
assert isinstance(dataset, FairseqDataset)
|
| 46 |
+
self.total_num_instances += len(dataset)
|
| 47 |
+
|
| 48 |
+
self._ordered_indices = None
|
| 49 |
+
|
| 50 |
+
def __len__(self):
|
| 51 |
+
"""
|
| 52 |
+
Length of this dataset is the sum of individual datasets
|
| 53 |
+
"""
|
| 54 |
+
return self.total_num_instances
|
| 55 |
+
|
| 56 |
+
def ordered_indices(self):
|
| 57 |
+
"""
|
| 58 |
+
Ordered indices for batching. Here we call the underlying
|
| 59 |
+
dataset's ordered_indices() so that we get the same random ordering
|
| 60 |
+
as we would have from using the underlying dataset directly.
|
| 61 |
+
"""
|
| 62 |
+
if self._ordered_indices is None:
|
| 63 |
+
self._ordered_indices = OrderedDict(
|
| 64 |
+
[
|
| 65 |
+
(key, dataset.ordered_indices())
|
| 66 |
+
for key, dataset in self.datasets.items()
|
| 67 |
+
]
|
| 68 |
+
)
|
| 69 |
+
return np.arange(len(self))
|
| 70 |
+
|
| 71 |
+
def _map_index_to_dataset(self, key: int, index: int):
|
| 72 |
+
"""
|
| 73 |
+
Different underlying datasets have different lengths. In order to ensure
|
| 74 |
+
we are not accessing an index outside the range of the current dataset
|
| 75 |
+
size, we wrap around. This function should be called after we have
|
| 76 |
+
created an ordering for this and all underlying datasets.
|
| 77 |
+
"""
|
| 78 |
+
assert (
|
| 79 |
+
self._ordered_indices is not None
|
| 80 |
+
), "Must call MultiCorpusSampledDataset.ordered_indices() first"
|
| 81 |
+
mapped_index = index % len(self.datasets[key])
|
| 82 |
+
return self._ordered_indices[key][mapped_index]
|
| 83 |
+
|
| 84 |
+
def __getitem__(self, index: int):
|
| 85 |
+
"""
|
| 86 |
+
Get the item associated with index from each underlying dataset.
|
| 87 |
+
Since index is in the range of [0, TotalNumInstances], we need to
|
| 88 |
+
map the index to the dataset before retrieving the item.
|
| 89 |
+
"""
|
| 90 |
+
return OrderedDict(
|
| 91 |
+
[
|
| 92 |
+
(key, dataset[self._map_index_to_dataset(key, index)])
|
| 93 |
+
for key, dataset in self.datasets.items()
|
| 94 |
+
]
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
def collater(self, samples: List[Dict]):
|
| 98 |
+
"""
|
| 99 |
+
Generate a mini-batch for this dataset.
|
| 100 |
+
To convert this into a regular mini-batch we use the following
|
| 101 |
+
logic:
|
| 102 |
+
1. Select a dataset using the specified probability distribution.
|
| 103 |
+
2. Call the collater function of the selected dataset.
|
| 104 |
+
"""
|
| 105 |
+
if len(samples) == 0:
|
| 106 |
+
return None
|
| 107 |
+
|
| 108 |
+
selected_key = self.sampling_func(list(self.datasets.keys()))
|
| 109 |
+
selected_samples = [sample[selected_key] for sample in samples]
|
| 110 |
+
return self.datasets[selected_key].collater(selected_samples)
|
| 111 |
+
|
| 112 |
+
def num_tokens(self, index: int):
|
| 113 |
+
"""
|
| 114 |
+
Return an example's length (number of tokens), used for batching. Here
|
| 115 |
+
we return the max across all examples at index across all underlying
|
| 116 |
+
datasets.
|
| 117 |
+
"""
|
| 118 |
+
return max(
|
| 119 |
+
dataset.num_tokens(self._map_index_to_dataset(key, index))
|
| 120 |
+
for key, dataset in self.datasets.items()
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
def size(self, index: int):
|
| 124 |
+
"""
|
| 125 |
+
Return an example's size as a float or tuple. Here we return the max
|
| 126 |
+
across all underlying datasets. This value is used when filtering a
|
| 127 |
+
dataset with max-positions.
|
| 128 |
+
"""
|
| 129 |
+
return max(
|
| 130 |
+
dataset.size(self._map_index_to_dataset(key, index))
|
| 131 |
+
for key, dataset in self.datasets.items()
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
@property
|
| 135 |
+
def supports_prefetch(self):
|
| 136 |
+
return all(
|
| 137 |
+
getattr(dataset, "supports_prefetch", False)
|
| 138 |
+
for dataset in self.datasets.values()
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
def prefetch(self, indices):
|
| 142 |
+
for key, dataset in self.datasets.items():
|
| 143 |
+
dataset.prefetch(
|
| 144 |
+
[self._map_index_to_dataset(key, index) for index in indices]
|
| 145 |
+
)
|
fairseq-0.10.2/fairseq/data/num_samples_dataset.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from . import FairseqDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class NumSamplesDataset(FairseqDataset):
|
| 10 |
+
def __getitem__(self, index):
|
| 11 |
+
return 1
|
| 12 |
+
|
| 13 |
+
def __len__(self):
|
| 14 |
+
return 0
|
| 15 |
+
|
| 16 |
+
def collater(self, samples):
|
| 17 |
+
return sum(samples)
|
fairseq-0.10.2/fairseq/data/offset_tokens_dataset.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from . import BaseWrapperDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class OffsetTokensDataset(BaseWrapperDataset):
|
| 10 |
+
def __init__(self, dataset, offset):
|
| 11 |
+
super().__init__(dataset)
|
| 12 |
+
self.offset = offset
|
| 13 |
+
|
| 14 |
+
def __getitem__(self, idx):
|
| 15 |
+
return self.dataset[idx] + self.offset
|
fairseq-0.10.2/fairseq/data/pad_dataset.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from fairseq.data import data_utils
|
| 7 |
+
|
| 8 |
+
from . import BaseWrapperDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PadDataset(BaseWrapperDataset):
|
| 12 |
+
def __init__(self, dataset, pad_idx, left_pad):
|
| 13 |
+
super().__init__(dataset)
|
| 14 |
+
self.pad_idx = pad_idx
|
| 15 |
+
self.left_pad = left_pad
|
| 16 |
+
|
| 17 |
+
def collater(self, samples):
|
| 18 |
+
return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class LeftPadDataset(PadDataset):
|
| 22 |
+
def __init__(self, dataset, pad_idx):
|
| 23 |
+
super().__init__(dataset, pad_idx, left_pad=True)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class RightPadDataset(PadDataset):
|
| 27 |
+
def __init__(self, dataset, pad_idx):
|
| 28 |
+
super().__init__(dataset, pad_idx, left_pad=False)
|
fairseq-0.10.2/fairseq/data/prepend_dataset.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from . import BaseWrapperDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PrependDataset(BaseWrapperDataset):
|
| 13 |
+
def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
|
| 14 |
+
super().__init__(dataset)
|
| 15 |
+
self.prepend_getter = prepend_getter
|
| 16 |
+
self.ensure_first_token = ensure_first_token_is
|
| 17 |
+
|
| 18 |
+
def __getitem__(self, idx):
|
| 19 |
+
item = self.dataset[idx]
|
| 20 |
+
is_tuple = isinstance(item, tuple)
|
| 21 |
+
src = item[0] if is_tuple else item
|
| 22 |
+
|
| 23 |
+
assert self.ensure_first_token is None or src[0] == self.ensure_first_token
|
| 24 |
+
prepend_idx = self.prepend_getter(self.dataset, idx)
|
| 25 |
+
assert isinstance(prepend_idx, int)
|
| 26 |
+
src[0] = prepend_idx
|
| 27 |
+
item = tuple((src,) + item[1:]) if is_tuple else src
|
| 28 |
+
return item
|
fairseq-0.10.2/fairseq/data/resampling_dataset.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from fairseq.data import BaseWrapperDataset, plasma_utils
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ResamplingDataset(BaseWrapperDataset):
|
| 16 |
+
"""Randomly samples from a given dataset at each epoch.
|
| 17 |
+
|
| 18 |
+
Sampling is done with or without replacement, depending on the "replace"
|
| 19 |
+
parameter.
|
| 20 |
+
|
| 21 |
+
Optionally, the epoch size can be rescaled. This is potentially desirable
|
| 22 |
+
to increase per-epoch coverage of the base dataset (since sampling with
|
| 23 |
+
replacement means that many items in the dataset will be left out). In the
|
| 24 |
+
case of sampling without replacement, size_ratio should be strictly less
|
| 25 |
+
than 1.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
dataset (~torch.utils.data.Dataset): dataset on which to sample.
|
| 29 |
+
weights (List[float]): list of probability weights
|
| 30 |
+
(default: None, which corresponds to uniform sampling).
|
| 31 |
+
replace (bool): sampling mode; True for "with replacement", or False
|
| 32 |
+
for "without replacement" (default: True)
|
| 33 |
+
size_ratio (float): the ratio to subsample to; must be positive
|
| 34 |
+
(default: 1.0).
|
| 35 |
+
batch_by_size (bool): whether or not to batch by sequence length
|
| 36 |
+
(default: True).
|
| 37 |
+
seed (int): RNG seed to use (default: 0).
|
| 38 |
+
epoch (int): starting epoch number (default: 1).
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
dataset,
|
| 44 |
+
weights=None,
|
| 45 |
+
replace=True,
|
| 46 |
+
size_ratio=1.0,
|
| 47 |
+
batch_by_size=True,
|
| 48 |
+
seed=0,
|
| 49 |
+
epoch=1,
|
| 50 |
+
):
|
| 51 |
+
super().__init__(dataset)
|
| 52 |
+
|
| 53 |
+
if weights is None:
|
| 54 |
+
self.weights = None
|
| 55 |
+
|
| 56 |
+
else:
|
| 57 |
+
assert len(weights) == len(dataset)
|
| 58 |
+
weights_arr = np.array(weights, dtype=np.float64)
|
| 59 |
+
weights_arr /= weights_arr.sum()
|
| 60 |
+
self.weights = plasma_utils.PlasmaArray(weights_arr)
|
| 61 |
+
|
| 62 |
+
self.replace = replace
|
| 63 |
+
|
| 64 |
+
assert size_ratio > 0.0
|
| 65 |
+
if not self.replace:
|
| 66 |
+
assert size_ratio < 1.0
|
| 67 |
+
self.size_ratio = float(size_ratio)
|
| 68 |
+
self.actual_size = np.ceil(len(dataset) * self.size_ratio).astype(int)
|
| 69 |
+
|
| 70 |
+
self.batch_by_size = batch_by_size
|
| 71 |
+
self.seed = seed
|
| 72 |
+
|
| 73 |
+
self._cur_epoch = None
|
| 74 |
+
self._cur_indices = None
|
| 75 |
+
|
| 76 |
+
self.set_epoch(epoch)
|
| 77 |
+
|
| 78 |
+
def __getitem__(self, index):
|
| 79 |
+
return self.dataset[self._cur_indices.array[index]]
|
| 80 |
+
|
| 81 |
+
def __len__(self):
|
| 82 |
+
return self.actual_size
|
| 83 |
+
|
| 84 |
+
@property
|
| 85 |
+
def sizes(self):
|
| 86 |
+
if isinstance(self.dataset.sizes, list):
|
| 87 |
+
return [s[self._cur_indices.array] for s in self.dataset.sizes]
|
| 88 |
+
return self.dataset.sizes[self._cur_indices.array]
|
| 89 |
+
|
| 90 |
+
def num_tokens(self, index):
|
| 91 |
+
return self.dataset.num_tokens(self._cur_indices.array[index])
|
| 92 |
+
|
| 93 |
+
def size(self, index):
|
| 94 |
+
return self.dataset.size(self._cur_indices.array[index])
|
| 95 |
+
|
| 96 |
+
def ordered_indices(self):
|
| 97 |
+
if self.batch_by_size:
|
| 98 |
+
order = [
|
| 99 |
+
np.arange(len(self)),
|
| 100 |
+
self.sizes,
|
| 101 |
+
] # No need to handle `self.shuffle == True`
|
| 102 |
+
return np.lexsort(order)
|
| 103 |
+
else:
|
| 104 |
+
return np.arange(len(self))
|
| 105 |
+
|
| 106 |
+
def prefetch(self, indices):
|
| 107 |
+
self.dataset.prefetch(self._cur_indices.array[indices])
|
| 108 |
+
|
| 109 |
+
@property
|
| 110 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
| 111 |
+
return False
|
| 112 |
+
|
| 113 |
+
def set_epoch(self, epoch):
|
| 114 |
+
logger.debug("ResamplingDataset.set_epoch: {}".format(epoch))
|
| 115 |
+
super().set_epoch(epoch)
|
| 116 |
+
|
| 117 |
+
if epoch == self._cur_epoch:
|
| 118 |
+
return
|
| 119 |
+
|
| 120 |
+
self._cur_epoch = epoch
|
| 121 |
+
|
| 122 |
+
# Generate a weighted sample of indices as a function of the
|
| 123 |
+
# random seed and the current epoch.
|
| 124 |
+
|
| 125 |
+
rng = np.random.RandomState(
|
| 126 |
+
[
|
| 127 |
+
42, # magic number
|
| 128 |
+
self.seed % (2 ** 32), # global seed
|
| 129 |
+
self._cur_epoch, # epoch index
|
| 130 |
+
]
|
| 131 |
+
)
|
| 132 |
+
self._cur_indices = plasma_utils.PlasmaArray(
|
| 133 |
+
rng.choice(
|
| 134 |
+
len(self.dataset),
|
| 135 |
+
self.actual_size,
|
| 136 |
+
replace=self.replace,
|
| 137 |
+
p=(None if self.weights is None else self.weights.array),
|
| 138 |
+
)
|
| 139 |
+
)
|
fairseq-0.10.2/fairseq/data/shorten_dataset.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from fairseq.data import data_utils
|
| 8 |
+
|
| 9 |
+
from . import BaseWrapperDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TruncateDataset(BaseWrapperDataset):
|
| 13 |
+
"""Truncate a sequence by returning the first truncation_length tokens"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, dataset, truncation_length):
|
| 16 |
+
super().__init__(dataset)
|
| 17 |
+
assert truncation_length is not None
|
| 18 |
+
self.truncation_length = truncation_length
|
| 19 |
+
self.dataset = dataset
|
| 20 |
+
|
| 21 |
+
def __getitem__(self, index):
|
| 22 |
+
item = self.dataset[index]
|
| 23 |
+
item_len = item.size(0)
|
| 24 |
+
if item_len > self.truncation_length:
|
| 25 |
+
item = item[: self.truncation_length]
|
| 26 |
+
return item
|
| 27 |
+
|
| 28 |
+
@property
|
| 29 |
+
def sizes(self):
|
| 30 |
+
return np.minimum(self.dataset.sizes, self.truncation_length)
|
| 31 |
+
|
| 32 |
+
def __len__(self):
|
| 33 |
+
return len(self.dataset)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class RandomCropDataset(TruncateDataset):
|
| 37 |
+
"""Truncate a sequence by returning a random crop of truncation_length tokens"""
|
| 38 |
+
|
| 39 |
+
def __init__(self, dataset, truncation_length, seed=1):
|
| 40 |
+
super().__init__(dataset, truncation_length)
|
| 41 |
+
self.seed = seed
|
| 42 |
+
self.epoch = 0
|
| 43 |
+
|
| 44 |
+
@property
|
| 45 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
| 46 |
+
return True # only the crop changes, not item sizes
|
| 47 |
+
|
| 48 |
+
def set_epoch(self, epoch, **unused):
|
| 49 |
+
super().set_epoch(epoch)
|
| 50 |
+
self.epoch = epoch
|
| 51 |
+
|
| 52 |
+
def __getitem__(self, index):
|
| 53 |
+
with data_utils.numpy_seed(self.seed, self.epoch, index):
|
| 54 |
+
item = self.dataset[index]
|
| 55 |
+
item_len = item.size(0)
|
| 56 |
+
excess = item_len - self.truncation_length
|
| 57 |
+
if excess > 0:
|
| 58 |
+
start_idx = np.random.randint(0, excess)
|
| 59 |
+
item = item[start_idx : start_idx + self.truncation_length]
|
| 60 |
+
return item
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def maybe_shorten_dataset(
|
| 64 |
+
dataset,
|
| 65 |
+
split,
|
| 66 |
+
shorten_data_split_list,
|
| 67 |
+
shorten_method,
|
| 68 |
+
tokens_per_sample,
|
| 69 |
+
seed,
|
| 70 |
+
):
|
| 71 |
+
truncate_split = (
|
| 72 |
+
split in shorten_data_split_list.split(",") or len(shorten_data_split_list) == 0
|
| 73 |
+
)
|
| 74 |
+
if shorten_method == "truncate" and truncate_split:
|
| 75 |
+
dataset = TruncateDataset(dataset, tokens_per_sample)
|
| 76 |
+
elif shorten_method == "random_crop" and truncate_split:
|
| 77 |
+
dataset = RandomCropDataset(dataset, tokens_per_sample, seed)
|
| 78 |
+
return dataset
|
fairseq-0.10.2/fairseq/data/strip_token_dataset.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from . import BaseWrapperDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class StripTokenDataset(BaseWrapperDataset):
|
| 10 |
+
def __init__(self, dataset, id_to_strip):
|
| 11 |
+
super().__init__(dataset)
|
| 12 |
+
self.id_to_strip = id_to_strip
|
| 13 |
+
|
| 14 |
+
def __getitem__(self, index):
|
| 15 |
+
item = self.dataset[index]
|
| 16 |
+
while len(item) > 0 and item[-1] == self.id_to_strip:
|
| 17 |
+
item = item[:-1]
|
| 18 |
+
while len(item) > 0 and item[0] == self.id_to_strip:
|
| 19 |
+
item = item[1:]
|
| 20 |
+
return item
|
fairseq-0.10.2/fairseq/data/token_block_dataset.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from fairseq.data import FairseqDataset, plasma_utils
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TokenBlockDataset(FairseqDataset):
|
| 12 |
+
"""Break a Dataset of tokens into blocks.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
dataset (~torch.utils.data.Dataset): dataset to break into blocks
|
| 16 |
+
sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
|
| 17 |
+
block_size (int): maximum block size (ignored in 'eos' break mode)
|
| 18 |
+
break_mode (str, optional): Mode used for breaking tokens. Values can
|
| 19 |
+
be one of:
|
| 20 |
+
- 'none': break tokens into equally sized blocks (up to block_size)
|
| 21 |
+
- 'complete': break tokens into blocks (up to block_size) such that
|
| 22 |
+
blocks contains complete sentences, although block_size may be
|
| 23 |
+
exceeded if some sentences exceed block_size
|
| 24 |
+
- 'complete_doc': similar to 'complete' mode, but do not
|
| 25 |
+
cross document boundaries
|
| 26 |
+
- 'eos': each block contains one sentence (block_size is ignored)
|
| 27 |
+
include_targets (bool, optional): return next tokens as targets
|
| 28 |
+
(default: False).
|
| 29 |
+
document_sep_len (int, optional): document separator size (required for
|
| 30 |
+
'complete_doc' break mode). Typically 1 if the sentences have eos
|
| 31 |
+
and 0 otherwise.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
dataset,
|
| 37 |
+
sizes,
|
| 38 |
+
block_size,
|
| 39 |
+
pad,
|
| 40 |
+
eos,
|
| 41 |
+
break_mode=None,
|
| 42 |
+
include_targets=False,
|
| 43 |
+
document_sep_len=1,
|
| 44 |
+
):
|
| 45 |
+
try:
|
| 46 |
+
from fairseq.data.token_block_utils_fast import (
|
| 47 |
+
_get_slice_indices_fast,
|
| 48 |
+
_get_block_to_dataset_index_fast,
|
| 49 |
+
)
|
| 50 |
+
except ImportError:
|
| 51 |
+
raise ImportError(
|
| 52 |
+
"Please build Cython components with: `pip install --editable .` "
|
| 53 |
+
"or `python setup.py build_ext --inplace`"
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.dataset = dataset
|
| 58 |
+
self.pad = pad
|
| 59 |
+
self.eos = eos
|
| 60 |
+
self.include_targets = include_targets
|
| 61 |
+
|
| 62 |
+
assert len(dataset) == len(sizes)
|
| 63 |
+
assert len(dataset) > 0
|
| 64 |
+
|
| 65 |
+
if isinstance(sizes, list):
|
| 66 |
+
sizes = np.array(sizes, dtype=np.int64)
|
| 67 |
+
else:
|
| 68 |
+
if torch.is_tensor(sizes):
|
| 69 |
+
sizes = sizes.numpy()
|
| 70 |
+
sizes = sizes.astype(np.int64)
|
| 71 |
+
|
| 72 |
+
break_mode = break_mode if break_mode is not None else "none"
|
| 73 |
+
|
| 74 |
+
# For "eos" break-mode, block_size is not required parameters.
|
| 75 |
+
if break_mode == "eos" and block_size is None:
|
| 76 |
+
block_size = 0
|
| 77 |
+
|
| 78 |
+
slice_indices = _get_slice_indices_fast(
|
| 79 |
+
sizes, str(break_mode), block_size, document_sep_len
|
| 80 |
+
)
|
| 81 |
+
self._sizes = slice_indices[:, 1] - slice_indices[:, 0]
|
| 82 |
+
|
| 83 |
+
# build index mapping block indices to the underlying dataset indices
|
| 84 |
+
if break_mode == "eos":
|
| 85 |
+
# much faster version for eos break mode
|
| 86 |
+
block_to_dataset_index = np.stack(
|
| 87 |
+
[
|
| 88 |
+
np.arange(len(sizes)), # starting index in dataset
|
| 89 |
+
np.zeros(
|
| 90 |
+
len(sizes), dtype=np.long
|
| 91 |
+
), # starting offset within starting index
|
| 92 |
+
np.arange(len(sizes)), # ending index in dataset
|
| 93 |
+
],
|
| 94 |
+
1,
|
| 95 |
+
)
|
| 96 |
+
else:
|
| 97 |
+
block_to_dataset_index = _get_block_to_dataset_index_fast(
|
| 98 |
+
sizes,
|
| 99 |
+
slice_indices,
|
| 100 |
+
)
|
| 101 |
+
self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
|
| 102 |
+
self._sizes = plasma_utils.PlasmaArray(self._sizes)
|
| 103 |
+
self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index)
|
| 104 |
+
|
| 105 |
+
@property
|
| 106 |
+
def slice_indices(self):
|
| 107 |
+
return self._slice_indices.array
|
| 108 |
+
|
| 109 |
+
@property
|
| 110 |
+
def sizes(self):
|
| 111 |
+
return self._sizes.array
|
| 112 |
+
|
| 113 |
+
@property
|
| 114 |
+
def block_to_dataset_index(self):
|
| 115 |
+
return self._block_to_dataset_index.array
|
| 116 |
+
|
| 117 |
+
def attr(self, attr: str, index: int):
|
| 118 |
+
start_ds_idx, _, _ = self.block_to_dataset_index[index]
|
| 119 |
+
return self.dataset.attr(attr, start_ds_idx)
|
| 120 |
+
|
| 121 |
+
def __getitem__(self, index):
|
| 122 |
+
start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
|
| 123 |
+
|
| 124 |
+
buffer = torch.cat(
|
| 125 |
+
[self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
slice_s, slice_e = self.slice_indices[index]
|
| 129 |
+
length = slice_e - slice_s
|
| 130 |
+
s, e = start_offset, start_offset + length
|
| 131 |
+
item = buffer[s:e]
|
| 132 |
+
|
| 133 |
+
if self.include_targets:
|
| 134 |
+
# *target* is the original sentence (=item)
|
| 135 |
+
# *source* is shifted right by 1 (maybe left-padded with eos)
|
| 136 |
+
# *past_target* is shifted right by 2 (left-padded as needed)
|
| 137 |
+
if s == 0:
|
| 138 |
+
source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]])
|
| 139 |
+
past_target = torch.cat(
|
| 140 |
+
[item.new([self.pad, self.eos]), buffer[0 : e - 2]]
|
| 141 |
+
)
|
| 142 |
+
else:
|
| 143 |
+
source = buffer[s - 1 : e - 1]
|
| 144 |
+
if s == 1:
|
| 145 |
+
past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]])
|
| 146 |
+
else:
|
| 147 |
+
past_target = buffer[s - 2 : e - 2]
|
| 148 |
+
|
| 149 |
+
return source, item, past_target
|
| 150 |
+
|
| 151 |
+
return item
|
| 152 |
+
|
| 153 |
+
def __len__(self):
|
| 154 |
+
return len(self.slice_indices)
|
| 155 |
+
|
| 156 |
+
@property
|
| 157 |
+
def supports_prefetch(self):
|
| 158 |
+
return getattr(self.dataset, "supports_prefetch", False)
|
| 159 |
+
|
| 160 |
+
def prefetch(self, indices):
|
| 161 |
+
self.dataset.prefetch(
|
| 162 |
+
{
|
| 163 |
+
ds_idx
|
| 164 |
+
for index in indices
|
| 165 |
+
for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
|
| 166 |
+
for ds_idx in range(start_ds_idx, end_ds_idx + 1)
|
| 167 |
+
}
|
| 168 |
+
)
|