# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from collections import OrderedDict
from typing import Dict, List

import numpy as np

from fairseq.data import data_utils

from . import FairseqDataset

logger = logging.getLogger(__name__)


class MultiCorpusDataset(FairseqDataset):
    """
    Stores multiple instances of FairseqDataset together. Requires each instance
    to be the same type of dataset, as the collate method needs to work on
    batches with samples from each dataset.

    Allows specifying a distribution over the datasets to use. Note that unlike
    MultiCorpusSampledDataset, this distribution allows sampling for each item,
    rather than on a batch level.

    Each time ordered_indices() is called, a new sample is drawn with the
    specified distribution (deterministic given the seed and current epoch).

    Args:
        datasets: an OrderedDict of FairseqDataset instances
        distribution: a List containing the probability of drawing an item from
            the corresponding dataset
        seed: seed for the pseudo-random number generator used for sampling
    """

    def __init__(
        self, datasets: Dict[str, FairseqDataset], distribution: List[float], seed: int
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        assert len(datasets) == len(distribution)
        self.datasets = datasets
        self.distribution = distribution
        self.seed = seed
        # Default epoch so ordered_indices() works before set_epoch() is called
        self.epoch = 0

        # Avoid repeated conversions to list later
        self.dataset_list = list(datasets.values())
        self.total_num_instances = 0

        first_dataset = self.dataset_list[0]

        self.dataset_offsets = []
        for dataset in datasets.values():
            assert isinstance(dataset, FairseqDataset)
            assert type(dataset) is type(first_dataset)
            self.dataset_offsets.append(self.total_num_instances)
            self.total_num_instances += len(dataset)

    def ordered_indices(self):
        with data_utils.numpy_seed(self.seed, self.epoch):
            # Used to store the order of indices of each dataset to use
            indices = [
                np.random.permutation(len(dataset))
                for dataset in self.datasets.values()
            ]
            # Keep track of which samples we've used for each dataset
            counters = [0 for _ in self.datasets]

            return np.array(
                [
                    self._sample(indices, counters)
                    for _ in range(self.total_num_instances)
                ],
                dtype=np.int64,
            )

    def _sample(self, indices, counters):
        # First pick a dataset according to the distribution
        dataset_idx = np.random.choice(len(self.distribution), p=self.distribution)

        # Then get the dataset-internal index
        idx = indices[dataset_idx][counters[dataset_idx]]

        # Convert it to an index into the combined dataset
        idx += self.dataset_offsets[dataset_idx]

        counters[dataset_idx] += 1

        # Reshuffle if we have exhausted this dataset
        if counters[dataset_idx] == len(self.dataset_list[dataset_idx]):
            counters[dataset_idx] = 0
            indices[dataset_idx] = np.random.permutation(
                len(self.dataset_list[dataset_idx])
            )

        return idx

    def _map_index(self, index: int):
        """
        If dataset A has length N and dataset B has length M, then index i < N
        maps to index i of dataset A, and index N + j maps to index j of B.
        """
        counter = 0
        for key, dataset in self.datasets.items():
            if index < counter + len(dataset):
                return index - counter, key
            counter += len(dataset)
        raise ValueError(
            "Invalid index: {}, max: {}".format(index, self.total_num_instances)
        )

    def __len__(self):
        """
        Length of this dataset is the sum of the lengths of the individual
        datasets.
        """
        return self.total_num_instances

    def __getitem__(self, index):
        index, key = self._map_index(index)
        return self.datasets[key][index]

    def collater(self, samples):
        """
        Since we enforce all datasets to be of the same type, collating just
        delegates to the first dataset's collater.
        """
        if len(samples) == 0:
            return None
        return self.dataset_list[0].collater(samples)

    def num_tokens(self, index: int):
        index, key = self._map_index(index)
        return self.datasets[key].num_tokens(index)

    def size(self, index: int):
        index, key = self._map_index(index)
        return self.datasets[key].size(index)

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    @property
    def supports_prefetch(self):
        return False