# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
from typing import Callable, Dict, List, Optional

import numpy as np

from . import FairseqDataset


def uniform_sampler(x):
    # Sample a single dataset key uniformly at random
    return np.random.choice(x, 1).item()
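

# A minimal sketch of a custom sampling_func, assuming uneven corpus weights
# are desired. `make_weighted_sampler` and its `weights` mapping are
# illustrative only, not part of the original fairseq API.
def make_weighted_sampler(weights: Dict[str, float]) -> Callable[[List[str]], str]:
    def sampler(keys: List[str]) -> str:
        # Normalize the configured weights over the keys actually present;
        # keys missing from `weights` default to weight 1.0.
        probs = np.array([weights.get(k, 1.0) for k in keys], dtype=np.float64)
        probs /= probs.sum()
        return np.random.choice(keys, 1, p=probs).item()

    return sampler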


class MultiCorpusSampledDataset(FairseqDataset):
    """
    Stores multiple instances of FairseqDataset together and in every iteration
    creates a batch by first sampling a dataset according to a specified
    probability distribution and then getting instances from that dataset.

    Args:
        datasets: an OrderedDict of FairseqDataset instances.
        sampling_func: A function for sampling over list of dataset keys.
            The default strategy is to sample uniformly.
    """

    def __init__(
        self,
        datasets: Dict[str, FairseqDataset],
        sampling_func: Optional[Callable[[List[str]], str]] = None,
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        self.datasets = datasets
        if sampling_func is None:
            sampling_func = uniform_sampler
        self.sampling_func = sampling_func

        self.total_num_instances = 0
        for _, dataset in datasets.items():
            assert isinstance(dataset, FairseqDataset)
            self.total_num_instances += len(dataset)

        self._ordered_indices = None

    def __len__(self):
        """
        Length of this dataset is the sum of the individual dataset lengths.
        """
        return self.total_num_instances

    def ordered_indices(self):
        """
        Ordered indices for batching. Here we call the underlying
        dataset's ordered_indices() so that we get the same random ordering
        as we would have from using the underlying dataset directly.
        """
        if self._ordered_indices is None:
            self._ordered_indices = OrderedDict(
                [
                    (key, dataset.ordered_indices())
                    for key, dataset in self.datasets.items()
                ]
            )
        # The wrapper itself is indexed 0..len(self) - 1; the per-dataset
        # orderings cached above are consumed by _map_index_to_dataset().
        return np.arange(len(self))

    def _map_index_to_dataset(self, key: str, index: int):
        """
        Different underlying datasets have different lengths. In order to ensure
        we are not accessing an index outside the range of the current dataset
        size, we wrap around. This function should be called after we have
        created an ordering for this and all underlying datasets.
        """
        assert (
            self._ordered_indices is not None
        ), "Must call MultiCorpusSampledDataset.ordered_indices() first"
        # Wrap around by the dataset's length, then look up the shuffled position
        mapped_index = index % len(self.datasets[key])
        return self._ordered_indices[key][mapped_index]

    def __getitem__(self, index: int):
        """
        Get the item associated with index from each underlying dataset.
        Since index is in the range of [0, total_num_instances), we need to
        map the index into each dataset before retrieving the item.
        """
        return OrderedDict(
            [
                (key, dataset[self._map_index_to_dataset(key, index)])
                for key, dataset in self.datasets.items()
            ]
        )

    def collater(self, samples: List[Dict]):
        """
        Generate a mini-batch for this dataset.
        To convert this into a regular mini-batch we use the following logic:
            1. Select a dataset using the specified probability distribution.
            2. Call the collater function of the selected dataset.
        """
        if len(samples) == 0:
            return None

        selected_key = self.sampling_func(list(self.datasets.keys()))
        selected_samples = [sample[selected_key] for sample in samples]
        return self.datasets[selected_key].collater(selected_samples)

    def num_tokens(self, index: int):
        """
        Return an example's length (number of tokens), used for batching. Here
        we return the max across all examples at index across all underlying
        datasets.
        """
        return max(
            dataset.num_tokens(self._map_index_to_dataset(key, index))
            for key, dataset in self.datasets.items()
        )

    def size(self, index: int):
        """
        Return an example's size as a float or tuple. Here we return the max
        across all underlying datasets. This value is used when filtering a
        dataset with max-positions.
        """
        return max(
            dataset.size(self._map_index_to_dataset(key, index))
            for key, dataset in self.datasets.items()
        )

    @property
    def supports_prefetch(self):
        return all(
            getattr(dataset, "supports_prefetch", False)
            for dataset in self.datasets.values()
        )

    def prefetch(self, indices):
        for key, dataset in self.datasets.items():
            dataset.prefetch(
                [self._map_index_to_dataset(key, index) for index in indices]
            )

    @property
    def supports_fetch_outside_dataloader(self):
        return all(
            self.datasets[key].supports_fetch_outside_dataloader
            for key in self.datasets
        )