diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..abecf8fbe941be1cfe0cb9fc20f47e06bd5be290 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/__init__.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96c3657a5dfda82479726d301c9fb1e4cdd8d6ce
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1e8006ca2ecaa2e35b56e3c8ba837a007ab4e93
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa49ca7677681fed6af7579d0af8989295bbd313
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1563b1be8e2937c24204f47a747de2ad7afa6d8e
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb2889aa229413ecfd12072d0353b07ac4c6c2af
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2d9cd5e0f7d731b877f1ad6ae51e0c41a74fbe1
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/dictionary.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/dictionary.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2178a6f62ee31358ded09f840ff4272076b5f605
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/dictionary.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/id_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/id_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..56fe292ddb6b418954fac1f16be271a556aeb816
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/id_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aee56d6a26e61c9c392d29f2bb772cba3bdc2d3e
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae7ba3410aff285298924c57199024cde1e626be
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88d95564d6166c1e13a48dfc134a6124c9827ade
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87293b6e9583b3ae3f3de1eccf15e9d24516c332
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..223ee06ac60b91925a696a992e57e95094f84db0
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2472cdfb1758d4734a519a45723a8d1affb5aab
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13a29dddd4f3171cb1bf276a26a8c0a3621d0bed
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/concat_dataset.py b/fairseq-0.10.2/fairseq/data/concat_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..01a4078bb159fa44b2d1062b9a971fe7f1abd1c2
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/concat_dataset.py
@@ -0,0 +1,124 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import bisect
+
+import numpy as np
+from torch.utils.data.dataloader import default_collate
+
+from . import FairseqDataset
+
+
+class ConcatDataset(FairseqDataset):
+ @staticmethod
+ def cumsum(sequence, sample_ratios):
+ r, s = [], 0
+ for e, ratio in zip(sequence, sample_ratios):
+ curr_len = int(ratio * len(e))
+ r.append(curr_len + s)
+ s += curr_len
+ return r
+
+ def __init__(self, datasets, sample_ratios=1):
+ super(ConcatDataset, self).__init__()
+ assert len(datasets) > 0, "datasets should not be an empty iterable"
+ self.datasets = list(datasets)
+ if isinstance(sample_ratios, int):
+ sample_ratios = [sample_ratios] * len(self.datasets)
+ self.sample_ratios = sample_ratios
+ self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
+ self.real_sizes = [len(d) for d in self.datasets]
+
+ def __len__(self):
+ return self.cumulative_sizes[-1]
+
+ def __getitem__(self, idx):
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
+ return self.datasets[dataset_idx][sample_idx]
+
+ def _get_dataset_and_sample_index(self, idx: int):
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ sample_idx = sample_idx % self.real_sizes[dataset_idx]
+ return dataset_idx, sample_idx
+
+ def collater(self, samples, **extra_args):
+ # For now only supports datasets with same underlying collater implementations
+ if hasattr(self.datasets[0], "collater"):
+ return self.datasets[0].collater(samples, **extra_args)
+ else:
+ return default_collate(samples, **extra_args)
+
+ def size(self, idx: int):
+ """
+ Return an example's size as a float or tuple.
+ """
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
+ return self.datasets[dataset_idx].size(sample_idx)
+
+ def num_tokens(self, index: int):
+ return np.max(self.size(index))
+
+ def attr(self, attr: str, index: int):
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
+ return getattr(self.datasets[dataset_idx], attr, None)
+
+ @property
+ def sizes(self):
+ _dataset_sizes = []
+ for ds, sr in zip(self.datasets, self.sample_ratios):
+ if isinstance(ds.sizes, np.ndarray):
+ _dataset_sizes.append(np.tile(ds.sizes, sr))
+ else:
+ # Only support underlying dataset with single size array.
+ assert isinstance(ds.sizes, list)
+ _dataset_sizes.append(np.tile(ds.sizes[0], sr))
+ return np.concatenate(_dataset_sizes)
+
+ @property
+ def supports_prefetch(self):
+ return all(d.supports_prefetch for d in self.datasets)
+
+ def ordered_indices(self):
+ """
+ Returns indices sorted by length. So less padding is needed.
+ """
+ if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
+ # special handling for concatenating lang_pair_datasets
+ indices = np.arange(len(self))
+ sizes = self.sizes
+ tgt_sizes = (
+ sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
+ )
+ src_sizes = (
+ sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
+ )
+ # sort by target length, then source length
+ if tgt_sizes is not None:
+ indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
+ return indices[np.argsort(src_sizes[indices], kind="mergesort")]
+ else:
+ return np.argsort(self.sizes)
+
+ def prefetch(self, indices):
+ frm = 0
+ for to, ds in zip(self.cumulative_sizes, self.datasets):
+ real_size = len(ds)
+ if getattr(ds, "supports_prefetch", False):
+ ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
+ frm = to
+
+ @property
+ def can_reuse_epoch_itr_across_epochs(self):
+ return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
+
+ def set_epoch(self, epoch):
+ super().set_epoch(epoch)
+ for ds in self.datasets:
+ if hasattr(ds, "set_epoch"):
+ ds.set_epoch(epoch)
diff --git a/fairseq-0.10.2/fairseq/data/concat_sentences_dataset.py b/fairseq-0.10.2/fairseq/data/concat_sentences_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..625a29370e90f9d1d7274024afb902ed83a22325
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/concat_sentences_dataset.py
@@ -0,0 +1,54 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from . import FairseqDataset
+
+
+class ConcatSentencesDataset(FairseqDataset):
+ def __init__(self, *datasets):
+ super().__init__()
+ self.datasets = datasets
+ assert all(
+ len(ds) == len(datasets[0]) for ds in datasets
+ ), "datasets must have the same length"
+
+ def __getitem__(self, index):
+ return torch.cat([ds[index] for ds in self.datasets])
+
+ def __len__(self):
+ return len(self.datasets[0])
+
+ def collater(self, samples):
+ return self.datasets[0].collater(samples)
+
+ @property
+ def sizes(self):
+ return sum(ds.sizes for ds in self.datasets)
+
+ def num_tokens(self, index):
+ return sum(ds.num_tokens(index) for ds in self.datasets)
+
+ def size(self, index):
+ return sum(ds.size(index) for ds in self.datasets)
+
+ def ordered_indices(self):
+ return self.datasets[0].ordered_indices()
+
+ @property
+ def supports_prefetch(self):
+ return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)
+
+ def prefetch(self, indices):
+ for ds in self.datasets:
+ if getattr(ds, "supports_prefetch", False):
+ ds.prefetch(indices)
+
+ def set_epoch(self, epoch):
+ super().set_epoch(epoch)
+ for ds in self.datasets:
+ if hasattr(ds, "set_epoch"):
+ ds.set_epoch(epoch)
diff --git a/fairseq-0.10.2/fairseq/data/dictionary.py b/fairseq-0.10.2/fairseq/data/dictionary.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2df08e092350c5d5feb34723fae8744c7286a44
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/dictionary.py
@@ -0,0 +1,387 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from collections import Counter
+from multiprocessing import Pool
+
+import torch
+from fairseq import utils
+from fairseq.binarizer import safe_readline
+from fairseq.data import data_utils
+from fairseq.file_io import PathManager
+from fairseq.tokenizer import tokenize_line
+
+
+class Dictionary(object):
+ """A mapping from symbols to consecutive integers"""
+
+ def __init__(
+ self,
+ *, # begin keyword-only arguments
+        bos="<s>",
+        pad="<pad>",
+        eos="</s>",
+        unk="<unk>",
+ extra_special_symbols=None,
+ ):
+ self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
+ self.symbols = []
+ self.count = []
+ self.indices = {}
+ self.bos_index = self.add_symbol(bos)
+ self.pad_index = self.add_symbol(pad)
+ self.eos_index = self.add_symbol(eos)
+ self.unk_index = self.add_symbol(unk)
+ if extra_special_symbols:
+ for s in extra_special_symbols:
+ self.add_symbol(s)
+ self.nspecial = len(self.symbols)
+
+ def __eq__(self, other):
+ return self.indices == other.indices
+
+ def __getitem__(self, idx):
+ if idx < len(self.symbols):
+ return self.symbols[idx]
+ return self.unk_word
+
+ def __len__(self):
+ """Returns the number of symbols in the dictionary"""
+ return len(self.symbols)
+
+ def __contains__(self, sym):
+ return sym in self.indices
+
+ def index(self, sym):
+ """Returns the index of the specified symbol"""
+ assert isinstance(sym, str)
+ if sym in self.indices:
+ return self.indices[sym]
+ return self.unk_index
+
+ def string(
+ self,
+ tensor,
+ bpe_symbol=None,
+ escape_unk=False,
+ extra_symbols_to_ignore=None,
+ unk_string=None,
+ ):
+ """Helper for converting a tensor of token indices to a string.
+
+ Can optionally remove BPE symbols or escape words.
+ """
+ if torch.is_tensor(tensor) and tensor.dim() == 2:
+ return "\n".join(
+ self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore)
+ for t in tensor
+ )
+
+ extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
+ extra_symbols_to_ignore.add(self.eos())
+
+ def token_string(i):
+ if i == self.unk():
+ if unk_string is not None:
+ return unk_string
+ else:
+ return self.unk_string(escape_unk)
+ else:
+ return self[i]
+
+ if hasattr(self, "bos_index"):
+ extra_symbols_to_ignore.add(self.bos())
+
+ sent = " ".join(
+ token_string(i)
+ for i in tensor
+ if utils.item(i) not in extra_symbols_to_ignore
+ )
+
+ return data_utils.post_process(sent, bpe_symbol)
+
+ def unk_string(self, escape=False):
+        """Return unknown string, optionally escaped as: <<unk>>"""
+ if escape:
+ return "<{}>".format(self.unk_word)
+ else:
+ return self.unk_word
+
+ def add_symbol(self, word, n=1, overwrite=False):
+ """Adds a word to the dictionary"""
+ if word in self.indices and not overwrite:
+ idx = self.indices[word]
+ self.count[idx] = self.count[idx] + n
+ return idx
+ else:
+ idx = len(self.symbols)
+ self.indices[word] = idx
+ self.symbols.append(word)
+ self.count.append(n)
+ return idx
+
+ def update(self, new_dict):
+ """Updates counts from new dictionary."""
+ for word in new_dict.symbols:
+ idx2 = new_dict.indices[word]
+ if word in self.indices:
+ idx = self.indices[word]
+ self.count[idx] = self.count[idx] + new_dict.count[idx2]
+ else:
+ idx = len(self.symbols)
+ self.indices[word] = idx
+ self.symbols.append(word)
+ self.count.append(new_dict.count[idx2])
+
+ def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
+ """Sort symbols by frequency in descending order, ignoring special ones.
+
+ Args:
+ - threshold defines the minimum word count
+ - nwords defines the total number of words in the final dictionary,
+ including special symbols
+ - padding_factor can be used to pad the dictionary size to be a
+ multiple of 8, which is important on some hardware (e.g., Nvidia
+ Tensor Cores).
+ """
+ if nwords <= 0:
+ nwords = len(self)
+
+ new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
+ new_symbols = self.symbols[: self.nspecial]
+ new_count = self.count[: self.nspecial]
+
+ c = Counter(
+ dict(
+ sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
+ )
+ )
+ for symbol, count in c.most_common(nwords - self.nspecial):
+ if count >= threshold:
+ new_indices[symbol] = len(new_symbols)
+ new_symbols.append(symbol)
+ new_count.append(count)
+ else:
+ break
+
+ assert len(new_symbols) == len(new_indices)
+
+ self.count = list(new_count)
+ self.symbols = list(new_symbols)
+ self.indices = new_indices
+
+ self.pad_to_multiple_(padding_factor)
+
+ def pad_to_multiple_(self, padding_factor):
+ """Pad Dictionary size to be a multiple of *padding_factor*."""
+ if padding_factor > 1:
+ i = 0
+ while len(self) % padding_factor != 0:
+ symbol = "madeupword{:04d}".format(i)
+ self.add_symbol(symbol, n=0)
+ i += 1
+
+ def bos(self):
+ """Helper to get index of beginning-of-sentence symbol"""
+ return self.bos_index
+
+ def pad(self):
+ """Helper to get index of pad symbol"""
+ return self.pad_index
+
+ def eos(self):
+ """Helper to get index of end-of-sentence symbol"""
+ return self.eos_index
+
+ def unk(self):
+ """Helper to get index of unk symbol"""
+ return self.unk_index
+
+ @classmethod
+ def load(cls, f):
+ """Loads the dictionary from a text file with the format:
+
+ ```
+        <symbol0> <count0>
+        <symbol1> <count1>
+ ...
+ ```
+ """
+ d = cls()
+ d.add_from_file(f)
+ return d
+
+ def add_from_file(self, f):
+ """
+ Loads a pre-existing dictionary from a text file and adds its symbols
+ to this instance.
+ """
+ if isinstance(f, str):
+ try:
+ with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
+ self.add_from_file(fd)
+ except FileNotFoundError as fnfe:
+ raise fnfe
+ except UnicodeError:
+ raise Exception(
+ "Incorrect encoding detected in {}, please "
+ "rebuild the dataset".format(f)
+ )
+ return
+
+ lines = f.readlines()
+ indices_start_line = self._load_meta(lines)
+
+ for line in lines[indices_start_line:]:
+ try:
+ line, field = line.rstrip().rsplit(" ", 1)
+ if field == "#fairseq:overwrite":
+ overwrite = True
+ line, field = line.rsplit(" ", 1)
+ else:
+ overwrite = False
+ count = int(field)
+ word = line
+ if word in self and not overwrite:
+ raise RuntimeError(
+ "Duplicate word found when loading Dictionary: '{}'. "
+ "Duplicate words can overwrite earlier ones by adding the "
+ "#fairseq:overwrite flag at the end of the corresponding row "
+ "in the dictionary file. If using the Camembert model, please "
+ "download an updated copy of the model file.".format(word)
+ )
+ self.add_symbol(word, n=count, overwrite=overwrite)
+ except ValueError:
+ raise ValueError(
+                    "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
+ )
+
+ def _save(self, f, kv_iterator):
+ if isinstance(f, str):
+ PathManager.mkdirs(os.path.dirname(f))
+ with PathManager.open(f, "w", encoding="utf-8") as fd:
+ return self.save(fd)
+ for k, v in kv_iterator:
+ print("{} {}".format(k, v), file=f)
+
+ def _get_meta(self):
+ return [], []
+
+ def _load_meta(self, lines):
+ return 0
+
+ def save(self, f):
+ """Stores dictionary into a text file"""
+ ex_keys, ex_vals = self._get_meta()
+ self._save(
+ f,
+ zip(
+ ex_keys + self.symbols[self.nspecial :],
+ ex_vals + self.count[self.nspecial :],
+ ),
+ )
+
+ def dummy_sentence(self, length):
+ t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
+ t[-1] = self.eos()
+ return t
+
+ def encode_line(
+ self,
+ line,
+ line_tokenizer=tokenize_line,
+ add_if_not_exist=True,
+ consumer=None,
+ append_eos=True,
+ reverse_order=False,
+ ):
+ words = line_tokenizer(line)
+ if reverse_order:
+ words = list(reversed(words))
+ nwords = len(words)
+ ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
+
+ for i, word in enumerate(words):
+ if add_if_not_exist:
+ idx = self.add_symbol(word)
+ else:
+ idx = self.index(word)
+ if consumer is not None:
+ consumer(word, idx)
+ ids[i] = idx
+ if append_eos:
+ ids[nwords] = self.eos_index
+ return ids
+
+ @staticmethod
+ def _add_file_to_dictionary_single_worker(
+ filename, tokenize, eos_word, worker_id=0, num_workers=1
+ ):
+ counter = Counter()
+ with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
+ size = os.fstat(f.fileno()).st_size
+ chunk_size = size // num_workers
+ offset = worker_id * chunk_size
+ end = offset + chunk_size
+ f.seek(offset)
+ if offset > 0:
+ safe_readline(f) # drop first incomplete line
+ line = f.readline()
+ while line:
+ for word in tokenize(line):
+ counter.update([word])
+ counter.update([eos_word])
+ if f.tell() > end:
+ break
+ line = f.readline()
+ return counter
+
+ @staticmethod
+ def add_file_to_dictionary(filename, dict, tokenize, num_workers):
+ def merge_result(counter):
+ for w, c in sorted(counter.items()):
+ dict.add_symbol(w, c)
+
+ if num_workers > 1:
+ pool = Pool(processes=num_workers)
+ results = []
+ for worker_id in range(num_workers):
+ results.append(
+ pool.apply_async(
+ Dictionary._add_file_to_dictionary_single_worker,
+ (filename, tokenize, dict.eos_word, worker_id, num_workers),
+ )
+ )
+ pool.close()
+ pool.join()
+ for r in results:
+ merge_result(r.get())
+ else:
+ merge_result(
+ Dictionary._add_file_to_dictionary_single_worker(
+ filename, tokenize, dict.eos_word
+ )
+ )
+
+
+class TruncatedDictionary(object):
+ def __init__(self, wrapped_dict, length):
+ self.__class__ = type(
+ wrapped_dict.__class__.__name__,
+ (self.__class__, wrapped_dict.__class__),
+ {},
+ )
+ self.__dict__ = wrapped_dict.__dict__
+ self.wrapped_dict = wrapped_dict
+ self.length = min(len(self.wrapped_dict), length)
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, i):
+ if i < self.length:
+ return self.wrapped_dict[i]
+ return self.wrapped_dict.unk()
diff --git a/fairseq-0.10.2/fairseq/data/encoders/byte_utils.py b/fairseq-0.10.2/fairseq/data/encoders/byte_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a305c080926c2d094b7e8ae48f5331da82025a75
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/encoders/byte_utils.py
@@ -0,0 +1,51 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+
+
+WHITESPACE_NORMALIZER = re.compile(r"\s+")
+SPACE = chr(32)
+SPACE_ESCAPE = chr(9601)
+# excluding non-breaking space (160) here
+PRINTABLE_LATIN = set(
+ list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1))
+)
+BYTE_TO_BCHAR = {
+ b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256)
+}
+BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
+
+
+def byte_encode(x: str) -> str:
+ normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
+ return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])
+
+
+def byte_decode(x: str) -> str:
+ try:
+ return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
+ except ValueError:
+ return ""
+
+
+def smart_byte_decode(x: str) -> str:
+ output = byte_decode(x)
+ if output == "":
+ # DP the best recovery (max valid chars) if it's broken
+ n_bytes = len(x)
+ f = [0 for _ in range(n_bytes + 1)]
+ pt = [0 for _ in range(n_bytes + 1)]
+ for i in range(1, n_bytes + 1):
+ f[i], pt[i] = f[i - 1], i - 1
+ for j in range(1, min(4, i) + 1):
+ if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
+ f[i], pt[i] = f[i - j] + 1, i - j
+ cur_pt = n_bytes
+ while cur_pt > 0:
+ if f[cur_pt] == f[pt[cur_pt]] + 1:
+ output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
+ cur_pt = pt[cur_pt]
+ return output
diff --git a/fairseq-0.10.2/fairseq/data/encoders/fastbpe.py b/fairseq-0.10.2/fairseq/data/encoders/fastbpe.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d4ad850409d69a5b2476480a9a3aa229038686
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/encoders/fastbpe.py
@@ -0,0 +1,35 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq import file_utils
+from fairseq.data.encoders import register_bpe
+
+
+@register_bpe("fastbpe")
+class fastBPE(object):
+ @staticmethod
+ def add_args(parser):
+ # fmt: off
+ parser.add_argument('--bpe-codes', type=str,
+ help='path to fastBPE BPE')
+ # fmt: on
+
+ def __init__(self, args):
+ if args.bpe_codes is None:
+ raise ValueError("--bpe-codes is required for --bpe=fastbpe")
+ codes = file_utils.cached_path(args.bpe_codes)
+ try:
+ import fastBPE
+
+ self.bpe = fastBPE.fastBPE(codes)
+ self.bpe_symbol = "@@ "
+ except ImportError:
+ raise ImportError("Please install fastBPE with: pip install fastBPE")
+
+ def encode(self, x: str) -> str:
+ return self.bpe.apply([x])[0]
+
+ def decode(self, x: str) -> str:
+ return (x + " ").replace(self.bpe_symbol, "").rstrip()
diff --git a/fairseq-0.10.2/fairseq/data/encoders/gpt2_bpe_utils.py b/fairseq-0.10.2/fairseq/data/encoders/gpt2_bpe_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..688d4e36e358df2dcc432d37d3e57bd81e2f1ed1
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/encoders/gpt2_bpe_utils.py
@@ -0,0 +1,140 @@
+"""
+Byte pair encoding utilities from GPT-2.
+
+Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
+Original license: MIT
+"""
+
+import json
+from functools import lru_cache
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1))
+ + list(range(ord("¡"), ord("¬") + 1))
+ + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2 ** 8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2 ** 8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+class Encoder:
+ def __init__(self, encoder, bpe_merges, errors="replace"):
+ self.encoder = encoder
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+
+ try:
+ import regex as re
+
+ self.re = re
+ except ImportError:
+ raise ImportError("Please install regex with: pip install regex")
+
+ # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+ self.pat = self.re.compile(
+ r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+ )
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ for token in self.re.findall(self.pat, text):
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+ bpe_tokens.extend(
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
+ )
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = "".join([self.decoder.get(token, token) for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode(
+ "utf-8", errors=self.errors
+ )
+ return text
+
+
+def get_encoder(encoder_json_path, vocab_bpe_path):
+ with open(encoder_json_path, "r") as f:
+ encoder = json.load(f)
+ with open(vocab_bpe_path, "r", encoding="utf-8") as f:
+ bpe_data = f.read()
+ bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
+ return Encoder(
+ encoder=encoder,
+ bpe_merges=bpe_merges,
+ )
diff --git a/fairseq-0.10.2/fairseq/data/encoders/moses_tokenizer.py b/fairseq-0.10.2/fairseq/data/encoders/moses_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c24844263a98c58a160b624b7741947ee290884
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/encoders/moses_tokenizer.py
@@ -0,0 +1,51 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq.data.encoders import register_tokenizer
+
+
+@register_tokenizer("moses")
+class MosesTokenizer(object):
+ @staticmethod
+ def add_args(parser):
+ # fmt: off
+ parser.add_argument('--moses-source-lang', metavar='SRC',
+ help='source language')
+ parser.add_argument('--moses-target-lang', metavar='TARGET',
+ help='target language')
+ parser.add_argument('--moses-no-dash-splits', action='store_true', default=False,
+ help='don\'t apply dash split rules')
+ parser.add_argument('--moses-no-escape', action='store_true', default=False,
+ help='don\'t perform HTML escaping on apostrophy, quotes, etc.')
+ # fmt: on
+
+ def __init__(self, args):
+ self.args = args
+
+ if getattr(args, "moses_source_lang", None) is None:
+ args.moses_source_lang = getattr(args, "source_lang", "en")
+ if getattr(args, "moses_target_lang", None) is None:
+ args.moses_target_lang = getattr(args, "target_lang", "en")
+
+ try:
+ from sacremoses import MosesTokenizer, MosesDetokenizer
+
+ self.tok = MosesTokenizer(args.moses_source_lang)
+ self.detok = MosesDetokenizer(args.moses_target_lang)
+ except ImportError:
+ raise ImportError(
+ "Please install Moses tokenizer with: pip install sacremoses"
+ )
+
+ def encode(self, x: str) -> str:
+ return self.tok.tokenize(
+ x,
+ aggressive_dash_splits=(not self.args.moses_no_dash_splits),
+ return_str=True,
+ escape=(not self.args.moses_no_escape),
+ )
+
+ def decode(self, x: str) -> str:
+ return self.detok.detokenize(x.split())
diff --git a/fairseq-0.10.2/fairseq/data/encoders/nltk_tokenizer.py b/fairseq-0.10.2/fairseq/data/encoders/nltk_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b617e7314f0e3aae298b685eea54cbb16312203
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/encoders/nltk_tokenizer.py
@@ -0,0 +1,23 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq.data.encoders import register_tokenizer
+
+
@register_tokenizer("nltk")
class NLTKTokenizer(object):
    """Word tokenizer backed by NLTK's ``word_tokenize``."""

    def __init__(self, source_lang=None, target_lang=None):
        try:
            from nltk.tokenize import word_tokenize
        except ImportError:
            raise ImportError("Please install nltk with: pip install nltk")
        self.word_tokenize = word_tokenize

    def encode(self, x: str) -> str:
        """Return *x* split into NLTK word tokens, joined by single spaces."""
        tokens = self.word_tokenize(x)
        return " ".join(tokens)

    def decode(self, x: str) -> str:
        """Detokenization is a no-op for this tokenizer."""
        return x
diff --git a/fairseq-0.10.2/fairseq/data/encoders/subword_nmt_bpe.py b/fairseq-0.10.2/fairseq/data/encoders/subword_nmt_bpe.py
new file mode 100644
index 0000000000000000000000000000000000000000..e85f99af396a77d273ba2a4ecaf3db399017b9af
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/encoders/subword_nmt_bpe.py
@@ -0,0 +1,54 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq import file_utils
+from fairseq.data.encoders import register_bpe
+
+
@register_bpe("subword_nmt")
class SubwordNMTBPE(object):
    """BPE encoder/decoder backed by the ``subword_nmt`` package."""

    @staticmethod
    def add_args(parser):
        """Add BPE-specific arguments to an argparse parser."""
        # fmt: off
        parser.add_argument('--bpe-codes', type=str,
                            help='path to subword NMT BPE')
        parser.add_argument('--bpe-separator', default='@@',
                            help='BPE separator')
        # fmt: on

    def __init__(self, args):
        if args.bpe_codes is None:
            raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
        codes = file_utils.cached_path(args.bpe_codes)
        try:
            from subword_nmt import apply_bpe
        except ImportError:
            raise ImportError(
                "Please install subword_nmt with: pip install subword-nmt"
            )
        # Reuse subword-nmt's own CLI parser so defaults for the remaining
        # options (merges, glossaries, ...) match the command-line tool.
        bpe_args = apply_bpe.create_parser().parse_args(
            ["--codes", codes, "--separator", args.bpe_separator]
        )
        self.bpe = apply_bpe.BPE(
            bpe_args.codes,
            bpe_args.merges,
            bpe_args.separator,
            None,
            bpe_args.glossaries,
        )
        self.bpe_symbol = bpe_args.separator + " "

    def encode(self, x: str) -> str:
        """Apply BPE segmentation to a pre-tokenized line."""
        return self.bpe.process_line(x)

    def decode(self, x: str) -> str:
        """Remove BPE continuation markers, restoring whole tokens."""
        return (x + " ").replace(self.bpe_symbol, "").rstrip()
diff --git a/fairseq-0.10.2/fairseq/data/encoders/utils.py b/fairseq-0.10.2/fairseq/data/encoders/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d93eb532ef84f0e2bc708b777229ab2cb76ca14b
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/encoders/utils.py
@@ -0,0 +1,30 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from fairseq.data import encoders
+
+
def get_whole_word_mask(args, dictionary):
    """Build a per-vocabulary-index mask of word-initial symbols.

    Returns a ByteTensor of length ``len(dictionary)`` whose entry *i* is 1
    when symbol *i* can begin a word under the configured BPE scheme, or
    ``None`` when no BPE encoder is configured via *args*.
    """
    bpe = encoders.build_bpe(args)
    if bpe is None:
        return None

    def _starts_word(idx):
        # Special symbols (pad/eos/...) always count as word beginnings.
        if idx < dictionary.nspecial:
            return True
        sym = dictionary[idx]
        if sym.startswith("madeupword"):
            return True
        try:
            return bpe.is_beginning_of_word(sym)
        except ValueError:
            # Symbols the BPE model cannot classify are treated as beginnings.
            return True

    return torch.ByteTensor([_starts_word(i) for i in range(len(dictionary))])
diff --git a/fairseq-0.10.2/fairseq/data/fairseq_dataset.py b/fairseq-0.10.2/fairseq/data/fairseq_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed08c1ba200f3d4b95053c02aaa227169fe80d26
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/fairseq_dataset.py
@@ -0,0 +1,191 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch.utils.data
+from fairseq.data import data_utils
+
+
class EpochListening:
    """Mixin for receiving updates whenever the epoch increments."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """
        Whether the :class:`fairseq.data.EpochBatchIterator` for this dataset
        can be reused across epochs.

        Return ``False`` if sample sizes can change between epochs (e.g. when
        the dataset depends on ``set_epoch``), since batches may then need to
        be regenerated each epoch.
        """
        return True

    def set_epoch(self, epoch):
        """Called with the new epoch number at the start of each epoch."""
        pass
+
+
class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the number of tokens in a sample, used to enforce
        ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple, used when filtering
        a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches are constructed based
        on this order."""
        return np.arange(len(self), dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        # Per-index attribute lookup; this base implementation ignores *index*.
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """
        Return a list of valid batch shapes, for example::

            [(8, 512), (16, 256), (32, 128)]

        The first dimension of each tuple is the batch size and can be ``None``
        to automatically infer the max batch size based on ``--max-tokens``.
        The second dimension of each tuple is the max supported length as given
        by :func:`fairseq.data.FairseqDataset.num_tokens`.

        This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
        to restrict batch shapes. This is useful on TPUs to avoid too many
        dynamic shapes (and recompilations).
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """
        Given an ordered set of indices, return batches according to
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.
        """
        from fairseq.data import data_utils

        shapes = self.get_batch_shapes()
        if shapes is not None:

            def _resolve_bsz(batch_size, length):
                # A ``None`` batch size means "as many as --max-tokens allows".
                if batch_size is None:
                    assert max_tokens is not None, "Must specify --max-tokens"
                    batch_size = max_tokens // length
                if max_sentences is not None:
                    batch_size = min(batch_size, max_sentences)
                elif (
                    batch_size >= required_batch_size_multiple
                    and batch_size % required_batch_size_multiple != 0
                ):
                    # Round down to the nearest multiple.
                    batch_size -= batch_size % required_batch_size_multiple
                return batch_size

            shapes = np.array(
                [[_resolve_bsz(bsz, length), length] for (bsz, length) in shapes]
            )

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """
        Filter a list of sample indices, removing those that are longer than
        specified in *max_sizes*.

        WARNING: don't update, override method in child classes

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        if isinstance(max_sizes, (float, int)):
            sizes = getattr(self, "sizes", None)
            if isinstance(sizes, np.ndarray):
                # Fast vectorized path when sizes is a single array.
                keep = sizes[indices] <= max_sizes
                ignored = indices[~keep].tolist()
                indices = indices[keep]
            elif isinstance(sizes, list) and len(sizes) == 1:
                keep = sizes[0][indices] <= max_sizes
                ignored = indices[~keep].tolist()
                indices = indices[keep]
            else:
                # Fall back to calling self.size() per index.
                indices, ignored = data_utils._filter_by_size_dynamic(
                    indices, self.size, max_sizes
                )
        else:
            indices, ignored = data_utils._filter_by_size_dynamic(
                indices, self.size, max_sizes
            )
        return indices, ignored

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside dataloader workers."""
        return True
+
+
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """
    Base class for datasets that must be read sequentially, e.g. because the
    data is streamed or otherwise cannot be manipulated on a single machine.
    """

    def __iter__(self):
        raise NotImplementedError
diff --git a/fairseq-0.10.2/fairseq/data/fasta_dataset.py b/fairseq-0.10.2/fairseq/data/fasta_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..007011974a997fd7446dd29d7eba097d7513bab0
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/fasta_dataset.py
@@ -0,0 +1,107 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import subprocess
+import threading
+from pathlib import Path
+
+import numpy as np
+import torch
+
+
def fasta_file_path(prefix_path):
    """Return the ``.fasta`` file path for a dataset *prefix_path*."""
    return "{}.fasta".format(prefix_path)
+
+
class FastaDataset(torch.utils.data.Dataset):
    """
    For loading protein sequence datasets in the common FASTA data format.

    Builds (or loads from ``<path>.fasta.idx.npy``) an index of per-record
    byte offsets and sequence lengths so records can be fetched with random
    access.
    """

    def __init__(self, path: str, cache_indices=False):
        self.fn = fasta_file_path(path)
        # One file handle per thread: a shared handle is not safe to
        # seek/read from multiple DataLoader worker threads.
        self.threadlocal = threading.local()
        self.cache = Path(f"{path}.fasta.idx.npy")
        if cache_indices:
            if self.cache.exists():
                self.offsets, self.sizes = np.load(self.cache)
            else:
                self.offsets, self.sizes = self._build_index(path)
                np.save(self.cache, np.stack([self.offsets, self.sizes]))
        else:
            self.offsets, self.sizes = self._build_index(path)

    def _get_file(self):
        # Lazily open one handle per thread.
        if not hasattr(self.threadlocal, "f"):
            self.threadlocal.f = open(self.fn, "r")
        return self.threadlocal.f

    def __getitem__(self, idx):
        """Return ``(description_line, sequence)`` for record *idx*."""
        f = self._get_file()
        f.seek(self.offsets[idx])
        desc = f.readline().strip()
        line = f.readline()
        seq = ""
        # Concatenate sequence lines until the next record header (">") or EOF.
        while line != "" and line[0] != ">":
            seq += line.strip()
            line = f.readline()
        return desc, seq

    def __len__(self):
        return self.offsets.size

    def _build_index(self, path: str):
        # Use grep and awk to get 100M/s on local SSD.
        # Should process your enormous 100G fasta in ~10 min single core...
        path = fasta_file_path(path)
        bytes_offsets = subprocess.check_output(
            f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
            "| grep --byte-offset '^>' -o | cut -d: -f1",
            shell=True,
        )
        fasta_lengths = subprocess.check_output(
            f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
            "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'",
            shell=True,
        )
        # np.fromstring's text mode is deprecated (and removed in newer
        # NumPy); parse the whitespace-separated integers explicitly.
        bytes_np = np.fromiter(map(int, bytes_offsets.split()), dtype=np.int64)
        sizes_np = np.fromiter(map(int, fasta_lengths.split()), dtype=np.int64)
        return bytes_np, sizes_np

    def __setstate__(self, state):
        self.__dict__ = state
        # Recreate the (unpicklable) thread-local handle storage.
        self.threadlocal = threading.local()

    def __getstate__(self):
        # Open file handles cannot be pickled; drop the thread-local state.
        d = {}
        for i, v in self.__dict__.items():
            if i != "threadlocal":
                d[i] = v
        return d

    def __del__(self):
        # Guard on "threadlocal": __del__ may run after a failed __init__
        # where the attribute was never assigned.
        if hasattr(self, "threadlocal") and hasattr(self.threadlocal, "f"):
            self.threadlocal.f.close()
            del self.threadlocal.f

    @staticmethod
    def exists(path):
        return os.path.exists(fasta_file_path(path))
+
+
class EncodedFastaDataset(FastaDataset):
    """
    Wraps :class:`FastaDataset` to return dictionary-encoded index tensors
    instead of raw sequence strings.
    """

    def __init__(self, path, dictionary):
        super().__init__(path, cache_indices=True)
        self.dictionary = dictionary

    def __getitem__(self, idx):
        _desc, seq = super().__getitem__(idx)
        # Encode character-by-character (line_tokenizer=list) into vocab ids.
        return self.dictionary.encode_line(seq, line_tokenizer=list).long()
diff --git a/fairseq-0.10.2/fairseq/data/id_dataset.py b/fairseq-0.10.2/fairseq/data/id_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e4d7969cf2a26e852b466f165a6fadabae3b35f
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/id_dataset.py
@@ -0,0 +1,19 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from . import FairseqDataset
+
+
class IdDataset(FairseqDataset):
    """Dataset that simply yields each example's own index.

    Useful for tracking example identities through a composite batch.
    """

    def __getitem__(self, index):
        return index

    def __len__(self):
        # Length is reported by sibling datasets in a composite batch.
        return 0

    def collater(self, samples):
        """Collate a list of indices into a 1-D tensor."""
        return torch.tensor(samples)
diff --git a/fairseq-0.10.2/fairseq/data/language_pair_dataset.py b/fairseq-0.10.2/fairseq/data/language_pair_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..db95e14f55d459e83d0665f8b6174ea3eaef3cd3
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/language_pair_dataset.py
@@ -0,0 +1,475 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import numpy as np
+import torch
+from fairseq.data import FairseqDataset, data_utils
+
+
+logger = logging.getLogger(__name__)
+
+
def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Collate a list of samples into a translation mini-batch.

    Args:
        samples (List[dict]): samples with keys ``id`` and ``source``, and
            optionally ``target``, ``prev_output_tokens``, ``alignment`` and
            ``constraints``
        pad_idx (int): padding symbol index
        eos_idx (int): end-of-sentence symbol index
        left_pad_source (bool): pad source tensors on the left side
        left_pad_target (bool): pad target tensors on the left side
        input_feeding (bool): also produce ``prev_output_tokens`` (targets
            shifted right by one) for teacher forcing
        pad_to_length (dict, optional): minimum lengths to pad the ``source``
            and ``target`` tensors to
        pad_to_multiple (int): pad lengths to a multiple of this value

    Returns:
        dict: a mini-batch with examples sorted by descending source length
        (empty dict if *samples* is empty)
    """
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx,
            left_pad,
            move_eos_to_beginning,
            pad_to_length=pad_to_length,
            pad_to_multiple=pad_to_multiple,
        )

    def check_alignment(alignment, src_len, tgt_len):
        # An alignment is unusable if empty or if it points past the last
        # non-EOS position of either sentence.
        if alignment is None or len(alignment) == 0:
            return False
        if (
            alignment[:, 0].max().item() >= src_len - 1
            or alignment[:, 1].max().item() >= tgt_len - 1
        ):
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(
            align_tgt, return_inverse=True, return_counts=True
        )
        align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
        return 1.0 / align_weights.float()

    # "ids" rather than "id" to avoid shadowing the builtin.
    ids = torch.LongTensor([s["id"] for s in samples])

    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor(
        [s["source"].ne(pad_idx).long().sum() for s in samples]
    )
    src_lengths, sort_order = src_lengths.sort(descending=True)
    ids = ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        ).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
        elif input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": ids,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
            0, sort_order
        )

    if samples[0].get("alignment", None) is not None:
        bsz, tgt_sz = batch["target"].shape
        src_sz = batch["net_input"]["src_tokens"].shape[1]

        # Offsets translate per-sentence alignment positions into positions
        # within the flattened (padded) batch tensors.
        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
        if left_pad_source:
            offsets[:, 0] += src_sz - src_lengths
        if left_pad_target:
            offsets[:, 1] += tgt_sz - tgt_lengths

        alignments = [
            alignment + offset
            for align_idx, offset, src_len, tgt_len in zip(
                sort_order, offsets, src_lengths, tgt_lengths
            )
            for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
            if check_alignment(alignment, src_len, tgt_len)
        ]

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch["alignments"] = alignments
            batch["align_weights"] = align_weights

    if samples[0].get("constraints", None) is not None:
        # Collate the packed constraints across the samples, padding to
        # the length of the longest sample.
        lens = [sample.get("constraints").size(0) for sample in samples]
        constraints = torch.zeros((len(samples), max(lens))).long()
        for i, sample in enumerate(samples):
            constraints[i, 0 : lens[i]] = sample.get("constraints")
        batch["constraints"] = constraints

    return batch
