Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- fairseq-0.10.2/fairseq/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/__pycache__/binarizer.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/__pycache__/checkpoint_utils.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/__pycache__/distributed_utils.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/__pycache__/file_utils.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/__pycache__/quantization_utils.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/__pycache__/search.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/__pycache__/token_generation_constraints.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/clib/libbleu/module.cpp +37 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__init__.py +124 -0
- fairseq-0.10.2/fairseq/data/add_target_dataset.py +70 -0
- fairseq-0.10.2/fairseq/data/append_token_dataset.py +41 -0
- fairseq-0.10.2/fairseq/data/backtranslation_dataset.py +165 -0
- fairseq-0.10.2/fairseq/data/base_wrapper_dataset.py +78 -0
- fairseq-0.10.2/fairseq/data/bucket_pad_length_dataset.py +76 -0
- fairseq-0.10.2/fairseq/data/colorize_dataset.py +25 -0
- fairseq-0.10.2/fairseq/data/data_utils.py +499 -0
- fairseq-0.10.2/fairseq/data/data_utils_fast.cpp +0 -0
- fairseq-0.10.2/fairseq/data/data_utils_fast.pyx +123 -0
- fairseq-0.10.2/fairseq/data/denoising_dataset.py +436 -0
- fairseq-0.10.2/fairseq/data/indexed_dataset.py +561 -0
- fairseq-0.10.2/fairseq/data/iterators.py +594 -0
- fairseq-0.10.2/fairseq/data/legacy/masked_lm_dataset.py +303 -0
- fairseq-0.10.2/fairseq/data/lm_context_window_dataset.py +79 -0
- fairseq-0.10.2/fairseq/data/monolingual_dataset.py +230 -0
- fairseq-0.10.2/fairseq/data/nested_dictionary_dataset.py +125 -0
- fairseq-0.10.2/fairseq/data/noising.py +333 -0
- fairseq-0.10.2/fairseq/data/numel_dataset.py +31 -0
- fairseq-0.10.2/fairseq/data/plasma_utils.py +91 -0
- fairseq-0.10.2/fairseq/data/prepend_token_dataset.py +41 -0
- fairseq-0.10.2/fairseq/data/raw_label_dataset.py +23 -0
- fairseq-0.10.2/fairseq/data/replace_dataset.py +36 -0
- fairseq-0.10.2/fairseq/data/roll_dataset.py +18 -0
- fairseq-0.10.2/fairseq/data/round_robin_zip_datasets.py +117 -0
- fairseq-0.10.2/fairseq/data/sort_dataset.py +21 -0
- fairseq-0.10.2/fairseq/data/subsample_dataset.py +72 -0
- fairseq-0.10.2/fairseq/data/token_block_utils_fast.cpp +0 -0
- fairseq-0.10.2/fairseq/data/token_block_utils_fast.pyx +187 -0
- fairseq-0.10.2/fairseq/data/transform_eos_dataset.py +120 -0
fairseq-0.10.2/fairseq/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (780 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/__pycache__/binarizer.cpython-310.pyc
ADDED
|
Binary file (3.11 kB). View file
|
|
|
fairseq-0.10.2/fairseq/__pycache__/checkpoint_utils.cpython-310.pyc
ADDED
|
Binary file (13.7 kB). View file
|
|
|
fairseq-0.10.2/fairseq/__pycache__/distributed_utils.cpython-310.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
fairseq-0.10.2/fairseq/__pycache__/file_utils.cpython-310.pyc
ADDED
|
Binary file (8.69 kB). View file
|
|
|
fairseq-0.10.2/fairseq/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc
ADDED
|
Binary file (5.01 kB). View file
|
|
|
fairseq-0.10.2/fairseq/__pycache__/quantization_utils.cpython-310.pyc
ADDED
|
Binary file (3.53 kB). View file
|
|
|
fairseq-0.10.2/fairseq/__pycache__/search.cpython-310.pyc
ADDED
|
Binary file (21.9 kB). View file
|
|
|
fairseq-0.10.2/fairseq/__pycache__/token_generation_constraints.cpython-310.pyc
ADDED
|
Binary file (16.2 kB). View file
|
|
|
fairseq-0.10.2/fairseq/clib/libbleu/module.cpp
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Copyright 2017-present, Facebook, Inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
*
|
| 5 |
+
* This source code is licensed under the license found in the
|
| 6 |
+
* LICENSE file in the root directory of this source tree.
|
| 7 |
+
*/
|
| 8 |
+
|
| 9 |
+
#include <Python.h>
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
static PyMethodDef method_def[] = {
|
| 13 |
+
{NULL, NULL, 0, NULL}
|
| 14 |
+
};
|
| 15 |
+
|
| 16 |
+
static struct PyModuleDef module_def = {
|
| 17 |
+
PyModuleDef_HEAD_INIT,
|
| 18 |
+
"libbleu", /* name of module */
|
| 19 |
+
NULL, /* module documentation, may be NULL */
|
| 20 |
+
-1, /* size of per-interpreter state of the module,
|
| 21 |
+
or -1 if the module keeps state in global variables. */
|
| 22 |
+
method_def
|
| 23 |
+
};
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
#if PY_MAJOR_VERSION == 2
|
| 27 |
+
PyMODINIT_FUNC init_libbleu()
|
| 28 |
+
#else
|
| 29 |
+
PyMODINIT_FUNC PyInit_libbleu()
|
| 30 |
+
#endif
|
| 31 |
+
{
|
| 32 |
+
PyObject *m = PyModule_Create(&module_def);
|
| 33 |
+
if (!m) {
|
| 34 |
+
return NULL;
|
| 35 |
+
}
|
| 36 |
+
return m;
|
| 37 |
+
}
|
fairseq-0.10.2/fairseq/criterions/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.07 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc
ADDED
|
Binary file (4.68 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc
ADDED
|
Binary file (4.45 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc
ADDED
|
Binary file (3.83 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc
ADDED
|
Binary file (6.12 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc
ADDED
|
Binary file (4.63 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc
ADDED
|
Binary file (3.15 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc
ADDED
|
Binary file (5.93 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc
ADDED
|
Binary file (4.53 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc
ADDED
|
Binary file (5.57 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__init__.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""isort:skip_file"""
|
| 6 |
+
|
| 7 |
+
from .dictionary import Dictionary, TruncatedDictionary
|
| 8 |
+
|
| 9 |
+
from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
|
| 10 |
+
|
| 11 |
+
from .base_wrapper_dataset import BaseWrapperDataset
|
| 12 |
+
|
| 13 |
+
from .add_target_dataset import AddTargetDataset
|
| 14 |
+
from .append_token_dataset import AppendTokenDataset
|
| 15 |
+
from .audio.raw_audio_dataset import FileAudioDataset
|
| 16 |
+
from .backtranslation_dataset import BacktranslationDataset
|
| 17 |
+
from .bucket_pad_length_dataset import BucketPadLengthDataset
|
| 18 |
+
from .colorize_dataset import ColorizeDataset
|
| 19 |
+
from .concat_dataset import ConcatDataset
|
| 20 |
+
from .concat_sentences_dataset import ConcatSentencesDataset
|
| 21 |
+
from .denoising_dataset import DenoisingDataset
|
| 22 |
+
from .id_dataset import IdDataset
|
| 23 |
+
from .indexed_dataset import (
|
| 24 |
+
IndexedCachedDataset,
|
| 25 |
+
IndexedDataset,
|
| 26 |
+
IndexedRawTextDataset,
|
| 27 |
+
MMapIndexedDataset,
|
| 28 |
+
)
|
| 29 |
+
from .language_pair_dataset import LanguagePairDataset
|
| 30 |
+
from .list_dataset import ListDataset
|
| 31 |
+
from .lm_context_window_dataset import LMContextWindowDataset
|
| 32 |
+
from .lru_cache_dataset import LRUCacheDataset
|
| 33 |
+
from .mask_tokens_dataset import MaskTokensDataset
|
| 34 |
+
from .monolingual_dataset import MonolingualDataset
|
| 35 |
+
from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
|
| 36 |
+
from .nested_dictionary_dataset import NestedDictionaryDataset
|
| 37 |
+
from .noising import NoisingDataset
|
| 38 |
+
from .numel_dataset import NumelDataset
|
| 39 |
+
from .num_samples_dataset import NumSamplesDataset
|
| 40 |
+
from .offset_tokens_dataset import OffsetTokensDataset
|
| 41 |
+
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
|
| 42 |
+
from .prepend_dataset import PrependDataset
|
| 43 |
+
from .prepend_token_dataset import PrependTokenDataset
|
| 44 |
+
from .raw_label_dataset import RawLabelDataset
|
| 45 |
+
from .replace_dataset import ReplaceDataset
|
| 46 |
+
from .resampling_dataset import ResamplingDataset
|
| 47 |
+
from .roll_dataset import RollDataset
|
| 48 |
+
from .round_robin_zip_datasets import RoundRobinZipDatasets
|
| 49 |
+
from .sort_dataset import SortDataset
|
| 50 |
+
from .strip_token_dataset import StripTokenDataset
|
| 51 |
+
from .subsample_dataset import SubsampleDataset
|
| 52 |
+
from .token_block_dataset import TokenBlockDataset
|
| 53 |
+
from .transform_eos_dataset import TransformEosDataset
|
| 54 |
+
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
|
| 55 |
+
from .shorten_dataset import TruncateDataset, RandomCropDataset
|
| 56 |
+
from .multilingual.sampled_multi_dataset import SampledMultiDataset
|
| 57 |
+
from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
|
| 58 |
+
from .fasta_dataset import FastaDataset, EncodedFastaDataset
|
| 59 |
+
|
| 60 |
+
from .iterators import (
|
| 61 |
+
CountingIterator,
|
| 62 |
+
EpochBatchIterator,
|
| 63 |
+
GroupedIterator,
|
| 64 |
+
ShardedIterator,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
__all__ = [
|
| 68 |
+
"AddTargetDataset",
|
| 69 |
+
"AppendTokenDataset",
|
| 70 |
+
"BacktranslationDataset",
|
| 71 |
+
"BaseWrapperDataset",
|
| 72 |
+
"BucketPadLengthDataset",
|
| 73 |
+
"ColorizeDataset",
|
| 74 |
+
"ConcatDataset",
|
| 75 |
+
"ConcatSentencesDataset",
|
| 76 |
+
"CountingIterator",
|
| 77 |
+
"DenoisingDataset",
|
| 78 |
+
"Dictionary",
|
| 79 |
+
"EncodedFastaDataset",
|
| 80 |
+
"EpochBatchIterator",
|
| 81 |
+
"FairseqDataset",
|
| 82 |
+
"FairseqIterableDataset",
|
| 83 |
+
"FastaDataset",
|
| 84 |
+
"GroupedIterator",
|
| 85 |
+
"IdDataset",
|
| 86 |
+
"IndexedCachedDataset",
|
| 87 |
+
"IndexedDataset",
|
| 88 |
+
"IndexedRawTextDataset",
|
| 89 |
+
"LanguagePairDataset",
|
| 90 |
+
"LeftPadDataset",
|
| 91 |
+
"ListDataset",
|
| 92 |
+
"LMContextWindowDataset",
|
| 93 |
+
"LRUCacheDataset",
|
| 94 |
+
"MaskTokensDataset",
|
| 95 |
+
"MMapIndexedDataset",
|
| 96 |
+
"MonolingualDataset",
|
| 97 |
+
"MultiCorpusSampledDataset",
|
| 98 |
+
"NestedDictionaryDataset",
|
| 99 |
+
"NoisingDataset",
|
| 100 |
+
"NumelDataset",
|
| 101 |
+
"NumSamplesDataset",
|
| 102 |
+
"OffsetTokensDataset",
|
| 103 |
+
"PadDataset",
|
| 104 |
+
"PrependDataset",
|
| 105 |
+
"PrependTokenDataset",
|
| 106 |
+
"ReplaceDataset",
|
| 107 |
+
"RollDataset",
|
| 108 |
+
"FileAudioDataset",
|
| 109 |
+
"RawLabelDataset",
|
| 110 |
+
"ResamplingDataset",
|
| 111 |
+
"RightPadDataset",
|
| 112 |
+
"RoundRobinZipDatasets",
|
| 113 |
+
"SampledMultiDataset",
|
| 114 |
+
"SampledMultiEpochDataset",
|
| 115 |
+
"ShardedIterator",
|
| 116 |
+
"SortDataset",
|
| 117 |
+
"StripTokenDataset",
|
| 118 |
+
"SubsampleDataset",
|
| 119 |
+
"TokenBlockDataset",
|
| 120 |
+
"TransformEosDataset",
|
| 121 |
+
"TransformEosLangPairDataset",
|
| 122 |
+
"TruncateDataset",
|
| 123 |
+
"TruncatedDictionary",
|
| 124 |
+
]
|
fairseq-0.10.2/fairseq/data/add_target_dataset.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from . import BaseWrapperDataset, data_utils
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AddTargetDataset(BaseWrapperDataset):
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
dataset,
|
| 15 |
+
labels,
|
| 16 |
+
pad,
|
| 17 |
+
eos,
|
| 18 |
+
batch_targets,
|
| 19 |
+
process_label=None,
|
| 20 |
+
add_to_input=False,
|
| 21 |
+
):
|
| 22 |
+
super().__init__(dataset)
|
| 23 |
+
self.labels = labels
|
| 24 |
+
self.batch_targets = batch_targets
|
| 25 |
+
self.pad = pad
|
| 26 |
+
self.eos = eos
|
| 27 |
+
self.process_label = process_label
|
| 28 |
+
self.add_to_input = add_to_input
|
| 29 |
+
|
| 30 |
+
def get_label(self, index):
|
| 31 |
+
return (
|
| 32 |
+
self.labels[index]
|
| 33 |
+
if self.process_label is None
|
| 34 |
+
else self.process_label(self.labels[index])
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
def __getitem__(self, index):
|
| 38 |
+
item = self.dataset[index]
|
| 39 |
+
item["label"] = self.get_label(index)
|
| 40 |
+
return item
|
| 41 |
+
|
| 42 |
+
def size(self, index):
|
| 43 |
+
sz = self.dataset.size(index)
|
| 44 |
+
own_sz = len(self.get_label(index))
|
| 45 |
+
return (sz, own_sz)
|
| 46 |
+
|
| 47 |
+
def collater(self, samples):
|
| 48 |
+
collated = self.dataset.collater(samples)
|
| 49 |
+
if len(collated) == 0:
|
| 50 |
+
return collated
|
| 51 |
+
indices = set(collated["id"].tolist())
|
| 52 |
+
target = [s["label"] for s in samples if s["id"] in indices]
|
| 53 |
+
|
| 54 |
+
if self.batch_targets:
|
| 55 |
+
collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
|
| 56 |
+
target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
|
| 57 |
+
collated["ntokens"] = collated["target_lengths"].sum().item()
|
| 58 |
+
else:
|
| 59 |
+
collated["ntokens"] = sum([len(t) for t in target])
|
| 60 |
+
|
| 61 |
+
collated["target"] = target
|
| 62 |
+
|
| 63 |
+
if self.add_to_input:
|
| 64 |
+
eos = target.new_full((target.size(0), 1), self.eos)
|
| 65 |
+
collated["target"] = torch.cat([target, eos], dim=-1).long()
|
| 66 |
+
collated["net_input"]["prev_output_tokens"] = torch.cat(
|
| 67 |
+
[eos, target], dim=-1
|
| 68 |
+
).long()
|
| 69 |
+
collated["ntokens"] += target.size(0)
|
| 70 |
+
return collated
|
fairseq-0.10.2/fairseq/data/append_token_dataset.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from . import BaseWrapperDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AppendTokenDataset(BaseWrapperDataset):
|
| 13 |
+
def __init__(self, dataset, token=None):
|
| 14 |
+
super().__init__(dataset)
|
| 15 |
+
self.token = token
|
| 16 |
+
if token is not None:
|
| 17 |
+
self._sizes = np.array(dataset.sizes) + 1
|
| 18 |
+
else:
|
| 19 |
+
self._sizes = dataset.sizes
|
| 20 |
+
|
| 21 |
+
def __getitem__(self, idx):
|
| 22 |
+
item = self.dataset[idx]
|
| 23 |
+
if self.token is not None:
|
| 24 |
+
item = torch.cat([item, item.new([self.token])])
|
| 25 |
+
return item
|
| 26 |
+
|
| 27 |
+
@property
|
| 28 |
+
def sizes(self):
|
| 29 |
+
return self._sizes
|
| 30 |
+
|
| 31 |
+
def num_tokens(self, index):
|
| 32 |
+
n = self.dataset.num_tokens(index)
|
| 33 |
+
if self.token is not None:
|
| 34 |
+
n += 1
|
| 35 |
+
return n
|
| 36 |
+
|
| 37 |
+
def size(self, index):
|
| 38 |
+
n = self.dataset.size(index)
|
| 39 |
+
if self.token is not None:
|
| 40 |
+
n += 1
|
| 41 |
+
return n
|
fairseq-0.10.2/fairseq/data/backtranslation_dataset.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from fairseq import utils
|
| 8 |
+
|
| 9 |
+
from . import FairseqDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
|
| 13 |
+
"""Backtranslate a list of samples.
|
| 14 |
+
|
| 15 |
+
Given an input (*samples*) of the form:
|
| 16 |
+
|
| 17 |
+
[{'id': 1, 'source': 'hallo welt'}]
|
| 18 |
+
|
| 19 |
+
this will return:
|
| 20 |
+
|
| 21 |
+
[{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
samples (List[dict]): samples to backtranslate. Individual samples are
|
| 25 |
+
expected to have a 'source' key, which will become the 'target'
|
| 26 |
+
after backtranslation.
|
| 27 |
+
collate_fn (callable): function to collate samples into a mini-batch
|
| 28 |
+
generate_fn (callable): function to generate backtranslations
|
| 29 |
+
cuda (bool): use GPU for generation (default: ``True``)
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
List[dict]: an updated list of samples with a backtranslated source
|
| 33 |
+
"""
|
| 34 |
+
collated_samples = collate_fn(samples)
|
| 35 |
+
s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
|
| 36 |
+
generated_sources = generate_fn(s)
|
| 37 |
+
|
| 38 |
+
id_to_src = {sample["id"]: sample["source"] for sample in samples}
|
| 39 |
+
|
| 40 |
+
# Go through each tgt sentence in batch and its corresponding best
|
| 41 |
+
# generated hypothesis and create a backtranslation data pair
|
| 42 |
+
# {id: id, source: generated backtranslation, target: original tgt}
|
| 43 |
+
return [
|
| 44 |
+
{
|
| 45 |
+
"id": id.item(),
|
| 46 |
+
"target": id_to_src[id.item()],
|
| 47 |
+
"source": hypos[0]["tokens"].cpu(),
|
| 48 |
+
}
|
| 49 |
+
for id, hypos in zip(collated_samples["id"], generated_sources)
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class BacktranslationDataset(FairseqDataset):
|
| 54 |
+
"""
|
| 55 |
+
Sets up a backtranslation dataset which takes a tgt batch, generates
|
| 56 |
+
a src using a tgt-src backtranslation function (*backtranslation_fn*),
|
| 57 |
+
and returns the corresponding `{generated src, input tgt}` batch.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
|
| 61 |
+
backtranslated. Only the source side of this dataset will be used.
|
| 62 |
+
After backtranslation, the source sentences in this dataset will be
|
| 63 |
+
returned as the targets.
|
| 64 |
+
src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
|
| 65 |
+
sentences.
|
| 66 |
+
tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
|
| 67 |
+
sentences to be backtranslated.
|
| 68 |
+
backtranslation_fn (callable, optional): function to call to generate
|
| 69 |
+
backtranslations. This is typically the `generate` method of a
|
| 70 |
+
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
|
| 71 |
+
Pass in None when it is not available at initialization time, and
|
| 72 |
+
use set_backtranslation_fn function to set it when available.
|
| 73 |
+
output_collater (callable, optional): function to call on the
|
| 74 |
+
backtranslated samples to create the final batch
|
| 75 |
+
(default: ``tgt_dataset.collater``).
|
| 76 |
+
cuda: use GPU for generation
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
tgt_dataset,
|
| 82 |
+
src_dict,
|
| 83 |
+
tgt_dict=None,
|
| 84 |
+
backtranslation_fn=None,
|
| 85 |
+
output_collater=None,
|
| 86 |
+
cuda=True,
|
| 87 |
+
**kwargs
|
| 88 |
+
):
|
| 89 |
+
self.tgt_dataset = tgt_dataset
|
| 90 |
+
self.backtranslation_fn = backtranslation_fn
|
| 91 |
+
self.output_collater = (
|
| 92 |
+
output_collater if output_collater is not None else tgt_dataset.collater
|
| 93 |
+
)
|
| 94 |
+
self.cuda = cuda if torch.cuda.is_available() else False
|
| 95 |
+
self.src_dict = src_dict
|
| 96 |
+
self.tgt_dict = tgt_dict
|
| 97 |
+
|
| 98 |
+
def __getitem__(self, index):
|
| 99 |
+
"""
|
| 100 |
+
Returns a single sample from *tgt_dataset*. Note that backtranslation is
|
| 101 |
+
not applied in this step; use :func:`collater` instead to backtranslate
|
| 102 |
+
a batch of samples.
|
| 103 |
+
"""
|
| 104 |
+
return self.tgt_dataset[index]
|
| 105 |
+
|
| 106 |
+
def __len__(self):
|
| 107 |
+
return len(self.tgt_dataset)
|
| 108 |
+
|
| 109 |
+
def set_backtranslation_fn(self, backtranslation_fn):
|
| 110 |
+
self.backtranslation_fn = backtranslation_fn
|
| 111 |
+
|
| 112 |
+
def collater(self, samples):
|
| 113 |
+
"""Merge and backtranslate a list of samples to form a mini-batch.
|
| 114 |
+
|
| 115 |
+
Using the samples from *tgt_dataset*, load a collated target sample to
|
| 116 |
+
feed to the backtranslation model. Then take the backtranslation with
|
| 117 |
+
the best score as the source and the original input as the target.
|
| 118 |
+
|
| 119 |
+
Note: we expect *tgt_dataset* to provide a function `collater()` that
|
| 120 |
+
will collate samples into the format expected by *backtranslation_fn*.
|
| 121 |
+
After backtranslation, we will feed the new list of samples (i.e., the
|
| 122 |
+
`(backtranslated source, original source)` pairs) to *output_collater*
|
| 123 |
+
and return the result.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
samples (List[dict]): samples to backtranslate and collate
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
dict: a mini-batch with keys coming from *output_collater*
|
| 130 |
+
"""
|
| 131 |
+
if samples[0].get("is_dummy", False):
|
| 132 |
+
return samples
|
| 133 |
+
samples = backtranslate_samples(
|
| 134 |
+
samples=samples,
|
| 135 |
+
collate_fn=self.tgt_dataset.collater,
|
| 136 |
+
generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
|
| 137 |
+
cuda=self.cuda,
|
| 138 |
+
)
|
| 139 |
+
return self.output_collater(samples)
|
| 140 |
+
|
| 141 |
+
def num_tokens(self, index):
|
| 142 |
+
"""Just use the tgt dataset num_tokens"""
|
| 143 |
+
return self.tgt_dataset.num_tokens(index)
|
| 144 |
+
|
| 145 |
+
def ordered_indices(self):
|
| 146 |
+
"""Just use the tgt dataset ordered_indices"""
|
| 147 |
+
return self.tgt_dataset.ordered_indices()
|
| 148 |
+
|
| 149 |
+
def size(self, index):
|
| 150 |
+
"""Return an example's size as a float or tuple. This value is used
|
| 151 |
+
when filtering a dataset with ``--max-positions``.
|
| 152 |
+
|
| 153 |
+
Note: we use *tgt_dataset* to approximate the length of the source
|
| 154 |
+
sentence, since we do not know the actual length until after
|
| 155 |
+
backtranslation.
|
| 156 |
+
"""
|
| 157 |
+
tgt_size = self.tgt_dataset.size(index)[0]
|
| 158 |
+
return (tgt_size, tgt_size)
|
| 159 |
+
|
| 160 |
+
@property
|
| 161 |
+
def supports_prefetch(self):
|
| 162 |
+
return getattr(self.tgt_dataset, "supports_prefetch", False)
|
| 163 |
+
|
| 164 |
+
def prefetch(self, indices):
|
| 165 |
+
return self.tgt_dataset.prefetch(indices)
|
fairseq-0.10.2/fairseq/data/base_wrapper_dataset.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from torch.utils.data.dataloader import default_collate
|
| 7 |
+
|
| 8 |
+
from . import FairseqDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BaseWrapperDataset(FairseqDataset):
|
| 12 |
+
def __init__(self, dataset):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.dataset = dataset
|
| 15 |
+
|
| 16 |
+
def __getitem__(self, index):
|
| 17 |
+
return self.dataset[index]
|
| 18 |
+
|
| 19 |
+
def __len__(self):
|
| 20 |
+
return len(self.dataset)
|
| 21 |
+
|
| 22 |
+
def collater(self, samples):
|
| 23 |
+
if hasattr(self.dataset, "collater"):
|
| 24 |
+
return self.dataset.collater(samples)
|
| 25 |
+
else:
|
| 26 |
+
return default_collate(samples)
|
| 27 |
+
|
| 28 |
+
@property
|
| 29 |
+
def sizes(self):
|
| 30 |
+
return self.dataset.sizes
|
| 31 |
+
|
| 32 |
+
def num_tokens(self, index):
|
| 33 |
+
return self.dataset.num_tokens(index)
|
| 34 |
+
|
| 35 |
+
def size(self, index):
|
| 36 |
+
return self.dataset.size(index)
|
| 37 |
+
|
| 38 |
+
def ordered_indices(self):
|
| 39 |
+
return self.dataset.ordered_indices()
|
| 40 |
+
|
| 41 |
+
@property
|
| 42 |
+
def supports_prefetch(self):
|
| 43 |
+
return getattr(self.dataset, "supports_prefetch", False)
|
| 44 |
+
|
| 45 |
+
def attr(self, attr: str, index: int):
|
| 46 |
+
return self.dataset.attr(attr, index)
|
| 47 |
+
|
| 48 |
+
def prefetch(self, indices):
|
| 49 |
+
self.dataset.prefetch(indices)
|
| 50 |
+
|
| 51 |
+
def get_batch_shapes(self):
|
| 52 |
+
return self.dataset.get_batch_shapes()
|
| 53 |
+
|
| 54 |
+
def batch_by_size(
|
| 55 |
+
self,
|
| 56 |
+
indices,
|
| 57 |
+
max_tokens=None,
|
| 58 |
+
max_sentences=None,
|
| 59 |
+
required_batch_size_multiple=1,
|
| 60 |
+
):
|
| 61 |
+
return self.dataset.batch_by_size(
|
| 62 |
+
indices,
|
| 63 |
+
max_tokens=max_tokens,
|
| 64 |
+
max_sentences=max_sentences,
|
| 65 |
+
required_batch_size_multiple=required_batch_size_multiple,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def filter_indices_by_size(self, indices, max_sizes):
|
| 69 |
+
return self.dataset.filter_indices_by_size(indices, max_sizes)
|
| 70 |
+
|
| 71 |
+
@property
|
| 72 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
| 73 |
+
return self.dataset.can_reuse_epoch_itr_across_epochs
|
| 74 |
+
|
| 75 |
+
def set_epoch(self, epoch):
|
| 76 |
+
super().set_epoch(epoch)
|
| 77 |
+
if hasattr(self.dataset, "set_epoch"):
|
| 78 |
+
self.dataset.set_epoch(epoch)
|
fairseq-0.10.2/fairseq/data/bucket_pad_length_dataset.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from fairseq.data import BaseWrapperDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BucketPadLengthDataset(BaseWrapperDataset):
    """
    Bucket and pad item lengths to the nearest bucket size. This can be used to
    reduce the number of unique batch shapes, which is important on TPUs since
    each new batch shape requires a recompilation.

    Args:
        dataset (FairseqDatset): dataset to bucket
        sizes (List[int]): all item sizes
        num_buckets (int): number of buckets to create
        pad_idx (int): padding symbol
        left_pad (bool): if True, pad on the left; otherwise right pad
    """

    def __init__(
        self,
        dataset,
        sizes,
        num_buckets,
        pad_idx,
        left_pad,
    ):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

        assert num_buckets > 0
        # Bucket boundaries are the evenly spaced percentiles of the item
        # sizes (the 0th percentile is dropped via [1:]); np.unique collapses
        # duplicate boundaries when many sizes repeat, so there may end up
        # being fewer than num_buckets distinct buckets.
        self.buckets = np.unique(
            np.percentile(
                sizes,
                np.linspace(0, 100, num_buckets + 1),
                interpolation="lower",
            )[1:]
        )

        def get_bucketed_sizes(orig_sizes, buckets):
            # Round every size up to the smallest bucket boundary >= size.
            # Iterates boundaries in ascending order, rewriting each
            # (start_val, end_val] range to end_val in place.
            sizes = np.copy(orig_sizes)
            assert np.min(sizes) >= 0
            start_val = -1
            for end_val in buckets:
                mask = (sizes > start_val) & (sizes <= end_val)
                sizes[mask] = end_val
                start_val = end_val
            return sizes

        self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)

    def __getitem__(self, index):
        # Pad each item up to its bucketed size so items in the same bucket
        # share a shape. Pads the last dimension only.
        item = self.dataset[index]
        bucket_size = self._bucketed_sizes[index]
        num_pad = bucket_size - item.size(-1)
        return F.pad(
            item,
            (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
            value=self.pad_idx,
        )

    @property
    def sizes(self):
        # Report the padded (bucketed) sizes, not the raw item sizes, so that
        # batching code sees the post-padding shapes.
        return self._bucketed_sizes

    def num_tokens(self, index):
        return self._bucketed_sizes[index]

    def size(self, index):
        return self._bucketed_sizes[index]
|
fairseq-0.10.2/fairseq/data/colorize_dataset.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from . import BaseWrapperDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ColorizeDataset(BaseWrapperDataset):
    """Wrapper that attaches a per-sample 'colors' tensor to ``net_input``.

    The color of each example is produced by the user-supplied
    ``color_getter(dataset, example_id)`` callable and is collated into a
    1-D long tensor aligned with the batch order, for use by models.
    """

    def __init__(self, dataset, color_getter):
        super().__init__(dataset)
        self.color_getter = color_getter

    def collater(self, samples):
        batch = super().collater(samples)
        if len(batch) == 0:
            return batch
        color_ids = list(
            self.color_getter(self.dataset, sample["id"]) for sample in samples
        )
        batch["net_input"]["colors"] = torch.tensor(color_ids, dtype=torch.long)
        return batch
|
fairseq-0.10.2/fairseq/data/data_utils.py
ADDED
|
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
from collections.abc import Iterable
|
| 8 |
+
except ImportError:
|
| 9 |
+
from collections import Iterable
|
| 10 |
+
import contextlib
|
| 11 |
+
import itertools
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
import warnings
|
| 15 |
+
from typing import Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def infer_language_pair(path):
    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx

    Scans the files in *path* for a name whose second dot-separated component
    looks like ``lang1-lang2`` and returns that pair; falls back to
    ``(None, None)`` when no such file exists.
    """
    for candidate in os.listdir(path):
        pieces = candidate.split(".")
        if len(pieces) < 3:
            continue
        langs = pieces[1].split("-")
        if len(langs) == 2:
            return langs
    return None, None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def collate_tokens(
    values,
    pad_idx,
    eos_idx=None,
    left_pad=False,
    move_eos_to_beginning=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Convert a list of 1d tensors into a padded 2d tensor.

    Rows shorter than the batch width are filled with ``pad_idx`` on the
    left or right according to ``left_pad``. When ``move_eos_to_beginning``
    is set, each row's EOS token is rotated to the front (used to build
    shifted decoder inputs).
    """
    width = max(v.size(0) for v in values)
    if pad_to_length is not None:
        width = max(width, pad_to_length)
    if pad_to_multiple != 1 and width % pad_to_multiple != 0:
        # Round width up to the next multiple of pad_to_multiple.
        width = int(((width - 0.1) // pad_to_multiple + 1) * pad_to_multiple)

    res = values[0].new(len(values), width).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if not move_eos_to_beginning:
            dst.copy_(src)
            return
        # If no eos_idx is given, use the last token of src as the EOS.
        dst[0] = src[-1] if eos_idx is None else eos_idx
        dst[1:] = src[:-1]

    for i, v in enumerate(values):
        target = res[i][width - len(v):] if left_pad else res[i][: len(v)]
        copy_tensor(v, target)
    return res
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def load_indexed_dataset(
    path, dictionary=None, dataset_impl=None, combine=False, default="cached"
):
    """A helper function for loading indexed datasets.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it will be inferred automatically. For legacy indexed
            data we use the 'cached' implementation by default.
        combine (bool, optional): automatically load and combine multiple
            datasets. For example, if *path* is 'data-bin/train', then we will
            combine 'data-bin/train', 'data-bin/train1', ... and return a
            single ConcatDataset instance.
    """
    import fairseq.data.indexed_dataset as indexed_dataset
    from fairseq.data.concat_dataset import ConcatDataset

    datasets = []
    for k in itertools.count():
        # Shards share a prefix: 'train', 'train1', 'train2', ...
        path_k = path + (str(k) if k > 0 else "")
        path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)

        impl_k = dataset_impl
        if impl_k is None:
            impl_k = indexed_dataset.infer_dataset_impl(path_k)
        dataset = indexed_dataset.make_dataset(
            path_k,
            impl=impl_k or default,
            fix_lua_indexing=True,
            dictionary=dictionary,
        )
        if dataset is None:
            # No shard at this index: stop scanning.
            break
        logger.info("loaded {} examples from: {}".format(len(dataset), path_k))
        datasets.append(dataset)
        if not combine:
            break

    if not datasets:
        return None
    if len(datasets) == 1:
        return datasets[0]
    return ConcatDataset(datasets)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Context manager which seeds the NumPy PRNG with the specified seed and
    restores the state afterward.

    Passing ``seed=None`` is a no-op. Extra positional seeds are hashed
    together with ``seed`` to derive a combined seed.
    """
    if seed is None:
        # No seeding requested; leave the global RNG untouched.
        yield
        return
    if addl_seeds:
        # Fold the extra seeds in deterministically, bounded to a valid range.
        seed = int(hash((seed, *addl_seeds)) % 1e6)
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(saved_state)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def collect_filtered(function, iterable, filtered):
    """
    Similar to :func:`filter` but collects filtered elements in ``filtered``.

    Args:
        function (callable): function that returns ``False`` for elements that
            should be filtered
        iterable (iterable): iterable to filter
        filtered (list): list to store filtered elements
    """
    for element in iterable:
        if not function(element):
            # Rejected element: record it instead of yielding it.
            filtered.append(element)
        else:
            yield element
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
    """Filter ``indices`` by calling ``size_fn`` per index.

    Handles scalar, dict, and tuple forms of ``max_positions``. Returns the
    surviving indices as an int64 array plus the list of rejected indices.
    """

    def compare_leq(a, b):
        # Tuple-valued sizes compare by their largest component.
        return a <= b if not isinstance(a, tuple) else max(a) <= b

    def check_size(idx):
        if isinstance(max_positions, (float, int)):
            return size_fn(idx) <= max_positions
        if isinstance(max_positions, dict):
            idx_size = size_fn(idx)
            assert isinstance(idx_size, dict)
            # Only keys present on both sides constrain; None disables a limit.
            intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
            return all(
                all(
                    a is None or b is None or a <= b
                    for a, b in zip(idx_size[key], max_positions[key])
                )
                for key in intersect_keys
            )
        # Hacky as heck, for the specific case of multilingual training with RoundRobin.
        if isinstance(size_fn(idx), dict) and isinstance(max_positions, tuple):
            return all(
                a is None or b is None or compare_leq(a, b)
                for a, b in zip(size_fn(idx).values(), max_positions)
            )
        # For MultiCorpusSampledDataset, will generalize it later
        if not isinstance(size_fn(idx), Iterable):
            return all(size_fn(idx) <= b for b in max_positions)
        return all(
            a is None or b is None or a <= b
            for a, b in zip(size_fn(idx), max_positions)
        )

    ignored = []
    survivors = collect_filtered(check_size, indices, ignored)
    indices = np.fromiter(survivors, dtype=np.int64, count=-1)
    return indices, ignored
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def filter_by_size(indices, dataset, max_positions, raise_exception=False):
    """
    [deprecated] Filter indices based on their size.
    Use `FairseqDataset::filter_indices_by_size` instead.

    Args:
        indices (List[int]): ordered list of dataset indices
        dataset (FairseqDataset): fairseq dataset instance
        max_positions (tuple): filter elements larger than this size.
            Comparisons are done component-wise.
        raise_exception (bool, optional): if ``True``, raise an exception if
            any elements are filtered (default: False).
    """
    warnings.warn(
        "data_utils.filter_by_size is deprecated. "
        "Use `FairseqDataset::filter_indices_by_size` instead.",
        stacklevel=2,
    )
    if isinstance(max_positions, float) or isinstance(max_positions, int):
        if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
            # Fast vectorized path: sizes are a single numpy array.
            ignored = indices[dataset.sizes[indices] > max_positions].tolist()
            indices = indices[dataset.sizes[indices] <= max_positions]
        elif (
            hasattr(dataset, "sizes")
            and isinstance(dataset.sizes, list)
            and len(dataset.sizes) == 1
        ):
            # Sizes wrapped in a one-element list (e.g. single-source datasets).
            ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
            indices = indices[dataset.sizes[0][indices] <= max_positions]
        else:
            # Slow path: call dataset.size() per index.
            indices, ignored = _filter_by_size_dynamic(
                indices, dataset.size, max_positions
            )
    else:
        # Non-scalar limits (tuple/dict) always use the dynamic path.
        indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)

    if len(ignored) > 0 and raise_exception:
        raise Exception(
            (
                "Size of sample #{} is invalid (={}) since max_positions={}, "
                "skip this example with --skip-invalid-size-inputs-valid-test"
            ).format(ignored[0], dataset.size(ignored[0]), max_positions)
        )
    if len(ignored) > 0:
        logger.warning(
            (
                "{} samples have invalid sizes and will be skipped, "
                "max_positions={}, first few sample ids={}"
            ).format(len(ignored), max_positions, ignored[:10])
        )
    return indices
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
    """Filter a list of sample indices. Remove those that are longer
    than specified in max_sizes.

    Args:
        indices (np.array): original array of sample indices
        max_sizes (int or list[int] or tuple[int]): max sample size,
            can be defined separately for src and tgt (then list or tuple)

    Returns:
        np.array: filtered sample array
        list: list of removed indices
    """
    if max_sizes is None:
        return indices, []
    if type(max_sizes) in (int, float):
        # A single scalar limit applies to both sides.
        max_src_size = max_tgt_size = max_sizes
    else:
        max_src_size, max_tgt_size = max_sizes

    too_long = src_sizes[indices] > max_src_size
    if tgt_sizes is not None:
        too_long = too_long | (tgt_sizes[indices] > max_tgt_size)
    ignored = indices[too_long]
    if len(ignored) > 0:
        indices = indices[~too_long]
    return indices, ignored.tolist()
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def batch_by_size(
    indices,
    num_tokens_fn,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
    fixed_shapes=None,
):
    """
    Yield mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be less than N or a multiple of N (default: 1).
        fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
            only be created with the given shapes. *max_sentences* and
            *required_batch_size_multiple* will be ignored (default: None).
    """
    try:
        from fairseq.data.data_utils_fast import (
            batch_by_size_fast,
            batch_fixed_shapes_fast,
        )
    except ImportError:
        raise ImportError(
            "Please build Cython components with: `pip install --editable .` "
            "or `python setup.py build_ext --inplace`"
        )

    # The Cython helpers treat non-positive limits as "no limit".
    max_tokens = -1 if max_tokens is None else max_tokens
    max_sentences = -1 if max_sentences is None else max_sentences
    bsz_mult = required_batch_size_multiple

    if not isinstance(indices, np.ndarray):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    if fixed_shapes is None:
        return batch_by_size_fast(
            indices,
            num_tokens_fn,
            max_tokens,
            max_sentences,
            bsz_mult,
        )

    fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
    sort_order = np.lexsort(
        [
            fixed_shapes[:, 1].argsort(),  # length
            fixed_shapes[:, 0].argsort(),  # bsz
        ]
    )
    return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes[sort_order])
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def post_process(sentence: str, symbol: str):
    """Undo subword segmentation of *sentence* for the given *symbol* scheme.

    Known schemes ("sentencepiece", "wordpiece", "letter", "_EOW") strip the
    original spaces and turn the scheme's word-boundary marker back into
    spaces. Any other non-``None``/non-``"none"`` symbol is treated as a BPE
    continuation marker and removed wherever it occurs.
    """
    if symbol == "sentencepiece":
        return sentence.replace(" ", "").replace("\u2581", " ").strip()
    if symbol == "wordpiece":
        return sentence.replace(" ", "").replace("_", " ").strip()
    if symbol == "letter":
        return sentence.replace(" ", "").replace("|", " ").strip()
    if symbol == "_EOW":
        return sentence.replace(" ", "").replace("_EOW", " ").strip()
    if symbol is not None and symbol != "none":
        # BPE continuation marker (e.g. "@@ "): appending a space lets a
        # sentence-final marker match too; rstrip drops the extra space.
        return (sentence + " ").replace(symbol, "").rstrip()
    return sentence
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from possion distribution with lambda = mask length
        min_masks: minimum number of masked spans
        no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans

    Returns:
        np.ndarray of bool with shape ``shape``; every row has the same
        number of True entries (rows are trimmed to the batch minimum).
    """
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )
    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            # Exclude padded positions from the usable length for this row.
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            # Degenerate draw: force at least one non-empty span.
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # NOTE(review): assumes e - length > s; candidate gaps are
                # pre-filtered by length + min_space below, so with
                # min_space == 0 a gap of exactly `length` would make
                # randint raise — TODO confirm upstream.
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                # Return the leftover gaps that can still hold a span.
                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - keep_length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                # Pick a gap with probability proportional to its size.
                # FIX: the original passed np.int here; that alias was
                # removed in NumPy 1.24, crashing this branch.
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    dtype=np.int64,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    # Trim every row to the same number of masked positions (the batch
    # minimum) so downstream code can reshape the masked elements uniformly.
    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        mask[i, mask_idc] = True

    return mask
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def get_mem_usage():
    """Return a human-readable host-memory summary, or "N/A" if psutil is
    not installed."""
    try:
        import psutil
    except ImportError:
        return "N/A"
    mb = 1024 * 1024
    return (
        f"used={psutil.virtual_memory().used / mb}Mb; "
        f"avail={psutil.virtual_memory().available / mb}Mb"
    )
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor:
    """Build a padding mask from sequence lengths.

    ``mask[b, t]`` is True exactly when position ``t`` is padding for batch
    element ``b`` (i.e. ``t >= lens[b]``); width is ``max(lens)``.
    """
    bsz = lens.size(0)
    max_len = torch.max(lens).item()
    positions = torch.arange(max_len).to(lens.device).view(1, max_len)
    return positions.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_len)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def lengths_to_mask(lens: torch.LongTensor) -> torch.BoolTensor:
    """Inverse of :func:`lengths_to_padding_mask`: True at real (non-pad)
    positions."""
    padding = lengths_to_padding_mask(lens)
    return ~padding
|
fairseq-0.10.2/fairseq/data/data_utils_fast.cpp
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
fairseq-0.10.2/fairseq/data/data_utils_fast.pyx
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# cython: language_level=3
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
cimport cython
|
| 10 |
+
cimport numpy as np
|
| 11 |
+
|
| 12 |
+
from libc.stdint cimport int32_t, int64_t
|
| 13 |
+
|
| 14 |
+
ctypedef int64_t DTYPE_t
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
cdef _is_batch_full(int64_t num_sentences, int64_t num_tokens, int64_t max_tokens, int64_t max_sentences):
    # Return 1 when the batch has hit either cap, 0 otherwise.
    # A non-positive max_tokens / max_sentences disables that limit.
    if num_sentences == 0:
        # An empty batch can always take one more sample.
        return 0
    if max_sentences > 0 and num_sentences == max_sentences:
        return 1
    if max_tokens > 0 and num_tokens > max_tokens:
        return 1
    return 0
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@cython.cdivision(True)
cpdef list batch_by_size_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    """Greedily group ``indices`` into batches of at most ``max_tokens``
    padded tokens and ``max_sentences`` samples.

    The padded cost of a batch is (batch size) * (longest sample in it).
    When adding a sample would exceed a limit, the largest prefix whose
    size is a multiple of ``bsz_mult`` is emitted and the remainder seeds
    the next batch. Returns a list of lists of dataset indices.
    """
    cdef int64_t sample_len = 0   # longest sample in the current batch
    cdef list sample_lens = []    # token counts, parallel to `batch`
    cdef list batch = []
    cdef list batches = []
    cdef int64_t mod_len
    cdef int64_t i
    cdef int64_t idx
    cdef int64_t num_tokens
    cdef DTYPE_t[:] indices_view = indices

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert max_tokens <= 0 or sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )
        # Padded cost if `idx` joined the current batch.
        num_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(len(batch), num_tokens, max_tokens, max_sentences):
            # Flush the largest bsz_mult-aligned prefix (or the whole
            # batch when it is smaller than bsz_mult).
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
cdef _find_valid_shape(
    DTYPE_t[:, :] shapes_view,
    int64_t num_sentences,
    int64_t num_tokens,
):
    """Return index of first valid shape, or -1 if none is found."""
    # Each row of shapes_view is (max batch size, max sample length); a
    # shape is valid when it can accommodate the candidate batch.
    for i in range(shapes_view.shape[0]):
        if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]:
            return i
    return -1
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
    """Group ``indices`` into batches that each fit one of the allowed
    fixed (batch size, length) shapes.

    ``fixed_shapes_sorted`` must be sorted ascending (see the caller's
    lexsort) so the scan can narrow the candidate shapes as the batch grows.
    Returns a list of lists of dataset indices.
    """
    cdef int64_t sample_len = 0   # longest sample in the current batch
    cdef list sample_lens = []    # token counts, parallel to `batch`
    cdef list batch = []
    cdef list batches = []
    cdef int64_t mod_len
    cdef int64_t i
    cdef int64_t idx
    cdef int64_t num_tokens
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
        if shape_idx == -1:
            # No shape fits with `idx` added: flush the batch and start over
            # with the full shape list.
            batches.append(batch)
            batch = []
            sample_lens = []
            sample_len = 0
            shapes_view = fixed_shapes_sorted
        elif shape_idx > 0:
            # small optimization for the next call to _find_valid_shape
            shapes_view = shapes_view[shape_idx:]

        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)

    return batches
|
fairseq-0.10.2/fairseq/data/denoising_dataset.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from . import FairseqDataset, data_utils
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    """Merge a list of denoising samples into a mini-batch.

    Args:
        samples (List[dict]): each sample has ``"id"`` (int), ``"source"``
            and optionally ``"target"`` (1-D LongTensors ending in EOS).
        pad_idx (int): padding symbol index.
        eos_idx (int): EOS index (unused for padding here; each sample's own
            trailing EOS is reused — see the ``merge`` helper below).
        vocab: vocabulary; kept for interface compatibility with callers.
        left_pad_source (bool): pad sources on the left instead of the right.
        left_pad_target (bool): pad targets on the left instead of the right.
        input_feeding (bool): must be True; adds ``prev_output_tokens``
            (targets shifted right by one) to ``net_input``.
        pad_to_length (dict, optional): ``{"source": int, "target": int}``
            minimum padded lengths.

    Returns:
        dict: batch with keys ``id``, ``ntokens``, ``net_input``, ``target``,
        ``nsentences`` and ``sort_order``; an empty dict for empty input.
    """
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        # Pad the chosen field from every sample into one [batch, time] tensor.
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": id,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        # BUG FIX: this used to be samples[0]["source"].size(0), i.e. the
        # token count of the first sample. "nsentences" is the number of
        # sentences in the batch, which is simply the number of samples.
        "nsentences": len(samples),
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class DenoisingDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset for BART dataset.

    Applies BART-style noising (sentence permutation, whole-word masking,
    token insertion, rolling/rotation) to each item at access time; the
    clean item is kept as the target.

    Args:
        dataset (TokenBlockDataset): dataset to wrap
        sizes (List[int]): sentence lengths
        vocab (~fairseq.data.Dictionary): vocabulary
        mask_idx (int): dictionary index used for masked token
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        shuffle (bool, optional): shuffle the elements before batching.
            Default: ``True``
        seed: Seed for random number generator for reproducibility.
        args: argparse arguments (reads ``mask``, ``mask_random``,
            ``insert``, ``rotate``, ``permute_sentences``, ``bpe``,
            ``replace_length``, ``mask_length``, ``poisson_lambda``).
        eos (int, optional): EOS index; defaults to ``vocab.eos()``.
        item_transform_func (callable, optional): extra
            ``(source, target) -> (source, target)`` transform applied in
            ``__getitem__`` after the built-in noising.
    """

    def __init__(
        self,
        dataset,
        sizes,
        vocab,
        mask_idx,
        mask_whole_words,
        shuffle,
        seed,
        args,
        eos=None,
        item_transform_func=None,
    ):
        self.dataset = dataset

        self.sizes = sizes

        self.vocab = vocab
        self.shuffle = shuffle
        self.seed = seed
        self.mask_idx = mask_idx
        self.mask_whole_word = mask_whole_words
        # noising hyper-parameters pulled from argparse
        self.mask_ratio = args.mask
        self.random_ratio = args.mask_random
        self.insert_ratio = args.insert
        self.rotate_ratio = args.rotate
        self.permute_sentence_ratio = args.permute_sentences
        self.eos = eos if eos is not None else vocab.eos()
        self.item_transform_func = item_transform_func

        # token that marks sentence boundaries for permute_sentences();
        # with the GPT-2 BPE, "13" is the BPE code for the period character
        if args.bpe != "gpt2":
            self.full_stop_index = self.vocab.eos()
        else:
            assert args.bpe == "gpt2"
            self.full_stop_index = self.vocab.index("13")

        # -1: replace each masked token with [MASK] (length preserved),
        #  0: delete masked spans, 1: replace each span with a single [MASK]
        self.replace_length = args.replace_length
        if self.replace_length not in [-1, 0, 1]:
            raise ValueError(f"invalid arg: replace_length={self.replace_length}")
        if args.mask_length not in ["subword", "word", "span-poisson"]:
            raise ValueError(f"invalid arg: mask-length={args.mask_length}")
        if args.mask_length == "subword" and args.replace_length not in [0, 1]:
            raise ValueError(f"if using subwords, use replace-length=1 or 0")

        self.mask_span_distribution = None
        if args.mask_length == "span-poisson":
            # Build a truncated Poisson(lambda) pmf over span lengths
            # 0..127, stopping early once the tail mass is negligible.
            _lambda = args.poisson_lambda

            lambda_to_the_k = 1
            e_to_the_minus_lambda = math.exp(-_lambda)
            k_factorial = 1
            ps = []
            for k in range(0, 128):
                # Poisson pmf: e^-lambda * lambda^k / k!
                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                lambda_to_the_k *= _lambda
                k_factorial *= k + 1
                if ps[-1] < 0.0000001:
                    break
            ps = torch.FloatTensor(ps)
            # Categorical normalizes the (unnormalized) pmf internally
            self.mask_span_distribution = torch.distributions.Categorical(ps)

        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        # epoch participates in the per-item RNG seed (see __getitem__)
        self.epoch = epoch

    def __getitem__(self, index):
        """Return ``{"id", "source", "target"}`` with noised source.

        Noising is deterministic given ``(seed, epoch, index)``.
        NOTE(review): ``source`` starts out aliasing the underlying dataset
        item (``target`` is the clone), and add_whole_word_mask mutates it
        in place — assumes the wrapped dataset returns a fresh tensor per
        access; confirm for cached datasets.
        """
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            tokens = self.dataset[index]
            assert tokens[-1] == self.eos
            source, target = tokens, tokens.clone()

            if self.permute_sentence_ratio > 0.0:
                source = self.permute_sentences(source, self.permute_sentence_ratio)

            if self.mask_ratio > 0:
                source = self.add_whole_word_mask(source, self.mask_ratio)

            if self.insert_ratio > 0:
                source = self.add_insertion_noise(source, self.insert_ratio)

            if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
                source = self.add_rolling_noise(source)
        # additional user-supplied transforms may be applied here:
        if self.item_transform_func is not None:
            source, target = self.item_transform_func(source, target)

        # sanity checks: ids in range, BOS/EOS preserved by the noising
        assert (source >= 0).all()
        assert (source[1:-1] >= 1).all()
        assert (source <= len(self.vocab)).all()
        assert source[0] == self.vocab.bos()
        assert source[-1] == self.eos
        return {
            "id": index,
            "source": source,
            "target": target,
        }

    def __len__(self):
        return len(self.dataset)

    def permute_sentences(self, source, p=1.0):
        """Randomly permute a fraction ``p`` of the sentences in ``source``.

        Sentences are spans ending at ``self.full_stop_index`` tokens.
        """
        full_stops = source == self.full_stop_index
        # Pretend it ends with a full stop so last span is a sentence
        full_stops[-2] = 1

        # Tokens that are full stops, where the previous token is not
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        result = source.clone()

        num_sentences = sentence_ends.size(0)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        # identity ordering, with the selected sentences shuffled among
        # themselves
        ordering = torch.arange(0, num_sentences)
        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]

        # Ignore <bos> at start
        index = 1
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
            result[index : index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result

    def word_starts(self, source):
        """Mask of word-start positions in ``source`` (BOS/EOS excluded)."""
        if self.mask_whole_word is not None:
            is_word_start = self.mask_whole_word.gather(0, source)
        else:
            # no whole-word info: treat every token as a word start
            is_word_start = torch.ones(source.size())
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

    def add_whole_word_mask(self, source, p):
        """Mask a fraction ``p`` of the words in ``source`` (BART masking).

        Depending on ``self.replace_length``, masked spans are deleted,
        replaced by a single [MASK], or replaced token-for-token; a
        ``self.random_ratio`` fraction of masks become random tokens.
        When ``self.mask_span_distribution`` is set, span lengths are drawn
        from it (span-poisson); zero-length spans become insertions.
        NOTE: mutates ``source`` in place before returning it.
        """
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        # pick num_to_mask random word-start positions to begin spans at
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        # which spans get random tokens instead of [MASK]
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        is_word_start[
            -1
        ] = 255  # acts as a long length, so spans don't go over the end of doc
        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            source[indices] = self.mask_idx
            source[indices[mask_random]] = torch.randint(
                1, len(self.vocab), size=(mask_random.sum(),)
            )

        if self.mask_span_distribution is not None:
            # extend each span token-by-token until its sampled length
            # (counted in words) is exhausted
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                # advance through the remaining subwords of each word
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )

            assert source_length - 1 not in indices

        # drop deleted positions (no-op when replace_length == -1)
        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source

    def add_permuted_noise(self, tokens, p):
        """Randomly swap a fraction ``p`` of the tokens (BOS/EOS fixed)."""
        num_words = len(tokens)
        num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
        # +1 keeps position 0 (<bos>) out of the permutation
        substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
        tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
        return tokens

    def add_rolling_noise(self, tokens):
        """Rotate the interior of ``tokens`` by a random offset (BOS/EOS fixed)."""
        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
        tokens = torch.cat(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
            dim=0,
        )
        return tokens

    def add_insertion_noise(self, tokens, p):
        """Insert ``ceil(len(tokens) * p)`` noise tokens at random interior
        positions; a ``self.random_ratio`` fraction are random vocab tokens,
        the rest are ``self.mask_idx``."""
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        # choose interior positions in the lengthened sequence (never the
        # first or last position, preserving BOS/EOS)
        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_idx
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=len(self.vocab), size=(num_random,)
        )

        # remaining slots take the original tokens, in order
        result[~noise_mask] = tokens

        # every slot must have been filled (no -1 left)
        assert (result >= 0).all()
        return result

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): minimum padded lengths, e.g.
                ``{"source": 128, "target": 128}``
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order.

        Indices are sorted by example size; when shuffling, a random
        permutation is applied first so the (stable) mergesort breaks size
        ties randomly. NOTE(review): uses the global numpy RNG, not
        ``self.seed`` — confirm callers seed it if determinism is required.
        """
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        return indices[np.argsort(self.sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        # NOTE(review): ``self.src`` / ``self.tgt`` are never assigned in
        # __init__ (this class wraps a single ``self.dataset``), so calling
        # prefetch() would raise AttributeError. Looks copied from a paired
        # src/tgt dataset — confirm and fix upstream.
        self.src.prefetch(indices)
        self.tgt.prefetch(indices)

    @property
    def supports_prefetch(self):
        # NOTE(review): same issue as prefetch() above — ``self.src`` and
        # ``self.tgt`` do not exist, so accessing this property would raise
        # AttributeError rather than return False.
        return (
            hasattr(self.src, "supports_prefetch")
            and self.src.supports_prefetch
            and hasattr(self.tgt, "supports_prefetch")
            and self.tgt.supports_prefetch
        )
|
fairseq-0.10.2/fairseq/data/indexed_dataset.py
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import shutil
|
| 8 |
+
import struct
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from fairseq.data.fasta_dataset import FastaDataset
|
| 14 |
+
from fairseq.file_io import PathManager
|
| 15 |
+
|
| 16 |
+
from . import FairseqDataset
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def __best_fitting_dtype(vocab_size=None):
    """Pick the narrowest integer dtype that can store token ids.

    Vocabularies smaller than 65500 fit in an unsigned 16-bit integer;
    anything larger (or an unknown size) falls back to int32.
    """
    if vocab_size is None or vocab_size >= 65500:
        return np.int32
    return np.uint16
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_available_dataset_impl():
    """Return the names of every supported on-disk dataset implementation."""
    impls = ["raw", "lazy", "cached", "mmap", "fasta"]
    return impls
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def infer_dataset_impl(path):
    """Guess which dataset implementation was used to write *path*.

    Returns one of ``"raw"``, ``"cached"``, ``"mmap"``, ``"fasta"``, or
    ``None`` when the format cannot be determined. Binary formats are
    distinguished by the 8-byte magic string at the start of the index file.
    """
    if IndexedRawTextDataset.exists(path):
        return "raw"
    if IndexedDataset.exists(path):
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
        if magic == IndexedDataset._HDR_MAGIC:
            return "cached"
        if magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
            return "mmap"
        return None
    if FastaDataset.exists(path):
        return "fasta"
    return None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def make_builder(out_file, impl, vocab_size=None):
    """Create a dataset builder that writes to *out_file*.

    Args:
        out_file: path of the output ``.bin`` file.
        impl (str): dataset implementation name ("mmap", "fasta", or a
            legacy TorchNet format).
        vocab_size (int, optional): used by the mmap builder to pick the
            narrowest token dtype.

    Raises:
        NotImplementedError: for the "fasta" implementation, which has no
            builder.
    """
    if impl == "mmap":
        token_dtype = __best_fitting_dtype(vocab_size)
        return MMapIndexedDatasetBuilder(out_file, dtype=token_dtype)
    if impl == "fasta":
        raise NotImplementedError
    return IndexedDatasetBuilder(out_file)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
    """Instantiate the dataset stored at *path* for implementation *impl*.

    Returns ``None`` when *impl* does not match what exists on disk.
    ``dictionary`` is required for the "raw" and "fasta" implementations.
    """
    if impl == "raw" and IndexedRawTextDataset.exists(path):
        assert dictionary is not None
        return IndexedRawTextDataset(path, dictionary)
    if impl == "lazy" and IndexedDataset.exists(path):
        return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
    if impl == "cached" and IndexedDataset.exists(path):
        return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
    if impl == "mmap" and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path)
    if impl == "fasta" and FastaDataset.exists(path):
        # imported lazily so the fasta dependency is only needed when used
        from fairseq.data.fasta_dataset import EncodedFastaDataset

        return EncodedFastaDataset(path, dictionary)
    return None
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def dataset_exists(path, impl):
    """Return True when a dataset of implementation *impl* exists at *path*."""
    if impl == "raw":
        checker = IndexedRawTextDataset.exists
    elif impl == "mmap":
        checker = MMapIndexedDataset.exists
    else:
        # "lazy" and "cached" share the same on-disk layout
        checker = IndexedDataset.exists
    return checker(path)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def read_longs(f, n):
    """Read *n* native-endian int64 values from binary file object *f*.

    Returns a fresh 1-D numpy array of dtype int64.
    """
    out = np.empty(n, dtype=np.int64)
    # fill the array buffer directly, avoiding an intermediate bytes object
    f.readinto(out)
    return out
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def write_longs(f, a):
    """Write the values of *a* to binary file object *f* as raw int64s."""
    payload = np.array(a, dtype=np.int64)
    f.write(payload.tobytes())
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Mapping from the on-disk dtype code (stored in the .idx file header) to a
# numpy dtype. The integer codes are part of the binary format and must stay
# stable across versions.
# BUG FIX: codes 6/7 used the deprecated ``np.float`` alias, which was
# removed in NumPy 1.24 and made importing this module crash on modern
# NumPy. ``np.float`` was an alias for the builtin float (float64), so both
# codes now use explicit 64-bit float dtypes (``np.double`` is ``float64``).
dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float64,
    7: np.double,
    8: np.uint16,
}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def code(dtype):
    """Return the on-disk integer code for *dtype* (inverse of ``dtypes``).

    Raises:
        ValueError: if *dtype* has no registered code.
    """
    for dtype_code, candidate in dtypes.items():
        if candidate == dtype:
            return dtype_code
    raise ValueError(dtype)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def index_file_path(prefix_path):
    """Return the path of the index (``.idx``) file for *prefix_path*."""
    return f"{prefix_path}.idx"
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def data_file_path(prefix_path):
    """Return the path of the data (``.bin``) file for *prefix_path*."""
    return f"{prefix_path}.bin"
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class IndexedDataset(FairseqDataset):
    """Loader for TorchNet IndexedDataset.

    Reads the legacy ``.idx``/``.bin`` file pair written by
    ``IndexedDatasetBuilder``. The index is parsed eagerly in the
    constructor; the data file is opened lazily on first item access.
    """

    # 8-byte magic string at the start of every ``.idx`` file
    _HDR_MAGIC = b"TNTIDX\x00\x00"

    def __init__(self, path, fix_lua_indexing=False):
        """
        Args:
            path (str): dataset prefix; ``path + ".idx"`` and
                ``path + ".bin"`` must both exist.
            fix_lua_indexing (bool): subtract 1 from every value to convert
                Lua/Torch 1-based token ids to 0-based ids.
        """
        super().__init__()
        self.path = path
        self.fix_lua_indexing = fix_lua_indexing
        self.data_file = None  # opened lazily by read_data()
        self.read_index(path)

    def read_index(self, path):
        """Parse the ``.idx`` header and load the offset/size tables."""
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, (
                "Index file doesn't match expected format. "
                "Make sure that --dataset-impl is configured properly."
            )
            version = f.read(8)
            assert struct.unpack("<Q", version) == (1,)
            # dtype code (see module-level ``dtypes``) and bytes per element
            code, self.element_size = struct.unpack("<QQ", f.read(16))
            self.dtype = dtypes[code]
            # number of examples and total number of size entries
            self._len, self.s = struct.unpack("<QQ", f.read(16))
            # per-example ranges into ``sizes`` (one entry per tensor dim)
            self.dim_offsets = read_longs(f, self._len + 1)
            # per-example element offsets into the .bin file
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)

    def read_data(self, path):
        # buffering=0: unbuffered handle, so seek()+readinto() hit the OS
        # directly for random access
        self.data_file = open(data_file_path(path), "rb", buffering=0)

    def check_index(self, i):
        """Raise IndexError when *i* is outside ``[0, len(self))``."""
        if i < 0 or i >= self._len:
            raise IndexError("index out of range")

    def __del__(self):
        # best-effort close of the lazily opened data file
        if self.data_file:
            self.data_file.close()

    # NOTE(review): lru_cache on an instance method keys on ``self`` and
    # keeps instances alive for the cache's lifetime (ruff B019); kept as-is
    # since upstream relies on this small read cache.
    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        """Return example *i* as a LongTensor read from the data file."""
        if not self.data_file:
            self.read_data(self.path)
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        self.data_file.seek(self.data_offsets[i] * self.element_size)
        self.data_file.readinto(a)
        item = torch.from_numpy(a).long()
        if self.fix_lua_indexing:
            item -= 1  # subtract 1 for 0-based indexing
        return item

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        # token count for example *index*; used for --max-tokens batching
        return self.sizes[index]

    def size(self, index):
        # example size used when filtering with --max-positions
        return self.sizes[index]

    @staticmethod
    def exists(path):
        """True when both the index and data files exist for *path*."""
        return PathManager.exists(index_file_path(path)) and PathManager.exists(
            data_file_path(path)
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class IndexedCachedDataset(IndexedDataset):
    """IndexedDataset variant that preloads requested examples into memory.

    ``prefetch(indices)`` copies the raw bytes of the given examples into a
    single in-memory buffer; ``__getitem__`` then serves items from that
    buffer instead of touching the data file.
    """

    def __init__(self, path, fix_lua_indexing=False):
        super().__init__(path, fix_lua_indexing=fix_lua_indexing)
        # flat buffer holding all prefetched examples back-to-back
        self.cache = None
        # example index -> start offset (in elements) into ``self.cache``
        self.cache_index = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        """Load the examples in *indices* into the in-memory cache."""
        if all(i in self.cache_index for i in indices):
            return  # everything requested is already cached
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            # read each example directly into its slice of the cache buffer
            a = self.cache[ptx : ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    # NOTE(review): lru_cache on an instance method keeps instances alive
    # (ruff B019); kept to match the parent class's behavior.
    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        """Return example *i* from the cache (must have been prefetched).

        Raises KeyError if *i* was not included in a prior prefetch() call.
        """
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        ptx = self.cache_index[i]
        np.copyto(a, self.cache[ptx : ptx + a.size])
        item = torch.from_numpy(a).long()
        if self.fix_lua_indexing:
            item -= 1  # subtract 1 for 0-based indexing
        return item
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class IndexedRawTextDataset(FairseqDataset):
    """Takes a text file as input and binarizes it in memory at instantiation.
    Original lines are also kept in memory.

    NOTE(review): ``__init__`` assigns ``self.size = len(self.tokens_list)``,
    which shadows the ``size(self, index)`` method on instances — calling
    ``dataset.size(i)`` would raise TypeError ("int is not callable").
    Confirm whether any caller hits the method path for raw datasets.
    """

    def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
        """
        Args:
            path (str): path of a UTF-8 text file, one example per line.
            dictionary: vocabulary used to binarize each line.
            append_eos (bool): append EOS to every encoded line.
            reverse_order (bool): reverse token order within each line.
        """
        self.tokens_list = []  # encoded LongTensors, one per line
        self.lines = []        # original text lines (newline stripped)
        self.sizes = []        # token count per line
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.read_data(path, dictionary)
        self.size = len(self.tokens_list)

    def read_data(self, path, dictionary):
        """Read and binarize every line of *path* into memory."""
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                self.lines.append(line.strip("\n"))
                tokens = dictionary.encode_line(
                    line,
                    add_if_not_exist=False,
                    append_eos=self.append_eos,
                    reverse_order=self.reverse_order,
                ).long()
                self.tokens_list.append(tokens)
                self.sizes.append(len(tokens))
        self.sizes = np.array(self.sizes)

    def check_index(self, i):
        """Raise IndexError when *i* is outside ``[0, len(self))``."""
        if i < 0 or i >= self.size:
            raise IndexError("index out of range")

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        # data already lives in memory; the cache only skips check_index
        self.check_index(i)
        return self.tokens_list[i]

    def get_original_text(self, i):
        """Return the raw (unbinarized) text of line *i*."""
        self.check_index(i)
        return self.lines[i]

    def __del__(self):
        pass  # nothing to release; everything is plain in-memory data

    def __len__(self):
        return self.size

    def num_tokens(self, index):
        # token count for example *index*; used for --max-tokens batching
        return self.sizes[index]

    def size(self, index):
        # example size used when filtering with --max-positions
        # (shadowed by the ``self.size`` int attribute — see class note)
        return self.sizes[index]

    @staticmethod
    def exists(path):
        """True when the raw text file *path* exists."""
        return PathManager.exists(path)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class IndexedDatasetBuilder(object):
    """Incrementally writes tensors to a ``.bin`` data file and, on
    :meth:`finalize`, emits the matching ``TNTIDX`` ``.idx`` index file.
    """

    # Size in bytes of a single element for each supported dtype.
    # NOTE: the `np.float` / `np.double` aliases were removed in NumPy 1.24
    # (they would crash this class definition); use the explicit
    # fixed-width dtypes instead. Callers passing the old aliases were
    # already broken on modern NumPy.
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float32: 4,
        np.float64: 8,
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, "wb")
        self.dtype = dtype
        self.data_offsets = [0]  # element offset per item, plus leading sentinel
        self.dim_offsets = [0]   # offsets into self.sizes per item, plus sentinel
        self.sizes = []          # flattened per-dimension sizes of all items
        self.element_size = self.element_sizes[self.dtype]

    def add_item(self, tensor):
        """Append one tensor to the data file and record its offsets."""
        # +1 for Lua compatibility
        n_bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype))
        # Integer division keeps offsets as ints; true division (`/`) would
        # silently produce floats, which later break file.seek() when the
        # dataset is read back. (Also avoid shadowing the `bytes` builtin.)
        self.data_offsets.append(self.data_offsets[-1] + n_bytes // self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def merge_file_(self, another_file):
        """Append the items of another on-disk indexed dataset to this builder."""
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        # Re-base the other dataset's offsets onto the end of ours.
        begin = self.data_offsets[-1]
        for offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + offset)
        self.sizes.extend(index.sizes)
        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        # Stream the raw binary payload across in chunks.
        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self.out_file)

    def finalize(self, index_file):
        """Close the data file and write the ``TNTIDX`` index to *index_file*."""
        self.out_file.close()
        with open(index_file, "wb") as index:
            index.write(b"TNTIDX\x00\x00")
            index.write(struct.pack("<Q", 1))  # format version
            index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
            index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
            write_longs(index, self.dim_offsets)
            write_longs(index, self.data_offsets)
            write_longs(index, self.sizes)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def _warmup_mmap_file(path):
|
| 360 |
+
with open(path, "rb") as stream:
|
| 361 |
+
while stream.read(100 * 1024 * 1024):
|
| 362 |
+
pass
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class MMapIndexedDataset(torch.utils.data.Dataset):
    """Memory-mapped indexed dataset.

    Items are read zero-copy from a ``.bin`` file; per-item sizes and byte
    pointers live in a companion ``.idx`` file parsed by :class:`Index`.
    """

    class Index(object):
        # Magic bytes identifying the mmap index file format.
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            """Return a context manager that writes an index file for items
            of *dtype* to *path*."""

            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, "wb")

                    # Header: magic, version (1), dtype code.
                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack("<Q", 1))
                    self._file.write(struct.pack("<B", code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    # Byte offset of each item: running sum of the preceding
                    # item sizes times the element size.
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size

                    return pointers

                def write(self, sizes):
                    """Write item count, sizes (int32) and byte pointers (int64)."""
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack("<Q", len(sizes)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path):
            # Parse the header with buffered reads, then memory-map the
            # sizes and pointers arrays as zero-copy views.
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()  # start of the sizes array

            _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )

        def __del__(self):
            # Explicitly close the mmap so the file handle is released.
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            # Element dtype of the stored items.
            return self._dtype

        @property
        def sizes(self):
            # Per-item sizes (int32 array view into the mmap).
            return self._sizes

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            # (byte pointer, size) pair for item *i*.
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path)

    def __getstate__(self):
        # Pickle only the path; mmaps are re-created on unpickle.
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path):
        """Open the index and memory-map the data file for *path*."""
        self._path = path
        self._index = self.Index(index_file_path(self._path))

        _warmup_mmap_file(data_file_path(self._path))
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        # Return item *i* as an int64 torch tensor; zero-copy when the
        # stored dtype is already int64, otherwise converted (copied).
        ptr, size = self._index[i]
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
        )
        if self._index.dtype != np.int64:
            np_array = np_array.astype(np.int64)

        return torch.from_numpy(np_array)

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def supports_prefetch(self):
        # Data is mmap'd; no explicit prefetch needed.
        return False

    @staticmethod
    def exists(path):
        return PathManager.exists(index_file_path(path)) and PathManager.exists(
            data_file_path(path)
        )
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def get_indexed_dataset_to_local(path):
    """Fetch the ``.idx``/``.bin`` pair for *path* to local storage and
    return the shared local path prefix (without extension)."""
    local_index_path = PathManager.get_local_path(index_file_path(path))
    local_data_path = PathManager.get_local_path(data_file_path(path))

    assert local_index_path.endswith(".idx") and local_data_path.endswith(".bin"), (
        "PathManager.get_local_path does not return files with expected patterns: "
        f"{local_index_path} and {local_data_path}"
    )

    # Strip the extensions; both files must share the same prefix.
    local_path = local_data_path[:-4]  # drop ".bin"
    assert local_path == local_index_path[:-4]  # drop ".idx"
    return local_path
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
class MMapIndexedDatasetBuilder(object):
    """Builds the ``.bin``/``.idx`` pair consumed by ``MMapIndexedDataset``."""

    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, "wb")
        self._dtype = dtype
        self._sizes = []

    def add_item(self, tensor):
        """Append one tensor's raw bytes to the data file."""
        np_array = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(np_array.tobytes(order="C"))
        self._sizes.append(np_array.size)

    def merge_file_(self, another_file):
        """Append another mmap dataset's items to this builder."""
        # Concatenate index
        other_index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert other_index.dtype == self._dtype

        for item_size in other_index.sizes:
            self._sizes.append(item_size)

        # Concatenate data
        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self._data_file)

    def finalize(self, index_file):
        """Close the data file and write the matching index file."""
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes)
|
fairseq-0.10.2/fairseq/data/iterators.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import itertools
|
| 7 |
+
import logging
|
| 8 |
+
import math
|
| 9 |
+
import operator
|
| 10 |
+
import os
|
| 11 |
+
import queue
|
| 12 |
+
import time
|
| 13 |
+
from threading import Thread
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
from fairseq.data import data_utils
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
# Object used by _background_consumer to signal the source is exhausted
|
| 23 |
+
# to the main thread.
|
| 24 |
+
_sentinel = object()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class CountingIterator(object):
    """Wrapper around an iterable that maintains the iteration count.

    Args:
        iterable (iterable): iterable to wrap
        start (int): starting iteration count. Note that this doesn't
            actually advance the iterator.
        total (int): override the iterator length returned by
            ``__len__``. This can be used to truncate *iterator*.

    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, start=None, total=None):
        self.iterable = iterable
        self.itr = iter(self)
        # Inherit the wrapped iterable's count when no explicit start is given.
        self.n = getattr(iterable, "n", 0) if start is None else start
        self.total = (self.n + len(iterable)) if total is None else total

    def __len__(self):
        return self.total

    def __iter__(self):
        for x in self.iterable:
            if self.n >= self.total:
                raise RuntimeError(
                    "Mismatch between actual and expected iterable length. "
                    "This may be caused by resuming training from a checkpoint using "
                    "a different number of GPUs, in which case you can try the "
                    "--reset-dataloader option. Alternatively you may have a train or "
                    "validation set that is smaller than the number of GPUs. If none "
                    "of these apply, please report this to the fairseq developers."
                )
            self.n += 1
            yield x

    def __next__(self):
        return next(self.itr)

    def has_next(self):
        """Whether the iterator has been exhausted."""
        return self.n < len(self)

    def skip(self, num_to_skip):
        """Fast-forward the iterator by skipping *num_to_skip* elements."""
        next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
        return self

    def take(self, n):
        """Truncate the iterator to at most *n* elements."""
        self.total = min(self.total, n)
        # Propagate the truncation to the underlying iterable, counting only
        # elements not yet consumed (e.g. after restarting from a mid-epoch
        # checkpoint), so the invariant self.total == self.n + len(remaining)
        # holds before the next __iter__/__next__ call.
        propagated_take = max(n - self.n, 0)
        if hasattr(self.iterable, "take"):
            self.iterable.take(propagated_take)
        else:
            self.iterable = itertools.islice(self.iterable, propagated_take)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class EpochBatchIterating(object):
    """Abstract interface shared by epoch-based batch iterators."""

    def __len__(self) -> int:
        raise NotImplementedError

    @property
    def next_epoch_idx(self):
        """The epoch index after :func:`next_epoch_itr` is called."""
        raise NotImplementedError

    def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
        """Return a new iterator over the dataset.

        Args:
            shuffle (bool, optional): shuffle batches before returning the
                iterator (default: True).
            fix_batches_to_gpus: ensure that batches are always
                allocated to the same shards across epochs. Requires
                that :attr:`dataset` supports prefetching (default: False).
        """
        raise NotImplementedError

    def end_of_epoch(self) -> bool:
        """Returns whether the most recent epoch iterator has been exhausted"""
        raise NotImplementedError

    @property
    def iterations_in_epoch(self) -> int:
        """The number of consumed batches in the current epoch."""
        raise NotImplementedError

    def state_dict(self):
        """Returns a dictionary containing a whole state of the iterator."""
        raise NotImplementedError

    def load_state_dict(self, state_dict):
        """Copies the state of the iterator from the given *state_dict*."""
        raise NotImplementedError
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class StreamingEpochBatchIterator(EpochBatchIterating):
    """Epoch iterator over a :class:`torch.utils.data.IterableDataset`,
    sharded across *num_shards* workers."""

    def __init__(
        self,
        dataset,
        epoch=1,
        num_shards=1,
        shard_id=0,
    ):
        assert isinstance(dataset, torch.utils.data.IterableDataset)
        self.dataset = dataset
        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
        self._current_epoch_iterator = None
        self.num_shards = num_shards
        self.shard_id = shard_id

    @property
    def next_epoch_idx(self):
        """Return the epoch index after *next_epoch_itr* is called."""
        if self._current_epoch_iterator is not None and self.end_of_epoch():
            return self.epoch + 1
        return self.epoch

    def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
        """Advance to the next epoch and return its (sharded) iterator."""
        self.epoch = self.next_epoch_idx
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(self.epoch)
        sharded = ShardedIterator(
            iterable=self.dataset,
            num_shards=self.num_shards,
            shard_id=self.shard_id,
        )
        self._current_epoch_iterator = CountingIterator(iterable=sharded)
        return self._current_epoch_iterator

    def end_of_epoch(self) -> bool:
        """Whether the current epoch iterator is exhausted."""
        return not self._current_epoch_iterator.has_next()

    @property
    def iterations_in_epoch(self) -> int:
        """Number of batches consumed so far in the current epoch."""
        if self._current_epoch_iterator is None:
            return 0
        return self._current_epoch_iterator.n

    def state_dict(self):
        """Serialize iterator state (only the epoch for streaming datasets)."""
        return {
            "epoch": self.epoch,
        }

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`."""
        self.epoch = state_dict["epoch"]
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class EpochBatchIterator(EpochBatchIterating):
    """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`.

    Compared to :class:`torch.utils.data.DataLoader`, this iterator:

    - can be reused across multiple epochs with the :func:`next_epoch_itr`
      method (optionally shuffled between epochs)
    - can be serialized/deserialized with the :func:`state_dict` and
      :func:`load_state_dict` methods
    - supports sharding with the *num_shards* and *shard_id* arguments

    Args:
        dataset (~torch.utils.data.Dataset): dataset from which to load the data
        collate_fn (callable): merges a list of samples to form a mini-batch
        batch_sampler (~torch.utils.data.Sampler or a callable): an iterator over batches of
            indices, or a callable to create such an iterator (~torch.utils.data.Sampler).
            A callable batch_sampler will be called for each epoch to enable per epoch dynamic
            batch iterators defined by this callable batch_sampler.
        seed (int, optional): seed for random number generator for
            reproducibility (default: 1).
        num_shards (int, optional): shard the data iterator into N
            shards (default: 1).
        shard_id (int, optional): which shard of the data iterator to
            return (default: 0).
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means the data will be loaded in the main process
            (default: 0).
        epoch (int, optional): the epoch to start the iterator from
            (default: 1).
        buffer_size (int, optional): the number of batches to keep ready in the
            queue. Helps speeding up dataloading. When buffer_size is zero, the
            default torch.utils.data.DataLoader preloading is used.
        timeout (int, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative (default: ``0``).
        disable_shuffling (bool, optional): force disable shuffling
            (default: ``False``).
    """

    def __init__(
        self,
        dataset,
        collate_fn,
        batch_sampler,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        buffer_size=0,
        timeout=0,
        disable_shuffling=False,
    ):
        assert isinstance(dataset, torch.utils.data.Dataset)
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.batch_sampler = batch_sampler
        # A static (non-callable) batch sampler is frozen up front so every
        # epoch sees the same batches; a callable sampler is re-invoked
        # lazily per epoch (see `frozen_batches`).
        self._frozen_batches = (
            tuple(batch_sampler) if not callable(batch_sampler) else None
        )
        self.seed = seed
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.num_workers = num_workers
        # This upper limit here is to prevent people from abusing this feature
        # in a shared computing environment.
        self.buffer_size = min(buffer_size, 20)
        self.timeout = timeout
        self.disable_shuffling = disable_shuffling

        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
        self.shuffle = not disable_shuffling
        self._cur_epoch_itr = None
        self._next_epoch_itr = None
        self._supports_prefetch = getattr(dataset, "supports_prefetch", False)

    @property
    def frozen_batches(self):
        # Lazily (re-)materialize the batches for the current epoch when the
        # batch sampler is a callable.
        if self._frozen_batches is None:
            self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch))
        return self._frozen_batches

    @property
    def first_batch(self):
        """A collated first batch (e.g. for dummy forward passes), or the
        string ``"DUMMY"`` when the dataset cannot be fetched outside a
        DataLoader worker."""
        if len(self.frozen_batches) == 0:
            raise Exception(
                "The dataset is empty. This could indicate "
                "that all elements in the dataset have been skipped. "
                "Try increasing the max number of allowed tokens or using "
                "a larger dataset."
            )

        if getattr(self.dataset, "supports_fetch_outside_dataloader", True):
            return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]])
        else:
            return "DUMMY"

    def __len__(self):
        # Batches in this shard; ceil so trailing batches are not dropped.
        return int(math.ceil(len(self.frozen_batches) / float(self.num_shards)))

    @property
    def n(self):
        # Alias so this iterator can be wrapped like a CountingIterator.
        return self.iterations_in_epoch

    @property
    def next_epoch_idx(self):
        """Return the epoch index after *next_epoch_itr* is called."""
        if self._next_epoch_itr is not None:
            # load_state_dict() already prepared an iterator for this epoch.
            return self.epoch
        elif self._cur_epoch_itr is not None and self.end_of_epoch():
            return self.epoch + 1
        else:
            return self.epoch

    def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
        """Return a new iterator over the dataset.

        Args:
            shuffle (bool, optional): shuffle batches before returning the
                iterator (default: True).
            fix_batches_to_gpus: ensure that batches are always
                allocated to the same shards across epochs. Requires
                that :attr:`dataset` supports prefetching (default: False).
        """
        if self.disable_shuffling:
            shuffle = False
        self.epoch = self.next_epoch_idx
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(self.epoch)
        if self._next_epoch_itr is not None:
            # Consume the iterator prepared ahead of time by load_state_dict().
            self._cur_epoch_itr = self._next_epoch_itr
            self._next_epoch_itr = None
        else:
            if callable(self.batch_sampler):
                # reset _frozen_batches to refresh the next epoch
                self._frozen_batches = None
            self._cur_epoch_itr = self._get_iterator_for_epoch(
                self.epoch,
                shuffle,
                fix_batches_to_gpus=fix_batches_to_gpus,
            )
        self.shuffle = shuffle
        return self._cur_epoch_itr

    def end_of_epoch(self) -> bool:
        """Returns whether the most recent epoch iterator has been exhausted"""
        return not self._cur_epoch_itr.has_next()

    @property
    def iterations_in_epoch(self):
        """The number of consumed batches in the current epoch."""
        if self._cur_epoch_itr is not None:
            return self._cur_epoch_itr.n
        elif self._next_epoch_itr is not None:
            return self._next_epoch_itr.n
        return 0

    def state_dict(self):
        """Returns a dictionary containing a whole state of the iterator."""
        if self.end_of_epoch():
            # A finished epoch is saved as the *start* of the next epoch,
            # rather than replaying an exhausted iterator on resume.
            epoch = self.epoch + 1
            iter_in_epoch = 0
        else:
            epoch = self.epoch
            iter_in_epoch = self.iterations_in_epoch
        return {
            "version": 2,
            "epoch": epoch,
            "iterations_in_epoch": iter_in_epoch,
            "shuffle": self.shuffle,
        }

    def load_state_dict(self, state_dict):
        """Copies the state of the iterator from the given *state_dict*."""
        self.epoch = state_dict["epoch"]
        itr_pos = state_dict.get("iterations_in_epoch", 0)
        version = state_dict.get("version", 1)
        if itr_pos > 0:
            # fast-forward epoch iterator
            self._next_epoch_itr = self._get_iterator_for_epoch(
                self.epoch,
                shuffle=state_dict.get("shuffle", True),
                offset=itr_pos,
            )
            if self._next_epoch_itr is None:
                # The saved offset is past the end of the epoch, which only
                # happens when the dataloader configuration changed.
                if version == 1:
                    # legacy behavior: we finished the epoch, increment epoch counter
                    self.epoch += 1
                else:
                    raise RuntimeError(
                        "Cannot resume training due to dataloader mismatch, please "
                        "report this to the fairseq developers. You can relaunch "
                        "training with `--reset-dataloader` and it should work."
                    )
        else:
            self._next_epoch_itr = None

    def _get_iterator_for_epoch(
        self, epoch, shuffle, fix_batches_to_gpus=False, offset=0
    ):
        # Build the sharded, optionally shuffled, DataLoader-backed iterator
        # for *epoch*, fast-forwarded by *offset* batches. Returns None when
        # *offset* is at or past the end of the epoch.

        def shuffle_batches(batches, seed):
            # Deterministic shuffle, seeded by (seed + epoch [+ shard_id]).
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
            return batches

        if self._supports_prefetch:
            batches = self.frozen_batches

            if shuffle and not fix_batches_to_gpus:
                # Shuffle globally before sharding: shards differ per epoch.
                batches = shuffle_batches(list(batches), self.seed + epoch)

            batches = list(
                ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
            )
            self.dataset.prefetch([i for s in batches for i in s])

            if shuffle and fix_batches_to_gpus:
                # Shuffle within the shard only, so each GPU keeps the same
                # set of batches across epochs.
                batches = shuffle_batches(batches, self.seed + epoch + self.shard_id)
        else:
            if shuffle:
                batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch)
            else:
                batches = self.frozen_batches
            batches = list(
                ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
            )

        if offset > 0 and offset >= len(batches):
            return None

        if self.num_workers > 0:
            # Silence noisy multiprocessing semaphore_tracker warnings.
            os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

        # Create data loader
        itr = torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.collate_fn,
            batch_sampler=batches[offset:],
            num_workers=self.num_workers,
            timeout=self.timeout,
        )

        # Wrap with a BufferedIterator if needed
        if self.buffer_size > 0:
            itr = BufferedIterator(self.buffer_size, itr)

        # Wrap with CountingIterator
        itr = CountingIterator(itr, start=offset)
        return itr
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class GroupedIterator(CountingIterator):
    """Wrapper around an iterable that returns groups (chunks) of items.

    Args:
        iterable (iterable): iterable to wrap
        chunk_size (int): size of each chunk

    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, chunk_size):
        itr = _chunk_iterator(iterable, chunk_size)
        super().__init__(
            itr,
            # Convert the wrapped iterator's element count (if it exposes
            # one via ``n``) into a chunk count, so a resumed iterator
            # reports progress in chunks rather than raw elements.
            start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))),
            total=int(math.ceil(len(iterable) / float(chunk_size))),
        )
        self.chunk_size = chunk_size
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def _chunk_iterator(itr, chunk_size):
|
| 468 |
+
chunk = []
|
| 469 |
+
for x in itr:
|
| 470 |
+
chunk.append(x)
|
| 471 |
+
if len(chunk) == chunk_size:
|
| 472 |
+
yield chunk
|
| 473 |
+
chunk = []
|
| 474 |
+
if len(chunk) > 0:
|
| 475 |
+
yield chunk
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
class ShardedIterator(CountingIterator):
    """A sharded wrapper around an iterable, padded to length.

    Args:
        iterable (iterable): iterable to wrap
        num_shards (int): number of shards to split the iterable into
        shard_id (int): which shard to iterator over
        fill_value (Any, optional): padding value when the iterable doesn't
            evenly divide *num_shards* (default: None).

    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, num_shards, shard_id, fill_value=None):
        if shard_id < 0 or shard_id >= num_shards:
            raise ValueError("shard_id must be between 0 and num_shards")
        # Every shard is padded up to the same ceiling length so that all
        # workers take the same number of steps.
        sharded_len = int(math.ceil(len(iterable) / float(num_shards)))
        itr = map(
            operator.itemgetter(1),
            itertools.zip_longest(
                # Zipping against a range of the target length pads the
                # short shards out with *fill_value*.
                range(sharded_len),
                # Take every num_shards-th element, starting at shard_id.
                itertools.islice(iterable, shard_id, len(iterable), num_shards),
                fillvalue=fill_value,
            ),
        )
        super().__init__(
            itr,
            # Convert any pre-consumed element count into per-shard units.
            start=int(math.ceil(getattr(iterable, "n", 0) / float(num_shards))),
            total=sharded_len,
        )
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
class BackgroundConsumer(Thread):
    """Producer thread for BufferedIterator.

    Drains *source* into the shared *queue* until the source is exhausted
    (or *max_len* items have been forwarded), then enqueues a sentinel so
    the consumer knows iteration is complete. Exceptions raised while
    reading the source are forwarded through the queue so the consuming
    thread can re-raise them.
    """

    def __init__(self, queue, source, max_len):
        Thread.__init__(self)

        self._queue = queue
        self._source = source
        self._max_len = max_len  # optional cap on the number of items forwarded
        self.count = 0  # number of items forwarded so far

    def run(self):
        try:
            for item in self._source:
                self._queue.put(item)

                # Stop if we reached the maximum length
                self.count += 1
                if self._max_len is not None and self.count >= self._max_len:
                    break

            # Signal the consumer we are done.
            self._queue.put(_sentinel)
        except Exception as e:
            # Forward the exception through the queue; the consumer
            # re-raises it on its own thread.
            self._queue.put(e)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
