Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| import numpy as np | |
| from fairseq.data import BaseWrapperDataset, plasma_utils | |
| logger = logging.getLogger(__name__) | |
class ResamplingDataset(BaseWrapperDataset):
    """Randomly samples from a given dataset at each epoch.

    Sampling is done with or without replacement, depending on the "replace"
    parameter.

    Optionally, the epoch size can be rescaled. This is potentially desirable
    to increase per-epoch coverage of the base dataset (since sampling with
    replacement means that many items in the dataset will be left out). In the
    case of sampling without replacement, size_ratio must not exceed 1; a
    ratio of exactly 1 yields a random permutation of the whole dataset.

    Args:
        dataset (~torch.utils.data.Dataset): dataset on which to sample.
        weights (List[float]): list of probability weights, one per item of
            *dataset*; normalized internally so they sum to 1
            (default: None, which corresponds to uniform sampling).
        replace (bool): sampling mode; True for "with replacement", or False
            for "without replacement" (default: True)
        size_ratio (float): the ratio to subsample to; must be positive
            (default: 1.0).
        batch_by_size (bool): whether or not to batch by sequence length
            (default: True).
        seed (int): RNG seed to use (default: 0).
        epoch (int): starting epoch number (default: 1).
    """

    def __init__(
        self,
        dataset,
        weights=None,
        replace=True,
        size_ratio=1.0,
        batch_by_size=True,
        seed=0,
        epoch=1,
    ):
        super().__init__(dataset)

        if weights is None:
            self.weights = None
        else:
            assert len(weights) == len(dataset), (
                "expected exactly one weight per dataset item"
            )
            weights_arr = np.array(weights, dtype=np.float64)
            total = weights_arr.sum()
            # Dividing by a zero (or NaN) total would silently produce NaN
            # probabilities; fail early with a clear message instead.
            assert total > 0, "weights must have a positive sum"
            weights_arr /= total
            # NOTE(review): PlasmaArray exposes the data via `.array`; it is
            # presumably used to share the array across data-loader workers.
            self.weights = plasma_utils.PlasmaArray(weights_arr)

        self.replace = replace

        assert size_ratio > 0.0, "size_ratio must be positive"
        if not self.replace:
            # Without replacement we cannot draw more items than exist.
            # A ratio of exactly 1 is allowed, so the defaults
            # (size_ratio=1.0) also work with replace=False.
            assert size_ratio <= 1.0, (
                "size_ratio must be <= 1 when sampling without replacement"
            )
        self.size_ratio = float(size_ratio)
        # Number of items drawn per epoch. Stored as a plain Python int so
        # that __len__ returns a builtin int rather than a numpy scalar.
        self.actual_size = int(np.ceil(len(dataset) * self.size_ratio))

        self.batch_by_size = batch_by_size
        self.seed = seed

        # Epoch that _cur_indices was sampled for, and the sampled indices.
        self._cur_epoch = None
        self._cur_indices = None

        self.set_epoch(epoch)

    def __getitem__(self, index):
        """Return the item of the wrapped dataset sampled at position *index*."""
        return self.dataset[self._cur_indices.array[index]]

    def __len__(self):
        """Number of items drawn per epoch (``ceil(len(dataset) * size_ratio)``)."""
        return self.actual_size

    @property
    def sizes(self):
        """Sizes of the currently sampled items.

        Mirrors ``self.dataset.sizes``. When that is a list of size arrays,
        each array is re-indexed by the current sample and a list is returned.
        Exposed as a property because ``ordered_indices`` consumes it as an
        array, not as a callable.
        """
        if isinstance(self.dataset.sizes, list):
            return [s[self._cur_indices.array] for s in self.dataset.sizes]
        return self.dataset.sizes[self._cur_indices.array]

    def num_tokens(self, index):
        """Delegate to the wrapped dataset for the sampled item at *index*."""
        return self.dataset.num_tokens(self._cur_indices.array[index])

    def size(self, index):
        """Delegate to the wrapped dataset for the sampled item at *index*."""
        return self.dataset.size(self._cur_indices.array[index])

    def ordered_indices(self):
        """Return an ordering of indices into this resampled dataset.

        With ``batch_by_size`` the indices are sorted by item size, with ties
        broken by index so the result is deterministic. When ``sizes`` is a
        list of arrays, the first array is the primary key, the second the
        secondary key, and so on. Without ``batch_by_size`` the indices are
        returned in sampled order.
        """
        if not self.batch_by_size:
            return np.arange(len(self))

        sizes = self.sizes
        # np.lexsort treats the LAST key as the primary one, so list-valued
        # sizes are reversed. Each array must be its own key; nesting the
        # list inside the key sequence would make it ragged.
        size_keys = list(reversed(sizes)) if isinstance(sizes, list) else [sizes]
        # No need to handle `self.shuffle == True`
        return np.lexsort([np.arange(len(self))] + size_keys)

    def prefetch(self, indices):
        """Prefetch the wrapped-dataset items behind the sampled *indices*."""
        self.dataset.prefetch(self._cur_indices.array[indices])

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """Always False: a different sample is drawn for every epoch.

        Must be a property. Callers read it as an attribute, and a plain
        method would always be truthy.
        """
        return False

    def set_epoch(self, epoch):
        """Resample the dataset for *epoch*.

        The sample is a deterministic function of ``seed`` and ``epoch``, so
        calling this again with the current epoch does not resample. The epoch
        is still forwarded to the parent class in that case.
        """
        logger.debug("ResamplingDataset.set_epoch: %s", epoch)
        super().set_epoch(epoch)

        if epoch == self._cur_epoch:
            return
        self._cur_epoch = epoch

        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.
        rng = np.random.RandomState(
            [
                42,  # fixed constant mixed into the seed sequence
                self.seed % (2 ** 32),  # global seed, reduced to 32 bits
                self._cur_epoch,  # epoch index
            ]
        )
        self._cur_indices = plasma_utils.PlasmaArray(
            rng.choice(
                len(self.dataset),
                self.actual_size,
                replace=self.replace,
                p=(None if self.weights is None else self.weights.array),
            )
        )