diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..abecf8fbe941be1cfe0cb9fc20f47e06bd5be290 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/__init__.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96c3657a5dfda82479726d301c9fb1e4cdd8d6ce Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1e8006ca2ecaa2e35b56e3c8ba837a007ab4e93 Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa49ca7677681fed6af7579d0af8989295bbd313 Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1563b1be8e2937c24204f47a747de2ad7afa6d8e Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb2889aa229413ecfd12072d0353b07ac4c6c2af Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2d9cd5e0f7d731b877f1ad6ae51e0c41a74fbe1 Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/dictionary.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/dictionary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2178a6f62ee31358ded09f840ff4272076b5f605 Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/dictionary.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/id_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/id_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56fe292ddb6b418954fac1f16be271a556aeb816 Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/id_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aee56d6a26e61c9c392d29f2bb772cba3bdc2d3e Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae7ba3410aff285298924c57199024cde1e626be Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88d95564d6166c1e13a48dfc134a6124c9827ade Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87293b6e9583b3ae3f3de1eccf15e9d24516c332 Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..223ee06ac60b91925a696a992e57e95094f84db0 Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2472cdfb1758d4734a519a45723a8d1affb5aab Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13a29dddd4f3171cb1bf276a26a8c0a3621d0bed Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/concat_dataset.py b/fairseq-0.10.2/fairseq/data/concat_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..01a4078bb159fa44b2d1062b9a971fe7f1abd1c2 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/concat_dataset.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import bisect + +import numpy as np +from torch.utils.data.dataloader import default_collate + +from . import FairseqDataset + + +class ConcatDataset(FairseqDataset): + @staticmethod + def cumsum(sequence, sample_ratios): + r, s = [], 0 + for e, ratio in zip(sequence, sample_ratios): + curr_len = int(ratio * len(e)) + r.append(curr_len + s) + s += curr_len + return r + + def __init__(self, datasets, sample_ratios=1): + super(ConcatDataset, self).__init__() + assert len(datasets) > 0, "datasets should not be an empty iterable" + self.datasets = list(datasets) + if isinstance(sample_ratios, int): + sample_ratios = [sample_ratios] * len(self.datasets) + self.sample_ratios = sample_ratios + self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios) + self.real_sizes = [len(d) for d in self.datasets] + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) + return self.datasets[dataset_idx][sample_idx] + + def _get_dataset_and_sample_index(self, idx: int): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + sample_idx = sample_idx % self.real_sizes[dataset_idx] + return dataset_idx, sample_idx + + def collater(self, samples, **extra_args): + # For now only supports datasets with same underlying collater implementations + if hasattr(self.datasets[0], "collater"): + return self.datasets[0].collater(samples, **extra_args) + else: + return default_collate(samples, **extra_args) + + def size(self, idx: int): + """ + Return an example's size as a float or tuple. + """ + dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) + return self.datasets[dataset_idx].size(sample_idx) + + def num_tokens(self, index: int): + return np.max(self.size(index)) + + def attr(self, attr: str, index: int): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, index) + return getattr(self.datasets[dataset_idx], attr, None) + + @property + def sizes(self): + _dataset_sizes = [] + for ds, sr in zip(self.datasets, self.sample_ratios): + if isinstance(ds.sizes, np.ndarray): + _dataset_sizes.append(np.tile(ds.sizes, sr)) + else: + # Only support underlying dataset with single size array. + assert isinstance(ds.sizes, list) + _dataset_sizes.append(np.tile(ds.sizes[0], sr)) + return np.concatenate(_dataset_sizes) + + @property + def supports_prefetch(self): + return all(d.supports_prefetch for d in self.datasets) + + def ordered_indices(self): + """ + Returns indices sorted by length. So less padding is needed. + """ + if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1: + # special handling for concatenating lang_pair_datasets + indices = np.arange(len(self)) + sizes = self.sizes + tgt_sizes = ( + sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + ) + src_sizes = ( + sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + ) + # sort by target length, then source length + if tgt_sizes is not None: + indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")] + return indices[np.argsort(src_sizes[indices], kind="mergesort")] + else: + return np.argsort(self.sizes) + + def prefetch(self, indices): + frm = 0 + for to, ds in zip(self.cumulative_sizes, self.datasets): + real_size = len(ds) + if getattr(ds, "supports_prefetch", False): + ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to]) + frm = to + + @property + def can_reuse_epoch_itr_across_epochs(self): + return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + for ds in self.datasets: + if hasattr(ds, "set_epoch"): + ds.set_epoch(epoch) diff --git a/fairseq-0.10.2/fairseq/data/concat_sentences_dataset.py b/fairseq-0.10.2/fairseq/data/concat_sentences_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..625a29370e90f9d1d7274024afb902ed83a22325 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/concat_sentences_dataset.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import FairseqDataset + + +class ConcatSentencesDataset(FairseqDataset): + def __init__(self, *datasets): + super().__init__() + self.datasets = datasets + assert all( + len(ds) == len(datasets[0]) for ds in datasets + ), "datasets must have the same length" + + def __getitem__(self, index): + return torch.cat([ds[index] for ds in self.datasets]) + + def __len__(self): + return len(self.datasets[0]) + + def collater(self, samples): + return self.datasets[0].collater(samples) + + @property + def sizes(self): + return sum(ds.sizes for ds in self.datasets) + + def num_tokens(self, index): + return sum(ds.num_tokens(index) for ds in self.datasets) + + def size(self, index): + return sum(ds.size(index) for ds in self.datasets) + + def ordered_indices(self): + return self.datasets[0].ordered_indices() + + @property + def supports_prefetch(self): + return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets) + + def prefetch(self, indices): + for ds in self.datasets: + if getattr(ds, "supports_prefetch", False): + ds.prefetch(indices) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + for ds in self.datasets: + if hasattr(ds, "set_epoch"): + ds.set_epoch(epoch) diff --git a/fairseq-0.10.2/fairseq/data/dictionary.py b/fairseq-0.10.2/fairseq/data/dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..e2df08e092350c5d5feb34723fae8744c7286a44 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/dictionary.py @@ -0,0 +1,387 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +from collections import Counter +from multiprocessing import Pool + +import torch +from fairseq import utils +from fairseq.binarizer import safe_readline +from fairseq.data import data_utils +from fairseq.file_io import PathManager +from fairseq.tokenizer import tokenize_line + + +class Dictionary(object): + """A mapping from symbols to consecutive integers""" + + def __init__( + self, + *, # begin keyword-only arguments + bos="", + pad="", + eos="", + unk="", + extra_special_symbols=None, + ): + self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos + self.symbols = [] + self.count = [] + self.indices = {} + self.bos_index = self.add_symbol(bos) + self.pad_index = self.add_symbol(pad) + self.eos_index = self.add_symbol(eos) + self.unk_index = self.add_symbol(unk) + if extra_special_symbols: + for s in extra_special_symbols: + self.add_symbol(s) + self.nspecial = len(self.symbols) + + def __eq__(self, other): + return self.indices == other.indices + + def __getitem__(self, idx): + if idx < len(self.symbols): + return self.symbols[idx] + return self.unk_word + + def __len__(self): + """Returns the number of symbols in the dictionary""" + return len(self.symbols) + + def __contains__(self, sym): + return sym in self.indices + + def index(self, sym): + """Returns the index of the specified symbol""" + assert isinstance(sym, str) + if sym in self.indices: + return self.indices[sym] + return self.unk_index + + def string( + self, + tensor, + bpe_symbol=None, + escape_unk=False, + extra_symbols_to_ignore=None, + unk_string=None, + ): + """Helper for converting a tensor of token indices to a string. + + Can optionally remove BPE symbols or escape words. + """ + if torch.is_tensor(tensor) and tensor.dim() == 2: + return "\n".join( + self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore) + for t in tensor + ) + + extra_symbols_to_ignore = set(extra_symbols_to_ignore or []) + extra_symbols_to_ignore.add(self.eos()) + + def token_string(i): + if i == self.unk(): + if unk_string is not None: + return unk_string + else: + return self.unk_string(escape_unk) + else: + return self[i] + + if hasattr(self, "bos_index"): + extra_symbols_to_ignore.add(self.bos()) + + sent = " ".join( + token_string(i) + for i in tensor + if utils.item(i) not in extra_symbols_to_ignore + ) + + return data_utils.post_process(sent, bpe_symbol) + + def unk_string(self, escape=False): + """Return unknown string, optionally escaped as: <>""" + if escape: + return "<{}>".format(self.unk_word) + else: + return self.unk_word + + def add_symbol(self, word, n=1, overwrite=False): + """Adds a word to the dictionary""" + if word in self.indices and not overwrite: + idx = self.indices[word] + self.count[idx] = self.count[idx] + n + return idx + else: + idx = len(self.symbols) + self.indices[word] = idx + self.symbols.append(word) + self.count.append(n) + return idx + + def update(self, new_dict): + """Updates counts from new dictionary.""" + for word in new_dict.symbols: + idx2 = new_dict.indices[word] + if word in self.indices: + idx = self.indices[word] + self.count[idx] = self.count[idx] + new_dict.count[idx2] + else: + idx = len(self.symbols) + self.indices[word] = idx + self.symbols.append(word) + self.count.append(new_dict.count[idx2]) + + def finalize(self, threshold=-1, nwords=-1, padding_factor=8): + """Sort symbols by frequency in descending order, ignoring special ones. + + Args: + - threshold defines the minimum word count + - nwords defines the total number of words in the final dictionary, + including special symbols + - padding_factor can be used to pad the dictionary size to be a + multiple of 8, which is important on some hardware (e.g., Nvidia + Tensor Cores). + """ + if nwords <= 0: + nwords = len(self) + + new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial))) + new_symbols = self.symbols[: self.nspecial] + new_count = self.count[: self.nspecial] + + c = Counter( + dict( + sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :])) + ) + ) + for symbol, count in c.most_common(nwords - self.nspecial): + if count >= threshold: + new_indices[symbol] = len(new_symbols) + new_symbols.append(symbol) + new_count.append(count) + else: + break + + assert len(new_symbols) == len(new_indices) + + self.count = list(new_count) + self.symbols = list(new_symbols) + self.indices = new_indices + + self.pad_to_multiple_(padding_factor) + + def pad_to_multiple_(self, padding_factor): + """Pad Dictionary size to be a multiple of *padding_factor*.""" + if padding_factor > 1: + i = 0 + while len(self) % padding_factor != 0: + symbol = "madeupword{:04d}".format(i) + self.add_symbol(symbol, n=0) + i += 1 + + def bos(self): + """Helper to get index of beginning-of-sentence symbol""" + return self.bos_index + + def pad(self): + """Helper to get index of pad symbol""" + return self.pad_index + + def eos(self): + """Helper to get index of end-of-sentence symbol""" + return self.eos_index + + def unk(self): + """Helper to get index of unk symbol""" + return self.unk_index + + @classmethod + def load(cls, f): + """Loads the dictionary from a text file with the format: + + ``` + + + ... + ``` + """ + d = cls() + d.add_from_file(f) + return d + + def add_from_file(self, f): + """ + Loads a pre-existing dictionary from a text file and adds its symbols + to this instance. + """ + if isinstance(f, str): + try: + with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd: + self.add_from_file(fd) + except FileNotFoundError as fnfe: + raise fnfe + except UnicodeError: + raise Exception( + "Incorrect encoding detected in {}, please " + "rebuild the dataset".format(f) + ) + return + + lines = f.readlines() + indices_start_line = self._load_meta(lines) + + for line in lines[indices_start_line:]: + try: + line, field = line.rstrip().rsplit(" ", 1) + if field == "#fairseq:overwrite": + overwrite = True + line, field = line.rsplit(" ", 1) + else: + overwrite = False + count = int(field) + word = line + if word in self and not overwrite: + raise RuntimeError( + "Duplicate word found when loading Dictionary: '{}'. " + "Duplicate words can overwrite earlier ones by adding the " + "#fairseq:overwrite flag at the end of the corresponding row " + "in the dictionary file. If using the Camembert model, please " + "download an updated copy of the model file.".format(word) + ) + self.add_symbol(word, n=count, overwrite=overwrite) + except ValueError: + raise ValueError( + "Incorrect dictionary format, expected ' [flags]'" + ) + + def _save(self, f, kv_iterator): + if isinstance(f, str): + PathManager.mkdirs(os.path.dirname(f)) + with PathManager.open(f, "w", encoding="utf-8") as fd: + return self.save(fd) + for k, v in kv_iterator: + print("{} {}".format(k, v), file=f) + + def _get_meta(self): + return [], [] + + def _load_meta(self, lines): + return 0 + + def save(self, f): + """Stores dictionary into a text file""" + ex_keys, ex_vals = self._get_meta() + self._save( + f, + zip( + ex_keys + self.symbols[self.nspecial :], + ex_vals + self.count[self.nspecial :], + ), + ) + + def dummy_sentence(self, length): + t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long() + t[-1] = self.eos() + return t + + def encode_line( + self, + line, + line_tokenizer=tokenize_line, + add_if_not_exist=True, + consumer=None, + append_eos=True, + reverse_order=False, + ): + words = line_tokenizer(line) + if reverse_order: + words = list(reversed(words)) + nwords = len(words) + ids = torch.IntTensor(nwords + 1 if append_eos else nwords) + + for i, word in enumerate(words): + if add_if_not_exist: + idx = self.add_symbol(word) + else: + idx = self.index(word) + if consumer is not None: + consumer(word, idx) + ids[i] = idx + if append_eos: + ids[nwords] = self.eos_index + return ids + + @staticmethod + def _add_file_to_dictionary_single_worker( + filename, tokenize, eos_word, worker_id=0, num_workers=1 + ): + counter = Counter() + with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f: + size = os.fstat(f.fileno()).st_size + chunk_size = size // num_workers + offset = worker_id * chunk_size + end = offset + chunk_size + f.seek(offset) + if offset > 0: + safe_readline(f) # drop first incomplete line + line = f.readline() + while line: + for word in tokenize(line): + counter.update([word]) + counter.update([eos_word]) + if f.tell() > end: + break + line = f.readline() + return counter + + @staticmethod + def add_file_to_dictionary(filename, dict, tokenize, num_workers): + def merge_result(counter): + for w, c in sorted(counter.items()): + dict.add_symbol(w, c) + + if num_workers > 1: + pool = Pool(processes=num_workers) + results = [] + for worker_id in range(num_workers): + results.append( + pool.apply_async( + Dictionary._add_file_to_dictionary_single_worker, + (filename, tokenize, dict.eos_word, worker_id, num_workers), + ) + ) + pool.close() + pool.join() + for r in results: + merge_result(r.get()) + else: + merge_result( + Dictionary._add_file_to_dictionary_single_worker( + filename, tokenize, dict.eos_word + ) + ) + + +class TruncatedDictionary(object): + def __init__(self, wrapped_dict, length): + self.__class__ = type( + wrapped_dict.__class__.__name__, + (self.__class__, wrapped_dict.__class__), + {}, + ) + self.__dict__ = wrapped_dict.__dict__ + self.wrapped_dict = wrapped_dict + self.length = min(len(self.wrapped_dict), length) + + def __len__(self): + return self.length + + def __getitem__(self, i): + if i < self.length: + return self.wrapped_dict[i] + return self.wrapped_dict.unk() diff --git a/fairseq-0.10.2/fairseq/data/encoders/byte_utils.py b/fairseq-0.10.2/fairseq/data/encoders/byte_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a305c080926c2d094b7e8ae48f5331da82025a75 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/encoders/byte_utils.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import re + + +WHITESPACE_NORMALIZER = re.compile(r"\s+") +SPACE = chr(32) +SPACE_ESCAPE = chr(9601) +# excluding non-breaking space (160) here +PRINTABLE_LATIN = set( + list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1)) +) +BYTE_TO_BCHAR = { + b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256) +} +BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()} + + +def byte_encode(x: str) -> str: + normalized = WHITESPACE_NORMALIZER.sub(SPACE, x) + return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")]) + + +def byte_decode(x: str) -> str: + try: + return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8") + except ValueError: + return "" + + +def smart_byte_decode(x: str) -> str: + output = byte_decode(x) + if output == "": + # DP the best recovery (max valid chars) if it's broken + n_bytes = len(x) + f = [0 for _ in range(n_bytes + 1)] + pt = [0 for _ in range(n_bytes + 1)] + for i in range(1, n_bytes + 1): + f[i], pt[i] = f[i - 1], i - 1 + for j in range(1, min(4, i) + 1): + if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0: + f[i], pt[i] = f[i - j] + 1, i - j + cur_pt = n_bytes + while cur_pt > 0: + if f[cur_pt] == f[pt[cur_pt]] + 1: + output = byte_decode(x[pt[cur_pt] : cur_pt]) + output + cur_pt = pt[cur_pt] + return output diff --git a/fairseq-0.10.2/fairseq/data/encoders/fastbpe.py b/fairseq-0.10.2/fairseq/data/encoders/fastbpe.py new file mode 100644 index 0000000000000000000000000000000000000000..74d4ad850409d69a5b2476480a9a3aa229038686 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/encoders/fastbpe.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import file_utils +from fairseq.data.encoders import register_bpe + + +@register_bpe("fastbpe") +class fastBPE(object): + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--bpe-codes', type=str, + help='path to fastBPE BPE') + # fmt: on + + def __init__(self, args): + if args.bpe_codes is None: + raise ValueError("--bpe-codes is required for --bpe=fastbpe") + codes = file_utils.cached_path(args.bpe_codes) + try: + import fastBPE + + self.bpe = fastBPE.fastBPE(codes) + self.bpe_symbol = "@@ " + except ImportError: + raise ImportError("Please install fastBPE with: pip install fastBPE") + + def encode(self, x: str) -> str: + return self.bpe.apply([x])[0] + + def decode(self, x: str) -> str: + return (x + " ").replace(self.bpe_symbol, "").rstrip() diff --git a/fairseq-0.10.2/fairseq/data/encoders/gpt2_bpe_utils.py b/fairseq-0.10.2/fairseq/data/encoders/gpt2_bpe_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..688d4e36e358df2dcc432d37d3e57bd81e2f1ed1 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/encoders/gpt2_bpe_utils.py @@ -0,0 +1,140 @@ +""" +Byte pair encoding utilities from GPT-2. + +Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py +Original license: MIT +""" + +import json +from functools import lru_cache + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2 ** 8): + if b not in bs: + bs.append(b) + cs.append(2 ** 8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Encoder: + def __init__(self, encoder, bpe_merges, errors="replace"): + self.encoder = encoder + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + try: + import regex as re + + self.re = re + except ImportError: + raise ImportError("Please install regex with: pip install regex") + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = self.re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + for token in self.re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder.get(token, token) for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + "utf-8", errors=self.errors + ) + return text + + +def get_encoder(encoder_json_path, vocab_bpe_path): + with open(encoder_json_path, "r") as f: + encoder = json.load(f) + with open(vocab_bpe_path, "r", encoding="utf-8") as f: + bpe_data = f.read() + bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] + return Encoder( + encoder=encoder, + bpe_merges=bpe_merges, + ) diff --git a/fairseq-0.10.2/fairseq/data/encoders/moses_tokenizer.py b/fairseq-0.10.2/fairseq/data/encoders/moses_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8c24844263a98c58a160b624b7741947ee290884 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/encoders/moses_tokenizer.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.data.encoders import register_tokenizer + + +@register_tokenizer("moses") +class MosesTokenizer(object): + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--moses-source-lang', metavar='SRC', + help='source language') + parser.add_argument('--moses-target-lang', metavar='TARGET', + help='target language') + parser.add_argument('--moses-no-dash-splits', action='store_true', default=False, + help='don\'t apply dash split rules') + parser.add_argument('--moses-no-escape', action='store_true', default=False, + help='don\'t perform HTML escaping on apostrophy, quotes, etc.') + # fmt: on + + def __init__(self, args): + self.args = args + + if getattr(args, "moses_source_lang", None) is None: + args.moses_source_lang = getattr(args, "source_lang", "en") + if getattr(args, "moses_target_lang", None) is None: + args.moses_target_lang = getattr(args, "target_lang", "en") + + try: + from sacremoses import MosesTokenizer, MosesDetokenizer + + self.tok = MosesTokenizer(args.moses_source_lang) + self.detok = MosesDetokenizer(args.moses_target_lang) + except ImportError: + raise ImportError( + "Please install Moses tokenizer with: pip install sacremoses" + ) + + def encode(self, x: str) -> str: + return self.tok.tokenize( + x, + aggressive_dash_splits=(not self.args.moses_no_dash_splits), + return_str=True, + escape=(not self.args.moses_no_escape), + ) + + def decode(self, x: str) -> str: + return self.detok.detokenize(x.split()) diff --git a/fairseq-0.10.2/fairseq/data/encoders/nltk_tokenizer.py b/fairseq-0.10.2/fairseq/data/encoders/nltk_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3b617e7314f0e3aae298b685eea54cbb16312203 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/encoders/nltk_tokenizer.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.data.encoders import register_tokenizer + + +@register_tokenizer("nltk") +class NLTKTokenizer(object): + def __init__(self, source_lang=None, target_lang=None): + try: + from nltk.tokenize import word_tokenize + + self.word_tokenize = word_tokenize + except ImportError: + raise ImportError("Please install nltk with: pip install nltk") + + def encode(self, x: str) -> str: + return " ".join(self.word_tokenize(x)) + + def decode(self, x: str) -> str: + return x diff --git a/fairseq-0.10.2/fairseq/data/encoders/subword_nmt_bpe.py b/fairseq-0.10.2/fairseq/data/encoders/subword_nmt_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..e85f99af396a77d273ba2a4ecaf3db399017b9af --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/encoders/subword_nmt_bpe.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import file_utils +from fairseq.data.encoders import register_bpe + + +@register_bpe("subword_nmt") +class SubwordNMTBPE(object): + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--bpe-codes', type=str, + help='path to subword NMT BPE') + parser.add_argument('--bpe-separator', default='@@', + help='BPE separator') + # fmt: on + + def __init__(self, args): + if args.bpe_codes is None: + raise ValueError("--bpe-codes is required for --bpe=subword_nmt") + codes = file_utils.cached_path(args.bpe_codes) + try: + from subword_nmt import apply_bpe + + bpe_parser = apply_bpe.create_parser() + bpe_args = bpe_parser.parse_args( + [ + "--codes", + codes, + "--separator", + args.bpe_separator, + ] + ) + self.bpe = apply_bpe.BPE( + bpe_args.codes, + bpe_args.merges, + bpe_args.separator, + None, + bpe_args.glossaries, + ) + self.bpe_symbol = bpe_args.separator + " " + except ImportError: + raise ImportError( + "Please install subword_nmt with: pip install subword-nmt" + ) + + def encode(self, x: str) -> str: + return self.bpe.process_line(x) + + def decode(self, x: str) -> str: + return (x + " ").replace(self.bpe_symbol, "").rstrip() diff --git a/fairseq-0.10.2/fairseq/data/encoders/utils.py b/fairseq-0.10.2/fairseq/data/encoders/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d93eb532ef84f0e2bc708b777229ab2cb76ca14b --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/encoders/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq.data import encoders + + +def get_whole_word_mask(args, dictionary): + bpe = encoders.build_bpe(args) + if bpe is not None: + + def is_beginning_of_word(i): + if i < dictionary.nspecial: + # special elements are always considered beginnings + return True + tok = dictionary[i] + if tok.startswith("madeupword"): + return True + try: + return bpe.is_beginning_of_word(tok) + except ValueError: + return True + + mask_whole_words = torch.ByteTensor( + list(map(is_beginning_of_word, range(len(dictionary)))) + ) + return mask_whole_words + return None diff --git a/fairseq-0.10.2/fairseq/data/fairseq_dataset.py b/fairseq-0.10.2/fairseq/data/fairseq_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ed08c1ba200f3d4b95053c02aaa227169fe80d26 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/fairseq_dataset.py @@ -0,0 +1,191 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch.utils.data +from fairseq.data import data_utils + + +class EpochListening: + """Mixin for receiving updates whenever the epoch increments.""" + + @property + def can_reuse_epoch_itr_across_epochs(self): + """ + Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for + this dataset across epochs. + + This needs to return ``False`` if the sample sizes can change across + epochs, in which case we may need to regenerate batches at each epoch. + If your dataset relies in ``set_epoch`` then you should consider setting + this to ``False``. + """ + return True + + def set_epoch(self, epoch): + """Will receive the updated epoch number at the beginning of the epoch.""" + pass + + +class FairseqDataset(torch.utils.data.Dataset, EpochListening): + """A dataset that provides helpers for batching.""" + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch suitable for forwarding with a Model + """ + raise NotImplementedError + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + raise NotImplementedError + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + raise NotImplementedError + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + return np.arange(len(self), dtype=np.int64) + + @property + def supports_prefetch(self): + """Whether this dataset supports prefetching.""" + return False + + def attr(self, attr: str, index: int): + return getattr(self, attr, None) + + def prefetch(self, indices): + """Prefetch the data required for this epoch.""" + raise NotImplementedError + + def get_batch_shapes(self): + """ + Return a list of valid batch shapes, for example:: + + [(8, 512), (16, 256), (32, 128)] + + The first dimension of each tuple is the batch size and can be ``None`` + to automatically infer the max batch size based on ``--max-tokens``. + The second dimension of each tuple is the max supported length as given + by :func:`fairseq.data.FairseqDataset.num_tokens`. + + This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size` + to restrict batch shapes. This is useful on TPUs to avoid too many + dynamic shapes (and recompilations). + """ + return None + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + """ + Given an ordered set of indices, return batches according to + *max_tokens*, *max_sentences* and *required_batch_size_multiple*. + """ + from fairseq.data import data_utils + + fixed_shapes = self.get_batch_shapes() + if fixed_shapes is not None: + + def adjust_bsz(bsz, num_tokens): + if bsz is None: + assert max_tokens is not None, "Must specify --max-tokens" + bsz = max_tokens // num_tokens + if max_sentences is not None: + bsz = min(bsz, max_sentences) + elif ( + bsz >= required_batch_size_multiple + and bsz % required_batch_size_multiple != 0 + ): + bsz -= bsz % required_batch_size_multiple + return bsz + + fixed_shapes = np.array( + [ + [adjust_bsz(bsz, num_tokens), num_tokens] + for (bsz, num_tokens) in fixed_shapes + ] + ) + + return data_utils.batch_by_size( + indices, + num_tokens_fn=self.num_tokens, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + fixed_shapes=fixed_shapes, + ) + + def filter_indices_by_size(self, indices, max_sizes): + """ + Filter a list of sample indices. Remove those that are longer than + specified in *max_sizes*. + + WARNING: don't update, override method in child classes + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if isinstance(max_sizes, float) or isinstance(max_sizes, int): + if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray): + ignored = indices[self.sizes[indices] > max_sizes].tolist() + indices = indices[self.sizes[indices] <= max_sizes] + elif ( + hasattr(self, "sizes") + and isinstance(self.sizes, list) + and len(self.sizes) == 1 + ): + ignored = indices[self.sizes[0][indices] > max_sizes].tolist() + indices = indices[self.sizes[0][indices] <= max_sizes] + else: + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + else: + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + return indices, ignored + + @property + def supports_fetch_outside_dataloader(self): + """Whether this dataset supports fetching outside the workers of the dataloader.""" + return True + + +class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening): + """ + For datasets that need to be read sequentially, usually because the data is + being streamed or otherwise can't be manipulated on a single machine. + """ + + def __iter__(self): + raise NotImplementedError diff --git a/fairseq-0.10.2/fairseq/data/fasta_dataset.py b/fairseq-0.10.2/fairseq/data/fasta_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..007011974a997fd7446dd29d7eba097d7513bab0 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/fasta_dataset.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import subprocess +import threading +from pathlib import Path + +import numpy as np +import torch + + +def fasta_file_path(prefix_path): + return prefix_path + ".fasta" + + +class FastaDataset(torch.utils.data.Dataset): + """ + For loading protein sequence datasets in the common FASTA data format + """ + + def __init__(self, path: str, cache_indices=False): + self.fn = fasta_file_path(path) + self.threadlocal = threading.local() + self.cache = Path(f"{path}.fasta.idx.npy") + if cache_indices: + if self.cache.exists(): + self.offsets, self.sizes = np.load(self.cache) + else: + self.offsets, self.sizes = self._build_index(path) + np.save(self.cache, np.stack([self.offsets, self.sizes])) + else: + self.offsets, self.sizes = self._build_index(path) + + def _get_file(self): + if not hasattr(self.threadlocal, "f"): + self.threadlocal.f = open(self.fn, "r") + return self.threadlocal.f + + def __getitem__(self, idx): + f = self._get_file() + f.seek(self.offsets[idx]) + desc = f.readline().strip() + line = f.readline() + seq = "" + while line != "" and line[0] != ">": + seq += line.strip() + line = f.readline() + return desc, seq + + def __len__(self): + return self.offsets.size + + def _build_index(self, path: str): + # Use grep and awk to get 100M/s on local SSD. + # Should process your enormous 100G fasta in ~10 min single core... + path = fasta_file_path(path) + bytes_offsets = subprocess.check_output( + f"cat {path} | tqdm --bytes --total $(wc -c < {path})" + "| grep --byte-offset '^>' -o | cut -d: -f1", + shell=True, + ) + fasta_lengths = subprocess.check_output( + f"cat {path} | tqdm --bytes --total $(wc -c < {path})" + "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'", + shell=True, + ) + bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ") + sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ") + return bytes_np, sizes_np + + def __setstate__(self, state): + self.__dict__ = state + self.threadlocal = threading.local() + + def __getstate__(self): + d = {} + for i, v in self.__dict__.items(): + if i != "threadlocal": + d[i] = v + return d + + def __del__(self): + if hasattr(self.threadlocal, "f"): + self.threadlocal.f.close() + del self.threadlocal.f + + @staticmethod + def exists(path): + return os.path.exists(fasta_file_path(path)) + + +class EncodedFastaDataset(FastaDataset): + """ + The FastaDataset returns raw sequences - this allows us to return + indices with a dictionary instead. + """ + + def __init__(self, path, dictionary): + super().__init__(path, cache_indices=True) + self.dictionary = dictionary + + def __getitem__(self, idx): + desc, seq = super().__getitem__(idx) + return self.dictionary.encode_line(seq, line_tokenizer=list).long() diff --git a/fairseq-0.10.2/fairseq/data/id_dataset.py b/fairseq-0.10.2/fairseq/data/id_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3e4d7969cf2a26e852b466f165a6fadabae3b35f --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/id_dataset.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import FairseqDataset + + +class IdDataset(FairseqDataset): + def __getitem__(self, index): + return index + + def __len__(self): + return 0 + + def collater(self, samples): + return torch.tensor(samples) diff --git a/fairseq-0.10.2/fairseq/data/language_pair_dataset.py b/fairseq-0.10.2/fairseq/data/language_pair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..db95e14f55d459e83d0665f8b6174ea3eaef3cd3 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/language_pair_dataset.py @@ -0,0 +1,475 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch +from fairseq.data import FairseqDataset, data_utils + + +logger = logging.getLogger(__name__) + + +def collate( + samples, + pad_idx, + eos_idx, + left_pad_source=True, + left_pad_target=False, + input_feeding=True, + pad_to_length=None, + pad_to_multiple=1, +): + if len(samples) == 0: + return {} + + def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): + return data_utils.collate_tokens( + [s[key] for s in samples], + pad_idx, + eos_idx, + left_pad, + move_eos_to_beginning, + pad_to_length=pad_to_length, + pad_to_multiple=pad_to_multiple, + ) + + def check_alignment(alignment, src_len, tgt_len): + if alignment is None or len(alignment) == 0: + return False + if ( + alignment[:, 0].max().item() >= src_len - 1 + or alignment[:, 1].max().item() >= tgt_len - 1 + ): + logger.warning("alignment size mismatch found, skipping alignment!") + return False + return True + + def compute_alignment_weights(alignments): + """ + Given a tensor of shape [:, 2] containing the source-target indices + corresponding to the alignments, a weight vector containing the + inverse frequency of each target index is computed. + For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then + a tensor containing [1., 0.5, 0.5, 1] should be returned (since target + index 3 is repeated twice) + """ + align_tgt = alignments[:, 1] + _, align_tgt_i, align_tgt_c = torch.unique( + align_tgt, return_inverse=True, return_counts=True + ) + align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]] + return 1.0 / align_weights.float() + + id = torch.LongTensor([s["id"] for s in samples]) + + # import pdb;pdb.set_trace() + + src_tokens = merge( + "source", + left_pad=left_pad_source, + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, + ) + # sort by descending source length + src_lengths = torch.LongTensor( + [s["source"].ne(pad_idx).long().sum() for s in samples] + ) + + # import pdb;pdb.set_trace() + + src_lengths, sort_order = src_lengths.sort(descending=True) + id = id.index_select(0, sort_order) + src_tokens = src_tokens.index_select(0, sort_order) + + prev_output_tokens = None + target = None + if samples[0].get("target", None) is not None: + target = merge( + "target", + left_pad=left_pad_target, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + target = target.index_select(0, sort_order) + tgt_lengths = torch.LongTensor( + [s["target"].ne(pad_idx).long().sum() for s in samples] + ).index_select(0, sort_order) + ntokens = tgt_lengths.sum().item() + + if samples[0].get("prev_output_tokens", None) is not None: + prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target) + elif input_feeding: + # we create a shifted version of targets for feeding the + # previous output token(s) into the next decoder step + prev_output_tokens = merge( + "target", + left_pad=left_pad_target, + move_eos_to_beginning=True, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + else: + ntokens = src_lengths.sum().item() + + batch = { + "id": id, + "nsentences": len(samples), + "ntokens": ntokens, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + }, + "target": target, + } + if prev_output_tokens is not None: + batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select( + 0, sort_order + ) + + if samples[0].get("alignment", None) is not None: + bsz, tgt_sz = batch["target"].shape + src_sz = batch["net_input"]["src_tokens"].shape[1] + + offsets = torch.zeros((len(sort_order), 2), dtype=torch.long) + offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz + if left_pad_source: + offsets[:, 0] += src_sz - src_lengths + if left_pad_target: + offsets[:, 1] += tgt_sz - tgt_lengths + + alignments = [ + alignment + offset + for align_idx, offset, src_len, tgt_len in zip( + sort_order, offsets, src_lengths, tgt_lengths + ) + for alignment in [samples[align_idx]["alignment"].view(-1, 2)] + if check_alignment(alignment, src_len, tgt_len) + ] + + if len(alignments) > 0: + alignments = torch.cat(alignments, dim=0) + align_weights = compute_alignment_weights(alignments) + + batch["alignments"] = alignments + batch["align_weights"] = align_weights + + if samples[0].get("constraints", None) is not None: + # Collate the packed constraints across the samples, padding to + # the length of the longest sample. + lens = [sample.get("constraints").size(0) for sample in samples] + max_len = max(lens) + constraints = torch.zeros((len(samples), max(lens))).long() + for i, sample in enumerate(samples): + constraints[i, 0 : lens[i]] = samples[i].get("constraints") + batch["constraints"] = constraints + + return batch + + +class LanguagePairDataset(FairseqDataset): + """ + A pair of torch.utils.data.Datasets. + + Args: + src (torch.utils.data.Dataset): source dataset to wrap + src_sizes (List[int]): source sentence lengths + src_dict (~fairseq.data.Dictionary): source vocabulary + tgt (torch.utils.data.Dataset, optional): target dataset to wrap + tgt_sizes (List[int], optional): target sentence lengths + tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary + left_pad_source (bool, optional): pad source tensors on the left side + (default: True). + left_pad_target (bool, optional): pad target tensors on the left side + (default: False). + shuffle (bool, optional): shuffle dataset elements before batching + (default: True). + input_feeding (bool, optional): create a shifted version of the targets + to be passed into the model for teacher forcing (default: True). + remove_eos_from_source (bool, optional): if set, removes eos from end + of source if it's present (default: False). + append_eos_to_target (bool, optional): if set, appends eos to end of + target if it's absent (default: False). + align_dataset (torch.utils.data.Dataset, optional): dataset + containing alignments. + constraints (Tensor, optional): 2d tensor with a concatenated, zero- + delimited list of constraints for each sentence. + append_bos (bool, optional): if set, appends bos to the beginning of + source/target sentence. + num_buckets (int, optional): if set to a value greater than 0, then + batches will be bucketed into the given number of batch shapes. + src_lang_id (int, optional): source language ID, if set, the collated batch + will contain a field 'src_lang_id' in 'net_input' which indicates the + source language of the samples. + tgt_lang_id (int, optional): target language ID, if set, the collated batch + will contain a field 'tgt_lang_id' which indicates the target language + of the samples. + """ + + def __init__( + self, + src, + src_sizes, + src_dict, + tgt=None, + tgt_sizes=None, + tgt_dict=None, + left_pad_source=True, + left_pad_target=False, + shuffle=True, + input_feeding=True, + remove_eos_from_source=False, + append_eos_to_target=False, + align_dataset=None, + constraints=None, + append_bos=False, + eos=None, + num_buckets=0, + src_lang_id=None, + tgt_lang_id=None, + pad_to_multiple=1, + ): + if tgt_dict is not None: + assert src_dict.pad() == tgt_dict.pad() + assert src_dict.eos() == tgt_dict.eos() + assert src_dict.unk() == tgt_dict.unk() + if tgt is not None: + assert len(src) == len( + tgt + ), "Source and target must contain the same number of examples" + self.src = src + self.tgt = tgt + self.src_sizes = np.array(src_sizes) + self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None + self.sizes = ( + np.vstack((self.src_sizes, self.tgt_sizes)).T + if self.tgt_sizes is not None + else self.src_sizes + ) + self.src_dict = src_dict + self.tgt_dict = tgt_dict + self.left_pad_source = left_pad_source + self.left_pad_target = left_pad_target + self.shuffle = shuffle + self.input_feeding = input_feeding + self.remove_eos_from_source = remove_eos_from_source + self.append_eos_to_target = append_eos_to_target + self.align_dataset = align_dataset + if self.align_dataset is not None: + assert ( + self.tgt_sizes is not None + ), "Both source and target needed when alignments are provided" + self.constraints = constraints + self.append_bos = append_bos + self.eos = eos if eos is not None else src_dict.eos() + self.src_lang_id = src_lang_id + self.tgt_lang_id = tgt_lang_id + if num_buckets > 0: + from fairseq.data import BucketPadLengthDataset + + self.src = BucketPadLengthDataset( + self.src, + sizes=self.src_sizes, + num_buckets=num_buckets, + pad_idx=self.src_dict.pad(), + left_pad=self.left_pad_source, + ) + self.src_sizes = self.src.sizes + logger.info("bucketing source lengths: {}".format(list(self.src.buckets))) + if self.tgt is not None: + self.tgt = BucketPadLengthDataset( + self.tgt, + sizes=self.tgt_sizes, + num_buckets=num_buckets, + pad_idx=self.tgt_dict.pad(), + left_pad=self.left_pad_target, + ) + self.tgt_sizes = self.tgt.sizes + logger.info( + "bucketing target lengths: {}".format(list(self.tgt.buckets)) + ) + + # determine bucket sizes using self.num_tokens, which will return + # the padded lengths (thanks to BucketPadLengthDataset) + num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) + self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) + self.buckets = [ + (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) + ] + else: + self.buckets = None + self.pad_to_multiple = pad_to_multiple + + def get_batch_shapes(self): + return self.buckets + + def __getitem__(self, index): + tgt_item = self.tgt[index] if self.tgt is not None else None + src_item = self.src[index] + # Append EOS to end of tgt sentence if it does not have an EOS and remove + # EOS from end of src sentence if it exists. This is useful when we use + # use existing datasets for opposite directions i.e., when we want to + # use tgt_dataset as src_dataset and vice versa + if self.append_eos_to_target: + eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos() + if self.tgt and self.tgt[index][-1] != eos: + tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])]) + + if self.append_bos: + bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos() + if self.tgt and self.tgt[index][0] != bos: + tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]]) + + bos = self.src_dict.bos() + if self.src[index][0] != bos: + src_item = torch.cat([torch.LongTensor([bos]), self.src[index]]) + + if self.remove_eos_from_source: + eos = self.src_dict.eos() + if self.src[index][-1] == eos: + src_item = self.src[index][:-1] + + example = { + "id": index, + "source": src_item, + "target": tgt_item, + } + if self.align_dataset is not None: + example["alignment"] = self.align_dataset[index] + if self.constraints is not None: + example["constraints"] = self.constraints[index] + return example + + def __len__(self): + return len(self.src) + + def collater(self, samples, pad_to_length=None): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + pad_to_length (dict, optional): a dictionary of + {'source': source_pad_to_length, 'target': target_pad_to_length} + to indicate the max length to pad to in source and target respectively. + + Returns: + dict: a mini-batch with the following keys: + + - `id` (LongTensor): example IDs in the original input order + - `ntokens` (int): total number of tokens in the batch + - `net_input` (dict): the input to the Model, containing keys: + + - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in + the source sentence of shape `(bsz, src_len)`. Padding will + appear on the left if *left_pad_source* is ``True``. + - `src_lengths` (LongTensor): 1D Tensor of the unpadded + lengths of each source sentence of shape `(bsz)` + - `prev_output_tokens` (LongTensor): a padded 2D Tensor of + tokens in the target sentence, shifted right by one + position for teacher forcing, of shape `(bsz, tgt_len)`. + This key will not be present if *input_feeding* is + ``False``. Padding will appear on the left if + *left_pad_target* is ``True``. + - `src_lang_id` (LongTensor): a long Tensor which contains source + language IDs of each sample in the batch + + - `target` (LongTensor): a padded 2D Tensor of tokens in the + target sentence of shape `(bsz, tgt_len)`. Padding will appear + on the left if *left_pad_target* is ``True``. + - `tgt_lang_id` (LongTensor): a long Tensor which contains target language + IDs of each sample in the batch + """ + res = collate( + samples, + pad_idx=self.src_dict.pad(), + eos_idx=self.eos, + left_pad_source=self.left_pad_source, + left_pad_target=self.left_pad_target, + input_feeding=self.input_feeding, + pad_to_length=pad_to_length, + pad_to_multiple=self.pad_to_multiple, + ) + if self.src_lang_id is not None or self.tgt_lang_id is not None: + src_tokens = res["net_input"]["src_tokens"] + bsz = src_tokens.size(0) + if self.src_lang_id is not None: + res["net_input"]["src_lang_id"] = ( + torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens) + ) + if self.tgt_lang_id is not None: + res["tgt_lang_id"] = ( + torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens) + ) + return res + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return max( + self.src_sizes[index], + self.tgt_sizes[index] if self.tgt_sizes is not None else 0, + ) + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return ( + self.src_sizes[index], + self.tgt_sizes[index] if self.tgt_sizes is not None else 0, + ) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)).astype(np.int64) + else: + indices = np.arange(len(self), dtype=np.int64) + if self.buckets is None: + # sort by target length, then source length + if self.tgt_sizes is not None: + indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] + return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] + else: + # sort by bucketed_num_tokens, which is: + # max(padded_src_len, padded_tgt_len) + return indices[ + np.argsort(self.bucketed_num_tokens[indices], kind="mergesort") + ] + + @property + def supports_prefetch(self): + return getattr(self.src, "supports_prefetch", False) and ( + getattr(self.tgt, "supports_prefetch", False) or self.tgt is None + ) + + def prefetch(self, indices): + self.src.prefetch(indices) + if self.tgt is not None: + self.tgt.prefetch(indices) + if self.align_dataset is not None: + self.align_dataset.prefetch(indices) + + def filter_indices_by_size(self, indices, max_sizes): + """Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + return data_utils.filter_paired_dataset_indices_by_size( + self.src_sizes, + self.tgt_sizes, + indices, + max_sizes, + ) diff --git a/fairseq-0.10.2/fairseq/data/legacy/__init__.py b/fairseq-0.10.2/fairseq/data/legacy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd5c72b5e9d7f67fb7e4ef10808d7ec08967ff4 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/legacy/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .block_pair_dataset import BlockPairDataset +from .masked_lm_dataset import MaskedLMDataset +from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary + + +__all__ = [ + "BertDictionary", + "BlockPairDataset", + "MaskedLMDataset", + "MaskedLMDictionary", +] diff --git a/fairseq-0.10.2/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b876ac07b0744e62e02a10a224b6fa998f455847 Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..558fa969915acdaf3488710fdfb21bb23dc60655 Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68e11cf8a9a95ee0b3416c47c59f1dc79a3fb169 Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3da204c7c3e48e0fea21257d1ac5feb6600e637 Binary files /dev/null and b/fairseq-0.10.2/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/data/legacy/block_pair_dataset.py b/fairseq-0.10.2/fairseq/data/legacy/block_pair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ba069b46052286c531b4f9706d96788732cd2ad2 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/legacy/block_pair_dataset.py @@ -0,0 +1,311 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import numpy as np +import torch +from fairseq.data import FairseqDataset + + +class BlockPairDataset(FairseqDataset): + """Break a Dataset of tokens into sentence pair blocks for next sentence + prediction as well as masked language model. + + High-level logics are: + 1. break input tensor to tensor blocks + 2. pair the blocks with 50% next sentence and 50% random sentence + 3. return paired blocks as well as related segment labels + + Args: + dataset (~torch.utils.data.Dataset): dataset to break into blocks + sizes: array of sentence lengths + dictionary: dictionary for the task + block_size: maximum block size + break_mode: mode for breaking copurs into block pairs. currently we support + 2 modes + doc: respect document boundaries and each part of the pair should belong to on document + none: don't respect any boundary and cut tokens evenly + short_seq_prob: probability for generating shorter block pairs + doc_break_size: Size for empty line separating documents. Typically 1 if + the sentences have eos, 0 otherwise. + """ + + def __init__( + self, + dataset, + dictionary, + sizes, + block_size, + break_mode="doc", + short_seq_prob=0.1, + doc_break_size=1, + ): + super().__init__() + self.dataset = dataset + self.pad = dictionary.pad() + self.eos = dictionary.eos() + self.cls = dictionary.cls() + self.mask = dictionary.mask() + self.sep = dictionary.sep() + self.break_mode = break_mode + self.dictionary = dictionary + self.short_seq_prob = short_seq_prob + self.block_indices = [] + + assert len(dataset) == len(sizes) + + if break_mode == "doc": + cur_doc = [] + for sent_id, sz in enumerate(sizes): + assert doc_break_size == 0 or sz != 0, ( + "when doc_break_size is non-zero, we expect documents to be" + "separated by a blank line with a single eos." + ) + # empty line as document separator + if sz == doc_break_size: + if len(cur_doc) == 0: + continue + self.block_indices.append(cur_doc) + cur_doc = [] + else: + cur_doc.append(sent_id) + max_num_tokens = block_size - 3 # Account for [CLS], [SEP], [SEP] + self.sent_pairs = [] + self.sizes = [] + for doc_id, doc in enumerate(self.block_indices): + self._generate_sentence_pair(doc, doc_id, max_num_tokens, sizes) + elif break_mode is None or break_mode == "none": + # each block should have half of the block size since we are constructing block pair + sent_length = (block_size - 3) // 2 + total_len = sum(dataset.sizes) + length = math.ceil(total_len / sent_length) + + def block_at(i): + start = i * sent_length + end = min(start + sent_length, total_len) + return (start, end) + + sent_indices = np.array([block_at(i) for i in range(length)]) + sent_sizes = np.array([e - s for s, e in sent_indices]) + dataset_index = self._sent_to_dataset_index(sent_sizes) + + # pair sentences + self._pair_sentences(dataset_index) + else: + raise ValueError("Invalid break_mode: " + break_mode) + + def _pair_sentences(self, dataset_index): + """ + Give a list of evenly cut blocks/sentences, pair these sentences with 50% + consecutive sentences and 50% random sentences. + This is used for none break mode + """ + # pair sentences + for sent_id, sent in enumerate(dataset_index): + next_sent_label = ( + 1 if np.random.rand() > 0.5 and sent_id != len(dataset_index) - 1 else 0 + ) + if next_sent_label: + next_sent = dataset_index[sent_id + 1] + else: + next_sent = dataset_index[ + self._skip_sampling(len(dataset_index), [sent_id, sent_id + 1]) + ] + self.sent_pairs.append((sent, next_sent, next_sent_label)) + + # The current blocks don't include the special tokens but the + # sizes already account for this + self.sizes.append(3 + sent[3] + next_sent[3]) + + def _sent_to_dataset_index(self, sent_sizes): + """ + Build index mapping block indices to the underlying dataset indices + """ + dataset_index = [] + ds_idx, ds_remaining = -1, 0 + for to_consume in sent_sizes: + sent_size = to_consume + if ds_remaining == 0: + ds_idx += 1 + ds_remaining = sent_sizes[ds_idx] + start_ds_idx = ds_idx + start_offset = sent_sizes[ds_idx] - ds_remaining + while to_consume > ds_remaining: + to_consume -= ds_remaining + ds_idx += 1 + ds_remaining = sent_sizes[ds_idx] + ds_remaining -= to_consume + dataset_index.append( + ( + start_ds_idx, # starting index in dataset + start_offset, # starting offset within starting index + ds_idx, # ending index in dataset + sent_size, # sentence length + ) + ) + assert ds_remaining == 0 + assert ds_idx == len(self.dataset) - 1 + return dataset_index + + def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes): + """ + Go through a single document and genrate sentence paris from it + """ + current_chunk = [] + current_length = 0 + curr = 0 + # To provide more randomness, we decrease target seq length for parts of + # samples (10% by default). Note that max_num_tokens is the hard threshold + # for batching and will never be changed. + target_seq_length = max_num_tokens + if np.random.random() < self.short_seq_prob: + target_seq_length = np.random.randint(2, max_num_tokens) + # loop through all sentences in document + while curr < len(doc): + sent_id = doc[curr] + current_chunk.append(sent_id) + current_length = sum(sizes[current_chunk]) + # split chunk and generate pair when exceed target_seq_length or + # finish the loop + if curr == len(doc) - 1 or current_length >= target_seq_length: + # split the chunk into 2 parts + a_end = 1 + if len(current_chunk) > 2: + a_end = np.random.randint(1, len(current_chunk) - 1) + sent_a = current_chunk[:a_end] + len_a = sum(sizes[sent_a]) + # generate next sentence label, note that if there is only 1 sentence + # in current chunk, label is always 0 + next_sent_label = ( + 1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0 + ) + if not next_sent_label: + # if next sentence label is 0, sample sent_b from a random doc + target_b_length = target_seq_length - len_a + rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id]) + random_doc = self.block_indices[rand_doc_id] + random_start = np.random.randint(0, len(random_doc)) + sent_b = [] + len_b = 0 + for j in range(random_start, len(random_doc)): + sent_b.append(random_doc[j]) + len_b = sum(sizes[sent_b]) + if len_b >= target_b_length: + break + # return the second part of the chunk since it's not used + num_unused_segments = len(current_chunk) - a_end + curr -= num_unused_segments + else: + # if next sentence label is 1, use the second part of chunk as sent_B + sent_b = current_chunk[a_end:] + len_b = sum(sizes[sent_b]) + # currently sent_a and sent_B may be longer than max_num_tokens, + # truncate them and return block idx and offsets for them + sent_a, sent_b = self._truncate_sentences( + sent_a, sent_b, max_num_tokens + ) + self.sent_pairs.append((sent_a, sent_b, next_sent_label)) + self.sizes.append(3 + sent_a[3] + sent_b[3]) + current_chunk = [] + curr += 1 + + def _skip_sampling(self, total, skip_ids): + """ + Generate a random integer which is not in skip_ids. Sample range is [0, total) + TODO: ids in skip_ids should be consecutive, we can extend it to more generic version later + """ + rand_id = np.random.randint(total - len(skip_ids)) + return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids) + + def _truncate_sentences(self, sent_a, sent_b, max_num_tokens): + """ + Trancate a pair of sentence to limit total length under max_num_tokens + Logics: + 1. Truncate longer sentence + 2. Tokens to be truncated could be at the beginning or the end of the sentnce + Returns: + Truncated sentences represented by dataset idx + """ + len_a, len_b = sum(self.dataset.sizes[sent_a]), sum(self.dataset.sizes[sent_b]) + front_cut_a = front_cut_b = end_cut_a = end_cut_b = 0 + + while True: + total_length = ( + len_a + len_b - front_cut_a - front_cut_b - end_cut_a - end_cut_b + ) + if total_length <= max_num_tokens: + break + + if len_a - front_cut_a - end_cut_a > len_b - front_cut_b - end_cut_b: + if np.random.rand() < 0.5: + front_cut_a += 1 + else: + end_cut_a += 1 + else: + if np.random.rand() < 0.5: + front_cut_b += 1 + else: + end_cut_b += 1 + + # calculate ds indices as well as offsets and return + truncated_sent_a = self._cut_sentence(sent_a, front_cut_a, end_cut_a) + truncated_sent_b = self._cut_sentence(sent_b, front_cut_b, end_cut_b) + return truncated_sent_a, truncated_sent_b + + def _cut_sentence(self, sent, front_cut, end_cut): + """ + Cut a sentence based on the numbers of tokens to be cut from beginning and end + Represent the sentence as dataset idx and return + """ + start_ds_idx, end_ds_idx, offset = sent[0], sent[-1], 0 + target_len = sum(self.dataset.sizes[sent]) - front_cut - end_cut + while front_cut > 0: + if self.dataset.sizes[start_ds_idx] > front_cut: + offset += front_cut + break + else: + front_cut -= self.dataset.sizes[start_ds_idx] + start_ds_idx += 1 + while end_cut > 0: + if self.dataset.sizes[end_ds_idx] > end_cut: + break + else: + end_cut -= self.dataset.sizes[end_ds_idx] + end_ds_idx -= 1 + return start_ds_idx, offset, end_ds_idx, target_len + + def _fetch_block(self, start_ds_idx, offset, end_ds_idx, length): + """ + Fetch a block of tokens based on its dataset idx + """ + buffer = torch.cat( + [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)] + ) + s, e = offset, offset + length + return buffer[s:e] + + def __getitem__(self, index): + block1, block2, next_sent_label = self.sent_pairs[index] + block1 = self._fetch_block(*block1) + block2 = self._fetch_block(*block2) + return block1, block2, next_sent_label + + def __len__(self): + return len(self.sizes) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + prefetch_idx = set() + for index in indices: + for block1, block2, _ in [self.sent_pairs[index]]: + for ds_idx in range(block1[0], block1[2] + 1): + prefetch_idx.add(ds_idx) + for ds_idx in range(block2[0], block2[2] + 1): + prefetch_idx.add(ds_idx) + self.dataset.prefetch(prefetch_idx) diff --git a/fairseq-0.10.2/fairseq/data/legacy/masked_lm_dictionary.py b/fairseq-0.10.2/fairseq/data/legacy/masked_lm_dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..dee88f7a3ed72ea465ea4e8ffe7b1c01ff6f57f1 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/legacy/masked_lm_dictionary.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.data import Dictionary + + +class MaskedLMDictionary(Dictionary): + """ + Dictionary for Masked Language Modelling tasks. This extends Dictionary by + adding the mask symbol. + """ + + def __init__( + self, + pad="", + eos="", + unk="", + mask="", + ): + super().__init__(pad=pad, eos=eos, unk=unk) + self.mask_word = mask + self.mask_index = self.add_symbol(mask) + self.nspecial = len(self.symbols) + + def mask(self): + """Helper to get index of mask symbol""" + return self.mask_index + + +class BertDictionary(MaskedLMDictionary): + """ + Dictionary for BERT task. This extends MaskedLMDictionary by adding support + for cls and sep symbols. + """ + + def __init__( + self, + pad="", + eos="", + unk="", + mask="", + cls="", + sep="", + ): + super().__init__(pad=pad, eos=eos, unk=unk, mask=mask) + self.cls_word = cls + self.sep_word = sep + self.cls_index = self.add_symbol(cls) + self.sep_index = self.add_symbol(sep) + self.nspecial = len(self.symbols) + + def cls(self): + """Helper to get index of cls symbol""" + return self.cls_index + + def sep(self): + """Helper to get index of sep symbol""" + return self.sep_index diff --git a/fairseq-0.10.2/fairseq/data/list_dataset.py b/fairseq-0.10.2/fairseq/data/list_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..12f00aa43661d6bad701c9e72653ba8779136906 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/list_dataset.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import BaseWrapperDataset + + +class ListDataset(BaseWrapperDataset): + def __init__(self, dataset, sizes=None): + super().__init__(dataset) + self._sizes = sizes + + def __iter__(self): + for x in self.dataset: + yield x + + def collater(self, samples): + return samples + + @property + def sizes(self): + return self._sizes + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + def set_epoch(self, epoch): + pass diff --git a/fairseq-0.10.2/fairseq/data/lru_cache_dataset.py b/fairseq-0.10.2/fairseq/data/lru_cache_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a7854ac1701392754ce5795cafe9c634671aebdf --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/lru_cache_dataset.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import lru_cache + +from . import BaseWrapperDataset + + +class LRUCacheDataset(BaseWrapperDataset): + def __init__(self, dataset, token=None): + super().__init__(dataset) + + @lru_cache(maxsize=8) + def __getitem__(self, index): + return self.dataset[index] + + @lru_cache(maxsize=8) + def collater(self, samples): + return self.dataset.collater(samples) diff --git a/fairseq-0.10.2/fairseq/data/mask_tokens_dataset.py b/fairseq-0.10.2/fairseq/data/mask_tokens_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea86245f76c4cf233bf6f023138ef341538a267 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/mask_tokens_dataset.py @@ -0,0 +1,178 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import lru_cache + +import numpy as np +import torch +from fairseq.data import Dictionary, data_utils + +from . import BaseWrapperDataset, LRUCacheDataset + + +class MaskTokensDataset(BaseWrapperDataset): + """ + A wrapper Dataset for masked language modeling. + + Input items are masked according to the specified masking probability. + + Args: + dataset: Dataset to wrap. + sizes: Sentence lengths + vocab: Dictionary with the vocabulary and special tokens. + pad_idx: Id of pad token in vocab + mask_idx: Id of mask token in vocab + return_masked_tokens: controls whether to return the non-masked tokens + (the default) or to return a tensor with the original masked token + IDs (and *pad_idx* elsewhere). The latter is useful as targets for + masked LM training. + seed: Seed for random number generator for reproducibility. + mask_prob: probability of replacing a token with *mask_idx*. + leave_unmasked_prob: probability that a masked token is unmasked. + random_token_prob: probability of replacing a masked token with a + random token from the vocabulary. + freq_weighted_replacement: sample random replacement words based on + word frequencies in the vocab. + mask_whole_words: only mask whole words. This should be a byte mask + over vocab indices, indicating whether it is the beginning of a + word. We will extend any mask to encompass the whole word. + bpe: BPE to use for whole-word masking. + """ + + @classmethod + def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs): + """Return the source and target datasets for masked LM training.""" + dataset = LRUCacheDataset(dataset) + return ( + LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)), + LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)), + ) + + def __init__( + self, + dataset: torch.utils.data.Dataset, + vocab: Dictionary, + pad_idx: int, + mask_idx: int, + return_masked_tokens: bool = False, + seed: int = 1, + mask_prob: float = 0.15, + leave_unmasked_prob: float = 0.1, + random_token_prob: float = 0.1, + freq_weighted_replacement: bool = False, + mask_whole_words: torch.Tensor = None, + ): + assert 0.0 < mask_prob < 1.0 + assert 0.0 <= random_token_prob <= 1.0 + assert 0.0 <= leave_unmasked_prob <= 1.0 + assert random_token_prob + leave_unmasked_prob <= 1.0 + + self.dataset = dataset + self.vocab = vocab + self.pad_idx = pad_idx + self.mask_idx = mask_idx + self.return_masked_tokens = return_masked_tokens + self.seed = seed + self.mask_prob = mask_prob + self.leave_unmasked_prob = leave_unmasked_prob + self.random_token_prob = random_token_prob + self.mask_whole_words = mask_whole_words + + if random_token_prob > 0.0: + if freq_weighted_replacement: + weights = np.array(self.vocab.count) + else: + weights = np.ones(len(self.vocab)) + weights[: self.vocab.nspecial] = 0 + self.weights = weights / weights.sum() + + self.epoch = 0 + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the noise changes, not item sizes + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=8) + def __getitem__(self, index: int): + with data_utils.numpy_seed(self.seed, self.epoch, index): + item = self.dataset[index] + sz = len(item) + + assert ( + self.mask_idx not in item + ), "Dataset contains mask_idx (={}), this is not expected!".format( + self.mask_idx, + ) + + if self.mask_whole_words is not None: + word_begins_mask = self.mask_whole_words.gather(0, item) + word_begins_idx = word_begins_mask.nonzero().view(-1) + sz = len(word_begins_idx) + words = np.split(word_begins_mask, word_begins_idx)[1:] + assert len(words) == sz + word_lens = list(map(len, words)) + + # decide elements to mask + mask = np.full(sz, False) + num_mask = int( + # add a random number for probabilistic rounding + self.mask_prob * sz + + np.random.rand() + ) + mask[np.random.choice(sz, num_mask, replace=False)] = True + + if self.return_masked_tokens: + # exit early if we're just returning the masked tokens + # (i.e., the targets for masked LM training) + if self.mask_whole_words is not None: + mask = np.repeat(mask, word_lens) + new_item = np.full(len(mask), self.pad_idx) + new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1] + return torch.from_numpy(new_item) + + # decide unmasking and random replacement + rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob + if rand_or_unmask_prob > 0.0: + rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) + if self.random_token_prob == 0.0: + unmask = rand_or_unmask + rand_mask = None + elif self.leave_unmasked_prob == 0.0: + unmask = None + rand_mask = rand_or_unmask + else: + unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob + decision = np.random.rand(sz) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + else: + unmask = rand_mask = None + + if unmask is not None: + mask = mask ^ unmask + + if self.mask_whole_words is not None: + mask = np.repeat(mask, word_lens) + + new_item = np.copy(item) + new_item[mask] = self.mask_idx + if rand_mask is not None: + num_rand = rand_mask.sum() + if num_rand > 0: + if self.mask_whole_words is not None: + rand_mask = np.repeat(rand_mask, word_lens) + num_rand = rand_mask.sum() + + new_item[rand_mask] = np.random.choice( + len(self.vocab), + num_rand, + p=self.weights, + ) + + return torch.from_numpy(new_item) diff --git a/fairseq-0.10.2/fairseq/data/multi_corpus_dataset.py b/fairseq-0.10.2/fairseq/data/multi_corpus_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d2457666d688f773a98af8eb610a4f5756b02dc0 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/multi_corpus_dataset.py @@ -0,0 +1,159 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from collections import OrderedDict +from typing import Dict, List + +import numpy as np +from fairseq.data import data_utils + +from . import FairseqDataset + + +logger = logging.getLogger(__name__) + + +class MultiCorpusDataset(FairseqDataset): + """ + Stores multiple instances of FairseqDataset together. Requires each instance + to be the same dataset, as the collate method needs to work on batches with + samples from each dataset. + + Allows specifying a distribution over the datasets to use. Note that unlike + MultiCorpusSampledDataset, this distribution allows sampling for each item, + rather than on a batch level. + + Each time ordered_indices() is called, a new sample is generated with + the specified distribution. + + Args: + datasets: a OrderedDict of FairseqDataset instances. + distribution: a List containing the probability of getting an utterance from + corresponding dataset + seed: random seed for sampling the datsets + sort_indices: if true, will sort the ordered indices by size + """ + + def __init__( + self, + datasets: Dict[str, FairseqDataset], + distribution: List[float], + seed: int, + sort_indices: bool = False, + ): + super().__init__() + assert isinstance(datasets, OrderedDict) + assert len(datasets) == len(distribution) + self.datasets = datasets + self.distribution = distribution + self.seed = seed + self.sort_indices = sort_indices + + # Avoid repeated conversions to list later + self.dataset_list = list(datasets.values()) + self.total_num_instances = 0 + + first_dataset = list(self.datasets.values())[0] + + self.dataset_offsets = [] + for dataset in datasets.values(): + assert isinstance(dataset, FairseqDataset) + assert type(dataset) is type(first_dataset) + self.dataset_offsets.append(self.total_num_instances) + self.total_num_instances += len(dataset) + + def ordered_indices(self): + with data_utils.numpy_seed(self.seed, self.epoch): + # Used to store the order of indices of each dataset to use + indices = [ + np.random.permutation(len(dataset)) + for dataset in self.datasets.values() + ] + # Keep track of which samples we've used for each dataset + counters = [0 for _ in self.datasets] + + sampled_indices = [ + self._sample(indices, counters) for _ in range(self.total_num_instances) + ] + if self.sort_indices: + sampled_indices.sort(key=lambda i: self.num_tokens(i)) + return np.array(sampled_indices, dtype=np.int64) + + def _sample(self, indices, counters): + # First pick dataset + dataset_idx = np.random.choice(len(self.distribution), p=self.distribution) + + # Then get dataset internal index + idx = indices[dataset_idx][counters[dataset_idx]] + + # Convert to multi-datasets index + idx += self.dataset_offsets[dataset_idx] + + counters[dataset_idx] += 1 + + # Reset if we reach end + if counters[dataset_idx] == len(self.dataset_list[dataset_idx]): + counters[dataset_idx] = 0 + indices[dataset_idx] = np.random.permutation( + len(self.dataset_list[dataset_idx]) + ) + + return idx + + def _map_index(self, index: int): + """ + If dataset A has length N and dataset B has length M + then index 1 maps to index 1 of dataset A, and index N + 1 + maps to index 1 of B. + """ + counter = 0 + for key, dataset in self.datasets.items(): + if index < counter + len(dataset): + return index - counter, key + counter += len(dataset) + raise ValueError( + "Invalid index: {}, max: {}".format(index, self.total_num_instances) + ) + + def __len__(self): + """ + Length of this dataset is the sum of individual datasets + """ + return self.total_num_instances + + def __getitem__(self, index): + index, key = self._map_index(index) + return self.datasets[key][index] + + def collater(self, samples): + """ + Since we enforce all datsets to be the same, collating is just + picking the first one and doing collate. + """ + if len(samples) == 0: + return None + + return list(self.datasets.values())[0].collater(samples) + + def num_tokens(self, index: int): + index, key = self._map_index(index) + return self.datasets[key].num_tokens(index) + + def size(self, index: int): + index, key = self._map_index(index) + return self.datasets[key].size(index) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @property + def supports_prefetch(self): + return False diff --git a/fairseq-0.10.2/fairseq/data/multi_corpus_sampled_dataset.py b/fairseq-0.10.2/fairseq/data/multi_corpus_sampled_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ad8e951cc905a73fea28b4fac449e307cadfa52f --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/multi_corpus_sampled_dataset.py @@ -0,0 +1,145 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict +from typing import Callable, Dict, List + +import numpy as np + +from . import FairseqDataset + + +def uniform_sampler(x): + # Sample from uniform distribution + return np.random.choice(x, 1).item() + + +class MultiCorpusSampledDataset(FairseqDataset): + """ + Stores multiple instances of FairseqDataset together and in every iteration + creates a batch by first sampling a dataset according to a specified + probability distribution and then getting instances from that dataset. + + Args: + datasets: an OrderedDict of FairseqDataset instances. + sampling_func: A function for sampling over list of dataset keys. + The default strategy is to sample uniformly. + """ + + def __init__( + self, + datasets: Dict[str, FairseqDataset], + sampling_func: Callable[[List], int] = None, + ): + super().__init__() + assert isinstance(datasets, OrderedDict) + self.datasets = datasets + if sampling_func is None: + sampling_func = uniform_sampler + self.sampling_func = sampling_func + + self.total_num_instances = 0 + for _, dataset in datasets.items(): + assert isinstance(dataset, FairseqDataset) + self.total_num_instances += len(dataset) + + self._ordered_indices = None + + def __len__(self): + """ + Length of this dataset is the sum of individual datasets + """ + return self.total_num_instances + + def ordered_indices(self): + """ + Ordered indices for batching. Here we call the underlying + dataset's ordered_indices() so that we get the same random ordering + as we would have from using the underlying dataset directly. + """ + if self._ordered_indices is None: + self._ordered_indices = OrderedDict( + [ + (key, dataset.ordered_indices()) + for key, dataset in self.datasets.items() + ] + ) + return np.arange(len(self)) + + def _map_index_to_dataset(self, key: int, index: int): + """ + Different underlying datasets have different lengths. In order to ensure + we are not accessing an index outside the range of the current dataset + size, we wrap around. This function should be called after we have + created an ordering for this and all underlying datasets. + """ + assert ( + self._ordered_indices is not None + ), "Must call MultiCorpusSampledDataset.ordered_indices() first" + mapped_index = index % len(self.datasets[key]) + return self._ordered_indices[key][mapped_index] + + def __getitem__(self, index: int): + """ + Get the item associated with index from each underlying dataset. + Since index is in the range of [0, TotalNumInstances], we need to + map the index to the dataset before retrieving the item. + """ + return OrderedDict( + [ + (key, dataset[self._map_index_to_dataset(key, index)]) + for key, dataset in self.datasets.items() + ] + ) + + def collater(self, samples: List[Dict]): + """ + Generate a mini-batch for this dataset. + To convert this into a regular mini-batch we use the following + logic: + 1. Select a dataset using the specified probability distribution. + 2. Call the collater function of the selected dataset. + """ + if len(samples) == 0: + return None + + selected_key = self.sampling_func(list(self.datasets.keys())) + selected_samples = [sample[selected_key] for sample in samples] + return self.datasets[selected_key].collater(selected_samples) + + def num_tokens(self, index: int): + """ + Return an example's length (number of tokens), used for batching. Here + we return the max across all examples at index across all underlying + datasets. + """ + return max( + dataset.num_tokens(self._map_index_to_dataset(key, index)) + for key, dataset in self.datasets.items() + ) + + def size(self, index: int): + """ + Return an example's size as a float or tuple. Here we return the max + across all underlying datasets. This value is used when filtering a + dataset with max-positions. + """ + return max( + dataset.size(self._map_index_to_dataset(key, index)) + for key, dataset in self.datasets.items() + ) + + @property + def supports_prefetch(self): + return all( + getattr(dataset, "supports_prefetch", False) + for dataset in self.datasets.values() + ) + + def prefetch(self, indices): + for key, dataset in self.datasets.items(): + dataset.prefetch( + [self._map_index_to_dataset(key, index) for index in indices] + ) diff --git a/fairseq-0.10.2/fairseq/data/num_samples_dataset.py b/fairseq-0.10.2/fairseq/data/num_samples_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..99a17495c701d8a05e0268f98bf453905e11d078 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/num_samples_dataset.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import FairseqDataset + + +class NumSamplesDataset(FairseqDataset): + def __getitem__(self, index): + return 1 + + def __len__(self): + return 0 + + def collater(self, samples): + return sum(samples) diff --git a/fairseq-0.10.2/fairseq/data/offset_tokens_dataset.py b/fairseq-0.10.2/fairseq/data/offset_tokens_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6fabbdcdaa1a8f70d8d8c07db4cd53754503c194 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/offset_tokens_dataset.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import BaseWrapperDataset + + +class OffsetTokensDataset(BaseWrapperDataset): + def __init__(self, dataset, offset): + super().__init__(dataset) + self.offset = offset + + def __getitem__(self, idx): + return self.dataset[idx] + self.offset diff --git a/fairseq-0.10.2/fairseq/data/pad_dataset.py b/fairseq-0.10.2/fairseq/data/pad_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8075bba6a9efc5f8421368ee0b2ae66afe3f5009 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/pad_dataset.py @@ -0,0 +1,28 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.data import data_utils + +from . import BaseWrapperDataset + + +class PadDataset(BaseWrapperDataset): + def __init__(self, dataset, pad_idx, left_pad): + super().__init__(dataset) + self.pad_idx = pad_idx + self.left_pad = left_pad + + def collater(self, samples): + return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad) + + +class LeftPadDataset(PadDataset): + def __init__(self, dataset, pad_idx): + super().__init__(dataset, pad_idx, left_pad=True) + + +class RightPadDataset(PadDataset): + def __init__(self, dataset, pad_idx): + super().__init__(dataset, pad_idx, left_pad=False) diff --git a/fairseq-0.10.2/fairseq/data/prepend_dataset.py b/fairseq-0.10.2/fairseq/data/prepend_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ad74784d2d7920e4a6225282d95543ce16ea50d9 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/prepend_dataset.py @@ -0,0 +1,28 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from . import BaseWrapperDataset + + +class PrependDataset(BaseWrapperDataset): + def __init__(self, dataset, prepend_getter, ensure_first_token_is=None): + super().__init__(dataset) + self.prepend_getter = prepend_getter + self.ensure_first_token = ensure_first_token_is + + def __getitem__(self, idx): + item = self.dataset[idx] + is_tuple = isinstance(item, tuple) + src = item[0] if is_tuple else item + + assert self.ensure_first_token is None or src[0] == self.ensure_first_token + prepend_idx = self.prepend_getter(self.dataset, idx) + assert isinstance(prepend_idx, int) + src[0] = prepend_idx + item = tuple((src,) + item[1:]) if is_tuple else src + return item diff --git a/fairseq-0.10.2/fairseq/data/resampling_dataset.py b/fairseq-0.10.2/fairseq/data/resampling_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3b993164dc3962df48bacff26714328e843e80 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/resampling_dataset.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +from fairseq.data import BaseWrapperDataset, plasma_utils + + +logger = logging.getLogger(__name__) + + +class ResamplingDataset(BaseWrapperDataset): + """Randomly samples from a given dataset at each epoch. + + Sampling is done with or without replacement, depending on the "replace" + parameter. + + Optionally, the epoch size can be rescaled. This is potentially desirable + to increase per-epoch coverage of the base dataset (since sampling with + replacement means that many items in the dataset will be left out). In the + case of sampling without replacement, size_ratio should be strictly less + than 1. + + Args: + dataset (~torch.utils.data.Dataset): dataset on which to sample. + weights (List[float]): list of probability weights + (default: None, which corresponds to uniform sampling). + replace (bool): sampling mode; True for "with replacement", or False + for "without replacement" (default: True) + size_ratio (float): the ratio to subsample to; must be positive + (default: 1.0). + batch_by_size (bool): whether or not to batch by sequence length + (default: True). + seed (int): RNG seed to use (default: 0). + epoch (int): starting epoch number (default: 1). + """ + + def __init__( + self, + dataset, + weights=None, + replace=True, + size_ratio=1.0, + batch_by_size=True, + seed=0, + epoch=1, + ): + super().__init__(dataset) + + if weights is None: + self.weights = None + + else: + assert len(weights) == len(dataset) + weights_arr = np.array(weights, dtype=np.float64) + weights_arr /= weights_arr.sum() + self.weights = plasma_utils.PlasmaArray(weights_arr) + + self.replace = replace + + assert size_ratio > 0.0 + if not self.replace: + assert size_ratio < 1.0 + self.size_ratio = float(size_ratio) + self.actual_size = np.ceil(len(dataset) * self.size_ratio).astype(int) + + self.batch_by_size = batch_by_size + self.seed = seed + + self._cur_epoch = None + self._cur_indices = None + + self.set_epoch(epoch) + + def __getitem__(self, index): + return self.dataset[self._cur_indices.array[index]] + + def __len__(self): + return self.actual_size + + @property + def sizes(self): + if isinstance(self.dataset.sizes, list): + return [s[self._cur_indices.array] for s in self.dataset.sizes] + return self.dataset.sizes[self._cur_indices.array] + + def num_tokens(self, index): + return self.dataset.num_tokens(self._cur_indices.array[index]) + + def size(self, index): + return self.dataset.size(self._cur_indices.array[index]) + + def ordered_indices(self): + if self.batch_by_size: + order = [ + np.arange(len(self)), + self.sizes, + ] # No need to handle `self.shuffle == True` + return np.lexsort(order) + else: + return np.arange(len(self)) + + def prefetch(self, indices): + self.dataset.prefetch(self._cur_indices.array[indices]) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + + def set_epoch(self, epoch): + logger.debug("ResamplingDataset.set_epoch: {}".format(epoch)) + super().set_epoch(epoch) + + if epoch == self._cur_epoch: + return + + self._cur_epoch = epoch + + # Generate a weighted sample of indices as a function of the + # random seed and the current epoch. + + rng = np.random.RandomState( + [ + 42, # magic number + self.seed % (2 ** 32), # global seed + self._cur_epoch, # epoch index + ] + ) + self._cur_indices = plasma_utils.PlasmaArray( + rng.choice( + len(self.dataset), + self.actual_size, + replace=self.replace, + p=(None if self.weights is None else self.weights.array), + ) + ) diff --git a/fairseq-0.10.2/fairseq/data/shorten_dataset.py b/fairseq-0.10.2/fairseq/data/shorten_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6ebb5d88feb3f29d1512a0873df304915d051209 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/shorten_dataset.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from fairseq.data import data_utils + +from . import BaseWrapperDataset + + +class TruncateDataset(BaseWrapperDataset): + """Truncate a sequence by returning the first truncation_length tokens""" + + def __init__(self, dataset, truncation_length): + super().__init__(dataset) + assert truncation_length is not None + self.truncation_length = truncation_length + self.dataset = dataset + + def __getitem__(self, index): + item = self.dataset[index] + item_len = item.size(0) + if item_len > self.truncation_length: + item = item[: self.truncation_length] + return item + + @property + def sizes(self): + return np.minimum(self.dataset.sizes, self.truncation_length) + + def __len__(self): + return len(self.dataset) + + +class RandomCropDataset(TruncateDataset): + """Truncate a sequence by returning a random crop of truncation_length tokens""" + + def __init__(self, dataset, truncation_length, seed=1): + super().__init__(dataset, truncation_length) + self.seed = seed + self.epoch = 0 + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the crop changes, not item sizes + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + def __getitem__(self, index): + with data_utils.numpy_seed(self.seed, self.epoch, index): + item = self.dataset[index] + item_len = item.size(0) + excess = item_len - self.truncation_length + if excess > 0: + start_idx = np.random.randint(0, excess) + item = item[start_idx : start_idx + self.truncation_length] + return item + + +def maybe_shorten_dataset( + dataset, + split, + shorten_data_split_list, + shorten_method, + tokens_per_sample, + seed, +): + truncate_split = ( + split in shorten_data_split_list.split(",") or len(shorten_data_split_list) == 0 + ) + if shorten_method == "truncate" and truncate_split: + dataset = TruncateDataset(dataset, tokens_per_sample) + elif shorten_method == "random_crop" and truncate_split: + dataset = RandomCropDataset(dataset, tokens_per_sample, seed) + return dataset diff --git a/fairseq-0.10.2/fairseq/data/strip_token_dataset.py b/fairseq-0.10.2/fairseq/data/strip_token_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cae39ba4d2f8106398eccd7eb0cf5c2194ec0db5 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/strip_token_dataset.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import BaseWrapperDataset + + +class StripTokenDataset(BaseWrapperDataset): + def __init__(self, dataset, id_to_strip): + super().__init__(dataset) + self.id_to_strip = id_to_strip + + def __getitem__(self, index): + item = self.dataset[index] + while len(item) > 0 and item[-1] == self.id_to_strip: + item = item[:-1] + while len(item) > 0 and item[0] == self.id_to_strip: + item = item[1:] + return item diff --git a/fairseq-0.10.2/fairseq/data/token_block_dataset.py b/fairseq-0.10.2/fairseq/data/token_block_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..aa33f9d06f37fa6a1e239d9733a2725ec158f6a8 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/token_block_dataset.py @@ -0,0 +1,168 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from fairseq.data import FairseqDataset, plasma_utils + + +class TokenBlockDataset(FairseqDataset): + """Break a Dataset of tokens into blocks. + + Args: + dataset (~torch.utils.data.Dataset): dataset to break into blocks + sizes (List[int]): sentence lengths (required for 'complete' and 'eos') + block_size (int): maximum block size (ignored in 'eos' break mode) + break_mode (str, optional): Mode used for breaking tokens. Values can + be one of: + - 'none': break tokens into equally sized blocks (up to block_size) + - 'complete': break tokens into blocks (up to block_size) such that + blocks contains complete sentences, although block_size may be + exceeded if some sentences exceed block_size + - 'complete_doc': similar to 'complete' mode, but do not + cross document boundaries + - 'eos': each block contains one sentence (block_size is ignored) + include_targets (bool, optional): return next tokens as targets + (default: False). + document_sep_len (int, optional): document separator size (required for + 'complete_doc' break mode). Typically 1 if the sentences have eos + and 0 otherwise. + """ + + def __init__( + self, + dataset, + sizes, + block_size, + pad, + eos, + break_mode=None, + include_targets=False, + document_sep_len=1, + ): + try: + from fairseq.data.token_block_utils_fast import ( + _get_slice_indices_fast, + _get_block_to_dataset_index_fast, + ) + except ImportError: + raise ImportError( + "Please build Cython components with: `pip install --editable .` " + "or `python setup.py build_ext --inplace`" + ) + + super().__init__() + self.dataset = dataset + self.pad = pad + self.eos = eos + self.include_targets = include_targets + + assert len(dataset) == len(sizes) + assert len(dataset) > 0 + + if isinstance(sizes, list): + sizes = np.array(sizes, dtype=np.int64) + else: + if torch.is_tensor(sizes): + sizes = sizes.numpy() + sizes = sizes.astype(np.int64) + + break_mode = break_mode if break_mode is not None else "none" + + # For "eos" break-mode, block_size is not required parameters. + if break_mode == "eos" and block_size is None: + block_size = 0 + + slice_indices = _get_slice_indices_fast( + sizes, str(break_mode), block_size, document_sep_len + ) + self._sizes = slice_indices[:, 1] - slice_indices[:, 0] + + # build index mapping block indices to the underlying dataset indices + if break_mode == "eos": + # much faster version for eos break mode + block_to_dataset_index = np.stack( + [ + np.arange(len(sizes)), # starting index in dataset + np.zeros( + len(sizes), dtype=np.long + ), # starting offset within starting index + np.arange(len(sizes)), # ending index in dataset + ], + 1, + ) + else: + block_to_dataset_index = _get_block_to_dataset_index_fast( + sizes, + slice_indices, + ) + self._slice_indices = plasma_utils.PlasmaArray(slice_indices) + self._sizes = plasma_utils.PlasmaArray(self._sizes) + self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index) + + @property + def slice_indices(self): + return self._slice_indices.array + + @property + def sizes(self): + return self._sizes.array + + @property + def block_to_dataset_index(self): + return self._block_to_dataset_index.array + + def attr(self, attr: str, index: int): + start_ds_idx, _, _ = self.block_to_dataset_index[index] + return self.dataset.attr(attr, start_ds_idx) + + def __getitem__(self, index): + start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index] + + buffer = torch.cat( + [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)] + ) + + slice_s, slice_e = self.slice_indices[index] + length = slice_e - slice_s + s, e = start_offset, start_offset + length + item = buffer[s:e] + + if self.include_targets: + # *target* is the original sentence (=item) + # *source* is shifted right by 1 (maybe left-padded with eos) + # *past_target* is shifted right by 2 (left-padded as needed) + if s == 0: + source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]]) + past_target = torch.cat( + [item.new([self.pad, self.eos]), buffer[0 : e - 2]] + ) + else: + source = buffer[s - 1 : e - 1] + if s == 1: + past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]]) + else: + past_target = buffer[s - 2 : e - 2] + + return source, item, past_target + + return item + + def __len__(self): + return len(self.slice_indices) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + self.dataset.prefetch( + { + ds_idx + for index in indices + for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]] + for ds_idx in range(start_ds_idx, end_ds_idx + 1) + } + ) diff --git a/fairseq-0.10.2/fairseq/data/transform_eos_lang_pair_dataset.py b/fairseq-0.10.2/fairseq/data/transform_eos_lang_pair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1dd3d93d2b41898ba6b25ba0255abbcebcf495b7 --- /dev/null +++ b/fairseq-0.10.2/fairseq/data/transform_eos_lang_pair_dataset.py @@ -0,0 +1,108 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional + +import torch + +from . import FairseqDataset + + +class TransformEosLangPairDataset(FairseqDataset): + """A :class:`~fairseq.data.FairseqDataset` wrapper that transform bos on + collated samples of language pair dataset. + + Note that the transformation is applied in :func:`collater`. + + Args: + dataset (~fairseq.data.FairseqDataset): dataset that collates sample into + LanguagePairDataset schema + src_eos (int): original source end-of-sentence symbol index to be replaced + new_src_eos (int, optional): new end-of-sentence symbol index to replace source eos symbol + tgt_bos (int, optional): original target beginning-of-sentence symbol index to be replaced + new_tgt_bos (int, optional): new beginning-of-sentence symbol index to replace at the + beginning of 'prev_output_tokens' + """ + + def __init__( + self, + dataset: FairseqDataset, + src_eos: int, + new_src_eos: Optional[int] = None, + tgt_bos: Optional[int] = None, + new_tgt_bos: Optional[int] = None, + ): + self.dataset = dataset + self.src_eos = src_eos + self.new_src_eos = new_src_eos + self.tgt_bos = tgt_bos + self.new_tgt_bos = new_tgt_bos + + def __getitem__(self, index): + return self.dataset[index] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples, **extra_args): + samples = self.dataset.collater(samples, **extra_args) + + if self.new_src_eos is not None: + if self.dataset.left_pad_source: + assert ( + samples["net_input"]["src_tokens"][:, -1] != self.src_eos + ).sum() == 0 + samples["net_input"]["src_tokens"][:, -1] = self.new_src_eos + else: + eos_idx = samples["net_input"]["src_lengths"] - 1 + assert ( + samples["net_input"]["src_tokens"][ + torch.arange(eos_idx.size(0)), eos_idx + ] + != self.src_eos + ).sum() == 0 + eos_idx = eos_idx.resize_(len(samples["net_input"]["src_lengths"]), 1) + samples["net_input"]["src_tokens"].scatter_( + 1, eos_idx, self.new_src_eos + ) + + if ( + self.new_tgt_bos is not None + and "prev_output_tokens" in samples["net_input"] + ): + if self.dataset.left_pad_target: + # TODO: support different padding direction on target side + raise NotImplementedError( + "TransformEosLangPairDataset does not implement --left-pad-target True option" + ) + else: + assert ( + samples["net_input"]["prev_output_tokens"][:, 0] != self.tgt_bos + ).sum() == 0 + samples["net_input"]["prev_output_tokens"][:, 0] = self.new_tgt_bos + + return samples + + def num_tokens(self, index): + return self.dataset.num_tokens(index) + + def size(self, index): + return self.dataset.size(index) + + @property + def sizes(self): + # dataset.sizes can be a dynamically computed sizes: + return self.dataset.sizes + + def ordered_indices(self): + return self.dataset.ordered_indices() + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + return self.dataset.prefetch(indices) diff --git a/fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so b/fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..c5b8c0cb206abc7a419cb637e5265662c160ef72 --- /dev/null +++ b/fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b0b5376f3a99c755cece761f9c9d1001ab792bd99574b8cb3401dd20ce49cbdc +size 148544 diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/__init__.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19dd5a8498d2951d11af13c91a1170b54d3cddb8 Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/fairseq_nat_model.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/fairseq_nat_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13540ed1698837eb859aa8648ef0ddad7f4b109d Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/fairseq_nat_model.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/insertion_transformer.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/insertion_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6559dc0c11b4deb81a49876053826cf42ba92ad1 Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/insertion_transformer.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/iterative_nonautoregressive_transformer.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/iterative_nonautoregressive_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d37907884648a79d97325454adfe183f6cf24e7f Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/iterative_nonautoregressive_transformer.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/levenshtein_transformer.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/levenshtein_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69dddfd8ea5f058b47080611bb02cb2972e25dee Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/levenshtein_transformer.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/levenshtein_utils.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/levenshtein_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e081bd4b82c2a35e71506e9860750c81c048729c Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/levenshtein_utils.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/nat_crf_transformer.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/nat_crf_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fb3c5cecc13bc3518c4916d18b1fe9b5360beb6 Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/nat_crf_transformer.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/models/nat/__pycache__/nonautoregressive_transformer.cpython-310.pyc b/fairseq-0.10.2/fairseq/models/nat/__pycache__/nonautoregressive_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..318a0ffcdbb2950522d638ec4c2210db956f2014 Binary files /dev/null and b/fairseq-0.10.2/fairseq/models/nat/__pycache__/nonautoregressive_transformer.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/optim/__init__.py b/fairseq-0.10.2/fairseq/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94eb2c7ee966756590a4294eab181fc87fac2fa1 --- /dev/null +++ b/fairseq-0.10.2/fairseq/optim/__init__.py @@ -0,0 +1,52 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import importlib +import os +from argparse import Namespace +from typing import Union + +from fairseq import registry +from fairseq.optim.bmuf import FairseqBMUF # noqa +from fairseq.optim.fairseq_optimizer import ( # noqa + FairseqOptimizer, + LegacyFairseqOptimizer, +) +from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer +from fairseq.optim.shard import shard_ +from omegaconf import DictConfig + + +__all__ = [ + "FairseqOptimizer", + "FP16Optimizer", + "MemoryEfficientFP16Optimizer", + "shard_", +] + + +( + _build_optimizer, + register_optimizer, + OPTIMIZER_REGISTRY, + OPTIMIZER_DATACLASS_REGISTRY, +) = registry.setup_registry("--optimizer", base_class=FairseqOptimizer, required=True) + + +def build_optimizer( + optimizer_cfg: Union[DictConfig, Namespace], params, *extra_args, **extra_kwargs +): + if all(isinstance(p, dict) for p in params): + params = [t for p in params for t in p.values()] + params = list(filter(lambda p: p.requires_grad, params)) + return _build_optimizer(optimizer_cfg, params, *extra_args, **extra_kwargs) + + +# automatically import any Python files in the optim/ directory +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.optim." + file_name) diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/adadelta.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/adadelta.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78b95e0d63c7f5698ee12550b8415dd5c452943f Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/adadelta.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/adagrad.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/adagrad.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65eaeeedec69594d23977ed23bd7b4bd1ce9944f Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/adagrad.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/adamax.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/adamax.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66d32c310934ec2f11c002dcbd5e9d33c184019e Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/adamax.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/bmuf.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/bmuf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f687c8d0e4a98d39a52995ee1ddfcf37c40c450d Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/bmuf.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..355deff664f9042cbb78f5cb709388c7c01b5824 Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ea7fdac9d2b2ddad8af9f0bf91b9f116222e08c Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cd6a32764a45a2e4f8a5f53ade1a315dcbdf56f Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/nag.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/nag.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66254eec31174ec38aa5908c6dc4c16fc5051d3d Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/nag.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/optim/__pycache__/sgd.cpython-310.pyc b/fairseq-0.10.2/fairseq/optim/__pycache__/sgd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a967605193d21ad9de0cca8e586c8877fab2cb5 Binary files /dev/null and b/fairseq-0.10.2/fairseq/optim/__pycache__/sgd.cpython-310.pyc differ diff --git a/fairseq-0.10.2/fairseq/optim/adafactor.py b/fairseq-0.10.2/fairseq/optim/adafactor.py new file mode 100644 index 0000000000000000000000000000000000000000..91745ce10e183479f8cb552f2a1a91834e2a61ed --- /dev/null +++ b/fairseq-0.10.2/fairseq/optim/adafactor.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adafactor") +class FairseqAdafactor(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = Adafactor(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--adafactor-eps', default='(1e-30, 1e-3)', metavar="E", + help='epsilons for Adafactor optimizer') + parser.add_argument('--clip-threshold', type=float, default=1.0, metavar="C", + help='threshold for clipping update root mean square') + parser.add_argument('--decay-rate', type=float, default=-0.8, metavar="D", + help='decay rate of the second moment estimator') + parser.add_argument('--beta1', type=float, default=None, metavar="B", + help='beta for first moment estimator. Optional') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + parser.add_argument('--scale-parameter', action='store_true', + help='scale learning rate by root mean square of parameter') + parser.add_argument('--relative-step', action='store_true', + help='set learning rate to inverse square root of timestep,' + 'otherwise use external learning rate') + parser.add_argument('--warmup-init', action='store_true', + help='use relative step for warm-up learning rate schedule') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + Note : Convergence issues empirically observed with fp16 on. + Might require search for appropriate configuration. + """ + return { + "lr": self.args.lr[0], + "eps": eval(self.args.adafactor_eps), + "clip_threshold": self.args.clip_threshold, + "decay_rate": self.args.decay_rate, + "beta1": self.args.beta1, + "weight_decay": self.args.weight_decay, + "scale_parameter": self.args.scale_parameter, # defaults to False + "relative_step": self.args.relative_step, # defaults to False + "warmup_init": self.args.warmup_init, + } + + +class Adafactor(torch.optim.Optimizer): + """Implements Adafactor algorithm. + + This implementation is based on: + `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` + (see https://arxiv.org/abs/1804.04235) + + Note that this optimizer internally adjusts the learning rate + depending on the *scale_parameter*, *relative_step* and + *warmup_init* options. To use a manual (external) learning rate + schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constans for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold (float): threshold of root mean square of + final gradient update (default: 1.0) + decay_rate (float): coefficient used to compute running averages of square + gradient (default: -0.8) + beta1 (float): coefficient used for computing running averages of gradient + (default: None) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + scale_parameter (bool): if True, learning rate is scaled by root mean square of + parameter (default: True) + relative_step (bool): if True, time-dependent learning rate is computed + instead of external learning rate (default: True) + warmup_init (bool): time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + """ + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + if lr is not None and relative_step: + raise ValueError("Cannot combine manual lr and relative_step options") + if warmup_init and not relative_step: + raise ValueError("warmup_init requires relative_step=True") + + defaults = dict( + lr=lr, + eps=eps, + clip_threshold=clip_threshold, + decay_rate=decay_rate, + beta1=beta1, + weight_decay=weight_decay, + scale_parameter=scale_parameter, + relative_step=relative_step, + warmup_init=warmup_init, + ) + super(Adafactor, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return False + + def _get_lr(self, param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = ( + 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + ) + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + def _get_options(self, param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + def _rms(self, tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = ( + (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)) + .rsqrt_() + .unsqueeze(-1) + ) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:] + ).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + group["lr"] = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad ** 2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_( + update.mean(dim=-1), alpha=1.0 - beta2t + ) + exp_avg_sq_col.mul_(beta2t).add_( + update.mean(dim=-2), alpha=1.0 - beta2t + ) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_( + (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0) + ) + update.mul_(group["lr"]) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"]) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) + + p_data_fp32.add_(-update) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss diff --git a/fairseq-0.10.2/fairseq/optim/adamax.py b/fairseq-0.10.2/fairseq/optim/adamax.py new file mode 100644 index 0000000000000000000000000000000000000000..577a68816692710b008f9088bdf7fa45868c64a0 --- /dev/null +++ b/fairseq-0.10.2/fairseq/optim/adamax.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adamax") +class FairseqAdamax(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = Adamax(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--adamax-betas', default='(0.9, 0.999)', metavar='B', + help='betas for Adam optimizer') + parser.add_argument('--adamax-eps', type=float, default=1e-8, metavar='D', + help='epsilon for Adam optimizer') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + parser.add_argument('--no-bias-correction', default=False, action='store_true', + help='disable bias correction') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "betas": eval(self.args.adamax_betas), + "eps": self.args.adamax_eps, + "weight_decay": self.args.weight_decay, + "bias_correction": not self.args.no_bias_correction, + } + + +class Adamax(torch.optim.Optimizer): + """Implements Adamax algorithm (a variant of Adam based on infinity norm). + + It has been proposed in `Adam: A Method for Stochastic Optimization`__. + + Compared to the version in PyTorch, this version implements a fix for weight decay. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + bias_correction (bool, optional): enable bias correction (default: True) + + __ https://arxiv.org/abs/1412.6980 + """ + + def __init__( + self, + params, + lr=2e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + bias_correction=True, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + bias_correction=bias_correction, + ) + super(Adamax, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("Adamax does not support sparse gradients") + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_inf"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].to(p_data_fp32) + state["exp_inf"] = state["exp_inf"].to(p_data_fp32) + + exp_avg, exp_inf = state["exp_avg"], state["exp_inf"] + beta1, beta2 = group["betas"] + eps = group["eps"] + + state["step"] += 1 + + # Update biased first moment estimate. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + # Update the exponentially weighted infinity norm. + torch.max( + exp_inf.mul_(beta2), + grad.abs_(), + out=exp_inf, + ) + + step_size = group["lr"] + if group["bias_correction"]: + bias_correction = 1 - beta1 ** state["step"] + step_size /= bias_correction + + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) + + p_data_fp32.addcdiv_(exp_avg, exp_inf.add(eps), value=-step_size) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss diff --git a/fairseq-0.10.2/fairseq/optim/bmuf.py b/fairseq-0.10.2/fairseq/optim/bmuf.py new file mode 100644 index 0000000000000000000000000000000000000000..3312f81103490f8414487e54c6e0a4a5b2aa5de3 --- /dev/null +++ b/fairseq-0.10.2/fairseq/optim/bmuf.py @@ -0,0 +1,231 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +import torch +import torch.distributed as dist +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.optim.fairseq_optimizer import FairseqOptimizer +from omegaconf import II + + +@dataclass +class FairseqBMUFConfig(FairseqDataclass): + block_lr: float = field( + default=1, metadata={"help": "block learning rate for bmuf"} + ) + block_momentum: float = field( + default=0.875, metadata={"help": "block momentum for bmuf"} + ) + global_sync_iter: int = field( + default=50, metadata={"help": "Iteration for syncing global model"} + ) + warmup_iterations: int = field( + default=500, metadata={"help": "warmup iterations for model to broadcast"} + ) + use_nbm: bool = field( + default=False, + metadata={"help": "Specify whether you want to use classical BM / Nesterov BM"}, + ) + average_sync: bool = field( + default=False, + metadata={ + "help": "Specify whether you want to average the local momentum after each sync" + }, + ) + distributed_world_size: int = II( + "params.distributed_training.distributed_world_size" + ) + + +class FairseqBMUF(FairseqOptimizer): + """ + Implements incremental block distributed data parallelism similar to + https://ieeexplore.ieee.org/document/7472805 + + Paper title: Scalable training of deep learning machines by incremental + block training with intra-block parallel optimization and blockwise + model-update filtering + """ + + def __init__(self, args, optimizer): + + super().__init__(args) + self._optimizer = optimizer + self._num_updates = 0 + self.sync_iter = self.args.global_sync_iter + self.block_momentum = self.args.block_momentum + self.block_lr = self.args.block_lr + self._reset_local_data() + self.warmup_iteration = self.args.warmup_iterations + self.use_nbm = self.args.use_nbm + self.initial_state = self._optimizer.state_dict() + self.average_sync = self.args.average_sync + self.world_size = self.args.distributed_world_size + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + gen_parser_from_dataclass(parser, FairseqBMUFConfig()) + + @property + def optimizer(self): + return self._optimizer.optimizer + + @property + def optimizer_config(self): + return self._optimizer.optimizer_config + + def get_lr(self): + return self._optimizer.get_lr() + + def set_lr(self, lr): + self._optimizer.set_lr(lr) + + def state_dict(self): + return self._optimizer.state_dict() + + def load_state_dict(self, state_dict, optimizer_overrides=None): + self._optimizer.load_state_dict(state_dict, optimizer_overrides) + self.initial_state = self._optimizer.state_dict() + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + self._optimizer.multiply_grads(c) + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + return self._optimizer.clip_grad_norm(max_norm, aggregate_norm_fn) + + def average_params(self): + self._optimizer.average_params() + + def _block_sync(self): + if self.world_size <= 1: + return + # Update the global model using local models from all GPUs + # (Step-1) Calculate grad between previously synced model and + # currrent local model + if self.block_momentum != 0: + self._calc_grad() + + # (Step-2) Average gradient from all GPUs + self._avg_grad_from_all_gpus() + + # (Step-3) Calculate global momentum and update the global model + if self.block_momentum != 0: + self._update_global_model() + + # (Step-4) Average local optimizer params + if self.average_sync: + self.average_params() + + def _is_warmup_end(self): + # Check whether train iterations is equal to warmup iter + if self.get_num_updates() == self.warmup_iteration: + return True + return False + + def _is_bmuf_iter(self): + # Check whether train iterations is equal to bmuf sync iter + if (self.get_num_updates() > self.warmup_iteration) and ( + self.get_num_updates() % self.sync_iter == 0 + ): + return True + return False + + def _warmup_sync(self, root_rank=0): + if self.world_size <= 1: + return + # Broadcast the local model to all gpus + for param in self.params: + dist.broadcast(param.data, src=root_rank) + + # Update local optimizer state + if self.average_sync: + self._optimizer.average_params() + else: + self._optimizer.load_state_dict(self.initial_state) + + self._reset_local_data() + + def step(self, closure=None): + """Performs a single optimization step.""" + self._optimizer.step(closure) + self.set_num_updates(self.get_num_updates() + 1) + if self._is_warmup_end(): + self._warmup_sync() + elif self._is_bmuf_iter(): + self._block_sync() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + self._optimizer.zero_grad() + + def get_num_updates(self): + """Get the number of parameters updates.""" + return self._num_updates + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self._num_updates = num_updates + + @torch.no_grad() + def _reset_local_data(self): + # (Step-0) Initialize global momentum parameters and store global copy on each gpu + self.global_params = [torch.zeros_like(p.data) for p in self.params] + self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params] + self.grads = [p.data.new_zeros(p.data.size()) for p in self.params] + + # saving the global model locally for calculating gradient during bmuf sync + for param, global_param in zip(self.params, self.global_params): + global_param.copy_(param.data) + + @torch.no_grad() + def _calc_grad(self): + # global_params is basically the global copy from the previously finished + # synchronisation. param.data is local parameter after block_sync_freq + # for the local gpu. so grad is difference between previously synced + # model and currrent local model. + for index, (param, global_param) in enumerate( + zip(self.params, self.global_params) + ): + self.grads[index] = global_param - param.data + + def _avg_grad_from_all_gpus(self): + for index, param in enumerate(self.params): + sync_para = param.data if self.block_momentum == 0 else self.grads[index] + sync_para /= float(dist.get_world_size()) + dist.all_reduce(sync_para, op=dist.ReduceOp.SUM) + + @torch.no_grad() + def _update_global_model(self): + for index, (param, global_param, smoothed_grad, grad) in enumerate( + zip( + self.params, + self.global_params, + self.smoothed_grads, + # all gpus would share the same value of smoothed_grad, since it is + # always computed on synchronized gradients. + self.grads, + ) + ): + # global_param is basically last syncrhornized parameter. though + # smoothed_grad is local, all processes will have same value of + # smoothed_grad and hence param is globally synchronized copy. + # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t) + smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad + param.data.copy_(global_param - smoothed_grad) + + # A Nesterov momentum here is to do a partial weight update before + # calculating the gradient + if self.use_nbm: + param.data.copy_(param.data - self.block_momentum * smoothed_grad) + + # backup for the next synchronization. + self.smoothed_grads[index] = smoothed_grad + global_param.copy_(param.data) diff --git a/fairseq-0.10.2/fairseq/optim/fairseq_optimizer.py b/fairseq-0.10.2/fairseq/optim/fairseq_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8a10399a8b413c4188dbe7c8d51f5353348d835b --- /dev/null +++ b/fairseq-0.10.2/fairseq/optim/fairseq_optimizer.py @@ -0,0 +1,150 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq import utils +from fairseq.dataclass.utils import gen_parser_from_dataclass + + +class FairseqOptimizer(object): + def __init__(self, args): + super().__init__() + self.args = args + + @classmethod + def add_args(cls, parser): + """Add optimizer-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + @property + def optimizer(self): + """Return a torch.optim.optimizer.Optimizer instance.""" + if not hasattr(self, "_optimizer"): + raise NotImplementedError + if not isinstance(self._optimizer, torch.optim.Optimizer): + raise ValueError("_optimizer must be an instance of torch.optim.Optimizer") + return self._optimizer + + @optimizer.setter + def optimizer(self, optimizer): + """Reset optimizer instance.""" + if not hasattr(self, "_optimizer"): + raise NotImplementedError + if not isinstance(self._optimizer, torch.optim.Optimizer): + raise ValueError("_optimizer must be an instance of torch.optim.Optimizer") + self._optimizer = optimizer + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + raise NotImplementedError + + @property + def params(self): + """Return an iterable of the parameters held by the optimizer.""" + for param_group in self.param_groups: + for p in param_group["params"]: + yield p + + @property + def param_groups(self): + return self.optimizer.param_groups + + def __getstate__(self): + return self._optimizer.__getstate__() + + def get_lr(self): + """Return the current learning rate.""" + return self.param_groups[0]["lr"] + + def set_lr(self, lr): + """Set the learning rate.""" + for param_group in self.param_groups: + param_group["lr"] = lr + + def state_dict(self): + """Return the optimizer's state dict.""" + return self.optimizer.state_dict() + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + self.optimizer.load_state_dict(state_dict) + + if optimizer_overrides is not None and len(optimizer_overrides) > 0: + # override learning rate, momentum, etc. with latest values + for group in self.param_groups: + group.update(optimizer_overrides) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves.""" + loss.backward() + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + for p in self.params: + if p.grad is not None: + p.grad.data.mul_(c) + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) + + def step(self, closure=None, scale=1.0): + """Performs a single optimization step.""" + if self.supports_step_with_scale: + self.optimizer.step(closure, scale=scale) + else: + if scale != 1.0: + self.multiply_grads(1.0 / scale) + self.optimizer.step(closure) + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + for p in self.params: + p.grad = None + self.optimizer.zero_grad() + + @property + def supports_memory_efficient_fp16(self): + if hasattr(self.optimizer, "supports_memory_efficient_fp16"): + return self.optimizer.supports_memory_efficient_fp16 + return False + + @property + def supports_step_with_scale(self): + if hasattr(self.optimizer, "supports_step_with_scale"): + return self.optimizer.supports_step_with_scale + return False + + @property + def supports_flat_params(self): + """ + Whether the optimizer supports collapsing of the model + parameters/gradients into a single contiguous Tensor. + """ + if hasattr(self.optimizer, "supports_flat_params"): + return self.optimizer.supports_flat_params + return False + + def average_params(self): + pass + + +class LegacyFairseqOptimizer(FairseqOptimizer): + def __init__(self, args): + self.args = args diff --git a/fairseq-0.10.2/fairseq/optim/fp16_optimizer.py b/fairseq-0.10.2/fairseq/optim/fp16_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..687315030fef38390d22f17db1cbae88a1560266 --- /dev/null +++ b/fairseq-0.10.2/fairseq/optim/fp16_optimizer.py @@ -0,0 +1,491 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from itertools import chain + +import torch +from fairseq import optim, utils + +from .dynamic_loss_scaler import DynamicLossScaler + + +class _FP16OptimizerMixin(object): + def __init__(self, *args, **kwargs): + # forward __init__ call to the next class in mro(method resolution order) + super().__init__(*args, **kwargs) + self._multiply_factor = 1.0 + + @property + def has_flat_params(self): + return torch.is_tensor(self.fp32_params) or ( + isinstance(self.fp32_params, dict) + and all(torch.is_tensor(t) for t in self.fp32_params.values()) + ) + + @classmethod + def build_fp32_params(cls, args, params, flatten=True): + # create FP32 copy of parameters and grads + if flatten: + is_pipeline_parallel = getattr( + args, "pipeline_model_parallel", False + ) and getattr(args, "distributed_no_spawn", False) + total_param_size = sum(p.data.numel() for p in params) + devices = [torch.cuda.current_device()] + if is_pipeline_parallel: + devices = list(set(args.pipeline_devices)) + fp32_params = {} + for device in devices: + if is_pipeline_parallel: + device_param_size = sum( + p.data.numel() for p in params if p.device.index == device + ) + device_params = [p for p in params if p.device.index == device] + else: + device_param_size = total_param_size + device_params = params + fp32_params[device] = ( + device_params[0].new(0).float().new(device_param_size) + ) + offset = 0 + for p in device_params: + numel = p.data.numel() + fp32_params[device][offset : offset + numel].copy_(p.data.view(-1)) + offset += numel + fp32_params[device] = torch.nn.Parameter(fp32_params[device]) + fp32_params[device].grad = fp32_params[device].data.new( + device_param_size + ) + return fp32_params + else: + fp32_params = [] + for p in params: + p32 = torch.nn.Parameter(p.data.float()) + p32.grad = torch.zeros_like(p32.data) + fp32_params.append(p32) + return fp32_params + + def state_dict(self): + """Return the optimizer's state dict.""" + state_dict = self.fp32_optimizer.state_dict() + if self.scaler is not None: + state_dict["loss_scale"] = self.scaler.loss_scale + return state_dict + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + if "loss_scale" in state_dict and self.scaler is not None: + self.scaler.loss_scale = state_dict["loss_scale"] + self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. + """ + if self.scaler is not None: + loss = self.scaler.scale(loss) + loss.backward() + self._needs_sync = True + + def _sync_fp16_grads_to_fp32(self): + if self._needs_sync: + # copy FP16 grads to FP32 + if self.has_flat_params: + devices = list(self.fp32_params.keys()) + device_params_dict = defaultdict(list) + for p in self.fp16_params: + if p.requires_grad: + device_params_dict[p.device.index].append(p) + for device in devices: + device_params = device_params_dict[device] + offset = 0 + for p in device_params: + grad_data = ( + p.grad.data + if p.grad is not None + else p.data.new_zeros(p.data.shape) + ) + numel = grad_data.numel() + self.fp32_params[device].grad.data[ + offset : offset + numel + ].copy_(grad_data.view(-1)) + offset += numel + else: + for p, p32 in zip(self.fp16_params, self.fp32_params): + if not p.requires_grad: + continue + if p.grad is not None: + p32.grad.data.copy_(p.grad.data) + else: + p32.grad = torch.zeros_like(p.data, dtype=torch.float) + + self._needs_sync = False + + def _sync_fp32_params_to_fp16(self): + # copy FP32 params back into FP16 model + if self.has_flat_params: + devices = list(self.fp32_params.keys()) + device_params_dict = defaultdict(list) + for p in self.fp16_params: + device_params_dict[p.device.index].append(p) + for device in devices: + device_params = device_params_dict[device] + offset = 0 + for p in device_params: + numel = p.data.numel() + p.data.copy_( + self.fp32_params[device] + .data[offset : offset + numel] + .view_as(p.data) + ) + offset += numel + else: + for p, p32 in zip(self.fp16_params, self.fp32_params): + if not p.requires_grad: + continue + p.data.copy_(p32.data) + + def _unscale_grads(self): + self._sync_fp16_grads_to_fp32() + if self._multiply_factor != 1.0: + self.fp32_optimizer.multiply_grads(self._multiply_factor) + self._multiply_factor = 1.0 + + def multiply_grads(self, c): + """Multiplies grads by a constant ``c``.""" + self._multiply_factor *= c + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm and updates dynamic loss scaler.""" + self._sync_fp16_grads_to_fp32() + + grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm( + 0, aggregate_norm_fn + ) + + if self.scaler is not None: + if grad_norm > max_norm > 0.0: + self._multiply_factor *= max_norm / grad_norm + + self.scaler.check_overflow(grad_norm) + elif max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) + self._multiply_factor *= clip_coef + + return grad_norm + + def step(self, closure=None): + """Performs a single optimization step.""" + self._sync_fp16_grads_to_fp32() + + if getattr(self, "supports_step_with_scale", False): + self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor)) + else: + self._unscale_grads() + self.fp32_optimizer.step(closure) + + if self.scaler is not None: + self.scaler.update() + + self._sync_fp32_params_to_fp16() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + for p in self.fp16_params: + p.grad = None + if self.has_flat_params: + if torch.is_tensor(self.fp32_params): + self.fp32_params.grad.zero_() + elif isinstance(self.fp32_params, dict): + for fp32_params in self.fp32_params.values(): + fp32_params.grad.zero_() + else: + raise ("self.fp32_params must be a tensor or dict") + else: + for p32 in self.fp32_params: + if p32.grad is None: + p32.grad.zero_() + self._needs_sync = False + + if self.scaler is not None: + self._multiply_factor = 1.0 / float(self.scaler.loss_scale) + + +class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer): + """ + Wrap an *optimizer* to support FP16 (mixed precision) training. + """ + + def __init__(self, args, params, fp32_optimizer, fp32_params): + super().__init__(args) + self.fp16_params = params + self.fp32_optimizer = fp32_optimizer + self.fp32_params = fp32_params + + if getattr(args, "fp16_scale_window", None) is None: + if len(args.update_freq) > 1: + raise ValueError( + "--fp16-scale-window must be given explicitly when using a " + "custom --update-freq schedule" + ) + data_parallel_size = int( + args.distributed_world_size / args.model_parallel_size + ) + scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0]) + else: + scale_window = args.fp16_scale_window + + if not getattr(args, "bf16", False): + self.scaler = DynamicLossScaler( + init_scale=args.fp16_init_scale, + scale_window=scale_window, + tolerance=args.fp16_scale_tolerance, + threshold=args.threshold_loss_scale, + min_loss_scale=args.min_loss_scale, + ) + else: + # disable loss scaling for bfloat16 + self.scaler = None + + @classmethod + def build_optimizer(cls, args, params): + """ + Args: + args (argparse.Namespace): fairseq args + params (iterable): iterable of parameters to optimize + """ + flatten = not getattr(args, "fp16_no_flatten_grads", False) + if getattr(args, "bf16", False): + flatten = False # mixed precision is faster on TPUs without flat grads + fp32_params = cls.build_fp32_params(args, params, flatten=flatten) + if flatten: + fp32_optimizer = optim.build_optimizer(args, [fp32_params]) + else: + fp32_optimizer = optim.build_optimizer(args, fp32_params) + if flatten and not fp32_optimizer.supports_flat_params: + raise RuntimeError( + "chosen optimizer does not support flat params, " + "please set --fp16-no-flatten-grads" + ) + return cls(args, params, fp32_optimizer, fp32_params) + + @property + def optimizer(self): + return self.fp32_optimizer.optimizer + + @optimizer.setter + def optimizer(self, optimizer): + self.fp32_optimizer.optimizer = optimizer + + @property + def optimizer_config(self): + return self.fp32_optimizer.optimizer_config + + def get_lr(self): + return self.fp32_optimizer.get_lr() + + def set_lr(self, lr): + self.fp32_optimizer.set_lr(lr) + + +class _MemoryEfficientFP16OptimizerMixin(object): + def __init__(self, *args, **kwargs): + # forward __init__ call to the next class in MRO (method resolution order) + super().__init__(*args, **kwargs) + self._multiply_factor = 1.0 + + @property + def has_flat_params(self): + return False + + def state_dict(self): + """Return the optimizer's state dict.""" + state_dict = self.wrapped_optimizer.state_dict() + if self.scaler is not None: + state_dict["loss_scale"] = self.scaler.loss_scale + return state_dict + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + if "loss_scale" in state_dict and self.scaler is not None: + self.scaler.loss_scale = state_dict["loss_scale"] + + self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides) + + # Hack: PyTorch automatically casts the optimizer state to match the + # type of the current parameters. But with --memory-efficient-fp16 the + # params are FP16 while the optimizer state is FP32 and we don't want + # to cast. A workaround is to manually copy back the original state + # after the optimizer has been loaded. + if not getattr(self.optimizer, "disable_mem_eff_fp16_loading_hack", False): + groups = self.optimizer.param_groups + saved_groups = state_dict["param_groups"] + id_map = { + old_id: p + for old_id, p in zip( + chain(*(g["params"] for g in saved_groups)), + chain(*(g["params"] for g in groups)), + ) + } + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + self.optimizer.state[param] = v + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. + """ + if self.scaler is not None: + loss = self.scaler.scale(loss) + loss.backward() + + def _unscale_grads(self): + if self._multiply_factor != 1.0: + self.wrapped_optimizer.multiply_grads(self._multiply_factor) + self._multiply_factor = 1.0 + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + self._multiply_factor *= c + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm and updates dynamic loss scaler.""" + max_norm = float(max_norm) + grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm( + 0, aggregate_norm_fn + ) + + if self.scaler is not None: + grad_norm_cpu = float(grad_norm) + if grad_norm_cpu > max_norm > 0.0: + self._multiply_factor *= max_norm / grad_norm_cpu + + # detect overflow and adjust loss scale + self.scaler.check_overflow(grad_norm_cpu) + elif max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) + self._multiply_factor *= clip_coef + + return grad_norm + + def step(self, closure=None): + """Performs a single optimization step.""" + if getattr(self, "supports_step_with_scale", False): + # NOTE(msb) optimizer divides by scale factor + self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor)) + else: + self._unscale_grads() + self.wrapped_optimizer.step(closure) + + if self.scaler is not None: + self.scaler.update() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + self.wrapped_optimizer.zero_grad() + if self.scaler is not None: + self._multiply_factor = 1.0 / float(self.scaler.loss_scale) + else: + self._multiply_factor = 1.0 + + +class MemoryEfficientFP16Optimizer( + _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer +): + """ + Wrap an *optimizer* to support FP16 (mixed precision) training. + + Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not + maintain an FP32 copy of the model. We instead expect the optimizer to + convert the gradients to FP32 internally and sync the results back to the + FP16 model params. This significantly reduces memory usage but slightly + increases the time spent in the optimizer. + + Since this wrapper depends on specific functionality in the wrapped + optimizer (i.e., on-the-fly conversion of grads to FP32), only certain + optimizers can be wrapped. This is determined by the + *supports_memory_efficient_fp16* property. + """ + + def __init__(self, args, params, optimizer): + if not optimizer.supports_memory_efficient_fp16: + raise ValueError( + "Unsupported optimizer: {}".format(optimizer.__class__.__name__) + ) + + super().__init__(args) + self.wrapped_optimizer = optimizer + + if getattr(args, "fp16_scale_window", None) is None: + if len(args.update_freq) > 1: + raise ValueError( + "--fp16-scale-window must be given explicitly when using a " + "custom --update-freq schedule" + ) + data_parallel_size = int( + args.distributed_world_size / args.model_parallel_size + ) + scale_window = 2 ** 14 / data_parallel_size / args.update_freq[0] + else: + scale_window = args.fp16_scale_window + + if not getattr(args, "bf16", False): + self.scaler = DynamicLossScaler( + init_scale=args.fp16_init_scale, + scale_window=scale_window, + tolerance=args.fp16_scale_tolerance, + threshold=args.threshold_loss_scale, + min_loss_scale=args.min_loss_scale, + ) + else: + # disable loss scaling for bfloat16 + self.scaler = None + + @classmethod + def build_optimizer(cls, args, params): + """ + Args: + args (argparse.Namespace): fairseq args + params (iterable): iterable of parameters to optimize + """ + fp16_optimizer = optim.build_optimizer(args, params) + return cls(args, params, fp16_optimizer) + + @property + def optimizer(self): + return self.wrapped_optimizer.optimizer + + @optimizer.setter + def optimizer(self, optimizer): + self.wrapped_optimizer.optimizer = optimizer + + @property + def optimizer_config(self): + return self.wrapped_optimizer.optimizer_config + + def get_lr(self): + return self.wrapped_optimizer.get_lr() + + def set_lr(self, lr): + self.wrapped_optimizer.set_lr(lr) diff --git a/fairseq-0.10.2/fairseq/optim/fused_lamb.py b/fairseq-0.10.2/fairseq/optim/fused_lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..f4f2bdb0c6c65f7758509b6d4d2f2c48cb6e8b4f --- /dev/null +++ b/fairseq-0.10.2/fairseq/optim/fused_lamb.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.optim import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("lamb") +class FairseqLAMB(LegacyFairseqOptimizer): + """LAMB optimizer.""" + + def __init__(self, args, params): + super().__init__(args) + try: + from apex.optimizers import FusedLAMB + + self._optimizer = FusedLAMB(params, **self.optimizer_config) + except ImportError: + raise ImportError("Please install apex to use LAMB optimizer") + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B', + help='betas for LAMB optimizer') + parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D', + help='epsilon for LAMB optimizer') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "betas": eval(self.args.lamb_betas), + "eps": self.args.lamb_eps, + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + return False diff --git a/fairseq-0.10.2/fairseq/optim/shard.py b/fairseq-0.10.2/fairseq/optim/shard.py new file mode 100644 index 0000000000000000000000000000000000000000..a035a1c1f93d3ce3eae93f8d80216af8ecf9620a --- /dev/null +++ b/fairseq-0.10.2/fairseq/optim/shard.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +try: + from fairscale.optim import OSS + + _has_fairscale = True +except ImportError: + _has_fairscale = False + + +def shard_(args, optimizer, group): + if not _has_fairscale: + raise ImportError( + "\n\nPlease install the fairscale package:" "\n\n pip install fairscale" + ) + + class FairseqOSS(OSS): + @property + def disable_mem_eff_fp16_loading_hack(self): + return True + + def __getattr__(self, name): + if name.startswith("supports") and hasattr(self.optim, name): + return getattr(self.optim, name) + raise AttributeError( + "'FairseqOSS' object has no attribute {0!r}".format(name) + ) + + torch_optimizer = optimizer.optimizer + optim_cls = type(torch_optimizer) + + optimizer.optimizer = FairseqOSS( + torch_optimizer.param_groups, + optim_cls, + group=group, + **optimizer.optimizer_config + )