class BufferedIterator(object):
    """Iterator wrapper that prefetches items on a background thread.

    Up to *size* items are read ahead of the consumer into a bounded
    queue, overlapping data loading with downstream computation. Intended
    for a single consumer; items are delivered in source order.
    """

    def __init__(self, size, iterable):
        self._queue = queue.Queue(size)
        self._iterable = iterable
        # The producer thread is created lazily on first __next__ so that
        # merely constructing this wrapper does not start a thread.
        self._consumer = None

        self.start_time = time.time()
        # Timestamp of the last bottleneck warning, used for rate limiting.
        self.warning_time = None

        self.total = len(iterable)

    def _create_consumer(self):
        # Start the background producer; capped at self.total so that a
        # prior take() is honored.
        self._consumer = BackgroundConsumer(
            self._queue,
            self._iterable,
            self.total,
        )
        self._consumer.daemon = True
        self._consumer.start()

    def __iter__(self):
        return self

    def __len__(self):
        return self.total

    def take(self, n):
        """Truncate iteration to at most *n* items."""
        self.total = min(self.total, n)

        # Propagate this change to the underlying iterator
        if hasattr(self._iterable, "take"):
            self._iterable.take(n)

    def __next__(self):
        # Create consumer if not created yet
        if self._consumer is None:
            self._create_consumer()

        # Notify the user if there is a data loading bottleneck: the buffer
        # staying (nearly) empty after a 5-minute warmup suggests loading
        # cannot keep up; warn at most once every 15 minutes.
        if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)):
            if time.time() - self.start_time > 5 * 60:
                if (
                    self.warning_time is None
                    or time.time() - self.warning_time > 15 * 60
                ):
                    logger.debug(
                        "Data loading buffer is empty or nearly empty. This may "
                        "indicate a data loading bottleneck, and increasing the "
                        "number of workers (--num-workers) may help."
                    )
                    self.warning_time = time.time()

        # Get next example (blocks until the producer supplies one).
        item = self._queue.get(True)
        if isinstance(item, Exception):
            # Re-raise any error that occurred on the producer thread.
            raise item
        if item is _sentinel:
            raise StopIteration()
        return item
