# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import hashlib
import logging
import math

import numpy as np

from fairseq.data import SampledMultiDataset

from .sampled_multi_dataset import CollateFormat, default_virtual_size_func


logger = logging.getLogger(__name__)


class SampledMultiEpochDataset(SampledMultiDataset):
| """Samples from multiple sub-datasets according to sampling ratios | |
| using virtual epoch sizes to speed up dataloading. | |
| Args: | |
| datasets ( | |
| List[~torch.utils.data.Dataset] | |
| or OrderedDict[str, ~torch.utils.data.Dataset] | |
| ): datasets | |
| sampling_ratios (List[float]): list of probability of each dataset to be sampled | |
| (default: None, which corresponds to concating all dataset together). | |
| seed (int): RNG seed to use (default: 2). | |
| epoch (int): starting epoch number (default: 1). | |
| eval_key (str, optional): a key used at evaluation time that causes | |
| this instance to pass-through batches from *datasets[eval_key]*. | |
| collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or | |
| CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures | |
| the collater to output batches of data mixed from all sub-datasets, | |
| and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys | |
| of sub-datasets. | |
| Note that not all sub-datasets will present in a single batch in both formats. | |
| virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func). | |
| split (str): the split of the data, e.g. 'train', 'valid' or 'test'. | |
| virtual_epoch_size (int): virtual epoch size, the dataset will go through the data by | |
| this virtual epoch size one by one to speed up data loading, e.g. indicing and filtering | |
| can be performed whenever a virtual epoch is loaded without waiting for the whole dataset to be loaded. | |
| shared_collater (bool): whether or not to all sub-datasets have the same collater. | |
| shard_epoch (int): the real epoch number for shard selection. | |
| shuffle (bool): whether or not to shuffle data (default: True). | |
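
    Example (an illustrative sketch; ``d1`` and ``d2`` are placeholders for
    arbitrary ~torch.utils.data.Dataset instances)::

        from collections import OrderedDict

        dataset = SampledMultiEpochDataset(
            OrderedDict([("en-fr", d1), ("en-de", d2)]),
            sampling_ratios=[0.7, 0.3],
            virtual_epoch_size=1000,
            split="train",
        )
        dataset.set_epoch(1)  # load the first virtual epoch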
| """ | |
    def __init__(
        self,
        datasets,
        sampling_ratios=None,
        seed=2,
        epoch=1,
        eval_key=None,
        collate_format=CollateFormat.single,
        virtual_size=default_virtual_size_func,
        split="",
        virtual_epoch_size=None,
        shared_collater=False,
        shard_epoch=1,
        shuffle=True,
    ):
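        # Bookkeeping for virtual epochs: the start offset of the current
        # virtual epoch within the shuffled global index array, the shuffled
        # array itself, and a per-epoch cache of sample sizes.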
        self.virtual_epoch_size = virtual_epoch_size
        self._current_epoch_start_index = None
        self._random_global_indices = None
        self.shard_epoch = shard_epoch if shard_epoch is not None else 1
        self.load_next_shard = None
        self._epoch_sizes = None
        super().__init__(
            datasets=datasets,
            sampling_ratios=sampling_ratios,
            seed=seed,
            epoch=epoch,
            eval_key=eval_key,
            collate_format=collate_format,
            virtual_size=virtual_size,
            split=split,
            shared_collater=shared_collater,
            shuffle=shuffle,
        )

    def _setup(self, epoch):
        self.virtual_epoch_size = (
            self.virtual_epoch_size
            if self.virtual_epoch_size is not None
            else self.virtual_size
        )
        if self.virtual_epoch_size > self.virtual_size:
            logger.warning(
                f"virtual epoch size {self.virtual_epoch_size} "
                f"is greater than virtual dataset size {self.virtual_size}"
            )
            self.virtual_epoch_size = self.virtual_size
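        # E.g. with virtual_size=1000 and virtual_epoch_size=300 this yields
        # math.ceil(1000 / 300) = 4 virtual epochs, the last one shorter.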
        self.num_virtual_epochs = math.ceil(self.virtual_size / self.virtual_epoch_size)
        self._current_epoch_start_index = self._get_epoch_start_index(epoch)
        logger.info(
            f"virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}"
        )

    def _map_epoch_index_to_global(self, index):
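        # E.g. with _current_epoch_start_index=300, the epoch-local index 5
        # looks up _random_global_indices[305], a random position within the
        # full virtual dataset.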
        index = self._current_epoch_start_index + index
        # add randomness
        return self._random_global_indices[index]

    @property
    def sizes(self):
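        # Restrict the parent's size array to the slice of shuffled global
        # indices covered by the current virtual epoch, caching the result.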
        if self._epoch_sizes is not None:
            return self._epoch_sizes
        _sizes = super().sizes
        indices = self._random_global_indices[
            self._current_epoch_start_index : self._current_epoch_start_index
            + len(self)
        ]
        self._epoch_sizes = _sizes[indices]
        # del super()._sizes to save memory
        del self._sizes
        self._sizes = None
        return self._epoch_sizes

    def _get_dataset_and_index(self, index):
        i = self._map_epoch_index_to_global(index)
        return super()._get_dataset_and_index(i)

    def __len__(self):
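        # Every virtual epoch has virtual_epoch_size items except possibly the
        # last one; e.g. a virtual dataset of 1000 items with
        # virtual_epoch_size=300 yields epochs of 300, 300, 300 and 100 items.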
        return (
            self.virtual_epoch_size
            if self._current_epoch_start_index + self.virtual_epoch_size
            < self.virtual_size
            else self.virtual_size - self._current_epoch_start_index
        )

    def set_epoch(self, epoch):
        if self._current_epoch_start_index is None:
            # initializing the epoch indices of a virtual dataset
            self._setup(epoch)
            self._next_virtual_epoch(epoch)
        else:
            # working on already initialized epoch indices
            if epoch == self._cur_epoch:
                # re-entering the same epoch, so nothing to do
                return
            self._next_virtual_epoch(epoch)

    def _get_epoch_start_index(self, epoch):
        assert epoch >= 1  # fairseq uses 1-based epochs everywhere
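        # Wraps around after num_virtual_epochs, e.g. with 4 virtual epochs of
        # size 300: epoch 2 starts at index 300, epoch 5 wraps back to index 0.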
        return ((epoch - 1) % self.num_virtual_epochs) * self.virtual_epoch_size

    def _next_global_indices(self, epoch):
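        # Seed the RNG deterministically from the class name, the global seed
        # and the epoch, so the same permutation of the virtual indices can be
        # regenerated for a given epoch.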
        rng = np.random.RandomState(
            [
                int(
                    hashlib.sha1(
                        str(self.__class__.__name__).encode("utf-8")
                    ).hexdigest(),
                    16,
                )
                % (2**32),
                self.seed % (2**32),  # global seed
                epoch,  # epoch index
            ]
        )
        del self._random_global_indices
        self._random_global_indices = rng.choice(
            self.virtual_size, self.virtual_size, replace=False
        )
        if self.load_next_shard is None:
            self.load_next_shard = False
        else:
            # increase shard epoch for next loading
            self.shard_epoch += 1
            self.load_next_shard = True
            logger.info(
                "to load next epoch/shard in next load_dataset: "
                f"epoch={epoch}/shard_epoch={self.shard_epoch}"
            )

    def _next_virtual_epoch(self, epoch):
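        # A fresh global shuffle is drawn only when wrapping back to the first
        # virtual epoch (or when no permutation exists yet); otherwise only the
        # start offset within the existing permutation advances.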
        index = self._get_epoch_start_index(epoch)
        if index == 0 or self._random_global_indices is None:
            # need to start from the beginning,
            # so call super().set_epoch(epoch) to establish the global virtual indices
            logger.info(
                "establishing a new set of global virtual indices for "
                f"epoch={epoch}/shard_epoch={self.shard_epoch}"
            )
            super().set_epoch(epoch)
            self._next_global_indices(epoch)
        else:
            self._cur_epoch = epoch

        # reset the cached sizes and ordered indices after moving to a new virtual epoch
        self._clean_if_not_none(
            [
                self._epoch_sizes,
            ]
        )
        self._epoch_sizes = None
        self._current_epoch_start_index = index