+
+
class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
            of the samples.
    """

    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
    ):
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(
                tgt
            ), "Source and target must contain the same number of examples"
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert (
                self.tgt_sizes is not None
            ), "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = eos if eos is not None else src_dict.eos()
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset

            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info(
                    "bucketing target lengths: {}".format(list(self.tgt.buckets))
                )

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset)
            # NOTE: np.long was removed in NumPy 1.24; np.int64 is equivalent.
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we use
        # use existing datasets for opposite directions i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if self.tgt and tgt_item[-1] != eos:
                tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            # Operate on tgt_item (not self.tgt[index]) so that an EOS
            # appended above is preserved when both flags are enabled.
            if self.tgt and tgt_item[0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

            bos = self.src_dict.bos()
            if src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            # Operate on src_item so that a BOS prepended above is preserved.
            if src_item[-1] == eos:
                src_item = src_item[:-1]

        example = {
            "id": index,
            "source": src_item,
            "target": tgt_item,
        }
        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``. Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                  IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is:
            #   max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
            ]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False) and (
            getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
        )

    def prefetch(self, indices):
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
        than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )
diff --git a/fairseq-0.10.2/fairseq/data/legacy/__init__.py b/fairseq-0.10.2/fairseq/data/legacy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bd5c72b5e9d7f67fb7e4ef10808d7ec08967ff4
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/legacy/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .block_pair_dataset import BlockPairDataset
+from .masked_lm_dataset import MaskedLMDataset
+from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
+
+
+__all__ = [
+ "BertDictionary",
+ "BlockPairDataset",
+ "MaskedLMDataset",
+ "MaskedLMDictionary",
+]
diff --git a/fairseq-0.10.2/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b876ac07b0744e62e02a10a224b6fa998f455847
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..558fa969915acdaf3488710fdfb21bb23dc60655
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68e11cf8a9a95ee0b3416c47c59f1dc79a3fb169
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3da204c7c3e48e0fea21257d1ac5feb6600e637
Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/data/legacy/block_pair_dataset.py b/fairseq-0.10.2/fairseq/data/legacy/block_pair_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba069b46052286c531b4f9706d96788732cd2ad2
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/legacy/block_pair_dataset.py
@@ -0,0 +1,311 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import numpy as np
+import torch
+from fairseq.data import FairseqDataset
+
+
+class BlockPairDataset(FairseqDataset):
+ """Break a Dataset of tokens into sentence pair blocks for next sentence
+ prediction as well as masked language model.
+
+ High-level logics are:
+ 1. break input tensor to tensor blocks
+ 2. pair the blocks with 50% next sentence and 50% random sentence
+ 3. return paired blocks as well as related segment labels
+
+ Args:
+ dataset (~torch.utils.data.Dataset): dataset to break into blocks
+ sizes: array of sentence lengths
+ dictionary: dictionary for the task
+ block_size: maximum block size
+ break_mode: mode for breaking corpus into block pairs. currently we support
+ 2 modes
+ doc: respect document boundaries and each part of the pair should belong to one document
+ none: don't respect any boundary and cut tokens evenly
+ short_seq_prob: probability for generating shorter block pairs
+ doc_break_size: Size for empty line separating documents. Typically 1 if
+ the sentences have eos, 0 otherwise.
+ """
+
+ def __init__(
+ self,
+ dataset,
+ dictionary,
+ sizes,
+ block_size,
+ break_mode="doc",
+ short_seq_prob=0.1,
+ doc_break_size=1,
+ ):
+ super().__init__()
+ self.dataset = dataset
+ self.pad = dictionary.pad()
+ self.eos = dictionary.eos()
+ self.cls = dictionary.cls()
+ self.mask = dictionary.mask()
+ self.sep = dictionary.sep()
+ self.break_mode = break_mode
+ self.dictionary = dictionary
+ self.short_seq_prob = short_seq_prob
+ self.block_indices = []
+
+ assert len(dataset) == len(sizes)
+
+ if break_mode == "doc":
+ cur_doc = []
+ for sent_id, sz in enumerate(sizes):
+ assert doc_break_size == 0 or sz != 0, (
+ "when doc_break_size is non-zero, we expect documents to be"
+ "separated by a blank line with a single eos."
+ )
+ # empty line as document separator
+ if sz == doc_break_size:
+ if len(cur_doc) == 0:
+ continue
+ self.block_indices.append(cur_doc)
+ cur_doc = []
+ else:
+ cur_doc.append(sent_id)
+ max_num_tokens = block_size - 3 # Account for [CLS], [SEP], [SEP]
+ self.sent_pairs = []
+ self.sizes = []
+ for doc_id, doc in enumerate(self.block_indices):
+ self._generate_sentence_pair(doc, doc_id, max_num_tokens, sizes)
+ elif break_mode is None or break_mode == "none":
+ # each block should have half of the block size since we are constructing block pair
+ sent_length = (block_size - 3) // 2
+ total_len = sum(dataset.sizes)
+ length = math.ceil(total_len / sent_length)
+
+ def block_at(i):
+ start = i * sent_length
+ end = min(start + sent_length, total_len)
+ return (start, end)
+
+ sent_indices = np.array([block_at(i) for i in range(length)])
+ sent_sizes = np.array([e - s for s, e in sent_indices])
+ dataset_index = self._sent_to_dataset_index(sent_sizes)
+
+ # pair sentences
+ self._pair_sentences(dataset_index)
+ else:
+ raise ValueError("Invalid break_mode: " + break_mode)
+
+ def _pair_sentences(self, dataset_index):
+ """
+ Given a list of evenly cut blocks/sentences, pair these sentences with 50%
+ consecutive sentences and 50% random sentences.
+ This is used for none break mode
+ """
+ # pair sentences
+ for sent_id, sent in enumerate(dataset_index):
+ next_sent_label = (
+ 1 if np.random.rand() > 0.5 and sent_id != len(dataset_index) - 1 else 0
+ )
+ if next_sent_label:
+ next_sent = dataset_index[sent_id + 1]
+ else:
+ next_sent = dataset_index[
+ self._skip_sampling(len(dataset_index), [sent_id, sent_id + 1])
+ ]
+ self.sent_pairs.append((sent, next_sent, next_sent_label))
+
+ # The current blocks don't include the special tokens but the
+ # sizes already account for this
+ self.sizes.append(3 + sent[3] + next_sent[3])
+
+ def _sent_to_dataset_index(self, sent_sizes):
+ """
+ Build index mapping block indices to the underlying dataset indices
+ """
+ dataset_index = []
+ ds_idx, ds_remaining = -1, 0
+ for to_consume in sent_sizes:
+ sent_size = to_consume
+ if ds_remaining == 0:
+ ds_idx += 1
+ ds_remaining = sent_sizes[ds_idx]
+ start_ds_idx = ds_idx
+ start_offset = sent_sizes[ds_idx] - ds_remaining
+ while to_consume > ds_remaining:
+ to_consume -= ds_remaining
+ ds_idx += 1
+ ds_remaining = sent_sizes[ds_idx]
+ ds_remaining -= to_consume
+ dataset_index.append(
+ (
+ start_ds_idx, # starting index in dataset
+ start_offset, # starting offset within starting index
+ ds_idx, # ending index in dataset
+ sent_size, # sentence length
+ )
+ )
+ assert ds_remaining == 0
+ assert ds_idx == len(self.dataset) - 1
+ return dataset_index
+
+ def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes):
+ """
+ Go through a single document and generate sentence pairs from it
+ """
+ current_chunk = []
+ current_length = 0
+ curr = 0
+ # To provide more randomness, we decrease target seq length for parts of
+ # samples (10% by default). Note that max_num_tokens is the hard threshold
+ # for batching and will never be changed.
+ target_seq_length = max_num_tokens
+ if np.random.random() < self.short_seq_prob:
+ target_seq_length = np.random.randint(2, max_num_tokens)
+ # loop through all sentences in document
+ while curr < len(doc):
+ sent_id = doc[curr]
+ current_chunk.append(sent_id)
+ current_length = sum(sizes[current_chunk])
+ # split chunk and generate pair when exceed target_seq_length or
+ # finish the loop
+ if curr == len(doc) - 1 or current_length >= target_seq_length:
+ # split the chunk into 2 parts
+ a_end = 1
+ if len(current_chunk) > 2:
+ a_end = np.random.randint(1, len(current_chunk) - 1)
+ sent_a = current_chunk[:a_end]
+ len_a = sum(sizes[sent_a])
+ # generate next sentence label, note that if there is only 1 sentence
+ # in current chunk, label is always 0
+ next_sent_label = (
+ 1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0
+ )
+ if not next_sent_label:
+ # if next sentence label is 0, sample sent_b from a random doc
+ target_b_length = target_seq_length - len_a
+ rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id])
+ random_doc = self.block_indices[rand_doc_id]
+ random_start = np.random.randint(0, len(random_doc))
+ sent_b = []
+ len_b = 0
+ for j in range(random_start, len(random_doc)):
+ sent_b.append(random_doc[j])
+ len_b = sum(sizes[sent_b])
+ if len_b >= target_b_length:
+ break
+ # return the second part of the chunk since it's not used
+ num_unused_segments = len(current_chunk) - a_end
+ curr -= num_unused_segments
+ else:
+ # if next sentence label is 1, use the second part of chunk as sent_B
+ sent_b = current_chunk[a_end:]
+ len_b = sum(sizes[sent_b])
+ # currently sent_a and sent_B may be longer than max_num_tokens,
+ # truncate them and return block idx and offsets for them
+ sent_a, sent_b = self._truncate_sentences(
+ sent_a, sent_b, max_num_tokens
+ )
+ self.sent_pairs.append((sent_a, sent_b, next_sent_label))
+ self.sizes.append(3 + sent_a[3] + sent_b[3])
+ current_chunk = []
+ curr += 1
+
+ def _skip_sampling(self, total, skip_ids):
+ """
+ Generate a random integer which is not in skip_ids. Sample range is [0, total)
+ TODO: ids in skip_ids should be consecutive, we can extend it to more generic version later
+ """
+ rand_id = np.random.randint(total - len(skip_ids))
+ return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids)
+
+ def _truncate_sentences(self, sent_a, sent_b, max_num_tokens):
+ """
+ Truncate a pair of sentences to limit total length under max_num_tokens
+ Logics:
+ 1. Truncate longer sentence
+ 2. Tokens to be truncated could be at the beginning or the end of the sentence
+ Returns:
+ Truncated sentences represented by dataset idx
+ """
+ len_a, len_b = sum(self.dataset.sizes[sent_a]), sum(self.dataset.sizes[sent_b])
+ front_cut_a = front_cut_b = end_cut_a = end_cut_b = 0
+
+ while True:
+ total_length = (
+ len_a + len_b - front_cut_a - front_cut_b - end_cut_a - end_cut_b
+ )
+ if total_length <= max_num_tokens:
+ break
+
+ if len_a - front_cut_a - end_cut_a > len_b - front_cut_b - end_cut_b:
+ if np.random.rand() < 0.5:
+ front_cut_a += 1
+ else:
+ end_cut_a += 1
+ else:
+ if np.random.rand() < 0.5:
+ front_cut_b += 1
+ else:
+ end_cut_b += 1
+
+ # calculate ds indices as well as offsets and return
+ truncated_sent_a = self._cut_sentence(sent_a, front_cut_a, end_cut_a)
+ truncated_sent_b = self._cut_sentence(sent_b, front_cut_b, end_cut_b)
+ return truncated_sent_a, truncated_sent_b
+
+ def _cut_sentence(self, sent, front_cut, end_cut):
+ """
+ Cut a sentence based on the numbers of tokens to be cut from beginning and end
+ Represent the sentence as dataset idx and return
+ """
+ start_ds_idx, end_ds_idx, offset = sent[0], sent[-1], 0
+ target_len = sum(self.dataset.sizes[sent]) - front_cut - end_cut
+ while front_cut > 0:
+ if self.dataset.sizes[start_ds_idx] > front_cut:
+ offset += front_cut
+ break
+ else:
+ front_cut -= self.dataset.sizes[start_ds_idx]
+ start_ds_idx += 1
+ while end_cut > 0:
+ if self.dataset.sizes[end_ds_idx] > end_cut:
+ break
+ else:
+ end_cut -= self.dataset.sizes[end_ds_idx]
+ end_ds_idx -= 1
+ return start_ds_idx, offset, end_ds_idx, target_len
+
+ def _fetch_block(self, start_ds_idx, offset, end_ds_idx, length):
+ """
+ Fetch a block of tokens based on its dataset idx
+ """
+ buffer = torch.cat(
+ [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
+ )
+ s, e = offset, offset + length
+ return buffer[s:e]
+
+ def __getitem__(self, index):
+ block1, block2, next_sent_label = self.sent_pairs[index]
+ block1 = self._fetch_block(*block1)
+ block2 = self._fetch_block(*block2)
+ return block1, block2, next_sent_label
+
+ def __len__(self):
+ return len(self.sizes)
+
+ @property
+ def supports_prefetch(self):
+ return getattr(self.dataset, "supports_prefetch", False)
+
+ def prefetch(self, indices):
+ prefetch_idx = set()
+ for index in indices:
+ for block1, block2, _ in [self.sent_pairs[index]]:
+ for ds_idx in range(block1[0], block1[2] + 1):
+ prefetch_idx.add(ds_idx)
+ for ds_idx in range(block2[0], block2[2] + 1):
+ prefetch_idx.add(ds_idx)
+ self.dataset.prefetch(prefetch_idx)
diff --git a/fairseq-0.10.2/fairseq/data/legacy/masked_lm_dictionary.py b/fairseq-0.10.2/fairseq/data/legacy/masked_lm_dictionary.py
new file mode 100644
index 0000000000000000000000000000000000000000..dee88f7a3ed72ea465ea4e8ffe7b1c01ff6f57f1
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/legacy/masked_lm_dictionary.py
@@ -0,0 +1,60 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq.data import Dictionary
+
+
+class MaskedLMDictionary(Dictionary):
+ """
+ Dictionary for Masked Language Modelling tasks. This extends Dictionary by
+ adding the mask symbol.
+ """
+
+ def __init__(
+ self,
+ pad="<pad>",
+ eos="</s>",
+ unk="<unk>",
+ mask="<mask>",
+ ):
+ super().__init__(pad=pad, eos=eos, unk=unk)
+ self.mask_word = mask
+ self.mask_index = self.add_symbol(mask)
+ self.nspecial = len(self.symbols)
+
+ def mask(self):
+ """Helper to get index of mask symbol"""
+ return self.mask_index
+
+
+class BertDictionary(MaskedLMDictionary):
+ """
+ Dictionary for BERT task. This extends MaskedLMDictionary by adding support
+ for cls and sep symbols.
+ """
+
+ def __init__(
+ self,
+ pad="<pad>",
+ eos="</s>",
+ unk="<unk>",
+ mask="<mask>",
+ cls="<cls>",
+ sep="<sep>",
+ ):
+ super().__init__(pad=pad, eos=eos, unk=unk, mask=mask)
+ self.cls_word = cls
+ self.sep_word = sep
+ self.cls_index = self.add_symbol(cls)
+ self.sep_index = self.add_symbol(sep)
+ self.nspecial = len(self.symbols)
+
+ def cls(self):
+ """Helper to get index of cls symbol"""
+ return self.cls_index
+
+ def sep(self):
+ """Helper to get index of sep symbol"""
+ return self.sep_index
diff --git a/fairseq-0.10.2/fairseq/data/list_dataset.py b/fairseq-0.10.2/fairseq/data/list_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..12f00aa43661d6bad701c9e72653ba8779136906
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/list_dataset.py
@@ -0,0 +1,32 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import BaseWrapperDataset
+
+
+class ListDataset(BaseWrapperDataset):
+ def __init__(self, dataset, sizes=None):
+ super().__init__(dataset)
+ self._sizes = sizes
+
+ def __iter__(self):
+ for x in self.dataset:
+ yield x
+
+ def collater(self, samples):
+ return samples
+
+ @property
+ def sizes(self):
+ return self._sizes
+
+ def num_tokens(self, index):
+ return self.sizes[index]
+
+ def size(self, index):
+ return self.sizes[index]
+
+ def set_epoch(self, epoch):
+ pass
diff --git a/fairseq-0.10.2/fairseq/data/lru_cache_dataset.py b/fairseq-0.10.2/fairseq/data/lru_cache_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7854ac1701392754ce5795cafe9c634671aebdf
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/lru_cache_dataset.py
@@ -0,0 +1,21 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import lru_cache
+
+from . import BaseWrapperDataset
+
+
+class LRUCacheDataset(BaseWrapperDataset):
+ def __init__(self, dataset, token=None):
+ super().__init__(dataset)
+
+ @lru_cache(maxsize=8)
+ def __getitem__(self, index):
+ return self.dataset[index]
+
+ @lru_cache(maxsize=8)
+ def collater(self, samples):
+ return self.dataset.collater(samples)
diff --git a/fairseq-0.10.2/fairseq/data/mask_tokens_dataset.py b/fairseq-0.10.2/fairseq/data/mask_tokens_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ea86245f76c4cf233bf6f023138ef341538a267
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/mask_tokens_dataset.py
@@ -0,0 +1,178 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import lru_cache
+
+import numpy as np
+import torch
+from fairseq.data import Dictionary, data_utils
+
+from . import BaseWrapperDataset, LRUCacheDataset
+
+
+class MaskTokensDataset(BaseWrapperDataset):
+ """
+ A wrapper Dataset for masked language modeling.
+
+ Input items are masked according to the specified masking probability.
+
+ Args:
+ dataset: Dataset to wrap.
+ sizes: Sentence lengths
+ vocab: Dictionary with the vocabulary and special tokens.
+ pad_idx: Id of pad token in vocab
+ mask_idx: Id of mask token in vocab
+ return_masked_tokens: controls whether to return the non-masked tokens
+ (the default) or to return a tensor with the original masked token
+ IDs (and *pad_idx* elsewhere). The latter is useful as targets for
+ masked LM training.
+ seed: Seed for random number generator for reproducibility.
+ mask_prob: probability of replacing a token with *mask_idx*.
+ leave_unmasked_prob: probability that a masked token is unmasked.
+ random_token_prob: probability of replacing a masked token with a
+ random token from the vocabulary.
+ freq_weighted_replacement: sample random replacement words based on
+ word frequencies in the vocab.
+ mask_whole_words: only mask whole words. This should be a byte mask
+ over vocab indices, indicating whether it is the beginning of a
+ word. We will extend any mask to encompass the whole word.
+ bpe: BPE to use for whole-word masking.
+ """
+
+ @classmethod
+ def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
+ """Return the source and target datasets for masked LM training."""
+ dataset = LRUCacheDataset(dataset)
+ return (
+ LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)),
+ LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)),
+ )
+
+ def __init__(
+ self,
+ dataset: torch.utils.data.Dataset,
+ vocab: Dictionary,
+ pad_idx: int,
+ mask_idx: int,
+ return_masked_tokens: bool = False,
+ seed: int = 1,
+ mask_prob: float = 0.15,
+ leave_unmasked_prob: float = 0.1,
+ random_token_prob: float = 0.1,
+ freq_weighted_replacement: bool = False,
+ mask_whole_words: torch.Tensor = None,
+ ):
+ assert 0.0 < mask_prob < 1.0
+ assert 0.0 <= random_token_prob <= 1.0
+ assert 0.0 <= leave_unmasked_prob <= 1.0
+ assert random_token_prob + leave_unmasked_prob <= 1.0
+
+ self.dataset = dataset
+ self.vocab = vocab
+ self.pad_idx = pad_idx
+ self.mask_idx = mask_idx
+ self.return_masked_tokens = return_masked_tokens
+ self.seed = seed
+ self.mask_prob = mask_prob
+ self.leave_unmasked_prob = leave_unmasked_prob
+ self.random_token_prob = random_token_prob
+ self.mask_whole_words = mask_whole_words
+
+ if random_token_prob > 0.0:
+ if freq_weighted_replacement:
+ weights = np.array(self.vocab.count)
+ else:
+ weights = np.ones(len(self.vocab))
+ weights[: self.vocab.nspecial] = 0
+ self.weights = weights / weights.sum()
+
+ self.epoch = 0
+
+ @property
+ def can_reuse_epoch_itr_across_epochs(self):
+ return True # only the noise changes, not item sizes
+
+ def set_epoch(self, epoch, **unused):
+ super().set_epoch(epoch)
+ self.epoch = epoch
+
+ @lru_cache(maxsize=8)
+ def __getitem__(self, index: int):
+ with data_utils.numpy_seed(self.seed, self.epoch, index):
+ item = self.dataset[index]
+ sz = len(item)
+
+ assert (
+ self.mask_idx not in item
+ ), "Dataset contains mask_idx (={}), this is not expected!".format(
+ self.mask_idx,
+ )
+
+ if self.mask_whole_words is not None:
+ word_begins_mask = self.mask_whole_words.gather(0, item)
+ word_begins_idx = word_begins_mask.nonzero().view(-1)
+ sz = len(word_begins_idx)
+ words = np.split(word_begins_mask, word_begins_idx)[1:]
+ assert len(words) == sz
+ word_lens = list(map(len, words))
+
+ # decide elements to mask
+ mask = np.full(sz, False)
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ self.mask_prob * sz
+ + np.random.rand()
+ )
+ mask[np.random.choice(sz, num_mask, replace=False)] = True
+
+ if self.return_masked_tokens:
+ # exit early if we're just returning the masked tokens
+ # (i.e., the targets for masked LM training)
+ if self.mask_whole_words is not None:
+ mask = np.repeat(mask, word_lens)
+ new_item = np.full(len(mask), self.pad_idx)
+ new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
+ return torch.from_numpy(new_item)
+
+ # decide unmasking and random replacement
+ rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
+ if rand_or_unmask_prob > 0.0:
+ rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
+ if self.random_token_prob == 0.0:
+ unmask = rand_or_unmask
+ rand_mask = None
+ elif self.leave_unmasked_prob == 0.0:
+ unmask = None
+ rand_mask = rand_or_unmask
+ else:
+ unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
+ decision = np.random.rand(sz) < unmask_prob
+ unmask = rand_or_unmask & decision
+ rand_mask = rand_or_unmask & (~decision)
+ else:
+ unmask = rand_mask = None
+
+ if unmask is not None:
+ mask = mask ^ unmask
+
+ if self.mask_whole_words is not None:
+ mask = np.repeat(mask, word_lens)
+
+ new_item = np.copy(item)
+ new_item[mask] = self.mask_idx
+ if rand_mask is not None:
+ num_rand = rand_mask.sum()
+ if num_rand > 0:
+ if self.mask_whole_words is not None:
+ rand_mask = np.repeat(rand_mask, word_lens)
+ num_rand = rand_mask.sum()
+
+ new_item[rand_mask] = np.random.choice(
+ len(self.vocab),
+ num_rand,
+ p=self.weights,
+ )
+
+ return torch.from_numpy(new_item)
diff --git a/fairseq-0.10.2/fairseq/data/multi_corpus_dataset.py b/fairseq-0.10.2/fairseq/data/multi_corpus_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2457666d688f773a98af8eb610a4f5756b02dc0
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/multi_corpus_dataset.py
@@ -0,0 +1,159 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from collections import OrderedDict
+from typing import Dict, List
+
+import numpy as np
+from fairseq.data import data_utils
+
+from . import FairseqDataset
+
+
+logger = logging.getLogger(__name__)
+
+
+class MultiCorpusDataset(FairseqDataset):
+ """
+ Stores multiple instances of FairseqDataset together. Requires each instance
+ to be the same dataset, as the collate method needs to work on batches with
+ samples from each dataset.
+
+ Allows specifying a distribution over the datasets to use. Note that unlike
+ MultiCorpusSampledDataset, this distribution allows sampling for each item,
+ rather than on a batch level.
+
+ Each time ordered_indices() is called, a new sample is generated with
+ the specified distribution.
+
+ Args:
+ datasets: an OrderedDict of FairseqDataset instances.
+ distribution: a List containing the probability of getting an utterance from
+ corresponding dataset
+ seed: random seed for sampling the datasets
+ sort_indices: if true, will sort the ordered indices by size
+ """
+
+ def __init__(
+ self,
+ datasets: Dict[str, FairseqDataset],
+ distribution: List[float],
+ seed: int,
+ sort_indices: bool = False,
+ ):
+ super().__init__()
+ assert isinstance(datasets, OrderedDict)
+ assert len(datasets) == len(distribution)
+ self.datasets = datasets
+ self.distribution = distribution
+ self.seed = seed
+ self.sort_indices = sort_indices
+
+ # Avoid repeated conversions to list later
+ self.dataset_list = list(datasets.values())
+ self.total_num_instances = 0
+
+ first_dataset = list(self.datasets.values())[0]
+
+ self.dataset_offsets = []
+ for dataset in datasets.values():
+ assert isinstance(dataset, FairseqDataset)
+ assert type(dataset) is type(first_dataset)
+ self.dataset_offsets.append(self.total_num_instances)
+ self.total_num_instances += len(dataset)
+
+ def ordered_indices(self):
+ with data_utils.numpy_seed(self.seed, self.epoch):
+ # Used to store the order of indices of each dataset to use
+ indices = [
+ np.random.permutation(len(dataset))
+ for dataset in self.datasets.values()
+ ]
+ # Keep track of which samples we've used for each dataset
+ counters = [0 for _ in self.datasets]
+
+ sampled_indices = [
+ self._sample(indices, counters) for _ in range(self.total_num_instances)
+ ]
+ if self.sort_indices:
+ sampled_indices.sort(key=lambda i: self.num_tokens(i))
+ return np.array(sampled_indices, dtype=np.int64)
+
+ def _sample(self, indices, counters):
+ # First pick dataset
+ dataset_idx = np.random.choice(len(self.distribution), p=self.distribution)
+
+ # Then get dataset internal index
+ idx = indices[dataset_idx][counters[dataset_idx]]
+
+ # Convert to multi-datasets index
+ idx += self.dataset_offsets[dataset_idx]
+
+ counters[dataset_idx] += 1
+
+ # Reset if we reach end
+ if counters[dataset_idx] == len(self.dataset_list[dataset_idx]):
+ counters[dataset_idx] = 0
+ indices[dataset_idx] = np.random.permutation(
+ len(self.dataset_list[dataset_idx])
+ )
+
+ return idx
+
+ def _map_index(self, index: int):
+ """
+ If dataset A has length N and dataset B has length M
+ then index 1 maps to index 1 of dataset A, and index N + 1
+ maps to index 1 of B.
+ """
+ counter = 0
+ for key, dataset in self.datasets.items():
+ if index < counter + len(dataset):
+ return index - counter, key
+ counter += len(dataset)
+ raise ValueError(
+ "Invalid index: {}, max: {}".format(index, self.total_num_instances)
+ )
+
+ def __len__(self):
+ """
+ Length of this dataset is the sum of individual datasets
+ """
+ return self.total_num_instances
+
+ def __getitem__(self, index):
+ index, key = self._map_index(index)
+ return self.datasets[key][index]
+
+ def collater(self, samples):
+ """
+ Since we enforce all datasets to be the same, collating is just
+ picking the first one and doing collate.
+ """
+ if len(samples) == 0:
+ return None
+
+ return list(self.datasets.values())[0].collater(samples)
+
+ def num_tokens(self, index: int):
+ index, key = self._map_index(index)
+ return self.datasets[key].num_tokens(index)
+
+ def size(self, index: int):
+ index, key = self._map_index(index)
+ return self.datasets[key].size(index)
+
+ @property
+ def can_reuse_epoch_itr_across_epochs(self):
+ return False
+
+ def set_epoch(self, epoch, **unused):
+ super().set_epoch(epoch)
+ self.epoch = epoch
+
+ @property
+ def supports_prefetch(self):
+ return False
diff --git a/fairseq-0.10.2/fairseq/data/multi_corpus_sampled_dataset.py b/fairseq-0.10.2/fairseq/data/multi_corpus_sampled_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad8e951cc905a73fea28b4fac449e307cadfa52f
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/multi_corpus_sampled_dataset.py
@@ -0,0 +1,145 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import OrderedDict
+from typing import Callable, Dict, List
+
+import numpy as np
+
+from . import FairseqDataset
+
+
+def uniform_sampler(x):
+ # Sample from uniform distribution
+ return np.random.choice(x, 1).item()
+
+
+class MultiCorpusSampledDataset(FairseqDataset):
+ """
+ Stores multiple instances of FairseqDataset together and in every iteration
+ creates a batch by first sampling a dataset according to a specified
+ probability distribution and then getting instances from that dataset.
+
+ Args:
+ datasets: an OrderedDict of FairseqDataset instances.
+ sampling_func: A function for sampling over list of dataset keys.
+ The default strategy is to sample uniformly.
+ """
+
+ def __init__(
+ self,
+ datasets: Dict[str, FairseqDataset],
+ sampling_func: Callable[[List], int] = None,
+ ):
+ super().__init__()
+ assert isinstance(datasets, OrderedDict)
+ self.datasets = datasets
+ if sampling_func is None:
+ sampling_func = uniform_sampler
+ self.sampling_func = sampling_func
+
+ self.total_num_instances = 0
+ for _, dataset in datasets.items():
+ assert isinstance(dataset, FairseqDataset)
+ self.total_num_instances += len(dataset)
+
+ self._ordered_indices = None
+
+ def __len__(self):
+ """
+ Length of this dataset is the sum of individual datasets
+ """
+ return self.total_num_instances
+
+ def ordered_indices(self):
+ """
+ Ordered indices for batching. Here we call the underlying
+ dataset's ordered_indices() so that we get the same random ordering
+ as we would have from using the underlying dataset directly.
+ """
+ if self._ordered_indices is None:
+ self._ordered_indices = OrderedDict(
+ [
+ (key, dataset.ordered_indices())
+ for key, dataset in self.datasets.items()
+ ]
+ )
+ return np.arange(len(self))
+
+ def _map_index_to_dataset(self, key: int, index: int):
+ """
+ Different underlying datasets have different lengths. In order to ensure
+ we are not accessing an index outside the range of the current dataset
+ size, we wrap around. This function should be called after we have
+ created an ordering for this and all underlying datasets.
+ """
+ assert (
+ self._ordered_indices is not None
+ ), "Must call MultiCorpusSampledDataset.ordered_indices() first"
+ mapped_index = index % len(self.datasets[key])
+ return self._ordered_indices[key][mapped_index]
+
+ def __getitem__(self, index: int):
+ """
+ Get the item associated with index from each underlying dataset.
+ Since index is in the range of [0, TotalNumInstances], we need to
+ map the index to the dataset before retrieving the item.
+ """
+ return OrderedDict(
+ [
+ (key, dataset[self._map_index_to_dataset(key, index)])
+ for key, dataset in self.datasets.items()
+ ]
+ )
+
+ def collater(self, samples: List[Dict]):
+ """
+ Generate a mini-batch for this dataset.
+ To convert this into a regular mini-batch we use the following
+ logic:
+ 1. Select a dataset using the specified probability distribution.
+ 2. Call the collater function of the selected dataset.
+ """
+ if len(samples) == 0:
+ return None
+
+ selected_key = self.sampling_func(list(self.datasets.keys()))
+ selected_samples = [sample[selected_key] for sample in samples]
+ return self.datasets[selected_key].collater(selected_samples)
+
+ def num_tokens(self, index: int):
+ """
+ Return an example's length (number of tokens), used for batching. Here
+ we return the max across all examples at index across all underlying
+ datasets.
+ """
+ return max(
+ dataset.num_tokens(self._map_index_to_dataset(key, index))
+ for key, dataset in self.datasets.items()
+ )
+
+ def size(self, index: int):
+ """
+ Return an example's size as a float or tuple. Here we return the max
+ across all underlying datasets. This value is used when filtering a
+ dataset with max-positions.
+ """
+ return max(
+ dataset.size(self._map_index_to_dataset(key, index))
+ for key, dataset in self.datasets.items()
+ )
+
+ @property
+ def supports_prefetch(self):
+ return all(
+ getattr(dataset, "supports_prefetch", False)
+ for dataset in self.datasets.values()
+ )
+
+ def prefetch(self, indices):
+ for key, dataset in self.datasets.items():
+ dataset.prefetch(
+ [self._map_index_to_dataset(key, index) for index in indices]
+ )
diff --git a/fairseq-0.10.2/fairseq/data/num_samples_dataset.py b/fairseq-0.10.2/fairseq/data/num_samples_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..99a17495c701d8a05e0268f98bf453905e11d078
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/num_samples_dataset.py
@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import FairseqDataset
+
+
+class NumSamplesDataset(FairseqDataset):
+ def __getitem__(self, index):
+ return 1
+
+ def __len__(self):
+ return 0
+
+ def collater(self, samples):
+ return sum(samples)
diff --git a/fairseq-0.10.2/fairseq/data/offset_tokens_dataset.py b/fairseq-0.10.2/fairseq/data/offset_tokens_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fabbdcdaa1a8f70d8d8c07db4cd53754503c194
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/offset_tokens_dataset.py
@@ -0,0 +1,15 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import BaseWrapperDataset
+
+
+class OffsetTokensDataset(BaseWrapperDataset):
+ def __init__(self, dataset, offset):
+ super().__init__(dataset)
+ self.offset = offset
+
+ def __getitem__(self, idx):
+ return self.dataset[idx] + self.offset
diff --git a/fairseq-0.10.2/fairseq/data/pad_dataset.py b/fairseq-0.10.2/fairseq/data/pad_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8075bba6a9efc5f8421368ee0b2ae66afe3f5009
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/pad_dataset.py
@@ -0,0 +1,28 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq.data import data_utils
+
+from . import BaseWrapperDataset
+
+
+class PadDataset(BaseWrapperDataset):
+ def __init__(self, dataset, pad_idx, left_pad):
+ super().__init__(dataset)
+ self.pad_idx = pad_idx
+ self.left_pad = left_pad
+
+ def collater(self, samples):
+ return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad)
+
+
+class LeftPadDataset(PadDataset):
+ def __init__(self, dataset, pad_idx):
+ super().__init__(dataset, pad_idx, left_pad=True)
+
+
+class RightPadDataset(PadDataset):
+ def __init__(self, dataset, pad_idx):
+ super().__init__(dataset, pad_idx, left_pad=False)
diff --git a/fairseq-0.10.2/fairseq/data/prepend_dataset.py b/fairseq-0.10.2/fairseq/data/prepend_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad74784d2d7920e4a6225282d95543ce16ea50d9
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/prepend_dataset.py
@@ -0,0 +1,28 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+from . import BaseWrapperDataset
+
+
class PrependDataset(BaseWrapperDataset):
    """Wrapper that overwrites the first token of each item.

    ``prepend_getter(dataset, index)`` supplies the token id to place at
    position 0.  If ``ensure_first_token_is`` is given, the item's current
    first token is asserted to equal it before being replaced.

    NOTE(review): the replacement happens in place on the object returned by
    the wrapped dataset; if that dataset returns cached/shared storage, the
    underlying data is mutated — confirm callers expect this.
    """

    def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
        super().__init__(dataset)
        self.prepend_getter = prepend_getter
        self.ensure_first_token = ensure_first_token_is

    def __getitem__(self, idx):
        item = self.dataset[idx]
        is_tuple = isinstance(item, tuple)
        src = item[0] if is_tuple else item

        if self.ensure_first_token is not None:
            assert src[0] == self.ensure_first_token
        new_first = self.prepend_getter(self.dataset, idx)
        assert isinstance(new_first, int)
        # In-place overwrite of the first position (see class NOTE).
        src[0] = new_first
        if is_tuple:
            return (src,) + item[1:]
        return src
diff --git a/fairseq-0.10.2/fairseq/data/resampling_dataset.py b/fairseq-0.10.2/fairseq/data/resampling_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d3b993164dc3962df48bacff26714328e843e80
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/resampling_dataset.py
@@ -0,0 +1,139 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import numpy as np
+from fairseq.data import BaseWrapperDataset, plasma_utils
+
+
+logger = logging.getLogger(__name__)
+
+
class ResamplingDataset(BaseWrapperDataset):
    """Randomly samples from a given dataset at each epoch.

    Sampling is done with or without replacement, depending on the "replace"
    parameter.

    Optionally, the epoch size can be rescaled. This is potentially desirable
    to increase per-epoch coverage of the base dataset (since sampling with
    replacement means that many items in the dataset will be left out). In the
    case of sampling without replacement, size_ratio should be strictly less
    than 1.

    Args:
        dataset (~torch.utils.data.Dataset): dataset on which to sample.
        weights (List[float]): list of probability weights
            (default: None, which corresponds to uniform sampling).
        replace (bool): sampling mode; True for "with replacement", or False
            for "without replacement" (default: True)
        size_ratio (float): the ratio to subsample to; must be positive
            (default: 1.0).
        batch_by_size (bool): whether or not to batch by sequence length
            (default: True).
        seed (int): RNG seed to use (default: 0).
        epoch (int): starting epoch number (default: 1).
    """

    def __init__(
        self,
        dataset,
        weights=None,
        replace=True,
        size_ratio=1.0,
        batch_by_size=True,
        seed=0,
        epoch=1,
    ):
        super().__init__(dataset)

        if weights is None:
            # Uniform sampling; ``rng.choice`` handles ``p=None`` below.
            self.weights = None

        else:
            assert len(weights) == len(dataset)
            # Normalize to a probability distribution, then wrap in a
            # PlasmaArray (see fairseq.data.plasma_utils) — presumably so the
            # buffer can be shared across dataloader workers.
            weights_arr = np.array(weights, dtype=np.float64)
            weights_arr /= weights_arr.sum()
            self.weights = plasma_utils.PlasmaArray(weights_arr)

        self.replace = replace

        assert size_ratio > 0.0
        if not self.replace:
            # Without replacement we cannot draw more unique items than exist.
            assert size_ratio < 1.0
        self.size_ratio = float(size_ratio)
        # Number of samples drawn per epoch (rounded up).
        self.actual_size = np.ceil(len(dataset) * self.size_ratio).astype(int)

        self.batch_by_size = batch_by_size
        self.seed = seed

        # Lazily filled by set_epoch(): current epoch number and the
        # resampled index array for that epoch.
        self._cur_epoch = None
        self._cur_indices = None

        self.set_epoch(epoch)

    def __getitem__(self, index):
        # Indirect through the per-epoch resampled index array.
        return self.dataset[self._cur_indices.array[index]]

    def __len__(self):
        return self.actual_size

    @property
    def sizes(self):
        # ``dataset.sizes`` may be a list of arrays (e.g. separate source and
        # target lengths); index each one with the resampled indices.
        if isinstance(self.dataset.sizes, list):
            return [s[self._cur_indices.array] for s in self.dataset.sizes]
        return self.dataset.sizes[self._cur_indices.array]

    def num_tokens(self, index):
        return self.dataset.num_tokens(self._cur_indices.array[index])

    def size(self, index):
        return self.dataset.size(self._cur_indices.array[index])

    def ordered_indices(self):
        if self.batch_by_size:
            # Stable sort by size so that minibatches group similar lengths.
            order = [
                np.arange(len(self)),
                self.sizes,
            ]  # No need to handle `self.shuffle == True`
            return np.lexsort(order)
        else:
            return np.arange(len(self))

    def prefetch(self, indices):
        self.dataset.prefetch(self._cur_indices.array[indices])

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # A fresh index set is drawn every epoch, so a cached epoch iterator
        # would serve stale indices.
        return False

    def set_epoch(self, epoch):
        logger.debug("ResamplingDataset.set_epoch: {}".format(epoch))
        super().set_epoch(epoch)

        if epoch == self._cur_epoch:
            return

        self._cur_epoch = epoch

        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.

        rng = np.random.RandomState(
            [
                42,  # magic number
                self.seed % (2 ** 32),  # global seed
                self._cur_epoch,  # epoch index
            ]
        )
        self._cur_indices = plasma_utils.PlasmaArray(
            rng.choice(
                len(self.dataset),
                self.actual_size,
                replace=self.replace,
                p=(None if self.weights is None else self.weights.array),
            )
        )
diff --git a/fairseq-0.10.2/fairseq/data/shorten_dataset.py b/fairseq-0.10.2/fairseq/data/shorten_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ebb5d88feb3f29d1512a0873df304915d051209
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/shorten_dataset.py
@@ -0,0 +1,78 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+from fairseq.data import data_utils
+
+from . import BaseWrapperDataset
+
+
class TruncateDataset(BaseWrapperDataset):
    """Truncate each item to its first ``truncation_length`` tokens."""

    def __init__(self, dataset, truncation_length):
        assert truncation_length is not None
        super().__init__(dataset)
        self.truncation_length = truncation_length
        self.dataset = dataset

    def __getitem__(self, index):
        item = self.dataset[index]
        # Keep only the prefix when the item exceeds the limit.
        if item.size(0) > self.truncation_length:
            return item[: self.truncation_length]
        return item

    @property
    def sizes(self):
        # Reported sizes are capped at the truncation length.
        return np.minimum(self.dataset.sizes, self.truncation_length)

    def __len__(self):
        return len(self.dataset)
+
+
class RandomCropDataset(TruncateDataset):
    """Truncate a sequence to ``truncation_length`` tokens by taking a
    random contiguous crop.

    The crop position is a deterministic function of ``(seed, epoch, index)``
    so results are reproducible for a given epoch.
    """

    def __init__(self, dataset, truncation_length, seed=1):
        super().__init__(dataset, truncation_length)
        self.seed = seed
        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the crop changes, not item sizes

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index):
        # Seed the crop on (seed, epoch, index) so it is deterministic.
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            item_len = item.size(0)
            excess = item_len - self.truncation_length
            if excess > 0:
                # randint's upper bound is exclusive; use ``excess + 1`` so
                # that all excess + 1 valid windows — including the one
                # ending at the last token — can be selected.  (The previous
                # ``randint(0, excess)`` could never return the final window.)
                start_idx = np.random.randint(0, excess + 1)
                item = item[start_idx : start_idx + self.truncation_length]
            return item
+
+
def maybe_shorten_dataset(
    dataset,
    split,
    shorten_data_split_list,
    shorten_method,
    tokens_per_sample,
    seed,
):
    """Optionally wrap *dataset* so that items fit in ``tokens_per_sample``.

    Shortening applies only when *split* appears in the comma-separated
    ``shorten_data_split_list`` (an empty list means every split).  With
    ``shorten_method == "truncate"`` each item keeps its prefix; with
    ``"random_crop"`` a random window is taken.  Any other method returns
    the dataset unchanged.
    """
    applies_to_split = (
        len(shorten_data_split_list) == 0
        or split in shorten_data_split_list.split(",")
    )
    if not applies_to_split:
        return dataset
    if shorten_method == "truncate":
        return TruncateDataset(dataset, tokens_per_sample)
    if shorten_method == "random_crop":
        return RandomCropDataset(dataset, tokens_per_sample, seed)
    return dataset
diff --git a/fairseq-0.10.2/fairseq/data/strip_token_dataset.py b/fairseq-0.10.2/fairseq/data/strip_token_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cae39ba4d2f8106398eccd7eb0cf5c2194ec0db5
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/strip_token_dataset.py
@@ -0,0 +1,20 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import BaseWrapperDataset
+
+
class StripTokenDataset(BaseWrapperDataset):
    """Strip every leading and trailing occurrence of ``id_to_strip``."""

    def __init__(self, dataset, id_to_strip):
        super().__init__(dataset)
        self.id_to_strip = id_to_strip

    def __getitem__(self, index):
        item = self.dataset[index]
        # Scan inward from both ends, then take a single slice.
        end = len(item)
        while end > 0 and item[end - 1] == self.id_to_strip:
            end -= 1
        start = 0
        while start < end and item[start] == self.id_to_strip:
            start += 1
        return item[start:end]
diff --git a/fairseq-0.10.2/fairseq/data/token_block_dataset.py b/fairseq-0.10.2/fairseq/data/token_block_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa33f9d06f37fa6a1e239d9733a2725ec158f6a8
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/token_block_dataset.py
@@ -0,0 +1,168 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from fairseq.data import FairseqDataset, plasma_utils
+
+
class TokenBlockDataset(FairseqDataset):
    """Break a Dataset of tokens into blocks.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to break into blocks
        sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
        block_size (int): maximum block size (ignored in 'eos' break mode)
        pad (int): padding symbol index (used to left-pad *past_target*)
        eos (int): end-of-sentence symbol index (used to left-pad *source*)
        break_mode (str, optional): Mode used for breaking tokens. Values can
            be one of:
            - 'none': break tokens into equally sized blocks (up to block_size)
            - 'complete': break tokens into blocks (up to block_size) such that
              blocks contains complete sentences, although block_size may be
              exceeded if some sentences exceed block_size
            - 'complete_doc': similar to 'complete' mode, but do not
              cross document boundaries
            - 'eos': each block contains one sentence (block_size is ignored)
        include_targets (bool, optional): return next tokens as targets
            (default: False).
        document_sep_len (int, optional): document separator size (required for
            'complete_doc' break mode). Typically 1 if the sentences have eos
            and 0 otherwise.
    """

    def __init__(
        self,
        dataset,
        sizes,
        block_size,
        pad,
        eos,
        break_mode=None,
        include_targets=False,
        document_sep_len=1,
    ):
        # The slicing helpers are Cython extensions; fail with a clear
        # message when the package was installed without building them.
        try:
            from fairseq.data.token_block_utils_fast import (
                _get_slice_indices_fast,
                _get_block_to_dataset_index_fast,
            )
        except ImportError:
            raise ImportError(
                "Please build Cython components with: `pip install --editable .` "
                "or `python setup.py build_ext --inplace`"
            )

        super().__init__()
        self.dataset = dataset
        self.pad = pad
        self.eos = eos
        self.include_targets = include_targets

        assert len(dataset) == len(sizes)
        assert len(dataset) > 0

        # Normalize ``sizes`` to an int64 numpy array for the fast helpers.
        if isinstance(sizes, list):
            sizes = np.array(sizes, dtype=np.int64)
        else:
            if torch.is_tensor(sizes):
                sizes = sizes.numpy()
            sizes = sizes.astype(np.int64)

        break_mode = break_mode if break_mode is not None else "none"

        # For the "eos" break mode, block_size is not a required parameter.
        if break_mode == "eos" and block_size is None:
            block_size = 0

        # Per-block [start, end) offsets into the flattened token stream.
        slice_indices = _get_slice_indices_fast(
            sizes, str(break_mode), block_size, document_sep_len
        )
        self._sizes = slice_indices[:, 1] - slice_indices[:, 0]

        # build index mapping block indices to the underlying dataset indices
        if break_mode == "eos":
            # much faster version for eos break mode
            block_to_dataset_index = np.stack(
                [
                    np.arange(len(sizes)),  # starting index in dataset
                    np.zeros(
                        # np.int64 matches the arange columns; the ``np.long``
                        # alias used previously was removed in NumPy 1.24.
                        len(sizes),
                        dtype=np.int64,
                    ),  # starting offset within starting index
                    np.arange(len(sizes)),  # ending index in dataset
                ],
                1,
            )
        else:
            block_to_dataset_index = _get_block_to_dataset_index_fast(
                sizes,
                slice_indices,
            )
        # Wrap the index arrays in PlasmaArray (see fairseq.data.plasma_utils).
        self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
        self._sizes = plasma_utils.PlasmaArray(self._sizes)
        self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index)

    @property
    def slice_indices(self):
        return self._slice_indices.array

    @property
    def sizes(self):
        return self._sizes.array

    @property
    def block_to_dataset_index(self):
        return self._block_to_dataset_index.array

    def attr(self, attr: str, index: int):
        # Delegate attribute lookup to the first sentence in the block.
        start_ds_idx, _, _ = self.block_to_dataset_index[index]
        return self.dataset.attr(attr, start_ds_idx)

    def __getitem__(self, index):
        start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]

        # Concatenate every sentence this block spans, then slice out the
        # exact token range the block covers.
        buffer = torch.cat(
            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
        )
        slice_s, slice_e = self.slice_indices[index]
        length = slice_e - slice_s
        s, e = start_offset, start_offset + length
        item = buffer[s:e]

        if self.include_targets:
            # *target* is the original sentence (=item)
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
            if s == 0:
                source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]])
                past_target = torch.cat(
                    [item.new([self.pad, self.eos]), buffer[0 : e - 2]]
                )
            else:
                source = buffer[s - 1 : e - 1]
                if s == 1:
                    past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]])
                else:
                    past_target = buffer[s - 2 : e - 2]

            return source, item, past_target

        return item

    def __len__(self):
        return len(self.slice_indices)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        # Prefetch every underlying sentence touched by the requested blocks.
        self.dataset.prefetch(
            {
                ds_idx
                for index in indices
                for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
                for ds_idx in range(start_ds_idx, end_ds_idx + 1)
            }
        )
diff --git a/fairseq-0.10.2/fairseq/data/transform_eos_lang_pair_dataset.py b/fairseq-0.10.2/fairseq/data/transform_eos_lang_pair_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dd3d93d2b41898ba6b25ba0255abbcebcf495b7
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/data/transform_eos_lang_pair_dataset.py
@@ -0,0 +1,108 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Optional
+
+import torch
+
+from . import FairseqDataset
+
+
class TransformEosLangPairDataset(FairseqDataset):
    """A :class:`~fairseq.data.FairseqDataset` wrapper that transform bos on
    collated samples of language pair dataset.

    Note that the transformation is applied in :func:`collater`.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset that collates sample into
            LanguagePairDataset schema
        src_eos (int): original source end-of-sentence symbol index to be replaced
        new_src_eos (int, optional): new end-of-sentence symbol index to replace source eos symbol
        tgt_bos (int, optional): original target beginning-of-sentence symbol index to be replaced
        new_tgt_bos (int, optional): new beginning-of-sentence symbol index to replace at the
            beginning of 'prev_output_tokens'
    """

    def __init__(
        self,
        dataset: FairseqDataset,
        src_eos: int,
        new_src_eos: Optional[int] = None,
        tgt_bos: Optional[int] = None,
        new_tgt_bos: Optional[int] = None,
    ):
        self.dataset = dataset
        self.src_eos = src_eos
        self.new_src_eos = new_src_eos
        self.tgt_bos = tgt_bos
        self.new_tgt_bos = new_tgt_bos

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples, **extra_args):
        """Collate via the wrapped dataset, then rewrite eos/bos in place."""
        samples = self.dataset.collater(samples, **extra_args)

        if self.new_src_eos is not None:
            if self.dataset.left_pad_source:
                # With left padding the eos of every sequence sits in the
                # last column; sanity-check it is not already src_eos.
                assert (
                    samples["net_input"]["src_tokens"][:, -1] != self.src_eos
                ).sum() == 0
                samples["net_input"]["src_tokens"][:, -1] = self.new_src_eos
            else:
                # With right padding the eos position varies per sequence.
                eos_idx = samples["net_input"]["src_lengths"] - 1
                assert (
                    samples["net_input"]["src_tokens"][
                        torch.arange(eos_idx.size(0)), eos_idx
                    ]
                    != self.src_eos
                ).sum() == 0
                # ``view`` reshapes to (batch, 1) without the in-place
                # ``resize_`` the original used (deprecated for this purpose
                # and needlessly mutates the index tensor).
                samples["net_input"]["src_tokens"].scatter_(
                    1, eos_idx.view(-1, 1), self.new_src_eos
                )

        if (
            self.new_tgt_bos is not None
            and "prev_output_tokens" in samples["net_input"]
        ):
            if self.dataset.left_pad_target:
                # TODO: support different padding direction on target side
                raise NotImplementedError(
                    "TransformEosLangPairDataset does not implement --left-pad-target True option"
                )
            else:
                assert (
                    samples["net_input"]["prev_output_tokens"][:, 0] != self.tgt_bos
                ).sum() == 0
                samples["net_input"]["prev_output_tokens"][:, 0] = self.new_tgt_bos

        return samples

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    @property
    def sizes(self):
        # dataset.sizes can be a dynamically computed sizes:
        return self.dataset.sizes

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
diff --git a/fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so b/fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so
new file mode 100644
index 0000000000000000000000000000000000000000..c5b8c0cb206abc7a419cb637e5265662c160ef72
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b0b5376f3a99c755cece761f9c9d1001ab792bd99574b8cb3401dd20ce49cbdc
+size 148544
diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/__init__.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19dd5a8498d2951d11af13c91a1170b54d3cddb8
Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/fairseq_nat_model.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/fairseq_nat_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13540ed1698837eb859aa8648ef0ddad7f4b109d
Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/fairseq_nat_model.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/insertion_transformer.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/insertion_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6559dc0c11b4deb81a49876053826cf42ba92ad1
Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/insertion_transformer.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/iterative_nonautoregressive_transformer.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/iterative_nonautoregressive_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d37907884648a79d97325454adfe183f6cf24e7f
Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/iterative_nonautoregressive_transformer.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/levenshtein_transformer.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/levenshtein_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69dddfd8ea5f058b47080611bb02cb2972e25dee
Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/levenshtein_transformer.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/levenshtein_utils.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/levenshtein_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e081bd4b82c2a35e71506e9860750c81c048729c
Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/levenshtein_utils.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/nat_crf_transformer.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/nat_crf_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2fb3c5cecc13bc3518c4916d18b1fe9b5360beb6
Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/nat_crf_transformer.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/nonautoregressive_transformer.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/nonautoregressive_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..318a0ffcdbb2950522d638ec4c2210db956f2014
Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/nonautoregressive_transformer.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/optim/__init__.py b/fairseq-0.10.2/fairseq/optim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..94eb2c7ee966756590a4294eab181fc87fac2fa1
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/optim/__init__.py
@@ -0,0 +1,52 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""isort:skip_file"""
+
+import importlib
+import os
+from argparse import Namespace
+from typing import Union
+
+from fairseq import registry
+from fairseq.optim.bmuf import FairseqBMUF # noqa
+from fairseq.optim.fairseq_optimizer import ( # noqa
+ FairseqOptimizer,
+ LegacyFairseqOptimizer,
+)
+from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
+from fairseq.optim.shard import shard_
+from omegaconf import DictConfig
+
+
# Public API of fairseq.optim.
__all__ = [
    "FairseqOptimizer",
    "FP16Optimizer",
    "MemoryEfficientFP16Optimizer",
    "shard_",
]