|
fairseq-0.10.2/fairseq/data/legacy/masked_lm_dataset.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Dict, List, Tuple
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from fairseq.data import Dictionary, FairseqDataset, data_utils
|
| 12 |
+
from fairseq.data.concat_dataset import ConcatDataset
|
| 13 |
+
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
|
| 14 |
+
from fairseq.data.token_block_dataset import TokenBlockDataset
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MaskedLMDataset(FairseqDataset):
    """
    A wrapper Dataset for masked language modelling. The dataset
    wraps around TokenBlockDataset or BlockedPairDataset and creates a batch
    where the input blocks are masked according to the specified masking
    probability. Additionally the batch can also contain sentence level targets
    if this is specified.

    Args:
        dataset: Dataset which generates blocks of data. Only BlockPairDataset
            and TokenBlockDataset are supported.
        sizes: Sentence lengths
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of padding token in dictionary
        mask_idx: Id of mask token in dictionary
        classif_token_idx: Id of classification token in dictionary. This is the
            token associated with the sentence embedding (Eg: CLS for BERT)
        sep_token_idx: Id of separator token in dictionary
            (Eg: SEP in BERT)
        seed: Seed for random number generator for reproducibility.
        shuffle: Shuffle the elements before batching.
        has_pairs: Specifies whether the underlying dataset
            generates a pair of blocks along with a sentence_target or not.
            Setting it to True assumes that the underlying dataset generates a
            label for the pair of sentences which is surfaced as
            sentence_target. The default value assumes a single block with no
            sentence target.
        segment_id: An optional segment id for filling in the segment labels
            when we are in the single block setting (Eg: XLM). Default is 0.
        masking_ratio: specifies what percentage of the blocks should be masked.
        masking_prob: specifies the probability of a given token being
            replaced with the "MASK" token.
        random_token_prob: specifies the probability of a given token being
            replaced by a random token from the vocabulary.
    """

    def __init__(
        self,
        dataset: FairseqDataset,
        sizes: np.ndarray,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        classif_token_idx: int,
        sep_token_idx: int,
        seed: int = 1,
        shuffle: bool = True,
        has_pairs: bool = True,
        segment_id: int = 0,
        masking_ratio: float = 0.15,
        masking_prob: float = 0.8,
        random_token_prob: float = 0.1,
    ):
        # Make sure the input datasets are the ones supported
        assert (
            isinstance(dataset, TokenBlockDataset)
            or isinstance(dataset, BlockPairDataset)
            or isinstance(dataset, ConcatDataset)
        ), (
            "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or "
            "ConcatDataset"
        )

        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.classif_token_idx = classif_token_idx
        self.sep_token_idx = sep_token_idx
        self.shuffle = shuffle
        self.seed = seed
        self.has_pairs = has_pairs
        self.segment_id = segment_id
        self.masking_ratio = masking_ratio
        self.masking_prob = masking_prob
        self.random_token_prob = random_token_prob

        # If we have only one block then sizes needs to be updated to include
        # the classification token
        if not has_pairs:
            self.sizes = self.sizes + 1

    def __getitem__(self, index: int):
        # if has_pairs, then expect 2 blocks and a sentence target
        if self.has_pairs:
            (block_one, block_two, sentence_target) = self.dataset[index]
        else:
            block_one = self.dataset[index]

        # block_two / sentence_target are only bound in the has_pairs
        # branch above; the conditionals below short-circuit to None
        # before touching them in the single-block case.
        return {
            "id": index,
            "block_one": block_one,
            "block_two": block_two if self.has_pairs else None,
            "sentence_target": sentence_target if self.has_pairs else None,
        }

    def __len__(self):
        return len(self.dataset)

    def _mask_block(
        self,
        sentence: np.ndarray,
        mask_idx: int,
        pad_idx: int,
        dictionary_token_range: Tuple,
    ):
        """
        Mask tokens for Masked Language Model training
        Samples mask_ratio tokens that will be predicted by LM.

        Note:This function may not be efficient enough since we had multiple
        conversions between np and torch, we can replace them with torch
        operators later.

        Args:
            sentence: 1d tensor to be masked
            mask_idx: index to use for masking the sentence
            pad_idx: index to use for masking the target for tokens we aren't
                predicting
            dictionary_token_range: range of indices in dictionary which can
                be used for random word replacement
                (e.g. without special characters)
        Return:
            masked_sent: masked sentence
            target: target with words which we are not predicting replaced
                by pad_idx
        """
        masked_sent = np.copy(sentence)
        sent_length = len(sentence)
        mask_num = math.ceil(sent_length * self.masking_ratio)
        mask = np.random.choice(sent_length, mask_num, replace=False)
        target = np.copy(sentence)

        for i in range(sent_length):
            if i in mask:
                rand = np.random.random()

                # replace with mask if probability is less than masking_prob
                # (Eg: 0.8)
                if rand < self.masking_prob:
                    masked_sent[i] = mask_idx

                # replace with random token if probability is less than
                # masking_prob + random_token_prob (Eg: 0.9)
                elif rand < (self.masking_prob + self.random_token_prob):
                    # sample random token from dictionary
                    masked_sent[i] = np.random.randint(
                        dictionary_token_range[0], dictionary_token_range[1]
                    )
            else:
                # Positions that are not masked are excluded from the LM
                # loss by padding them out in the target.
                target[i] = pad_idx

        return masked_sent, target

    def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int):
        """
        Does the heavy lifting for creating a batch from the input list of
        examples. The logic is as follows:
            1. Mask the input blocks. In case has_pair is True then we have 2
               blocks to mask.
            2. Prepend the first masked block tensor with the special token
               used as sentence embedding. Eg: CLS in BERT. This happens
               irrespective of the value of has_pair.
            3. If has_pair is True, then append the first masked block with the
               special separator token (eg: SEP for BERT) and compute segment
               label accordingly. In this case, also append the second masked
               block with this special separator token and compute its segment
               label.
            4. For the targets tensor, prepend and append with padding index
               accordingly.
            5. Concatenate all tensors.
        """
        if len(samples) == 0:
            return {}
        # To ensure determinism, we reset the state of the PRNG after every
        # batch based on the seed and the first id of the batch. This ensures
        # that across epochs we get the same mask for the same example. This
        # is needed for reproducibility and is how BERT does masking
        # TODO: Can we add deteminism without this constraint?
        with data_utils.numpy_seed(self.seed + samples[0]["id"]):
            for s in samples:

                # token range is needed for replacing with random token during
                # masking
                token_range = (self.vocab.nspecial, len(self.vocab))

                # mask according to specified probabilities.
                masked_blk_one, masked_tgt_one = self._mask_block(
                    s["block_one"],
                    self.mask_idx,
                    self.pad_idx,
                    token_range,
                )

                # Prepend CLS; its target is pad so it never contributes
                # to the LM loss.
                tokens = np.concatenate([[self.classif_token_idx], masked_blk_one])
                targets = np.concatenate([[self.pad_idx], masked_tgt_one])
                segments = np.ones(len(tokens)) * self.segment_id

                # if has_pairs is True then we need to add the SEP token to both
                # the blocks after masking and re-compute segments based on the new
                # lengths.
                if self.has_pairs:
                    tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
                    targets_one = np.concatenate([targets, [self.pad_idx]])

                    masked_blk_two, masked_tgt_two = self._mask_block(
                        s["block_two"], self.mask_idx, self.pad_idx, token_range
                    )
                    tokens_two = np.concatenate([masked_blk_two, [self.sep_token_idx]])
                    targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])

                    # block + 1 sep + 1 special (CLS)
                    segments_one = np.zeros(len(tokens_one))
                    # block + 1 sep
                    segments_two = np.ones(len(tokens_two))

                    tokens = np.concatenate([tokens_one, tokens_two])
                    targets = np.concatenate([targets_one, targets_two])
                    segments = np.concatenate([segments_one, segments_two])

                s["source"] = torch.LongTensor(tokens)
                s["segment_labels"] = torch.LongTensor(segments)
                s["lm_target"] = torch.LongTensor(targets)

        def merge(key):
            # Pad-and-stack one field across all samples in the batch.
            return data_utils.collate_tokens(
                [s[key] for s in samples], pad_idx, eos_idx, left_pad=False
            )

        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "ntokens": sum(len(s["source"]) for s in samples),
            "net_input": {
                "src_tokens": merge("source"),
                "segment_labels": merge("segment_labels"),
            },
            "lm_target": merge("lm_target"),
            "sentence_target": torch.LongTensor([s["sentence_target"] for s in samples])
            if self.has_pairs
            else None,
            "nsentences": len(samples),
        }

    def collater(self, samples: List[Dict]):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch of data
        """
        return self._collate(samples, self.vocab.pad(), self.vocab.eos())

    def num_tokens(self, index: int):
        """
        Return the number of tokens in a sample. This value is used to
        enforce max-tokens during batching.
        """
        return self.sizes[index]

    def size(self, index: int):
        """
        Return an example's size as a float or tuple. This value is used when
        filtering a dataset with max-positions.
        """
        return self.sizes[index]

    def ordered_indices(self):
        """
        Return an ordered list of indices. Batches will be constructed based
        on this order.
        """
        if self.shuffle:
            return np.random.permutation(len(self))
        else:
            # Sort by size (last key) with original index as tie-breaker.
            order = [np.arange(len(self))]
            order.append(self.sizes)
            return np.lexsort(order)

    @property
    def supports_prefetch(self):
        # Prefetching is supported iff the wrapped dataset supports it.
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        # Delegate cache warming to the wrapped dataset.
        self.dataset.prefetch(indices)
|
fairseq-0.10.2/fairseq/data/lm_context_window_dataset.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from fairseq.data.monolingual_dataset import MonolingualDataset
|
| 9 |
+
|
| 10 |
+
from . import FairseqDataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LMContextWindowDataset(FairseqDataset):
    """Wraps a MonolingualDataset and provides more context for evaluation.

    The collater carries up to ``context_window`` non-pad tokens over from
    the previous batch and prepends them to each sample, so the model sees
    extra left-context at evaluation time. The carried context is scored
    with pad targets (i.e. it does not contribute to the loss); the real
    targets start at the per-sample offsets returned in ``start_indices``.

    NOTE(review): the carry-over makes ``collater`` stateful across calls,
    so batches must be collated in corpus order (see ``ordered_indices``).
    """

    def __init__(self, dataset, tokens_per_sample, context_window, pad_idx):
        assert isinstance(dataset, MonolingualDataset)
        assert context_window > 0
        self.dataset = dataset
        self.tokens_per_sample = tokens_per_sample
        self.context_window = context_window
        self.pad_idx = pad_idx
        # Tokens carried over from the previous collated batch.
        self.prev_tokens = np.empty([0])

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        sample = self.dataset.collater(samples)

        pad = self.pad_idx
        max_sample_len = self.tokens_per_sample + self.context_window

        bsz, tsz = sample["net_input"]["src_tokens"].shape
        start_idxs = [0] * bsz
        toks = sample["net_input"]["src_tokens"]
        lengths = sample["net_input"]["src_lengths"]
        tgt = sample["target"]
        # Widened buffers: each row holds [carried context | tokens | pads].
        new_toks = np.empty([bsz, tsz + self.context_window], dtype=np.int64)
        new_tgt = np.full([bsz, tsz + self.context_window], pad, dtype=np.int64)
        sample_lens = toks.ne(pad).long().sum(dim=1).cpu()
        for i in range(bsz):
            sample_len = sample_lens[i]
            # Trim the carried context so context + sample fits in the
            # maximum widened length.
            extra = len(self.prev_tokens) + sample_len - max_sample_len
            if extra > 0:
                self.prev_tokens = self.prev_tokens[extra:]
            pads = np.full(self.context_window - len(self.prev_tokens), pad)
            new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads])
            # Targets align with the original tokens; the carried-context
            # prefix stays pad so it is excluded from scoring.
            new_tgt[
                i, len(self.prev_tokens) : len(self.prev_tokens) + len(tgt[i])
            ] = tgt[i]
            start_idxs[i] = len(self.prev_tokens)
            lengths[i] += len(self.prev_tokens)
            # Carry the last context_window non-pad tokens into the next row.
            self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window :]
        sample["net_input"]["src_tokens"] = torch.from_numpy(new_toks)
        sample["target"] = torch.from_numpy(new_tgt)
        sample["start_indices"] = start_idxs

        return sample

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        # NOTE we don't shuffle the data to retain access to the previous dataset elements
        return np.arange(len(self.dataset))

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
|
fairseq-0.10.2/fairseq/data/monolingual_dataset.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from . import FairseqDataset, data_utils
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def collate(samples, pad_idx, eos_idx):
    """Collate a list of monolingual samples into a padded mini-batch dict.

    Each field is padded on the right with *pad_idx*. A list-valued
    ``target`` (multiple targets per example) is merged position by
    position; a missing target falls back to the source tokens.
    """
    if not samples:
        return {}

    def merge(key, is_list=False):
        # Pad-and-stack one field across all samples.
        if is_list:
            return [
                data_utils.collate_tokens(
                    [s[key][i] for s in samples],
                    pad_idx,
                    eos_idx,
                    left_pad=False,
                )
                for i in range(len(samples[0][key]))
            ]
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx,
            left_pad=False,
        )

    src_tokens = merge("source")
    first_target = samples[0]["target"]
    if first_target is not None:
        target = merge("target", is_list=isinstance(first_target, list))
    else:
        target = src_tokens

    return {
        "id": torch.LongTensor([s["id"] for s in samples]),
        "nsentences": len(samples),
        "ntokens": sum(len(s["source"]) for s in samples),
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": torch.LongTensor([s["source"].numel() for s in samples]),
        },
        "target": target,
    }
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class MonolingualDataset(FairseqDataset):
|
| 57 |
+
"""
|
| 58 |
+
A wrapper around torch.utils.data.Dataset for monolingual data.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
dataset (torch.utils.data.Dataset): dataset to wrap
|
| 62 |
+
sizes (List[int]): sentence lengths
|
| 63 |
+
vocab (~fairseq.data.Dictionary): vocabulary
|
| 64 |
+
shuffle (bool, optional): shuffle the elements before batching
|
| 65 |
+
(default: True).
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
def __init__(
|
| 69 |
+
self,
|
| 70 |
+
dataset,
|
| 71 |
+
sizes,
|
| 72 |
+
src_vocab,
|
| 73 |
+
tgt_vocab,
|
| 74 |
+
add_eos_for_other_targets,
|
| 75 |
+
shuffle,
|
| 76 |
+
targets=None,
|
| 77 |
+
add_bos_token=False,
|
| 78 |
+
):
|
| 79 |
+
self.dataset = dataset
|
| 80 |
+
self.sizes = np.array(sizes)
|
| 81 |
+
self.vocab = src_vocab
|
| 82 |
+
self.tgt_vocab = tgt_vocab
|
| 83 |
+
self.add_eos_for_other_targets = add_eos_for_other_targets
|
| 84 |
+
self.shuffle = shuffle
|
| 85 |
+
self.add_bos_token = add_bos_token
|
| 86 |
+
|
| 87 |
+
assert targets is None or all(
|
| 88 |
+
t in {"self", "future", "past"} for t in targets
|
| 89 |
+
), "targets must be none or one of 'self', 'future', 'past'"
|
| 90 |
+
if targets is not None and len(targets) == 0:
|
| 91 |
+
targets = None
|
| 92 |
+
self.targets = targets
|
| 93 |
+
|
| 94 |
+
def __getitem__(self, index):
|
| 95 |
+
if self.targets is not None:
|
| 96 |
+
# *future_target* is the original sentence
|
| 97 |
+
# *source* is shifted right by 1 (maybe left-padded with eos)
|
| 98 |
+
# *past_target* is shifted right by 2 (left-padded as needed)
|
| 99 |
+
#
|
| 100 |
+
# Left-to-right language models should condition on *source* and
|
| 101 |
+
# predict *future_target*.
|
| 102 |
+
# Right-to-left language models should condition on *source* and
|
| 103 |
+
# predict *past_target*.
|
| 104 |
+
source, future_target, past_target = self.dataset[index]
|
| 105 |
+
source, target = self._make_source_target(
|
| 106 |
+
source, future_target, past_target
|
| 107 |
+
)
|
| 108 |
+
else:
|
| 109 |
+
source = self.dataset[index]
|
| 110 |
+
target = None
|
| 111 |
+
source, target = self._maybe_add_bos(source, target)
|
| 112 |
+
return {"id": index, "source": source, "target": target}
|
| 113 |
+
|
| 114 |
+
def __len__(self):
|
| 115 |
+
return len(self.dataset)
|
| 116 |
+
|
| 117 |
+
def _make_source_target(self, source, future_target, past_target):
|
| 118 |
+
if self.targets is not None:
|
| 119 |
+
target = []
|
| 120 |
+
|
| 121 |
+
if (
|
| 122 |
+
self.add_eos_for_other_targets
|
| 123 |
+
and (("self" in self.targets) or ("past" in self.targets))
|
| 124 |
+
and source[-1] != self.vocab.eos()
|
| 125 |
+
):
|
| 126 |
+
# append eos at the end of source
|
| 127 |
+
source = torch.cat([source, source.new([self.vocab.eos()])])
|
| 128 |
+
|
| 129 |
+
if "future" in self.targets:
|
| 130 |
+
future_target = torch.cat(
|
| 131 |
+
[future_target, future_target.new([self.vocab.pad()])]
|
| 132 |
+
)
|
| 133 |
+
if "past" in self.targets:
|
| 134 |
+
# first token is before the start of sentence which is only used in "none" break mode when
|
| 135 |
+
# add_eos_for_other_targets is False
|
| 136 |
+
past_target = torch.cat(
|
| 137 |
+
[
|
| 138 |
+
past_target.new([self.vocab.pad()]),
|
| 139 |
+
past_target[1:],
|
| 140 |
+
source[-2, None],
|
| 141 |
+
]
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
for t in self.targets:
|
| 145 |
+
if t == "self":
|
| 146 |
+
target.append(source)
|
| 147 |
+
elif t == "future":
|
| 148 |
+
target.append(future_target)
|
| 149 |
+
elif t == "past":
|
| 150 |
+
target.append(past_target)
|
| 151 |
+
else:
|
| 152 |
+
raise Exception("invalid target " + t)
|
| 153 |
+
|
| 154 |
+
if len(target) == 1:
|
| 155 |
+
target = target[0]
|
| 156 |
+
else:
|
| 157 |
+
target = future_target
|
| 158 |
+
|
| 159 |
+
return source, self._filter_vocab(target)
|
| 160 |
+
|
| 161 |
+
def _maybe_add_bos(self, source, target):
|
| 162 |
+
if self.add_bos_token:
|
| 163 |
+
source = torch.cat([source.new([self.vocab.bos()]), source])
|
| 164 |
+
if target is not None:
|
| 165 |
+
target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
|
| 166 |
+
return source, target
|
| 167 |
+
|
| 168 |
+
def _filter_vocab(self, target):
|
| 169 |
+
if len(self.tgt_vocab) != len(self.vocab):
|
| 170 |
+
|
| 171 |
+
def _filter(target):
|
| 172 |
+
mask = target.ge(len(self.tgt_vocab))
|
| 173 |
+
if mask.any():
|
| 174 |
+
target[mask] = self.tgt_vocab.unk()
|
| 175 |
+
return target
|
| 176 |
+
|
| 177 |
+
if isinstance(target, list):
|
| 178 |
+
return [_filter(t) for t in target]
|
| 179 |
+
return _filter(target)
|
| 180 |
+
return target
|
| 181 |
+
|
| 182 |
+
def collater(self, samples):
    """Merge a list of samples to form a mini-batch.

    Delegates padding/batching to the module-level ``collate`` helper,
    passing this dataset's pad and eos ids.

    Args:
        samples (List[dict]): samples to collate

    Returns:
        dict: a mini-batch with the following keys:

            - `id` (LongTensor): example IDs in the original input order
            - `ntokens` (int): total number of tokens in the batch
            - `net_input` (dict): the input to the Model, containing keys:

              - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                the source sentence of shape `(bsz, src_len)`. Padding will
                appear on the right.

            - `target` (LongTensor): a padded 2D Tensor of tokens in the
              target sentence of shape `(bsz, tgt_len)`. Padding will appear
              on the right.
    """
    return collate(samples, self.vocab.pad(), self.vocab.eos())
|
| 204 |
+
|
| 205 |
+
def num_tokens(self, index):
    """Return the number of tokens in a sample. This value is used to
    enforce ``--max-tokens`` during batching."""
    # sizes is precomputed per example; direct lookup, no recomputation.
    return self.sizes[index]
|
| 209 |
+
|
| 210 |
+
def size(self, index):
    """Return an example's size as a float or tuple. This value is used when
    filtering a dataset with ``--max-positions``."""
    # Same backing array as num_tokens(): one size per example.
    return self.sizes[index]
|
| 214 |
+
|
| 215 |
+
def ordered_indices(self):
    """Return indices ordered for batching: primarily by example size,
    with ties broken randomly when shuffling is enabled, or by original
    position otherwise."""
    num = len(self)
    tiebreak = np.random.permutation(num) if self.shuffle else np.arange(num)
    # np.lexsort sorts by the LAST key first, so sizes dominate the order.
    return np.lexsort([tiebreak, self.sizes])
|
| 224 |
+
|
| 225 |
+
@property
def supports_prefetch(self):
    # Delegates to the wrapped dataset; False when the inner dataset does
    # not declare prefetch support.
    return getattr(self.dataset, "supports_prefetch", False)
|
| 228 |
+
|
| 229 |
+
def prefetch(self, indices):
    """Forward the prefetch request to the wrapped dataset."""
    inner = self.dataset
    inner.prefetch(indices)
|
fairseq-0.10.2/fairseq/data/nested_dictionary_dataset.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data.dataloader import default_collate
|
| 10 |
+
|
| 11 |
+
from . import FairseqDataset
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _flatten(dico, prefix=None):
|
| 15 |
+
"""Flatten a nested dictionary."""
|
| 16 |
+
new_dico = OrderedDict()
|
| 17 |
+
if isinstance(dico, dict):
|
| 18 |
+
prefix = prefix + "." if prefix is not None else ""
|
| 19 |
+
for k, v in dico.items():
|
| 20 |
+
if v is None:
|
| 21 |
+
continue
|
| 22 |
+
new_dico.update(_flatten(v, prefix + k))
|
| 23 |
+
elif isinstance(dico, list):
|
| 24 |
+
for i, v in enumerate(dico):
|
| 25 |
+
new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]"))
|
| 26 |
+
else:
|
| 27 |
+
new_dico = OrderedDict({prefix: dico})
|
| 28 |
+
return new_dico
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _unflatten(dico):
|
| 32 |
+
"""Unflatten a flattened dictionary into a nested dictionary."""
|
| 33 |
+
new_dico = OrderedDict()
|
| 34 |
+
for full_k, v in dico.items():
|
| 35 |
+
full_k = full_k.split(".")
|
| 36 |
+
node = new_dico
|
| 37 |
+
for k in full_k[:-1]:
|
| 38 |
+
if k.startswith("[") and k.endswith("]"):
|
| 39 |
+
k = int(k[1:-1])
|
| 40 |
+
if k not in node:
|
| 41 |
+
node[k] = OrderedDict()
|
| 42 |
+
node = node[k]
|
| 43 |
+
node[full_k[-1]] = v
|
| 44 |
+
return new_dico
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class NestedDictionaryDataset(FairseqDataset):
    """Combine a (possibly nested) dict/list of datasets into one dataset.

    The definition is flattened with ``_flatten`` so every leaf dataset is
    addressed by a dotted key; ``collater`` re-nests the batch with
    ``_unflatten``. All non-empty leaf datasets must report the same length.
    """

    def __init__(self, defn, sizes=None):
        # defn: nested dict/list whose leaves are Dataset instances.
        # sizes: one size array, or a list/tuple of size arrays.
        super().__init__()
        self.defn = _flatten(defn)
        # Normalize to a list so num_tokens/size can iterate uniformly.
        # NOTE(review): when sizes is None this becomes [None]; num_tokens
        # would then fail on subscripting — presumably callers pass sizes
        # whenever those methods are exercised. TODO confirm.
        self.sizes = [sizes] if not isinstance(sizes, (list, tuple)) else sizes

        first = None
        for v in self.defn.values():
            if not isinstance(
                v,
                (
                    FairseqDataset,
                    torch.utils.data.Dataset,
                ),
            ):
                raise ValueError("Expected Dataset but found: {}".format(v.__class__))
            first = first or v
            if len(v) > 0:
                # every non-empty leaf must agree with the first on length
                assert len(v) == len(first), "dataset lengths must match"

        self._len = len(first)

    def __getitem__(self, index):
        # Flat mapping of dotted key -> leaf item; re-nesting happens in
        # collater, not here.
        return OrderedDict((k, ds[index]) for k, ds in self.defn.items())

    def __len__(self):
        return self._len

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        if len(samples) == 0:
            return {}
        sample = OrderedDict()
        for k, ds in self.defn.items():
            try:
                sample[k] = ds.collater([s[k] for s in samples])
            except NotImplementedError:
                # leaf has no custom collater; fall back to torch's default
                sample[k] = default_collate([s[k] for s in samples])
        return _unflatten(sample)

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        # conservative: the largest size across all provided size arrays
        return max(s[index] for s in self.sizes)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        if len(self.sizes) == 1:
            return self.sizes[0][index]
        else:
            # NOTE(review): returns a generator, not a tuple — callers
            # appear to only iterate it once. TODO confirm.
            return (s[index] for s in self.sizes)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return any(ds.supports_prefetch for ds in self.defn.values())

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        for ds in self.defn.values():
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # reusable only if every leaf is reusable
        return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values())

    def set_epoch(self, epoch):
        # propagate epoch to every leaf dataset
        super().set_epoch(epoch)
        for ds in self.defn.values():
            ds.set_epoch(epoch)
