| |
| |
| |
| |
|
|
| import logging |
| import numpy as np |
| import torch.utils.data |
| from fairseq.data import data_utils |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class EpochListening:
    """Mixin that lets an object be told when training moves to a new epoch."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """
        Report whether one :class:`fairseq.data.EpochBatchIterator` may be
        kept for this dataset from one epoch to the next.

        Implementations should return ``False`` when sample sizes may differ
        between epochs, because batches may then have to be rebuilt at every
        epoch. Datasets that rely on ``set_epoch`` should consider returning
        ``False`` as well.

        The default implementation always answers ``True``.
        """
        return True

    def set_epoch(self, epoch):
        """Receive the new epoch number at the start of an epoch (no-op by default)."""
        return None
|
|
|
|
class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """Map-style dataset base class with extra hooks used for batching."""

    def __getitem__(self, index):
        # Subclasses must provide item access.
        raise NotImplementedError()

    def __len__(self):
        # Subclasses must report how many samples they hold.
        raise NotImplementedError()

    def collater(self, samples):
        """Combine several samples into a single mini-batch.

        Args:
            samples (List[dict]): the samples to merge

        Returns:
            dict: a mini-batch that can be forwarded through a Model
        """
        raise NotImplementedError()

    def num_tokens(self, index):
        """Give the token count of the sample at *index*; batching uses it
        to honour ``--max-tokens``."""
        raise NotImplementedError()

    def num_tokens_vec(self, indices):
        """Give the token counts for every position listed in *indices*;
        batching uses them to honour ``--max-tokens``."""
        raise NotImplementedError()

    def size(self, index):
        """Give the size of one example as a float or tuple; it is consulted
        when a dataset is filtered with ``--max-positions``."""
        raise NotImplementedError()

    def ordered_indices(self):
        """Produce the index order in which batches will be built.

        The default is simply ``0 .. len(self) - 1`` as an int64 array.
        """
        sample_count = len(self)
        return np.arange(sample_count, dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Tell whether prefetching is available (not by default)."""
        return False

    def attr(self, attr: str, index: int):
        """Look up the attribute called *attr* on this dataset.

        Returns ``None`` when no such attribute exists. This base
        implementation does not use *index*.
        """
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Load ahead of time the data needed for this epoch."""
        raise NotImplementedError()

    def get_batch_shapes(self):
        """
        Give the list of permitted batch shapes, e.g.::

            [(8, 512), (16, 256), (32, 128)]

        Each tuple is ``(batch size, max length)``. The batch size may be
        ``None``, in which case the largest batch size is derived from
        ``--max-tokens``. The max length is expressed in the units returned
        by :func:`fairseq.data.FairseqDataset.num_tokens`.

        :func:`fairseq.data.FairseqDataset.batch_by_size` uses these shapes
        to constrain the batches it creates, which helps on TPUs by limiting
        the number of dynamic shapes (and therefore recompilations).

        The default returns ``None``, meaning no restriction.
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """
        Split an ordered collection of indices into batches that respect
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.
        """
        from fairseq.data import data_utils

        def _resolve_bsz(bsz, num_tokens):
            # An unspecified batch size is derived from the token budget.
            if bsz is None:
                assert max_tokens is not None, "Must specify --max-tokens"
                bsz = max_tokens // num_tokens
            # A sentence cap takes precedence over rounding to the multiple.
            if max_sentences is not None:
                return min(bsz, max_sentences)
            # Otherwise round down to the required multiple when large enough.
            if bsz >= required_batch_size_multiple:
                return bsz - bsz % required_batch_size_multiple
            return bsz

        shapes = self.get_batch_shapes()
        if shapes is not None:
            rows = []
            for shape_bsz, shape_len in shapes:
                rows.append([_resolve_bsz(shape_bsz, shape_len), shape_len])
            shapes = np.array(rows)

        # Use the vectorised token counts when the subclass offers them.
        try:
            token_counts = self.num_tokens_vec(indices).astype("int64")
        except NotImplementedError:
            token_counts = None

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            num_tokens_vec=token_counts,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """
        Drop from *indices* every sample whose size exceeds *max_sizes*.

        WARNING: don't update, override method in child classes

        Args:
            indices (np.array): the sample indices to filter
            max_sizes (int or list[int] or tuple[int]): the size limit; a
                list or tuple gives separate limits for src and tgt

        Returns:
            np.array: the indices that were kept
            list: the indices that were removed
        """
        if isinstance(max_sizes, (float, int)):
            known_sizes = getattr(self, "sizes", None)
            if isinstance(known_sizes, np.ndarray):
                selected = known_sizes[indices]
            elif isinstance(known_sizes, list) and len(known_sizes) == 1:
                selected = known_sizes[0][indices]
            else:
                # No usable precomputed sizes: query self.size per sample.
                kept, removed = data_utils._filter_by_size_dynamic(
                    indices, self.size, max_sizes
                )
                return kept, removed
            removed = indices[selected > max_sizes].tolist()
            kept = indices[selected <= max_sizes]
            return kept, removed

        # Non-scalar limits always go through the dynamic filter.
        kept, removed = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return kept, removed

    @property
    def supports_fetch_outside_dataloader(self):
        """Tell whether samples may be fetched outside the dataloader's workers."""
        return True
|
|
|
|
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """
    Base class for datasets that have to be consumed in sequence, typically
    because the data is streamed or cannot otherwise be manipulated on a
    single machine.
    """

    def __iter__(self):
        # Subclasses must supply the actual iteration logic.
        raise NotImplementedError()
|
|