| | |
| | |
| | |
| | |
| |
|
| | try: |
| | from collections.abc import Iterable |
| | except ImportError: |
| | from collections import Iterable |
| | import contextlib |
| | import itertools |
| | import logging |
| | import re |
| | import warnings |
| | from typing import Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from fairseq.file_io import PathManager |
| | from fairseq import utils |
| | import os |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def infer_language_pair(path):
    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
    for fname in PathManager.ls(path):
        pieces = fname.split(".")
        if len(pieces) >= 3:
            langs = pieces[1].split("-")
            if len(langs) == 2:
                # first file matching <split>.<src>-<tgt>.* wins
                return langs
    # nothing matched: signal "unknown pair"
    return None, None
| |
|
| |
|
def collate_tokens(
    values,
    pad_idx,
    eos_idx=None,
    left_pad=False,
    move_eos_to_beginning=False,
    pad_to_length=None,
    pad_to_multiple=1,
    pad_to_bsz=None,
):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    width = max(v.size(0) for v in values)
    if pad_to_length is not None:
        width = max(width, pad_to_length)
    if pad_to_multiple != 1 and width % pad_to_multiple != 0:
        # round width up to the next multiple of pad_to_multiple
        width = int(((width - 0.1) // pad_to_multiple + 1) * pad_to_multiple)

    n_rows = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
    # allocate on the same device/dtype as the inputs, pre-filled with padding
    res = values[0].new(n_rows, width).fill_(pad_idx)

    def _fill(src, dst):
        # destination slice must exactly fit the source tokens
        assert dst.numel() == src.numel()
        if move_eos_to_beginning:
            # rotate: EOS (or the given eos_idx) goes to position 0,
            # the rest shifts right by one
            dst[0] = src[-1] if eos_idx is None else eos_idx
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for row, v in enumerate(values):
        target = res[row][width - len(v):] if left_pad else res[row][: len(v)]
        _fill(v, target)
    return res
| |
|
| |
|
def load_indexed_dataset(
    path, dictionary=None, dataset_impl=None, combine=False, default="cached"
):
    """A helper function for loading indexed datasets.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it will be inferred automatically. For legacy indexed
            data we use the 'cached' implementation by default.
        combine (bool, optional): automatically load and combine multiple
            datasets. For example, if *path* is 'data-bin/train', then we will
            combine 'data-bin/train', 'data-bin/train1', ... and return a
            single ConcatDataset instance.
    """
    # imported lazily to avoid a circular import at module load time
    import fairseq.data.indexed_dataset as indexed_dataset
    from fairseq.data.concat_dataset import ConcatDataset

    datasets = []
    # probe 'path', 'path1', 'path2', ... until make_dataset returns None
    for k in itertools.count():
        path_k = path + (str(k) if k > 0 else "")
        try:
            # may download/copy the dataset to a local cache path
            path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
        except Exception as e:
            if "StorageException: [404] Path not found" in str(e):
                # NOTE(review): on 404 this only logs and falls through with
                # the un-localized path_k — presumably make_dataset then
                # returns None and ends the loop; confirm against PathManager
                logger.warning(f"path_k: {e} not found")
            else:
                raise e

        # per-shard impl: explicit argument wins, otherwise infer from files
        dataset_impl_k = dataset_impl
        if dataset_impl_k is None:
            dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
        dataset = indexed_dataset.make_dataset(
            path_k,
            # fall back to *default* when inference failed (returned None)
            impl=dataset_impl_k or default,
            fix_lua_indexing=True,
            dictionary=dictionary,
        )
        if dataset is None:
            break
        logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
        datasets.append(dataset)
        if not combine:
            break
    # return None / the single dataset / a concatenation, matching input shape
    if len(datasets) == 0:
        return None
    elif len(datasets) == 1:
        return datasets[0]
    else:
        return ConcatDataset(datasets)
| |
|
| |
|
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Context manager which seeds the NumPy PRNG with the specified seed and
    restores the state afterward"""
    if seed is None:
        # no seeding requested: run the body with the PRNG untouched
        yield
        return
    if addl_seeds:
        # fold the extra seeds in deterministically, keeping the value small
        seed = int(hash((seed, *addl_seeds)) % 1e6)
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # always restore, even if the body raised
        np.random.set_state(saved_state)
| |
|
| |
|
def collect_filtered(function, iterable, filtered):
    """
    Similar to :func:`filter` but collects filtered elements in ``filtered``.

    Args:
        function (callable): function that returns ``False`` for elements that
            should be filtered
        iterable (iterable): iterable to filter
        filtered (list): list to store filtered elements
    """
    for item in iterable:
        if not function(item):
            # rejected: record it for the caller instead of yielding
            filtered.append(item)
        else:
            yield item
| |
|
| |
|
def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
    """Split *indices* into those whose size fits *max_positions* and the rest.

    Args:
        indices (iterable[int]): candidate dataset indices
        size_fn (callable): maps an index to its size — a scalar, a tuple of
            per-component sizes, or a dict of per-key size tuples
        max_positions (int | float | dict | sequence): size limit(s), shaped
            to match what ``size_fn`` returns
        raise_exception (bool, optional): unused here; kept for interface
            compatibility with callers (the caller raises)

    Returns:
        (np.ndarray, list): kept indices (int64) and the filtered-out indices.
    """
    # (removed an unused inner `compare_leq` helper that was never called)

    def check_size(idx):
        if isinstance(max_positions, (float, int)):
            # single scalar limit
            return size_fn(idx) <= max_positions
        elif isinstance(max_positions, dict):
            # per-key limits; only keys present on both sides are checked
            idx_size = size_fn(idx)
            assert isinstance(idx_size, dict)
            intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
            return all(
                all(
                    a is None or b is None or a <= b
                    for a, b in zip(idx_size[key], max_positions[key])
                )
                for key in intersect_keys
            )
        else:
            # sequence of limits: a scalar size must satisfy every limit,
            # otherwise compare component-wise (None means "no limit")
            if not isinstance(size_fn(idx), Iterable):
                return all(size_fn(idx) <= b for b in max_positions)
            return all(
                a is None or b is None or a <= b
                for a, b in zip(size_fn(idx), max_positions)
            )

    ignored = []
    itr = collect_filtered(check_size, indices, ignored)
    indices = np.fromiter(itr, dtype=np.int64, count=-1)
    return indices, ignored
| |
|
| |
|
def filter_by_size(indices, dataset, max_positions, raise_exception=False):
    """
    [deprecated] Filter indices based on their size.
    Use `FairseqDataset::filter_indices_by_size` instead.

    Args:
        indices (List[int]): ordered list of dataset indices
        dataset (FairseqDataset): fairseq dataset instance
        max_positions (tuple): filter elements larger than this size.
            Comparisons are done component-wise.
        raise_exception (bool, optional): if ``True``, raise an exception if
            any elements are filtered (default: False).
    """
    warnings.warn(
        "data_utils.filter_by_size is deprecated. "
        "Use `FairseqDataset::filter_indices_by_size` instead.",
        stacklevel=2,
    )
    scalar_limit = isinstance(max_positions, (float, int))
    has_sizes = hasattr(dataset, "sizes")
    if scalar_limit and has_sizes and isinstance(dataset.sizes, np.ndarray):
        # fast vectorized path over the dataset's size array
        row_sizes = dataset.sizes[indices]
        ignored = indices[row_sizes > max_positions].tolist()
        indices = indices[row_sizes <= max_positions]
    elif (
        scalar_limit
        and has_sizes
        and isinstance(dataset.sizes, list)
        and len(dataset.sizes) == 1
    ):
        # same fast path when sizes is a one-element list of arrays
        row_sizes = dataset.sizes[0][indices]
        ignored = indices[row_sizes > max_positions].tolist()
        indices = indices[row_sizes <= max_positions]
    else:
        # general case: query sizes one index at a time
        indices, ignored = _filter_by_size_dynamic(
            indices, dataset.size, max_positions
        )

    if ignored and raise_exception:
        raise Exception(
            (
                "Size of sample #{} is invalid (={}) since max_positions={}, "
                "skip this example with --skip-invalid-size-inputs-valid-test"
            ).format(ignored[0], dataset.size(ignored[0]), max_positions)
        )
    if ignored:
        logger.warning(
            (
                "{} samples have invalid sizes and will be skipped, "
                "max_positions={}, first few sample ids={}"
            ).format(len(ignored), max_positions, ignored[:10])
        )
    return indices
| |
|
| |
|
def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
    """Filter a list of sample indices. Remove those that are longer
    than specified in max_sizes.

    Args:
        indices (np.array): original array of sample indices
        max_sizes (int or list[int] or tuple[int]): max sample size,
            can be defined separately for src and tgt (then list or tuple)

    Returns:
        np.array: filtered sample array
        list: list of removed indices
    """
    if max_sizes is None:
        # no limit: keep everything
        return indices, []
    if type(max_sizes) in (int, float):
        # a single scalar limit applies to both sides
        max_src_size = max_tgt_size = max_sizes
    else:
        max_src_size, max_tgt_size = max_sizes

    if tgt_sizes is None:
        too_long = src_sizes[indices] > max_src_size
    else:
        # a pair is removed if either side exceeds its limit
        too_long = (src_sizes[indices] > max_src_size) | (
            tgt_sizes[indices] > max_tgt_size
        )
    ignored = indices[too_long]
    if len(ignored) > 0:
        indices = indices[~too_long]
    return indices, ignored.tolist()
| |
|
| |
|
def batch_by_size(
    indices,
    num_tokens_fn,
    num_tokens_vec=None,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
    fixed_shapes=None,
):
    """
    Yield mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        num_tokens_vec (List[int], optional): precomputed vector of the number
            of tokens for each index in indices (to enable faster batch generation)
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be less than N or a multiple of N (default: 1).
        fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
            only be created with the given shapes. *max_sentences* and
            *required_batch_size_multiple* will be ignored (default: None).
    """
    # the batching kernels are Cython extensions; fail with build guidance
    # when they are missing (ImportError) or ABI-incompatible (ValueError)
    try:
        from fairseq.data.data_utils_fast import (
            batch_by_size_fn,
            batch_by_size_vec,
            batch_fixed_shapes_fast,
        )
    except ImportError:
        raise ImportError(
            "Please build Cython components with: "
            "`python setup.py build_ext --inplace`"
        )
    except ValueError:
        raise ValueError(
            "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
        )

    # the fast kernels use -1 to mean "no limit"
    max_tokens = int(max_tokens) if max_tokens is not None else -1
    max_sentences = max_sentences if max_sentences is not None else -1
    bsz_mult = required_batch_size_multiple

    # the kernels require contiguous int64 numpy arrays
    if not isinstance(indices, np.ndarray):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
        num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)

    if fixed_shapes is None:
        if num_tokens_vec is None:
            # callback variant: slower, calls back into Python per index
            return batch_by_size_fn(
                indices,
                num_tokens_fn,
                max_tokens,
                max_sentences,
                bsz_mult,
            )
        else:
            # vectorized variant using the precomputed token counts
            return batch_by_size_vec(
                indices,
                num_tokens_vec,
                max_tokens,
                max_sentences,
                bsz_mult,
            )

    else:
        fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
        # NOTE(review): lexsort is applied to the *argsort* arrays rather than
        # the raw shape columns, which is unusual — confirm this ordering is
        # intentional before changing it.
        sort_order = np.lexsort(
            [
                fixed_shapes[:, 1].argsort(),
                fixed_shapes[:, 0].argsort(),
            ]
        )
        fixed_shapes_sorted = fixed_shapes[sort_order]
        return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
| |
|
| |
|
def post_process(sentence: str, symbol: str):
    """Undo subword/segmentation markers in *sentence* for the given scheme.

    Args:
        sentence: decoded text still carrying segmentation markers
        symbol: segmentation scheme — one of "sentencepiece", "wordpiece",
            "letter", "silence", "_EOW", "subword_nmt"/"@@ "/"@@", "none",
            or None (the last two leave the sentence unchanged)

    Returns:
        the detokenized sentence

    Raises:
        NotImplementedError: if *symbol* is not a recognized scheme
    """
    if symbol == "sentencepiece":
        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
    elif symbol == "wordpiece":
        sentence = sentence.replace(" ", "").replace("_", " ").strip()
    elif symbol == "letter":
        sentence = sentence.replace(" ", "").replace("|", " ").strip()
    elif symbol == "silence":
        # drop silence tokens, then collapse any resulting runs of spaces
        # (uses the module-level `re` import; a redundant local import was removed)
        sentence = sentence.replace("<SIL>", "")
        sentence = re.sub(" +", " ", sentence).strip()
    elif symbol == "_EOW":
        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
    elif symbol in {"subword_nmt", "@@ ", "@@"}:
        if symbol == "subword_nmt":
            symbol = "@@ "
        # trailing space lets a sentence-final continuation marker match too
        sentence = (sentence + " ").replace(symbol, "").rstrip()
    elif symbol == "none":
        pass
    elif symbol is not None:
        raise NotImplementedError(f"Unknown post_process option: {symbol}")
    return sentence
| |
|
| |
|
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from possion distribution with lambda = mask length
        min_masks: minimum number of masked spans
        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example

    Returns:
        boolean np.ndarray of shape (batch, timesteps); True marks masked positions
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    # expected number of spans; the random term dithers the int() truncation
    all_num_mask = int(
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            # recompute the span budget from the unpadded length only
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        # sample one length per span according to mask_type
        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            # degenerate draw (all zero lengths): force at least one span
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # place one span inside [s, e) and return the remaining gaps
                # that are still large enough to hold another span
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                # pick a gap with probability proportional to its free room
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,  # was np.int — that alias was removed in NumPy >= 1.24
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    # no gap can hold this span; stop placing
                    break
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                # not enough distinct start positions; shrink so sampling works
                min_len = sz - num_mask - 1

            # sample span starts without replacement, then expand to full spans
            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        # deduplicate overlapping spans and clip to the unpadded region
        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len and require_same_masks:
            # trim every row to the smallest count so all rows mask equally many
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        if mask_dropout > 0:
            # randomly knock out a fraction of the chosen positions
            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
            mask_idc = np.random.choice(
                mask_idc, len(mask_idc) - num_holes, replace=False
            )

        mask[i, mask_idc] = True

    return mask
| |
|
| |
|
def get_mem_usage():
    """Return a human-readable RAM usage string, or "N/A" if psutil is absent."""
    try:
        import psutil
    except ImportError:
        # psutil is an optional dependency; degrade gracefully
        return "N/A"
    mb = 1024 * 1024
    return (
        f"used={psutil.virtual_memory().used / mb}Mb; "
        f"avail={psutil.virtual_memory().available / mb}Mb"
    )
| |
|
| |
|
| | |
| | |
def lengths_to_padding_mask(lens):
    """Boolean mask of shape (bsz, max_len): True where position >= length (i.e. padding)."""
    max_len = torch.max(lens).item()
    positions = torch.arange(max_len).to(lens.device).unsqueeze(0)
    # broadcast (1, max_len) against (bsz, 1): True marks padded slots
    return positions >= lens.unsqueeze(1)
| |
|
| |
|
| | |
| | |
def lengths_to_mask(lens):
    """Inverse of :func:`lengths_to_padding_mask`: True at real (non-pad) positions."""
    padding = lengths_to_padding_mask(lens)
    return ~padding
| |
|
| |
|
def get_buckets(sizes, num_buckets):
    """Compute *num_buckets* upper boundaries from the distribution of *sizes*.

    Boundaries are the (deduplicated) percentiles of *sizes* at evenly spaced
    quantiles, excluding the 0th percentile, so each returned value is a size
    that actually occurs in the data.

    Args:
        sizes (np.ndarray): per-example sizes
        num_buckets (int): desired number of buckets

    Returns:
        np.ndarray: sorted unique bucket boundaries (may be fewer than
        *num_buckets* after deduplication)
    """
    buckets = np.unique(
        np.percentile(
            sizes,
            np.linspace(0, 100, num_buckets + 1),
            # `interpolation=` was deprecated in NumPy 1.22; `method=` is the
            # supported spelling with identical "lower" semantics
            method="lower",
        )[1:]
    )
    return buckets
| |
|
| |
|
def get_bucketed_sizes(orig_sizes, buckets):
    """Snap each size up to the smallest bucket boundary that contains it."""
    sizes = np.copy(orig_sizes)
    # bucketing assumes non-negative sizes
    assert np.min(sizes) >= 0
    prev_bound = -1
    for bound in buckets:
        # everything in (prev_bound, bound] is rounded up to bound
        in_bucket = (sizes > prev_bound) & (sizes <= bound)
        sizes[in_bucket] = bound
        prev_bound = bound
    return sizes
| |
|
| |
|
def _find_extra_valid_paths(dataset_path: str) -> set:
    """Return extension-less basenames of files matching 'valid*[0-9].*'
    found in any of the (possibly colon-separated) data directories."""
    matched = set()
    for sub_dir in utils.split_paths(dataset_path):
        for entry in PathManager.ls(sub_dir):
            if re.match("valid*[0-9].*", entry) is not None:
                matched.add(os.path.basename(entry))
    # strip the file extension so 'valid1.en-de.en' and 'valid1.en-de.de'
    # collapse to the same subset name
    return {os.path.splitext(name)[0] for name in matched}
| |
|
| |
|
def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
    """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
    dataset_cfg = train_cfg.dataset
    # any of these settings makes extra valid subsets intentional/no-ops
    if dataset_cfg.ignore_unused_valid_subsets:
        return
    if dataset_cfg.combine_valid_subsets:
        return
    if dataset_cfg.disable_validation:
        return
    if not hasattr(train_cfg.task, "data"):
        return
    on_disk = _find_extra_valid_paths(train_cfg.task.data)
    requested = dataset_cfg.valid_subset.split(",")
    ignored_paths = [p for p in on_disk if p not in requested]
    if ignored_paths:
        advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
        raise ValueError(f"Valid paths {ignored_paths} will be ignored. {advice}")
| |
|