|
fairseq-0.10.2/fairseq/data/noising.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from fairseq.data import data_utils
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class WordNoising(object):
    """Base class for noise functions that perturb word order/presence
    without altering the words themselves."""

    def __init__(self, dictionary, bpe_cont_marker="@@", bpe_end_marker=None):
        self.dictionary = dictionary
        self.bpe_end = None
        if bpe_cont_marker:
            # A token ends a word iff it does NOT carry the continuation marker.
            ends = [
                not dictionary[i].endswith(bpe_cont_marker)
                for i in range(len(dictionary))
            ]
            self.bpe_end = np.array(ends)
        elif bpe_end_marker:
            # Alternative convention: tokens carry an explicit end-of-word marker.
            ends = [
                dictionary[i].endswith(bpe_end_marker)
                for i in range(len(dictionary))
            ]
            self.bpe_end = np.array(ends)

        # Without BPE info, every token counts as its own word.
        if self.bpe_end is not None:
            self.get_word_idx = self._get_bpe_word_idx
        else:
            self.get_word_idx = self._get_token_idx

    def noising(self, x, lengths, noising_prob=0.0):
        raise NotImplementedError()

    def _get_bpe_word_idx(self, x):
        """Map each BPE token position in ``x`` (T x B) to the index of the
        word it belongs to, e.g. ["how", "are", "y@@", "ou"] -> [0, 1, 2, 2].
        """
        bpe_end = self.bpe_end[x]

        if x.size(0) == 1 and x.size(1) == 1:
            # Single-token input: bpe_end degenerates to a scalar and the
            # cumulative-sum trick below would fail, so short-circuit.
            return np.array([[0]])

        # Reverse cumulative count of word endings gives, per column, how many
        # words remain at/after each position; subtracting from the column max
        # turns that into 0-based word indices counted from the front.
        word_idx = bpe_end[::-1].cumsum(0)[::-1]
        word_idx = word_idx.max(0)[None, :] - word_idx
        return word_idx

    def _get_token_idx(self, x):
        """Fallback word indexing for non-BPE tokens (words or characters):
        position i simply maps to word i."""
        x = torch.t(x)
        word_idx = np.array([range(len(x_i)) for x_i in x])
        return np.transpose(word_idx)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class WordDropout(WordNoising):
    """Randomly drop input words. If not passing blank_idx (default is None),
    then dropped words will be removed. Otherwise, it will be replaced by the
    blank_idx."""

    def __init__(
        self,
        dictionary,
        default_dropout_prob=0.1,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
        # fallback probability used when noising() gets dropout_prob=None
        self.default_dropout_prob = default_dropout_prob

    def noising(self, x, lengths, dropout_prob=None, blank_idx=None):
        """Drop (or blank) whole words from each sentence in ``x``.

        Args:
            x: token id tensor of shape (T x B).
            lengths: per-sentence lengths, shape (B,).
            dropout_prob: per-word drop probability; defaults to
                ``self.default_dropout_prob``. 0 is a no-op.
            blank_idx: when None, dropped words are removed (shortening the
                sentence); otherwise they are replaced by this id.

        Returns:
            (modified tokens padded with pad(), modified lengths)

        NOTE(review): output reproducibility depends on the exact sequence
        of np.random calls below — do not reorder them.
        """
        if dropout_prob is None:
            dropout_prob = self.default_dropout_prob
        # x: (T x B), lengths: B
        if dropout_prob == 0:
            return x, lengths

        assert 0 < dropout_prob < 1

        # be sure to drop entire words, not individual BPE fragments
        word_idx = self.get_word_idx(x)
        sentences = []
        modified_lengths = []
        for i in range(lengths.size(0)):
            # Since dropout probabilities need to apply over non-pad tokens,
            # it is not trivial to generate the keep mask without considering
            # input lengths; otherwise, this could be done outside the loop

            # We want to drop whole words based on word_idx grouping
            num_words = max(word_idx[:, i]) + 1

            # ith example: [x0, x1, ..., eos, pad, ..., pad]
            # We should only generate keep probs for non-EOS tokens. Thus if the
            # input sentence ends in EOS, the last word idx is not included in
            # the dropout mask generation and we append True to always keep EOS.
            # Otherwise, just generate the dropout mask for all word idx
            # positions.
            has_eos = x[lengths[i] - 1, i] == self.dictionary.eos()
            if has_eos:  # EOS word is never dropped
                keep = np.random.rand(num_words - 1) >= dropout_prob
                keep = np.append(keep, [True])  # keep EOS symbol
            else:
                keep = np.random.rand(num_words) >= dropout_prob

            words = x[: lengths[i], i].tolist()

            # TODO: speed up the following loop
            # drop words from the input according to keep
            new_s = [
                w if keep[word_idx[j, i]] else blank_idx for j, w in enumerate(words)
            ]
            # blank_idx=None means dropped words are removed outright
            new_s = [w for w in new_s if w is not None]
            # we need to have at least one word in the sentence (more than the
            # start / end sentence symbols)
            if len(new_s) <= 1:
                # insert at beginning in case the only token left is EOS
                # EOS should be at end of list.
                new_s.insert(0, words[np.random.randint(0, len(words))])
            assert len(new_s) >= 1 and (
                not has_eos  # Either don't have EOS at end or last token is EOS
                or (len(new_s) >= 2 and new_s[-1] == self.dictionary.eos())
            ), "New sentence is invalid."
            sentences.append(new_s)
            modified_lengths.append(len(new_s))
        # re-construct input: pad the ragged sentences back into (T' x B)
        modified_lengths = torch.LongTensor(modified_lengths)
        modified_x = torch.LongTensor(
            modified_lengths.max(), modified_lengths.size(0)
        ).fill_(self.dictionary.pad())
        for i in range(modified_lengths.size(0)):
            modified_x[: modified_lengths[i], i].copy_(torch.LongTensor(sentences[i]))

        return modified_x, modified_lengths
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class WordShuffle(WordNoising):
    """Shuffle words by no more than k positions."""

    def __init__(
        self,
        dictionary,
        default_max_shuffle_distance=3,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
        # BUG FIX: this used to be hard-coded to 3, silently ignoring the
        # caller-supplied default_max_shuffle_distance parameter.
        self.default_max_shuffle_distance = default_max_shuffle_distance

    def noising(self, x, lengths, max_shuffle_distance=None):
        """Shuffle each sentence in ``x`` so no word moves more than
        ``max_shuffle_distance`` positions.

        Args:
            x: token id tensor of shape (T x B).
            lengths: per-sentence lengths, shape (B,).
            max_shuffle_distance: displacement cap; falls back to
                ``self.default_max_shuffle_distance`` when None. 0 is a
                no-op; values in (0, 1] are rejected.

        Returns:
            (shuffled tokens, unchanged lengths)
        """
        if max_shuffle_distance is None:
            max_shuffle_distance = self.default_max_shuffle_distance
        # x: (T x B), lengths: B
        if max_shuffle_distance == 0:
            return x, lengths

        # max_shuffle_distance < 1 will return the same sequence
        assert max_shuffle_distance > 1

        # define noise word scores: each word gets a random offset in
        # [0, max_shuffle_distance); sorting by position+offset bounds moves
        noise = np.random.uniform(
            0,
            max_shuffle_distance,
            size=(x.size(0), x.size(1)),
        )
        noise[0] = -1  # do not move start sentence symbol
        # be sure to shuffle entire words
        word_idx = self.get_word_idx(x)
        x2 = x.clone()
        for i in range(lengths.size(0)):
            length_no_eos = lengths[i]
            if x[lengths[i] - 1, i] == self.dictionary.eos():
                # leave a trailing EOS in place
                length_no_eos = lengths[i] - 1
            # generate a random permutation
            scores = word_idx[:length_no_eos, i] + noise[word_idx[:length_no_eos, i], i]
            # ensure no reordering inside a word (stable tiny tie-break)
            scores += 1e-6 * np.arange(length_no_eos.item())
            permutation = scores.argsort()
            # shuffle words
            x2[:length_no_eos, i].copy_(
                x2[:length_no_eos, i][torch.from_numpy(permutation)]
            )
        return x2, lengths
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class UnsupervisedMTNoising(WordNoising):
    """Default noising pipeline from UnsupervisedMT
    (github.com/facebookresearch/UnsupervisedMT): word shuffle, then word
    dropout, then word blanking.
    """

    def __init__(
        self,
        dictionary,
        max_word_shuffle_distance,
        word_dropout_prob,
        word_blanking_prob,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        # NOTE(review): the base class is initialized with the default BPE
        # markers; only the sub-noisers receive the ones passed here.
        # Preserved as-is from the original.
        super().__init__(dictionary)
        self.max_word_shuffle_distance = max_word_shuffle_distance
        self.word_dropout_prob = word_dropout_prob
        self.word_blanking_prob = word_blanking_prob

        self.word_dropout = WordDropout(
            dictionary=dictionary,
            bpe_cont_marker=bpe_cont_marker,
            bpe_end_marker=bpe_end_marker,
        )
        self.word_shuffle = WordShuffle(
            dictionary=dictionary,
            bpe_cont_marker=bpe_cont_marker,
            bpe_end_marker=bpe_end_marker,
        )

    def noising(self, x, lengths):
        """Apply shuffle -> dropout -> blanking to ``x`` (T x B) and return
        the noised token tensor (updated lengths are consumed internally)."""
        # 1) shuffle words within a bounded distance
        tokens, lens = self.word_shuffle.noising(
            x=x,
            lengths=lengths,
            max_shuffle_distance=self.max_word_shuffle_distance,
        )
        # 2) drop words outright
        tokens, lens = self.word_dropout.noising(
            x=tokens,
            lengths=lens,
            dropout_prob=self.word_dropout_prob,
        )
        # 3) replace words with <unk> ("blanking")
        tokens, lens = self.word_dropout.noising(
            x=tokens,
            lengths=lens,
            dropout_prob=self.word_blanking_prob,
            blank_idx=self.dictionary.unk(),
        )
        return tokens
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class NoisingDataset(torch.utils.data.Dataset):
    """Wrap a :class:`~torch.utils.data.Dataset` and apply noise to each
    sample on access.

    Args:
        src_dataset (~torch.utils.data.Dataset): dataset to wrap. Items
            should be unpadded token tensors so lengths are computed
            correctly downstream.
        src_dict (~fairseq.data.Dictionary): source dictionary
        seed (int): base seed; combined with the example index so every
            example gets reproducible noise.
        noiser (WordNoising): pre-initialized noiser; when None a new one
            is built from *noising_class* and *kwargs*.
        noising_class (class, optional): class used to build the default
            :class:`WordNoising` instance.
        kwargs (dict, optional): arguments for the default noiser.
    """

    def __init__(
        self,
        src_dataset,
        src_dict,
        seed,
        noiser=None,
        noising_class=UnsupervisedMTNoising,
        **kwargs
    ):
        self.src_dataset = src_dataset
        self.src_dict = src_dict
        self.seed = seed
        if noiser is not None:
            self.noiser = noiser
        else:
            self.noiser = noising_class(
                dictionary=src_dict,
                **kwargs,
            )

    def __getitem__(self, index):
        """Return one noised sample; batching happens in the collater."""
        item = self.src_dataset[index]
        lengths = torch.LongTensor([len(item)])
        # noising functions expect (sequence length, batch size) layout
        column = torch.t(item.unsqueeze(0))
        # per-index seeding keeps each example reproducible across epochs
        with data_utils.numpy_seed(self.seed + index):
            noised = self.noiser.noising(column, lengths)
        # back to (batch size, sequence length), then drop the batch dim
        return torch.t(noised)[0]

    def __len__(self):
        """The length of the noising dataset is the length of src."""
        return len(self.src_dataset)

    @property
    def supports_prefetch(self):
        return self.src_dataset.supports_prefetch

    def prefetch(self, indices):
        if self.src_dataset.supports_prefetch:
            self.src_dataset.prefetch(indices)
|
fairseq-0.10.2/fairseq/data/numel_dataset.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from . import BaseWrapperDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class NumelDataset(BaseWrapperDataset):
    """Wraps a dataset and yields, per index, the element count of each
    item (``torch.numel`` for tensors, ``np.size`` for anything else)."""

    def __init__(self, dataset, reduce=False):
        super().__init__(dataset)
        # when True, collation sums the counts instead of stacking them
        self.reduce = reduce

    def __getitem__(self, index):
        item = self.dataset[index]
        return torch.numel(item) if torch.is_tensor(item) else np.size(item)

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        """Sum the counts (``reduce=True``) or return them as a tensor."""
        return sum(samples) if self.reduce else torch.tensor(samples)
|
fairseq-0.10.2/fairseq/data/plasma_utils.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import subprocess
|
| 7 |
+
import tempfile
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class PlasmaArray(object):
    """
    Wrapper around numpy arrays that automatically moves the data to shared
    memory upon serialization. This is particularly helpful when passing numpy
    arrays through multiprocessing, so that data is not unnecessarily
    duplicated or pickled.
    """

    def __init__(self, array):
        super().__init__()
        self.array = array
        self.disable = array.nbytes < 134217728  # disable for arrays <128MB
        self.object_id = None
        self.path = None

        # variables with underscores shouldn't be pickled
        self._client = None
        self._server = None
        self._server_tmp = None
        self._plasma = None

    @property
    def plasma(self):
        # Lazily import pyarrow.plasma; stays None (plain pickling) when the
        # wrapper is disabled or pyarrow is unavailable.
        if self._plasma is None and not self.disable:
            try:
                import pyarrow.plasma as plasma

                self._plasma = plasma
            except ImportError:
                self._plasma = None
        return self._plasma

    def start_server(self):
        # Launch a plasma_store subprocess backed by a temp socket file,
        # sized ~5% above the array so the put() in __getstate__ fits.
        if self.plasma is None or self._server is not None:
            return
        assert self.object_id is None
        assert self.path is None
        self._server_tmp = tempfile.NamedTemporaryFile()
        self.path = self._server_tmp.name
        self._server = subprocess.Popen(
            [
                "plasma_store",
                "-m",
                str(int(1.05 * self.array.nbytes)),
                "-s",
                self.path,
            ]
        )

    @property
    def client(self):
        # Connect on first use; requires start_server() to have set self.path.
        if self._client is None:
            assert self.path is not None
            self._client = self.plasma.connect(self.path)
        return self._client

    def __getstate__(self):
        # On pickling: push the array into plasma once, then strip the array
        # and all process-local handles from the pickled state.
        if self.plasma is None:
            return self.__dict__
        if self.object_id is None:
            self.start_server()
            self.object_id = self.client.put(self.array)
        state = self.__dict__.copy()
        del state["array"]
        state["_client"] = None
        state["_server"] = None
        state["_server_tmp"] = None
        state["_plasma"] = None
        return state

    def __setstate__(self, state):
        # On unpickling: re-fetch the array from the shared plasma store
        # (zero-copy in the child process) unless plasma is unavailable.
        self.__dict__.update(state)
        if self.plasma is None:
            return
        self.array = self.client.get(self.object_id)

    def __del__(self):
        # Best-effort teardown of the store subprocess and its socket file.
        if self._server is not None:
            self._server.kill()
            self._server = None
            self._server_tmp.close()
            self._server_tmp = None
|
fairseq-0.10.2/fairseq/data/prepend_token_dataset.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from . import BaseWrapperDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PrependTokenDataset(BaseWrapperDataset):
    """Wraps a dataset and prepends ``token`` to every example (a no-op
    when ``token`` is None); reported sizes grow by one accordingly."""

    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token
        if token is None:
            self._sizes = dataset.sizes
        else:
            self._sizes = np.array(dataset.sizes) + 1

    def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is None:
            return item
        prefix = item.new([self.token])
        return torch.cat([prefix, item])

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        n = self.dataset.num_tokens(index)
        return n if self.token is None else n + 1

    def size(self, index):
        n = self.dataset.size(index)
        return n if self.token is None else n + 1
|
fairseq-0.10.2/fairseq/data/raw_label_dataset.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from . import FairseqDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RawLabelDataset(FairseqDataset):
    """Expose a plain sequence of labels as a dataset; collation stacks
    them with ``torch.tensor``."""

    def __init__(self, labels):
        super().__init__()
        self.labels = labels

    def __getitem__(self, index):
        return self.labels[index]

    def __len__(self):
        return len(self.labels)

    def collater(self, samples):
        """Batch raw labels into a single tensor."""
        return torch.tensor(samples)
|
fairseq-0.10.2/fairseq/data/replace_dataset.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from . import BaseWrapperDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ReplaceDataset(BaseWrapperDataset):
    """Replaces tokens found in the dataset by a specified replacement token

    Args:
        dataset (~torch.utils.data.Dataset): dataset to replace tokens in
        replace_map(Dictionary[int,int]): map of token to replace -> replacement token
        offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be
        as many as the number of objects returned by the underlying dataset __getitem__ method.
    """

    def __init__(self, dataset, replace_map, offsets):
        super().__init__(dataset)
        assert len(replace_map) > 0
        self.replace_map = replace_map
        self.offsets = offsets

    def __getitem__(self, index):
        fetched = self.dataset[index]
        as_tuple = isinstance(fetched, tuple)
        tensors = fetched if as_tuple else [fetched]

        for offset, tensor in zip(self.offsets, tensors):
            # Slicing yields a view, so masked_fill_ edits the tensor in
            # place while leaving the protected prefix/suffix untouched.
            window = tensor[offset:] if offset >= 0 else tensor[:offset]
            for old, new in self.replace_map.items():
                window.masked_fill_(window == old, new)

        return tensors if as_tuple else tensors[0]
|
fairseq-0.10.2/fairseq/data/roll_dataset.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from . import BaseWrapperDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RollDataset(BaseWrapperDataset):
    """Wrapper that circularly shifts every example via ``torch.roll``.

    Args:
        dataset: underlying dataset yielding tensors
        shifts: number of places to roll each item (see ``torch.roll``)
    """

    def __init__(self, dataset, shifts):
        super().__init__(dataset)
        self.shifts = shifts

    def __getitem__(self, index):
        # Fetch the item and roll it by the configured shift amount.
        return torch.roll(self.dataset[index], self.shifts)
|
fairseq-0.10.2/fairseq/data/round_robin_zip_datasets.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from . import FairseqDataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RoundRobinZipDatasets(FairseqDataset):
    """Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.

    Shorter datasets are repeated in a round-robin fashion to match the length
    of the longest one.

    Args:
        datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of
            :class:`~fairseq.data.FairseqDataset` instances.
        eval_key (str, optional): a key used at evaluation time that causes
            this instance to pass-through batches from *datasets[eval_key]*.
    """

    def __init__(self, datasets, eval_key=None):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        self.datasets = datasets
        self.eval_key = eval_key

        # Track the longest constituent dataset; it defines len(self).
        self.longest_dataset = None
        self.longest_dataset_key = None
        for key, dataset in datasets.items():
            assert isinstance(dataset, FairseqDataset)
            if self.longest_dataset is None or len(dataset) > len(self.longest_dataset):
                self.longest_dataset = dataset
                self.longest_dataset_key = key

        # Per-key index orderings; populated lazily by ordered_indices(),
        # which must be called before any indexing (see _map_index).
        self._ordered_indices = None

    def _map_index(self, key, index):
        # Map a "flat" index into dataset `key`; the modulo repeats shorter
        # datasets round-robin to cover the full length of the longest one.
        assert (
            self._ordered_indices is not None
        ), "Must call RoundRobinZipDatasets.ordered_indices() first"
        return self._ordered_indices[key][index % len(self.datasets[key])]

    def __getitem__(self, index):
        if self.eval_key is None:
            # One sample per constituent dataset, keyed like self.datasets.
            return OrderedDict(
                [
                    (key, dataset[self._map_index(key, index)])
                    for key, dataset in self.datasets.items()
                ]
            )
        else:
            # at evaluation time it's useful to pass-through batches from a single key
            return self.datasets[self.eval_key][self._map_index(self.eval_key, index)]

    def __len__(self):
        return len(self.longest_dataset)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch."""
        if len(samples) == 0:
            return None
        if self.eval_key is None:
            # Delegate collation per key to each constituent dataset.
            return OrderedDict(
                [
                    (key, dataset.collater([sample[key] for sample in samples]))
                    for key, dataset in self.datasets.items()
                ]
            )
        else:
            # at evaluation time it's useful to pass-through batches from a single key
            return self.datasets[self.eval_key].collater(samples)

    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching."""
        # TODO make it configurable whether to use max() or sum() here
        return max(
            dataset.num_tokens(self._map_index(key, index))
            for key, dataset in self.datasets.items()
        )

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return {
            key: dataset.size(self._map_index(key, index))
            for key, dataset in self.datasets.items()
        }

    def ordered_indices(self):
        """Ordered indices for batching."""
        if self._ordered_indices is None:
            # Call the underlying dataset's ordered_indices() here, so that we
            # get the same random ordering as we would have from using the
            # underlying dataset directly.
            self._ordered_indices = OrderedDict(
                [
                    (key, dataset.ordered_indices())
                    for key, dataset in self.datasets.items()
                ]
            )
        # The returned ordering is the identity; per-key orderings are applied
        # inside _map_index instead.
        return np.arange(len(self))

    @property
    def supports_prefetch(self):
        # Prefetching is only possible when every constituent supports it.
        return all(
            getattr(dataset, "supports_prefetch", False)
            for dataset in self.datasets.values()
        )

    def prefetch(self, indices):
        for key, dataset in self.datasets.items():
            dataset.prefetch([self._map_index(key, index) for index in indices])
|
fairseq-0.10.2/fairseq/data/sort_dataset.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from . import BaseWrapperDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SortDataset(BaseWrapperDataset):
    """Orders batches of the wrapped dataset by one or more sort keys.

    ``np.lexsort`` semantics apply: the LAST entry of ``sort_order`` is the
    primary key. Every key must have one entry per dataset example.
    """

    def __init__(self, dataset, sort_order):
        super().__init__(dataset)
        # Accept a single key by normalizing it into a one-element list.
        sort_order = sort_order if isinstance(sort_order, (list, tuple)) else [sort_order]
        self.sort_order = sort_order

        assert all(len(so) == len(dataset) for so in sort_order)

    def ordered_indices(self):
        return np.lexsort(self.sort_order)
|
fairseq-0.10.2/fairseq/data/subsample_dataset.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from . import BaseWrapperDataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SubsampleDataset(BaseWrapperDataset):
    """Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples

    Args:
        dataset (~torch.utils.data.Dataset): dataset to subsample
        size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive)
        shuffle (bool, optional): shuffle the kept examples in ordered_indices()
    """

    def __init__(self, dataset, size_ratio, shuffle=False):
        super().__init__(dataset)
        # Enforce the documented contract: ratio strictly between 0 and 1.
        # (The original only checked the upper bound, silently accepting
        # non-positive ratios.)
        assert 0 < size_ratio < 1
        self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
        # Sample without replacement; passing the int directly avoids
        # materializing list(range(len(dataset))).
        self.indices = np.random.choice(
            len(self.dataset), self.actual_size, replace=False
        )
        self.shuffle = shuffle
        logger.info(
            "subsampled dataset from {} to {} (ratio={})".format(
                len(self.dataset), self.actual_size, size_ratio
            )
        )

    def __getitem__(self, index):
        # Indirect through the sampled index set.
        return self.dataset[self.indices[index]]

    def __len__(self):
        return self.actual_size

    def collater(self, samples):
        return self.dataset.collater(samples)

    @property
    def sizes(self):
        return self.dataset.sizes[self.indices]

    @property
    def name(self):
        return self.dataset.name

    def num_tokens(self, index):
        return self.dataset.num_tokens(self.indices[index])

    def size(self, index):
        return self.dataset.size(self.indices[index])

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # Size is the primary sort key (last entry under lexsort semantics).
        order.append(self.sizes)
        return np.lexsort(order)

    def prefetch(self, indices):
        self.dataset.prefetch(self.indices[indices])
|
fairseq-0.10.2/fairseq/data/token_block_utils_fast.cpp
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
fairseq-0.10.2/fairseq/data/token_block_utils_fast.pyx
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# cython: language_level=3
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from itertools import chain
|
| 10 |
+
from libc.math cimport ceil
|
| 11 |
+
|
| 12 |
+
cimport cython
|
| 13 |
+
cimport numpy as np
|
| 14 |
+
|
| 15 |
+
from libc.stdint cimport int32_t, int64_t
|
| 16 |
+
|
| 17 |
+
DTYPE = np.int64
|
| 18 |
+
ctypedef int64_t DTYPE_t
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size):
    # Treat the corpus as one contiguous token stream and cut it into
    # fixed-size blocks; the final block may be shorter than block_size.
    cdef DTYPE_t total_size = sizes.sum()
    cdef DTYPE_t length = <DTYPE_t> ceil(total_size / <double> block_size)
    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE)
    cdef DTYPE_t[:, :] slice_indices_view = slice_indices
    cdef DTYPE_t i
    cdef DTYPE_t start
    cdef DTYPE_t end
    for i in range(length):
        # Row i holds the half-open token span [start, end) of block i.
        start = i * block_size
        end = min(start + block_size, total_size)
        slice_indices_view[i][0] = start
        slice_indices_view[i][1] = end
    return slice_indices
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list):
    """
    Faster function to convert DTYPE_t list of list.
    Only fast when there are huge number of rows and low number of columns.
    """
    # Flatten in a single C-speed pass, then reshape; the reshape with -1
    # requires every inner sequence to have the same length.
    cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(chain.from_iterable(list_of_list), DTYPE, -1)
    return flat.reshape((len(list_of_list), -1))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len):
    # Returns one (start, end) token-offset pair per block, according to the
    # chosen break_mode.
    cdef DTYPE_t tok_idx = 0
    cdef DTYPE_t sz_idx = 0
    cdef DTYPE_t curr_size = 0
    cdef DTYPE_t i = 0
    cdef DTYPE_t length
    cdef DTYPE_t total_size
    cdef DTYPE_t[:] sizes_view = sizes
    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices
    cdef list slice_indices_list = []

    if break_mode is None or break_mode == 'none':
        # Fixed-size blocks that may cross sentence boundaries.
        slice_indices = _get_slice_indices_none_mode(sizes, block_size)
    elif break_mode == 'complete':
        # Pack whole sentences into each block without exceeding block_size.
        # A single sentence longer than block_size still gets its own block
        # via the `curr_size == 0` escape hatch.
        while sz_idx < len(sizes_view):
            if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0:
                curr_size += sizes_view[sz_idx]
                sz_idx += 1
            else:
                slice_indices_list.append((tok_idx, tok_idx + curr_size))
                tok_idx += curr_size
                curr_size = 0
        if curr_size > 0:
            # Flush the trailing partial block.
            slice_indices_list.append((tok_idx, tok_idx + curr_size))
        slice_indices = _fast_convert_to_np_array(slice_indices_list)
    elif break_mode == 'complete_doc':
        # Like 'complete', but additionally never cross document boundaries;
        # a sentence whose size equals document_sep_len marks end-of-document.
        while sz_idx < len(sizes_view):
            if (
                (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0)
                # an empty sentence indicates end-of-document:
                and sizes_view[sz_idx] != document_sep_len
            ):
                curr_size += sizes_view[sz_idx]
                sz_idx += 1
            else:
                # Only keep non-empty documents.
                if curr_size > 1:
                    slice_indices_list.append((tok_idx, tok_idx + curr_size))
                tok_idx += curr_size
                curr_size = 0
                if sizes_view[sz_idx] == document_sep_len:
                    # Skip the separator sentence itself.
                    tok_idx += sizes_view[sz_idx]
                    sz_idx += 1
        if curr_size > 1:
            slice_indices_list.append((tok_idx, tok_idx + curr_size))
        slice_indices = _fast_convert_to_np_array(slice_indices_list)
    elif break_mode == 'eos':
        # One block per sentence: boundaries come straight from the
        # cumulative sizes (row 0 starts at offset 0).
        slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE)
        cumsum = sizes.cumsum(axis=0)
        slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1]
        slice_indices[:, 1] = cumsum
    else:
        raise ValueError('Invalid break_mode: ' + break_mode)
    return slice_indices
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices):
    # For each flat (start, end) token slice, record which underlying dataset
    # items it spans: (start item index, offset within start item, end item index).
    cdef DTYPE_t start_ds_idx
    cdef DTYPE_t start_offset
    cdef DTYPE_t end_ds_idx
    cdef DTYPE_t i
    cdef DTYPE_t s
    cdef DTYPE_t e
    cdef DatasetSearcher ds = DatasetSearcher(sizes)
    cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE)
    cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index
    cdef DTYPE_t[:, :] slice_indices_view = slice_indices
    cdef Py_ssize_t x_max = slice_indices.shape[0]

    for i in range(x_max):
        s = slice_indices_view[i][0]
        e = slice_indices_view[i][1]
        ds.seek(s)
        start_ds_idx = ds.current_index
        start_offset = ds.current_offset
        if e <= s:
            # Empty slice: start and end map to the same item.
            end_ds_idx = start_ds_idx
        else:
            # Seek to the last token of the slice (e is exclusive).
            ds.seek(e - 1)
            end_ds_idx = ds.current_index
        block_to_dataset_index_view[i][0] = start_ds_idx  # starting index in dataset
        block_to_dataset_index_view[i][1] = start_offset  # starting offset within starting index
        block_to_dataset_index_view[i][2] = end_ds_idx  # ending index in dataset
    return block_to_dataset_index
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
cdef class DatasetSearcher(object):
    """Helper for mapping "flat" indices to indices and offsets in an
    underlying dataset."""
    # current_i: flat token position reached so far
    # current_offset: offset within the current underlying item
    # current_index: index of the current underlying item
    cdef DTYPE_t current_i
    cdef DTYPE_t current_offset
    cdef DTYPE_t current_index
    cdef DTYPE_t[:] sizes

    def __init__(self, DTYPE_t[:] sizes):
        self.sizes = sizes
        self.reset()

    cdef reset(self):
        self.current_offset = 0  # offset within current index in underlying dataset
        self.current_i = 0  # "flat" index
        self.current_index = 0  # index in underlying dataset

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.nonecheck(False)
    cdef int step(self, DTYPE_t i):
        # Advance toward flat index i by at most one underlying item.
        # Returns 1 if the cursor crossed an item boundary and seek() must
        # call step() again; 0 once current_i == i. Seeking backwards
        # restarts the scan from the beginning.
        cdef DTYPE_t to_consume
        cdef DTYPE_t remaining
        if i < self.current_i:
            self.reset()
        if i > self.current_i:
            to_consume = i - self.current_i
            remaining = self.sizes[self.current_index] - self.current_offset
            if remaining > to_consume:
                # Target lies inside the current item: land on it directly.
                self.current_offset += to_consume
                self.current_i += to_consume
            else:
                # Consume the rest of this item and move to the next one.
                assert remaining >= 0
                self.current_i += remaining
                self.current_index += 1
                self.current_offset = 0
                return 1
        return 0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.nonecheck(False)
    cdef seek(self, DTYPE_t i):
        # Step until the cursor is exactly at flat index i.
        cdef int not_done = 1
        while not_done == 1:
            not_done = self.step(i)
        assert self.current_i == i
