# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch.utils.data

from fairseq.data import data_utils

logger = logging.getLogger(__name__)


class EpochListening:
    """Mixin for receiving updates whenever the epoch increments."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """
        Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for
        this dataset across epochs.

        This needs to return ``False`` if the sample sizes can change across
        epochs, in which case we may need to regenerate batches at each epoch.
        If your dataset relies on ``set_epoch`` then you should consider setting
        this to ``False``.
        """
        return True

    def set_epoch(self, epoch):
        """Will receive the updated epoch number at the beginning of the epoch."""
        pass
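

# A minimal sketch (not part of fairseq) of how a dataset that reshuffles per
# epoch would use the EpochListening hooks: because the ordering changes every
# epoch, it opts out of iterator reuse. The class and its fields are
# hypothetical.
class _EpochShuffledData(EpochListening):
    def __init__(self, num_samples, seed=1):
        self.num_samples = num_samples
        self.seed = seed
        self.epoch = 1

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # Batching depends on the per-epoch ordering, so the cached
        # EpochBatchIterator must be regenerated each epoch.
        return False

    def set_epoch(self, epoch):
        # Called by the training loop at the start of each epoch.
        self.epoch = epoch

    def ordering(self):
        # Deterministic per-epoch shuffle keyed on (seed, epoch).
        rng = np.random.RandomState(self.seed + self.epoch)
        return rng.permutation(self.num_samples)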


class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self), dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """
        Return a list of valid batch shapes, for example::

            [(8, 512), (16, 256), (32, 128)]

        The first dimension of each tuple is the batch size and can be ``None``
        to automatically infer the max batch size based on ``--max-tokens``.
        The second dimension of each tuple is the max supported length as given
        by :func:`fairseq.data.FairseqDataset.num_tokens`.

        This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
        to restrict batch shapes. This is useful on TPUs to avoid too many
        dynamic shapes (and recompilations).
        """
        return None
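
    # A TPU-oriented override might, for instance, return a few fixed shapes
    # (a sketch; the values are hypothetical):
    #
    #   def get_batch_shapes(self):
    #       return [(None, 512), (16, 256), (32, 128)]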

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """
        Given an ordered set of indices, return batches according to
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.
        """
        fixed_shapes = self.get_batch_shapes()
        if fixed_shapes is not None:

            def adjust_bsz(bsz, num_tokens):
                if bsz is None:
                    assert max_tokens is not None, "Must specify --max-tokens"
                    bsz = max_tokens // num_tokens
                if max_sentences is not None:
                    bsz = min(bsz, max_sentences)
                elif (
                    bsz >= required_batch_size_multiple
                    and bsz % required_batch_size_multiple != 0
                ):
                    bsz -= bsz % required_batch_size_multiple
                return bsz

            fixed_shapes = np.array(
                [
                    [adjust_bsz(bsz, num_tokens), num_tokens]
                    for (bsz, num_tokens) in fixed_shapes
                ]
            )

        try:
            num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
        except NotImplementedError:
            num_tokens_vec = None

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            num_tokens_vec=num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=fixed_shapes,
        )
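
    # Typical use (a sketch; ``ds`` stands for any concrete FairseqDataset):
    #
    #   indices = ds.ordered_indices()
    #   batches = ds.batch_by_size(
    #       indices, max_tokens=4096, required_batch_size_multiple=8
    #   )
    #   # ``batches`` is a list of index arrays, one per mini-batch, each
    #   # holding at most 4096 tokens as measured by ``num_tokens``.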

    def filter_indices_by_size(self, indices, max_sizes):
        """
        Filter a list of sample indices. Remove those that are longer than
        specified in *max_sizes*.

        WARNING: don't modify this method here; override it in child classes
        instead.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        if isinstance(max_sizes, float) or isinstance(max_sizes, int):
            if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
                ignored = indices[self.sizes[indices] > max_sizes].tolist()
                indices = indices[self.sizes[indices] <= max_sizes]
            elif (
                hasattr(self, "sizes")
                and isinstance(self.sizes, list)
                and len(self.sizes) == 1
            ):
                ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
                indices = indices[self.sizes[0][indices] <= max_sizes]
            else:
                indices, ignored = data_utils._filter_by_size_dynamic(
                    indices, self.size, max_sizes
                )
        else:
            indices, ignored = data_utils._filter_by_size_dynamic(
                indices, self.size, max_sizes
            )
        return indices, ignored
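
    # Example (a sketch): drop samples longer than 1024 positions on either
    # side, assuming a paired source/target dataset:
    #
    #   indices, ignored = ds.filter_indices_by_size(
    #       ds.ordered_indices(), max_sizes=(1024, 1024)
    #   )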

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the dataloader."""
        return True
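

# A minimal, self-contained sketch (not part of fairseq) of a concrete
# FairseqDataset over pre-tokenized sequences. The class name, padding index,
# and collation scheme are hypothetical; a real implementation would typically
# delegate padding to fairseq's collate helpers.
class _ExampleTokenDataset(FairseqDataset):
    def __init__(self, sequences, pad_idx=1):
        # ``sequences`` is a list of 1-D LongTensors of varying length.
        self.sequences = sequences
        self.pad_idx = pad_idx
        # Cache sizes as an ndarray so filter_indices_by_size takes the
        # vectorized path above.
        self.sizes = np.array([len(seq) for seq in sequences], dtype=np.int64)

    def __getitem__(self, index):
        return {"id": index, "tokens": self.sequences[index]}

    def __len__(self):
        return len(self.sequences)

    def collater(self, samples):
        # Right-pad every sequence to the longest one in the mini-batch.
        max_len = max(len(s["tokens"]) for s in samples)
        batch = torch.full((len(samples), max_len), self.pad_idx, dtype=torch.long)
        for i, s in enumerate(samples):
            batch[i, : len(s["tokens"])] = s["tokens"]
        return {
            "id": torch.tensor([s["id"] for s in samples]),
            "net_input": {"src_tokens": batch},
        }

    def num_tokens(self, index):
        return int(self.sizes[index])

    def num_tokens_vec(self, indices):
        # Vectorized variant; lets batch_by_size skip per-index Python calls.
        return self.sizes[indices]

    def size(self, index):
        return int(self.sizes[index])


# Usage sketch:
#
#   ds = _ExampleTokenDataset([torch.tensor([5, 6, 7]), torch.tensor([8, 9])])
#   batches = ds.batch_by_size(ds.ordered_indices(), max_tokens=8)
#   mini_batch = ds.collater([ds[i] for i in batches[0]])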


class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """
    For datasets that need to be read sequentially, usually because the data is
    being streamed or otherwise can't be manipulated on a single machine.
    """

    def __iter__(self):
        raise NotImplementedError
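

# A minimal sketch (not part of fairseq) of a streaming dataset: it re-reads a
# text file on every pass, so it must be consumed sequentially. The file path
# and whitespace "tokenization" are hypothetical placeholders for a real
# source and dictionary.
class _ExampleLineStream(FairseqIterableDataset):
    def __init__(self, path):
        self.path = path
        self.epoch = 1

    def set_epoch(self, epoch):
        # A streaming source could use this to, e.g., pick a different shard.
        self.epoch = epoch

    def __iter__(self):
        with open(self.path, encoding="utf-8") as f:
            for line in f:
                yield line.split()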