# Set up the optimizer registry: exposes the ``--optimizer`` CLI choice and
# the ``@register_optimizer`` decorator used by concrete optimizer modules.
(
    _build_optimizer,
    register_optimizer,
    OPTIMIZER_REGISTRY,
    OPTIMIZER_DATACLASS_REGISTRY,
) = registry.setup_registry("--optimizer", base_class=FairseqOptimizer, required=True)
+
+
def build_optimizer(
    optimizer_cfg: Union[DictConfig, Namespace], params, *extra_args, **extra_kwargs
):
    """Instantiate the optimizer selected by *optimizer_cfg* over *params*.

    *params* may contain parameter-group dicts, which are flattened to their
    values first; parameters that do not require gradients are dropped
    before the optimizer is constructed.
    """
    if all(isinstance(p, dict) for p in params):
        # Flatten a list of parameter-group dicts into a flat parameter list.
        params = [t for group in params for t in group.values()]
    params = [p for p in params if p.requires_grad]
    return _build_optimizer(optimizer_cfg, params, *extra_args, **extra_kwargs)
+
+
# Automatically import every Python module in the optim/ directory so that
# their ``@register_optimizer`` decorators run and populate the registry.
for file in os.listdir(os.path.dirname(__file__)):
    if not file.endswith(".py") or file.startswith("_"):
        continue
    file_name = file[: file.find(".py")]
    importlib.import_module("fairseq.optim." + file_name)
diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/adadelta.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/adadelta.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78b95e0d63c7f5698ee12550b8415dd5c452943f
Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/adadelta.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/adagrad.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/adagrad.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65eaeeedec69594d23977ed23bd7b4bd1ce9944f
Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/adagrad.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/adamax.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/adamax.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66d32c310934ec2f11c002dcbd5e9d33c184019e
Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/adamax.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/bmuf.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/bmuf.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f687c8d0e4a98d39a52995ee1ddfcf37c40c450d
Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/bmuf.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..355deff664f9042cbb78f5cb709388c7c01b5824
Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ea7fdac9d2b2ddad8af9f0bf91b9f116222e08c
Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2cd6a32764a45a2e4f8a5f53ade1a315dcbdf56f
Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/nag.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/nag.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66254eec31174ec38aa5908c6dc4c16fc5051d3d
Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/nag.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/sgd.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/sgd.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a967605193d21ad9de0cca8e586c8877fab2cb5
Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/sgd.cpython-310.pyc differ
diff --git a/fairseq-0.10.2/fairseq/optim/adafactor.py b/fairseq-0.10.2/fairseq/optim/adafactor.py
new file mode 100644
index 0000000000000000000000000000000000000000..91745ce10e183479f8cb552f2a1a91834e2a61ed
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/optim/adafactor.py
@@ -0,0 +1,268 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.optim
+
+from . import LegacyFairseqOptimizer, register_optimizer
+
+
@register_optimizer("adafactor")
class FairseqAdafactor(LegacyFairseqOptimizer):
    """fairseq wrapper exposing :class:`Adafactor` as ``--optimizer adafactor``."""

    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = Adafactor(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--adafactor-eps', default='(1e-30, 1e-3)', metavar="E",
                            help='epsilons for Adafactor optimizer')
        parser.add_argument('--clip-threshold', type=float, default=1.0, metavar="C",
                            help='threshold for clipping update root mean square')
        parser.add_argument('--decay-rate', type=float, default=-0.8, metavar="D",
                            help='decay rate of the second moment estimator')
        parser.add_argument('--beta1', type=float, default=None, metavar="B",
                            help='beta for first moment estimator. Optional')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        parser.add_argument('--scale-parameter', action='store_true',
                            help='scale learning rate by root mean square of parameter')
        parser.add_argument('--relative-step', action='store_true',
                            help='set learning rate to inverse square root of timestep,'
                                 'otherwise use external learning rate')
        parser.add_argument('--warmup-init', action='store_true',
                            help='use relative step for warm-up learning rate schedule')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        Note : Convergence issues empirically observed with fp16 on.
        Might require search for appropriate configuration.
        """
        return {
            "lr": self.args.lr[0],
            # NOTE(review): eval() of a CLI-supplied string; acceptable for
            # trusted command lines, but ast.literal_eval would be safer.
            "eps": eval(self.args.adafactor_eps),
            "clip_threshold": self.args.clip_threshold,
            "decay_rate": self.args.decay_rate,
            "beta1": self.args.beta1,
            "weight_decay": self.args.weight_decay,
            "scale_parameter": self.args.scale_parameter,  # defaults to False
            "relative_step": self.args.relative_step,  # defaults to False
            "warmup_init": self.args.warmup_init,
        }
+
+
class Adafactor(torch.optim.Optimizer):
    """Implements Adafactor algorithm.

    This implementation is based on:
    `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
    (see https://arxiv.org/abs/1804.04235)

    Note that this optimizer internally adjusts the learning rate
    depending on the *scale_parameter*, *relative_step* and
    *warmup_init* options. To use a manual (external) learning rate
    schedule you should set `scale_parameter=False` and
    `relative_step=False`.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (tuple[float, float]): regularization constants for square gradient
            and parameter scale respectively (default: (1e-30, 1e-3))
        clip_threshold (float): threshold of root mean square of
            final gradient update (default: 1.0)
        decay_rate (float): coefficient used to compute running averages of square
            gradient (default: -0.8)
        beta1 (float): coefficient used for computing running averages of gradient
            (default: None)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        scale_parameter (bool): if True, learning rate is scaled by root mean square of
            parameter (default: True)
        relative_step (bool): if True, time-dependent learning rate is computed
            instead of external learning rate (default: True)
        warmup_init (bool): time-dependent learning rate computation depends on
            whether warm-up initialization is being used (default: False)
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
    ):
        # The two lr modes are mutually exclusive: an external lr cannot be
        # combined with the internal relative-step schedule.
        if lr is not None and relative_step:
            raise ValueError("Cannot combine manual lr and relative_step options")
        if warmup_init and not relative_step:
            raise ValueError("warmup_init requires relative_step=True")

        defaults = dict(
            lr=lr,
            eps=eps,
            clip_threshold=clip_threshold,
            decay_rate=decay_rate,
            beta1=beta1,
            weight_decay=weight_decay,
            scale_parameter=scale_parameter,
            relative_step=relative_step,
            warmup_init=warmup_init,
        )
        super(Adafactor, self).__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return False

    def _get_lr(self, param_group, param_state):
        """Compute the effective step size for one parameter.

        With ``relative_step`` the lr follows a 1/sqrt(step) schedule (with a
        smaller warm-up cap when ``warmup_init``); with ``scale_parameter``
        it is additionally scaled by the parameter's RMS.
        """
        rel_step_sz = param_group["lr"]
        if param_group["relative_step"]:
            min_step = (
                1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            )
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            param_scale = max(param_group["eps"][1], param_state["RMS"])
        return param_scale * rel_step_sz

    def _get_options(self, param_group, param_shape):
        # Matrices (>= 2 dims) get the factored row/col second-moment
        # estimate; vectors fall back to the full estimate.
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    def _rms(self, tensor):
        # Root mean square of all elements.
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        # Reconstruct an approximate inverse-sqrt second moment from the
        # factored row/column statistics (outer-product approximation).
        r_factor = (
            (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True))
            .rsqrt_()
            .unsqueeze(-1)
        )
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                # All accumulator math is done in fp32 for stability.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        # Row/column factored second-moment accumulators.
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
                        state["exp_avg_sq_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:]
                        ).to(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0
                else:
                    # Move accumulators to the gradient's device/dtype
                    # (e.g. after checkpoint load).
                    if use_first_moment:
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                    else:
                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)
                # NOTE: the effective lr is written back into the param group
                # each step, overwriting any externally stored value.
                group["lr"] = self._get_lr(group, state)

                # beta2 follows the time-dependent schedule 1 - t^decay_rate.
                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = (grad ** 2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(
                        update.mean(dim=-1), alpha=1.0 - beta2t
                    )
                    exp_avg_sq_col.mul_(beta2t).add_(
                        update.mean(dim=-2), alpha=1.0 - beta2t
                    )

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Clip the update so its RMS does not exceed clip_threshold.
                update.div_(
                    (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)
                )
                update.mul_(group["lr"])

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"])
                    update = exp_avg

                if group["weight_decay"] != 0:
                    # Decoupled weight decay, scaled by the effective lr.
                    p_data_fp32.add_(
                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
                    )

                p_data_fp32.add_(-update)

                # Copy the fp32 result back into half-precision parameters.
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss
diff --git a/fairseq-0.10.2/fairseq/optim/adamax.py b/fairseq-0.10.2/fairseq/optim/adamax.py
new file mode 100644
index 0000000000000000000000000000000000000000..577a68816692710b008f9088bdf7fa45868c64a0
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/optim/adamax.py
@@ -0,0 +1,172 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.optim
+
+from . import LegacyFairseqOptimizer, register_optimizer
+
+
@register_optimizer("adamax")
class FairseqAdamax(LegacyFairseqOptimizer):
    """Fairseq wrapper registering the local :class:`Adamax` implementation
    under ``--optimizer adamax``."""

    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = Adamax(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--adamax-betas', default='(0.9, 0.999)', metavar='B',
                            help='betas for Adam optimizer')
        parser.add_argument('--adamax-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for Adam optimizer')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        parser.add_argument('--no-bias-correction', default=False, action='store_true',
                            help='disable bias correction')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        import ast  # local import: only needed when the config is built

        return {
            "lr": self.args.lr[0],
            # literal_eval parses only Python literals, so a malformed or
            # malicious --adamax-betas string cannot execute arbitrary code
            # (the previous eval() could).
            "betas": ast.literal_eval(self.args.adamax_betas),
            "eps": self.args.adamax_eps,
            "weight_decay": self.args.weight_decay,
            "bias_correction": not self.args.no_bias_correction,
        }
+
+
class Adamax(torch.optim.Optimizer):
    """Implements Adamax algorithm (a variant of Adam based on infinity norm).

    It has been proposed in `Adam: A Method for Stochastic Optimization`__.

    Compared to the version in PyTorch, this version implements a fix for weight decay.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        bias_correction (bool, optional): enable bias correction (default: True)

    __ https://arxiv.org/abs/1412.6980
    """

    def __init__(
        self,
        params,
        lr=2e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        bias_correction=True,
    ):
        # Validate hyperparameters up front so misconfiguration fails loudly
        # at construction time rather than mid-training.
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            bias_correction=bias_correction,
        )
        super(Adamax, self).__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return True

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # NOTE: .float() is a no-op for fp32 params, so `grad` may
                # alias p.grad.data -- never mutate it in place below.
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("Adamax does not support sparse gradients")

                # Keep an fp32 working copy of fp16/bf16 params for stability.
                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_inf"] = torch.zeros_like(p_data_fp32)
                else:
                    # Move optimizer state to the param's device/dtype.
                    state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
                    state["exp_inf"] = state["exp_inf"].to(p_data_fp32)

                exp_avg, exp_inf = state["exp_avg"], state["exp_inf"]
                beta1, beta2 = group["betas"]
                eps = group["eps"]

                state["step"] += 1

                # Update biased first moment estimate.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # Update the exponentially weighted infinity norm.
                # BUGFIX: use out-of-place abs(); the previous in-place
                # grad.abs_() silently clobbered the caller's gradient
                # whenever the param (and hence grad) was already fp32.
                torch.max(
                    exp_inf.mul_(beta2),
                    grad.abs(),
                    out=exp_inf,
                )

                step_size = group["lr"]
                if group["bias_correction"]:
                    bias_correction = 1 - beta1 ** state["step"]
                    step_size /= bias_correction

                # Weight decay applied directly on the params (the "fix for
                # weight decay" mentioned in the class docstring).
                if group["weight_decay"] != 0:
                    p_data_fp32.add_(
                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
                    )

                p_data_fp32.addcdiv_(exp_avg, exp_inf.add(eps), value=-step_size)

                # Write the fp32 result back into low-precision params.
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss
diff --git a/fairseq-0.10.2/fairseq/optim/bmuf.py b/fairseq-0.10.2/fairseq/optim/bmuf.py
new file mode 100644
index 0000000000000000000000000000000000000000..3312f81103490f8414487e54c6e0a4a5b2aa5de3
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/optim/bmuf.py
@@ -0,0 +1,231 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+
+import torch
+import torch.distributed as dist
+from fairseq.dataclass import FairseqDataclass
+from fairseq.dataclass.utils import gen_parser_from_dataclass
+from fairseq.optim.fairseq_optimizer import FairseqOptimizer
+from omegaconf import II
+
+
@dataclass
class FairseqBMUFConfig(FairseqDataclass):
    """Configuration for blockwise model-update filtering (BMUF) training."""

    # Learning rate applied to the aggregated block gradient at sync time.
    block_lr: float = field(
        default=1, metadata={"help": "block learning rate for bmuf"}
    )
    # Momentum applied to the smoothed block gradient (0 disables momentum).
    block_momentum: float = field(
        default=0.875, metadata={"help": "block momentum for bmuf"}
    )
    # Number of local updates between global model synchronizations.
    global_sync_iter: int = field(
        default=50, metadata={"help": "Iteration for syncing global model"}
    )
    # Number of initial updates before the one-time warmup broadcast.
    warmup_iterations: int = field(
        default=500, metadata={"help": "warmup iterations for model to broadcast"}
    )
    # Nesterov-style block momentum instead of classical block momentum.
    use_nbm: bool = field(
        default=False,
        metadata={"help": "Specify whether you want to use classical BM / Nesterov BM"},
    )
    # Average local optimizer state across workers after each sync.
    average_sync: bool = field(
        default=False,
        metadata={
            "help": "Specify whether you want to average the local momentum after each sync"
        },
    )
    # Interpolated from the distributed-training config at parse time.
    distributed_world_size: int = II(
        "params.distributed_training.distributed_world_size"
    )
+
+
class FairseqBMUF(FairseqOptimizer):
    """
    Implements incremental block distributed data parallelism similar to
    https://ieeexplore.ieee.org/document/7472805

    Paper title: Scalable training of deep learning machines by incremental
    block training with intra-block parallel optimization and blockwise
    model-update filtering
    """

    def __init__(self, args, optimizer):

        super().__init__(args)
        # Wrapped local optimizer that performs the per-worker updates.
        self._optimizer = optimizer
        self._num_updates = 0
        # Number of local steps between global (block) synchronizations.
        self.sync_iter = self.args.global_sync_iter
        self.block_momentum = self.args.block_momentum
        self.block_lr = self.args.block_lr
        self._reset_local_data()
        self.warmup_iteration = self.args.warmup_iterations
        self.use_nbm = self.args.use_nbm
        # Snapshot of the wrapped optimizer state; restored after warmup sync.
        self.initial_state = self._optimizer.state_dict()
        self.average_sync = self.args.average_sync
        self.world_size = self.args.distributed_world_size

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        gen_parser_from_dataclass(parser, FairseqBMUFConfig())

    @property
    def optimizer(self):
        """Return the underlying torch optimizer of the wrapped optimizer."""
        return self._optimizer.optimizer

    @property
    def optimizer_config(self):
        """Delegate checkpoint-override config to the wrapped optimizer."""
        return self._optimizer.optimizer_config

    def get_lr(self):
        """Return the current learning rate."""
        return self._optimizer.get_lr()

    def set_lr(self, lr):
        """Set the learning rate."""
        self._optimizer.set_lr(lr)

    def state_dict(self):
        """Return the wrapped optimizer's state dict."""
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load state into the wrapped optimizer and refresh the warmup
        restore point so it matches the loaded state."""
        self._optimizer.load_state_dict(state_dict, optimizer_overrides)
        self.initial_state = self._optimizer.state_dict()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        self._optimizer.multiply_grads(c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return self._optimizer.clip_grad_norm(max_norm, aggregate_norm_fn)

    def average_params(self):
        """Average optimizer params across workers (delegated)."""
        self._optimizer.average_params()

    def _block_sync(self):
        # Nothing to synchronize when running on a single worker.
        if self.world_size <= 1:
            return
        # Update the global model using local models from all GPUs
        # (Step-1) Calculate grad between previously synced model and
        # current local model
        if self.block_momentum != 0:
            self._calc_grad()

        # (Step-2) Average gradient from all GPUs
        self._avg_grad_from_all_gpus()

        # (Step-3) Calculate global momentum and update the global model
        if self.block_momentum != 0:
            self._update_global_model()

        # (Step-4) Average local optimizer params
        if self.average_sync:
            self.average_params()

    def _is_warmup_end(self):
        # Check whether train iterations is equal to warmup iter
        if self.get_num_updates() == self.warmup_iteration:
            return True
        return False

    def _is_bmuf_iter(self):
        # Check whether train iterations is equal to bmuf sync iter
        if (self.get_num_updates() > self.warmup_iteration) and (
            self.get_num_updates() % self.sync_iter == 0
        ):
            return True
        return False

    def _warmup_sync(self, root_rank=0):
        # One-time sync at the end of warmup: everyone adopts root's model.
        if self.world_size <= 1:
            return
        # Broadcast the local model to all gpus
        for param in self.params:
            dist.broadcast(param.data, src=root_rank)

        # Update local optimizer state
        if self.average_sync:
            self._optimizer.average_params()
        else:
            # Reset each worker to the pre-training optimizer state.
            self._optimizer.load_state_dict(self.initial_state)

        self._reset_local_data()

    def step(self, closure=None):
        """Performs a single optimization step.

        The warmup broadcast fires exactly once; afterwards block syncs run
        every ``sync_iter`` updates.
        """
        self._optimizer.step(closure)
        self.set_num_updates(self.get_num_updates() + 1)
        if self._is_warmup_end():
            self._warmup_sync()
        elif self._is_bmuf_iter():
            self._block_sync()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self._optimizer.zero_grad()

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self._num_updates = num_updates

    @torch.no_grad()
    def _reset_local_data(self):
        # (Step-0) Initialize global momentum parameters and store global copy on each gpu
        self.global_params = [torch.zeros_like(p.data) for p in self.params]
        self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params]
        self.grads = [p.data.new_zeros(p.data.size()) for p in self.params]

        # saving the global model locally for calculating gradient during bmuf sync
        for param, global_param in zip(self.params, self.global_params):
            global_param.copy_(param.data)

    @torch.no_grad()
    def _calc_grad(self):
        # global_params is basically the global copy from the previously finished
        # synchronisation. param.data is local parameter after block_sync_freq
        # for the local gpu. so grad is difference between previously synced
        # model and current local model.
        for index, (param, global_param) in enumerate(
            zip(self.params, self.global_params)
        ):
            self.grads[index] = global_param - param.data

    def _avg_grad_from_all_gpus(self):
        for index, param in enumerate(self.params):
            sync_para = param.data if self.block_momentum == 0 else self.grads[index]
            # Pre-divide so the SUM all-reduce yields the mean.
            sync_para /= float(dist.get_world_size())
            dist.all_reduce(sync_para, op=dist.ReduceOp.SUM)

    @torch.no_grad()
    def _update_global_model(self):
        for index, (param, global_param, smoothed_grad, grad) in enumerate(
            zip(
                self.params,
                self.global_params,
                self.smoothed_grads,
                # all gpus would share the same value of smoothed_grad, since it is
                # always computed on synchronized gradients.
                self.grads,
            )
        ):
            # global_param is basically last synchronized parameter. though
            # smoothed_grad is local, all processes will have same value of
            # smoothed_grad and hence param is globally synchronized copy.
            # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
            smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad
            param.data.copy_(global_param - smoothed_grad)

            # A Nesterov momentum here is to do a partial weight update before
            # calculating the gradient
            if self.use_nbm:
                param.data.copy_(param.data - self.block_momentum * smoothed_grad)

            # backup for the next synchronization.
            self.smoothed_grads[index] = smoothed_grad
            global_param.copy_(param.data)
diff --git a/fairseq-0.10.2/fairseq/optim/fairseq_optimizer.py b/fairseq-0.10.2/fairseq/optim/fairseq_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a10399a8b413c4188dbe7c8d51f5353348d835b
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/optim/fairseq_optimizer.py
@@ -0,0 +1,150 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from fairseq import utils
+from fairseq.dataclass.utils import gen_parser_from_dataclass
+
+
class FairseqOptimizer(object):
    """Base class for fairseq optimizers.

    Wraps a :class:`torch.optim.Optimizer` (stored as ``self._optimizer`` by
    subclasses) and adds fairseq-specific helpers for gradient scaling,
    clipping and checkpoint overrides.
    """

    def __init__(self, args):
        super().__init__()
        # argparse namespace (or dataclass-backed config) for this optimizer
        self.args = args

    @classmethod
    def add_args(cls, parser):
        """Add optimizer-specific arguments to the parser."""
        # Subclasses may expose a dataclass config under "__dataclass";
        # getattr with a string avoids Python's name mangling here.
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @property
    def optimizer(self):
        """Return a torch.optim.optimizer.Optimizer instance."""
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        """Reset optimizer instance."""
        # Only allow replacing an already-initialized, valid optimizer.
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        self._optimizer = optimizer

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        raise NotImplementedError

    @property
    def params(self):
        """Return an iterable of the parameters held by the optimizer."""
        for param_group in self.param_groups:
            for p in param_group["params"]:
                yield p

    @property
    def param_groups(self):
        # Delegate to the wrapped torch optimizer's param groups.
        return self.optimizer.param_groups

    def __getstate__(self):
        # Pickle the wrapped optimizer's state directly.
        return self._optimizer.__getstate__()

    def get_lr(self):
        """Return the current learning rate."""
        # By convention the first param group's LR is "the" learning rate.
        return self.param_groups[0]["lr"]

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.param_groups:
            param_group["lr"] = lr

    def state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        self.optimizer.load_state_dict(state_dict)

        if optimizer_overrides is not None and len(optimizer_overrides) > 0:
            # override learning rate, momentum, etc. with latest values
            for group in self.param_groups:
                group.update(optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
        loss.backward()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        for p in self.params:
            if p.grad is not None:
                p.grad.data.mul_(c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)

    def step(self, closure=None, scale=1.0):
        """Performs a single optimization step."""
        if self.supports_step_with_scale:
            # The wrapped optimizer divides grads by ``scale`` internally.
            self.optimizer.step(closure, scale=scale)
        else:
            # Otherwise un-scale the grads here before stepping.
            if scale != 1.0:
                self.multiply_grads(1.0 / scale)
            self.optimizer.step(closure)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        # Drop grad tensors entirely (frees memory) before delegating.
        for p in self.params:
            p.grad = None
        self.optimizer.zero_grad()

    @property
    def supports_memory_efficient_fp16(self):
        # Opt-in flag exposed by optimizers that handle fp16 grads themselves.
        if hasattr(self.optimizer, "supports_memory_efficient_fp16"):
            return self.optimizer.supports_memory_efficient_fp16
        return False

    @property
    def supports_step_with_scale(self):
        # Opt-in flag for optimizers whose step() accepts a ``scale`` kwarg.
        if hasattr(self.optimizer, "supports_step_with_scale"):
            return self.optimizer.supports_step_with_scale
        return False

    @property
    def supports_flat_params(self):
        """
        Whether the optimizer supports collapsing of the model
        parameters/gradients into a single contiguous Tensor.
        """
        if hasattr(self.optimizer, "supports_flat_params"):
            return self.optimizer.supports_flat_params
        return False

    def average_params(self):
        # No-op by default; overridden by distributed-averaging optimizers.
        pass
+
+
class LegacyFairseqOptimizer(FairseqOptimizer):
    """Base class for optimizers configured directly from an argparse
    namespace rather than a structured dataclass config."""

    def __init__(self, args):
        # The base initializer just records the args namespace, which is
        # exactly what legacy optimizers need.
        super().__init__(args)
diff --git a/fairseq-0.10.2/fairseq/optim/fp16_optimizer.py b/fairseq-0.10.2/fairseq/optim/fp16_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..687315030fef38390d22f17db1cbae88a1560266
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/optim/fp16_optimizer.py
@@ -0,0 +1,491 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+from itertools import chain
+
+import torch
+from fairseq import optim, utils
+
+from .dynamic_loss_scaler import DynamicLossScaler
+
+
class _FP16OptimizerMixin(object):
    """Machinery for FP16 training with a separate FP32 master copy.

    The concrete class is expected to provide ``fp16_params``,
    ``fp32_params``, ``fp32_optimizer`` and ``scaler`` attributes.
    """

    def __init__(self, *args, **kwargs):
        # forward __init__ call to the next class in mro(method resolution order)
        super().__init__(*args, **kwargs)
        # Scalar multiplier lazily folded into the grads (see _unscale_grads).
        self._multiply_factor = 1.0

    @property
    def has_flat_params(self):
        # True when the fp32 copy is stored as flat tensor(s) per device.
        return torch.is_tensor(self.fp32_params) or (
            isinstance(self.fp32_params, dict)
            and all(torch.is_tensor(t) for t in self.fp32_params.values())
        )

    @classmethod
    def build_fp32_params(cls, args, params, flatten=True):
        # create FP32 copy of parameters and grads
        if flatten:
            is_pipeline_parallel = getattr(
                args, "pipeline_model_parallel", False
            ) and getattr(args, "distributed_no_spawn", False)
            total_param_size = sum(p.data.numel() for p in params)
            devices = [torch.cuda.current_device()]
            if is_pipeline_parallel:
                devices = list(set(args.pipeline_devices))
            fp32_params = {}
            for device in devices:
                if is_pipeline_parallel:
                    device_param_size = sum(
                        p.data.numel() for p in params if p.device.index == device
                    )
                    device_params = [p for p in params if p.device.index == device]
                else:
                    device_param_size = total_param_size
                    device_params = params
                # Allocate one flat fp32 buffer per device and pack each
                # param into it contiguously.
                fp32_params[device] = (
                    device_params[0].new(0).float().new(device_param_size)
                )
                offset = 0
                for p in device_params:
                    numel = p.data.numel()
                    fp32_params[device][offset : offset + numel].copy_(p.data.view(-1))
                    offset += numel
                fp32_params[device] = torch.nn.Parameter(fp32_params[device])
                fp32_params[device].grad = fp32_params[device].data.new(
                    device_param_size
                )
            return fp32_params
        else:
            fp32_params = []
            for p in params:
                p32 = torch.nn.Parameter(p.data.float())
                p32.grad = torch.zeros_like(p32.data)
                fp32_params.append(p32)
            return fp32_params

    def state_dict(self):
        """Return the optimizer's state dict."""
        state_dict = self.fp32_optimizer.state_dict()
        if self.scaler is not None:
            # Persist the dynamic loss scale so training resumes smoothly.
            state_dict["loss_scale"] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        if "loss_scale" in state_dict and self.scaler is not None:
            self.scaler.loss_scale = state_dict["loss_scale"]
        self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()
        # Mark the fp16 grads as newer than the fp32 copy.
        self._needs_sync = True

    def _sync_fp16_grads_to_fp32(self):
        if self._needs_sync:
            # copy FP16 grads to FP32
            if self.has_flat_params:
                devices = list(self.fp32_params.keys())
                device_params_dict = defaultdict(list)
                for p in self.fp16_params:
                    if p.requires_grad:
                        device_params_dict[p.device.index].append(p)
                for device in devices:
                    device_params = device_params_dict[device]
                    offset = 0
                    for p in device_params:
                        # Missing grads are treated as zeros so the packed
                        # layout stays aligned with the params.
                        grad_data = (
                            p.grad.data
                            if p.grad is not None
                            else p.data.new_zeros(p.data.shape)
                        )
                        numel = grad_data.numel()
                        self.fp32_params[device].grad.data[
                            offset : offset + numel
                        ].copy_(grad_data.view(-1))
                        offset += numel
            else:
                for p, p32 in zip(self.fp16_params, self.fp32_params):
                    if not p.requires_grad:
                        continue
                    if p.grad is not None:
                        p32.grad.data.copy_(p.grad.data)
                    else:
                        p32.grad = torch.zeros_like(p.data, dtype=torch.float)

            self._needs_sync = False

    def _sync_fp32_params_to_fp16(self):
        # copy FP32 params back into FP16 model
        if self.has_flat_params:
            devices = list(self.fp32_params.keys())
            device_params_dict = defaultdict(list)
            for p in self.fp16_params:
                device_params_dict[p.device.index].append(p)
            for device in devices:
                device_params = device_params_dict[device]
                offset = 0
                for p in device_params:
                    numel = p.data.numel()
                    p.data.copy_(
                        self.fp32_params[device]
                        .data[offset : offset + numel]
                        .view_as(p.data)
                    )
                    offset += numel
        else:
            for p, p32 in zip(self.fp16_params, self.fp32_params):
                if not p.requires_grad:
                    continue
                p.data.copy_(p32.data)

    def _unscale_grads(self):
        # Apply the pending scalar multiplier to the fp32 grads exactly once.
        self._sync_fp16_grads_to_fp32()
        if self._multiply_factor != 1.0:
            self.fp32_optimizer.multiply_grads(self._multiply_factor)
            self._multiply_factor = 1.0

    def multiply_grads(self, c):
        """Multiplies grads by a constant ``c``."""
        # Deferred: folded into _multiply_factor instead of touching grads now.
        self._multiply_factor *= c

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()

        # Norm of the *unscaled* grads: raw norm times the pending multiplier.
        grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(
            0, aggregate_norm_fn
        )

        if self.scaler is not None:
            if grad_norm > max_norm > 0.0:
                self._multiply_factor *= max_norm / grad_norm

            self.scaler.check_overflow(grad_norm)
        elif max_norm > 0.0:
            clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
            self._multiply_factor *= clip_coef

        return grad_norm

    def step(self, closure=None):
        """Performs a single optimization step."""
        self._sync_fp16_grads_to_fp32()

        if getattr(self, "supports_step_with_scale", False):
            # The wrapped optimizer divides by the scale factor itself.
            self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor))
        else:
            self._unscale_grads()
            self.fp32_optimizer.step(closure)

        if self.scaler is not None:
            self.scaler.update()

        self._sync_fp32_params_to_fp16()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.fp16_params:
            p.grad = None
        if self.has_flat_params:
            if torch.is_tensor(self.fp32_params):
                self.fp32_params.grad.zero_()
            elif isinstance(self.fp32_params, dict):
                for fp32_params in self.fp32_params.values():
                    fp32_params.grad.zero_()
            else:
                # BUGFIX: raising a plain string is itself a TypeError; raise
                # a proper exception carrying the intended message instead.
                raise TypeError("self.fp32_params must be a tensor or dict")
        else:
            # BUGFIX: the condition was inverted ("is None"), which both
            # skipped zeroing real grads and crashed on missing ones.
            for p32 in self.fp32_params:
                if p32.grad is not None:
                    p32.grad.zero_()
        self._needs_sync = False

        if self.scaler is not None:
            self._multiply_factor = 1.0 / float(self.scaler.loss_scale)
+
+
class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.
    """

    def __init__(self, args, params, fp32_optimizer, fp32_params):
        super().__init__(args)
        # fp16 model params plus their fp32 master copy; the wrapped
        # ``fp32_optimizer`` steps on the fp32 copy.
        self.fp16_params = params
        self.fp32_optimizer = fp32_optimizer
        self.fp32_params = fp32_params

        if getattr(args, "fp16_scale_window", None) is None:
            if len(args.update_freq) > 1:
                raise ValueError(
                    "--fp16-scale-window must be given explicitly when using a "
                    "custom --update-freq schedule"
                )
            data_parallel_size = int(
                args.distributed_world_size / args.model_parallel_size
            )
            # Default heuristic: shrink the window as effective batch grows.
            scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0])
        else:
            scale_window = args.fp16_scale_window

        if not getattr(args, "bf16", False):
            self.scaler = DynamicLossScaler(
                init_scale=args.fp16_init_scale,
                scale_window=scale_window,
                tolerance=args.fp16_scale_tolerance,
                threshold=args.threshold_loss_scale,
                min_loss_scale=args.min_loss_scale,
            )
        else:
            # disable loss scaling for bfloat16
            self.scaler = None

    @classmethod
    def build_optimizer(cls, args, params):
        """
        Args:
            args (argparse.Namespace): fairseq args
            params (iterable): iterable of parameters to optimize
        """
        flatten = not getattr(args, "fp16_no_flatten_grads", False)
        if getattr(args, "bf16", False):
            flatten = False  # mixed precision is faster on TPUs without flat grads
        fp32_params = cls.build_fp32_params(args, params, flatten=flatten)
        if flatten:
            # Flat mode: a single fused fp32 parameter per device.
            fp32_optimizer = optim.build_optimizer(args, [fp32_params])
        else:
            fp32_optimizer = optim.build_optimizer(args, fp32_params)
        if flatten and not fp32_optimizer.supports_flat_params:
            raise RuntimeError(
                "chosen optimizer does not support flat params, "
                "please set --fp16-no-flatten-grads"
            )
        return cls(args, params, fp32_optimizer, fp32_params)

    @property
    def optimizer(self):
        # Expose the inner torch optimizer of the fp32 optimizer.
        return self.fp32_optimizer.optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        self.fp32_optimizer.optimizer = optimizer

    @property
    def optimizer_config(self):
        return self.fp32_optimizer.optimizer_config

    def get_lr(self):
        """Return the current learning rate."""
        return self.fp32_optimizer.get_lr()

    def set_lr(self, lr):
        """Set the learning rate."""
        self.fp32_optimizer.set_lr(lr)
+
+
class _MemoryEfficientFP16OptimizerMixin(object):
    """Machinery for memory-efficient FP16 training.

    The concrete class provides ``wrapped_optimizer`` and ``scaler``
    attributes; no separate fp32 master copy of the params is kept.
    """

    def __init__(self, *args, **kwargs):
        # forward __init__ call to the next class in MRO (method resolution order)
        super().__init__(*args, **kwargs)
        # Scalar multiplier lazily folded into the grads (see _unscale_grads).
        self._multiply_factor = 1.0

    @property
    def has_flat_params(self):
        # This mode never keeps a flat fp32 master copy.
        return False

    def state_dict(self):
        """Return the optimizer's state dict."""
        state_dict = self.wrapped_optimizer.state_dict()
        if self.scaler is not None:
            # Persist the dynamic loss scale so training resumes smoothly.
            state_dict["loss_scale"] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        if "loss_scale" in state_dict and self.scaler is not None:
            self.scaler.loss_scale = state_dict["loss_scale"]

        self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides)

        # Hack: PyTorch automatically casts the optimizer state to match the
        # type of the current parameters. But with --memory-efficient-fp16 the
        # params are FP16 while the optimizer state is FP32 and we don't want
        # to cast. A workaround is to manually copy back the original state
        # after the optimizer has been loaded.
        if not getattr(self.optimizer, "disable_mem_eff_fp16_loading_hack", False):
            groups = self.optimizer.param_groups
            saved_groups = state_dict["param_groups"]
            # Pair each saved param id with the live param at the same
            # position (relies on identical param-group ordering).
            id_map = {
                old_id: p
                for old_id, p in zip(
                    chain(*(g["params"] for g in saved_groups)),
                    chain(*(g["params"] for g in groups)),
                )
            }
            for k, v in state_dict["state"].items():
                if k in id_map:
                    param = id_map[k]
                    self.optimizer.state[param] = v

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()

    def _unscale_grads(self):
        # Apply any pending scalar multiplier to the grads exactly once.
        if self._multiply_factor != 1.0:
            self.wrapped_optimizer.multiply_grads(self._multiply_factor)
            self._multiply_factor = 1.0

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        # Deferred: folded into _multiply_factor instead of touching grads now.
        self._multiply_factor *= c

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        max_norm = float(max_norm)
        # Norm of the *unscaled* grads: raw norm times the pending multiplier.
        grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm(
            0, aggregate_norm_fn
        )

        if self.scaler is not None:
            grad_norm_cpu = float(grad_norm)
            if grad_norm_cpu > max_norm > 0.0:
                self._multiply_factor *= max_norm / grad_norm_cpu

            # detect overflow and adjust loss scale
            self.scaler.check_overflow(grad_norm_cpu)
        elif max_norm > 0.0:
            clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
            self._multiply_factor *= clip_coef

        return grad_norm

    def step(self, closure=None):
        """Performs a single optimization step."""
        if getattr(self, "supports_step_with_scale", False):
            # NOTE(msb) optimizer divides by scale factor
            self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor))
        else:
            self._unscale_grads()
            self.wrapped_optimizer.step(closure)

        if self.scaler is not None:
            self.scaler.update()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self.wrapped_optimizer.zero_grad()
        if self.scaler is not None:
            # Fold the inverse loss scale into the next unscale step.
            self._multiply_factor = 1.0 / float(self.scaler.loss_scale)
        else:
            self._multiply_factor = 1.0
+
+
class MemoryEfficientFP16Optimizer(
    _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer
):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.

    Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not
    maintain an FP32 copy of the model. We instead expect the optimizer to
    convert the gradients to FP32 internally and sync the results back to the
    FP16 model params. This significantly reduces memory usage but slightly
    increases the time spent in the optimizer.

    Since this wrapper depends on specific functionality in the wrapped
    optimizer (i.e., on-the-fly conversion of grads to FP32), only certain
    optimizers can be wrapped. This is determined by the
    *supports_memory_efficient_fp16* property.
    """

    def __init__(self, args, params, optimizer):
        if not optimizer.supports_memory_efficient_fp16:
            raise ValueError(
                "Unsupported optimizer: {}".format(optimizer.__class__.__name__)
            )

        super().__init__(args)
        self.wrapped_optimizer = optimizer

        if getattr(args, "fp16_scale_window", None) is None:
            if len(args.update_freq) > 1:
                raise ValueError(
                    "--fp16-scale-window must be given explicitly when using a "
                    "custom --update-freq schedule"
                )
            data_parallel_size = int(
                args.distributed_world_size / args.model_parallel_size
            )
            # FIX: cast to int, matching FP16Optimizer — the scale window is
            # an iteration count, and the division above yields a float.
            scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0])
        else:
            scale_window = args.fp16_scale_window

        if not getattr(args, "bf16", False):
            self.scaler = DynamicLossScaler(
                init_scale=args.fp16_init_scale,
                scale_window=scale_window,
                tolerance=args.fp16_scale_tolerance,
                threshold=args.threshold_loss_scale,
                min_loss_scale=args.min_loss_scale,
            )
        else:
            # disable loss scaling for bfloat16
            self.scaler = None

    @classmethod
    def build_optimizer(cls, args, params):
        """
        Args:
            args (argparse.Namespace): fairseq args
            params (iterable): iterable of parameters to optimize
        """
        fp16_optimizer = optim.build_optimizer(args, params)
        return cls(args, params, fp16_optimizer)

    @property
    def optimizer(self):
        # Expose the inner torch optimizer of the wrapped optimizer.
        return self.wrapped_optimizer.optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        self.wrapped_optimizer.optimizer = optimizer

    @property
    def optimizer_config(self):
        return self.wrapped_optimizer.optimizer_config

    def get_lr(self):
        """Return the current learning rate."""
        return self.wrapped_optimizer.get_lr()

    def set_lr(self, lr):
        """Set the learning rate."""
        self.wrapped_optimizer.set_lr(lr)
diff --git a/fairseq-0.10.2/fairseq/optim/fused_lamb.py b/fairseq-0.10.2/fairseq/optim/fused_lamb.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4f2bdb0c6c65f7758509b6d4d2f2c48cb6e8b4f
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/optim/fused_lamb.py
@@ -0,0 +1,51 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq.optim import LegacyFairseqOptimizer, register_optimizer
+
+
@register_optimizer("lamb")
class FairseqLAMB(LegacyFairseqOptimizer):
    """LAMB optimizer, backed by apex's ``FusedLAMB`` implementation.

    Requires NVIDIA apex to be installed; construction raises
    ``ImportError`` otherwise.
    """

    def __init__(self, args, params):
        """
        Args:
            args (argparse.Namespace): fairseq args
            params (iterable): iterable of parameters to optimize

        Raises:
            ImportError: if apex is not installed.
        """
        super().__init__(args)
        try:
            from apex.optimizers import FusedLAMB

            self._optimizer = FusedLAMB(params, **self.optimizer_config)
        except ImportError:
            raise ImportError("Please install apex to use LAMB optimizer")

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B',
                            help='betas for LAMB optimizer')
        parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for LAMB optimizer')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        # Local import keeps the module's top-level imports untouched.
        import ast

        return {
            "lr": self.args.lr[0],
            # --lamb-betas arrives as a string like "(0.9, 0.999)".
            # ast.literal_eval parses it as a Python literal without the
            # arbitrary-code-execution risk of eval() on config input.
            "betas": ast.literal_eval(self.args.lamb_betas),
            "eps": self.args.lamb_eps,
            "weight_decay": self.args.weight_decay,
        }

    @property
    def supports_flat_params(self):
        # Flattened parameter tensors are not supported by this optimizer.
        return False
diff --git a/fairseq-0.10.2/fairseq/optim/shard.py b/fairseq-0.10.2/fairseq/optim/shard.py
new file mode 100644
index 0000000000000000000000000000000000000000..a035a1c1f93d3ce3eae93f8d80216af8ecf9620a
--- /dev/null
+++ b/fairseq-0.10.2/fairseq/optim/shard.py
@@ -0,0 +1,41 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+try:
+ from fairscale.optim import OSS
+
+ _has_fairscale = True
+except ImportError:
+ _has_fairscale = False
+
+
+def shard_(args, optimizer, group):
+ if not _has_fairscale:
+ raise ImportError(
+ "\n\nPlease install the fairscale package:" "\n\n pip install fairscale"
+ )
+
+ class FairseqOSS(OSS):
+ @property
+ def disable_mem_eff_fp16_loading_hack(self):
+ return True
+
+ def __getattr__(self, name):
+ if name.startswith("supports") and hasattr(self.optim, name):
+ return getattr(self.optim, name)
+ raise AttributeError(
+ "'FairseqOSS' object has no attribute {0!r}".format(name)
+ )
+
+ torch_optimizer = optimizer.optimizer
+ optim_cls = type(torch_optimizer)
+
+ optimizer.optimizer = FairseqOSS(
+ torch_optimizer.param_groups,
+ optim_cls,
+ group=group,
+ **optimizer.optimizer_config
+ )