| | |
| | |
| | |
| | |
| |
|
| | import logging |
| | from collections import OrderedDict |
| | from typing import Dict, List |
| |
|
| | import numpy as np |
| | from fairseq.data import data_utils |
| |
|
| | from . import FairseqDataset |
| |
|
| |
|
# Module-level logger, named after this module per fairseq convention.
logger = logging.getLogger(__name__)
| |
|
| |
|
class MultiCorpusDataset(FairseqDataset):
    """
    Stores multiple instances of FairseqDataset together. Requires each instance
    to be the same dataset, as the collate method needs to work on batches with
    samples from each dataset.

    Allows specifying a distribution over the datasets to use. Note that unlike
    MultiCorpusSampledDataset, this distribution allows sampling for each item,
    rather than on a batch level.

    Each time ordered_indices() is called, a new sample is generated with
    the specified distribution.

    Args:
        datasets: an OrderedDict of FairseqDataset instances.
        distribution: a List containing the probability of getting an utterance from
            corresponding dataset
        seed: random seed used (together with the epoch) when sampling indices
    """

    def __init__(
        self, datasets: Dict[str, FairseqDataset], distribution: List[float], seed: int
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        assert len(datasets) == len(distribution)
        self.datasets = datasets
        self.distribution = distribution
        self.seed = seed
        # FIX: ordered_indices() reads self.epoch, but previously the attribute
        # was only created by set_epoch(). Default to 0 so calling
        # ordered_indices() before set_epoch() no longer raises AttributeError.
        self.epoch = 0

        # Avoid repeated conversion of the values view elsewhere.
        self.dataset_list = list(datasets.values())
        self.total_num_instances = 0

        first_dataset = self.dataset_list[0]

        # dataset_offsets[i] is the global index at which dataset i begins.
        self.dataset_offsets = []
        for dataset in self.dataset_list:
            assert isinstance(dataset, FairseqDataset)
            # collater() delegates to the first dataset's collater, so every
            # dataset must be the same concrete type.
            assert type(dataset) is type(first_dataset)
            self.dataset_offsets.append(self.total_num_instances)
            self.total_num_instances += len(dataset)

    def ordered_indices(self):
        """
        Return a freshly sampled ordering over all instances, drawing each
        position's dataset according to self.distribution. Deterministic for a
        fixed (seed, epoch) pair.
        """
        with data_utils.numpy_seed(self.seed, self.epoch):
            # One independent permutation per dataset; consumed by _sample.
            indices = [
                np.random.permutation(len(dataset))
                for dataset in self.datasets.values()
            ]
            # Per-dataset cursor into the corresponding permutation.
            counters = [0 for _ in self.datasets]

            return np.array(
                [
                    self._sample(indices, counters)
                    for _ in range(self.total_num_instances)
                ],
                dtype=np.int64,
            )

    def _sample(self, indices, counters):
        """
        Draw one global index: pick a dataset per the distribution, take the
        next unconsumed index from its permutation, and offset it into the
        global index space. Reshuffles a dataset's permutation once exhausted
        (so small datasets may be oversampled).
        """
        # Select which dataset to draw from.
        dataset_idx = np.random.choice(len(self.distribution), p=self.distribution)

        # Next local index from that dataset's current permutation.
        idx = indices[dataset_idx][counters[dataset_idx]]

        # Map the local index into the global (concatenated) index space.
        idx += self.dataset_offsets[dataset_idx]

        counters[dataset_idx] += 1

        # If we've consumed every index of this dataset, start a new
        # permutation so further draws remain possible.
        if counters[dataset_idx] == len(self.dataset_list[dataset_idx]):
            counters[dataset_idx] = 0
            indices[dataset_idx] = np.random.permutation(
                len(self.dataset_list[dataset_idx])
            )

        return idx

    def _map_index(self, index: int):
        """
        Map a global index to (local index, dataset key).

        If dataset A has length N and dataset B has length M
        then index 1 maps to index 1 of dataset A, and index N + 1
        maps to index 1 of B.

        Raises:
            ValueError: if index is out of range for the combined datasets.
        """
        counter = 0
        for key, dataset in self.datasets.items():
            if index < counter + len(dataset):
                return index - counter, key
            counter += len(dataset)
        raise ValueError(
            "Invalid index: {}, max: {}".format(index, self.total_num_instances)
        )

    def __len__(self):
        """
        Length of this dataset is the sum of individual datasets
        """
        return self.total_num_instances

    def __getitem__(self, index):
        # Delegate to the dataset that owns this global index.
        index, key = self._map_index(index)
        return self.datasets[key][index]

    def collater(self, samples):
        """
        Since we enforce all datasets to be the same, collating is just
        picking the first one and doing collate.
        """
        if len(samples) == 0:
            return None

        return list(self.datasets.values())[0].collater(samples)

    def num_tokens(self, index: int):
        # Delegate to the owning dataset using its local index.
        index, key = self._map_index(index)
        return self.datasets[key].num_tokens(index)

    def size(self, index: int):
        # Delegate to the owning dataset using its local index.
        index, key = self._map_index(index)
        return self.datasets[key].size(index)

    def set_epoch(self, epoch, **unused):
        # Record the epoch so ordered_indices() reseeds per epoch.
        super().set_epoch(epoch)
        self.epoch = epoch

    @property
    def supports_prefetch(self):
        # Prefetching is not implemented for this wrapper.
        return False
| |
|