|
fairseq-0.10.2/fairseq/data/transform_eos_dataset.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from . import FairseqDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TransformEosDataset(FairseqDataset):
    """A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.

    Note that the transformation is applied in :func:`collater`.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset to wrap
        eos (int): index of the end-of-sentence symbol
        append_eos_to_src (bool, optional): append EOS to the end of src
        remove_eos_from_src (bool, optional): remove EOS from the end of src
        append_eos_to_tgt (bool, optional): append EOS to the end of tgt
        remove_eos_from_tgt (bool, optional): remove EOS from the end of tgt
    """

    def __init__(
        self,
        dataset,
        eos,
        append_eos_to_src=False,
        remove_eos_from_src=False,
        append_eos_to_tgt=False,
        remove_eos_from_tgt=False,
        has_target=True,
    ):
        if not isinstance(dataset, FairseqDataset):
            raise ValueError("dataset must be an instance of FairseqDataset")
        if append_eos_to_src and remove_eos_from_src:
            raise ValueError("cannot combine append_eos_to_src and remove_eos_from_src")
        if append_eos_to_tgt and remove_eos_from_tgt:
            raise ValueError("cannot combine append_eos_to_tgt and remove_eos_from_tgt")

        self.dataset = dataset
        # Stored as a 1-element LongTensor so it can be torch.cat'd onto items.
        self.eos = torch.LongTensor([eos])
        self.append_eos_to_src = append_eos_to_src
        self.remove_eos_from_src = remove_eos_from_src
        self.append_eos_to_tgt = append_eos_to_tgt
        self.remove_eos_from_tgt = remove_eos_from_tgt
        self.has_target = has_target

        # precompute how we should adjust the reported sizes
        self._src_delta = 0
        self._src_delta += 1 if append_eos_to_src else 0
        self._src_delta -= 1 if remove_eos_from_src else 0
        self._tgt_delta = 0
        self._tgt_delta += 1 if append_eos_to_tgt else 0
        self._tgt_delta -= 1 if remove_eos_from_tgt else 0

        # Sanity checks below run only once, on the first collated item.
        self._checked_src = False
        self._checked_tgt = False

    def _check_src(self, src, expect_eos):
        # One-time check that the source's trailing-EOS state matches what
        # the configured transform expects.
        if not self._checked_src:
            assert (src[-1] == self.eos[0]) == expect_eos
            self._checked_src = True

    def _check_tgt(self, tgt, expect_eos):
        # Same one-time check for the target (skipped when has_target=False).
        if self.has_target and not self._checked_tgt:
            assert (tgt[-1] == self.eos[0]) == expect_eos
            self._checked_tgt = True

    def __getitem__(self, index):
        # Items are returned untransformed; the EOS edits happen in collater().
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        def transform(item):
            # NOTE(review): item["source"]/item["target"] are reassigned on
            # the same dict, mutating the sample in place. self.eos is moved
            # to the item's device before each use.
            if self.append_eos_to_src:
                self.eos = self.eos.to(device=item["source"].device)
                self._check_src(item["source"], expect_eos=False)
                item["source"] = torch.cat([item["source"], self.eos])
            if self.remove_eos_from_src:
                self.eos = self.eos.to(device=item["source"].device)
                self._check_src(item["source"], expect_eos=True)
                item["source"] = item["source"][:-1]
            if self.append_eos_to_tgt:
                self.eos = self.eos.to(device=item["target"].device)
                self._check_tgt(item["target"], expect_eos=False)
                item["target"] = torch.cat([item["target"], self.eos])
            if self.remove_eos_from_tgt:
                self.eos = self.eos.to(device=item["target"].device)
                self._check_tgt(item["target"], expect_eos=True)
                item["target"] = item["target"][:-1]
            return item

        samples = list(map(transform, samples))
        return self.dataset.collater(samples)

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        if self.has_target:
            # Adjust the reported sizes by the precomputed EOS deltas.
            src_len, tgt_len = self.dataset.size(index)
            return (src_len + self._src_delta, tgt_len + self._tgt_delta)
        else:
            return self.dataset.size(index)

    def ordered_indices(self):
        # NOTE: we assume that the ordering does not change based on the
        # addition or removal of eos
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
|