# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from collections import OrderedDict
from typing import Dict, List

import numpy as np

from fairseq.data import data_utils

from . import FairseqDataset

logger = logging.getLogger(__name__)


class MultiCorpusDataset(FairseqDataset):
"""
Stores multiple instances of FairseqDataset together. Requires each instance
to be the same dataset, as the collate method needs to work on batches with
samples from each dataset.
Allows specifying a distribution over the datasets to use. Note that unlike
MultiCorpusSampledDataset, this distribution allows sampling for each item,
rather than on a batch level.
Each time ordered_indices() is called, a new sample is generated with
the specified distribution.
Args:
datasets: a OrderedDict of FairseqDataset instances.
distribution: a List containing the probability of getting an utterance from
corresponding dataset
"""
def __init__(
self, datasets: Dict[str, FairseqDataset], distribution: List[float], seed: int
):
super().__init__()
assert isinstance(datasets, OrderedDict)
assert len(datasets) == len(distribution)
self.datasets = datasets
self.distribution = distribution
self.seed = seed
# Avoid repeated conversions to list later
self.dataset_list = list(datasets.values())
self.total_num_instances = 0
        first_dataset = self.dataset_list[0]
        # dataset_offsets[i] is the total number of instances in all datasets
        # preceding dataset i; it converts a per-dataset index into an index
        # over the combined dataset.
        self.dataset_offsets = []
for dataset in datasets.values():
assert isinstance(dataset, FairseqDataset)
assert type(dataset) is type(first_dataset)
self.dataset_offsets.append(self.total_num_instances)
self.total_num_instances += len(dataset)

    def ordered_indices(self):
        with data_utils.numpy_seed(self.seed, self.epoch):
            # One random permutation of indices per dataset, consumed in order
            # as items are drawn from that dataset
indices = [
np.random.permutation(len(dataset))
for dataset in self.datasets.values()
]
# Keep track of which samples we've used for each dataset
counters = [0 for _ in self.datasets]
return np.array(
[
self._sample(indices, counters)
for _ in range(self.total_num_instances)
],
dtype=np.int64,
)

    def _sample(self, indices, counters):
        # First, pick a dataset according to the sampling distribution
dataset_idx = np.random.choice(len(self.distribution), p=self.distribution)
# Then get dataset internal index
idx = indices[dataset_idx][counters[dataset_idx]]
        # Convert to an index over the combined (multi-corpus) dataset
idx += self.dataset_offsets[dataset_idx]
counters[dataset_idx] += 1
        # If this dataset is exhausted, reshuffle it and start over
if counters[dataset_idx] == len(self.dataset_list[dataset_idx]):
counters[dataset_idx] = 0
indices[dataset_idx] = np.random.permutation(
len(self.dataset_list[dataset_idx])
)
return idx

    def _map_index(self, index: int):
        """
        Maps an index over the combined dataset to (index within a dataset,
        dataset key).

        If dataset A has length N and dataset B has length M, then indices
        0 .. N - 1 map to dataset A and indices N .. N + M - 1 map to dataset B,
        i.e. index N + i corresponds to index i of B.
        """
counter = 0
for key, dataset in self.datasets.items():
if index < counter + len(dataset):
return index - counter, key
counter += len(dataset)
raise ValueError(
"Invalid index: {}, max: {}".format(index, self.total_num_instances)
)

    def __len__(self):
        """
        Length of this dataset is the sum of the lengths of the individual
        datasets.
        """
return self.total_num_instances

    def __getitem__(self, index):
index, key = self._map_index(index)
return self.datasets[key][index]

    def collater(self, samples):
        """
        Since all datasets are enforced to be of the same type, collation is
        delegated to the collater of the first dataset.
        """
if len(samples) == 0:
return None
        return self.dataset_list[0].collater(samples)

    def num_tokens(self, index: int):
index, key = self._map_index(index)
return self.datasets[key].num_tokens(index)

    def size(self, index: int):
index, key = self._map_index(index)
return self.datasets[key].size(index)

    def set_epoch(self, epoch, **unused):
super().set_epoch(epoch)
self.epoch = epoch

    @property
def supports_prefetch(self):
return False
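

# A minimal, self-contained usage sketch. The `_ToyDataset` class and the
# tensors below are illustrative assumptions, not part of fairseq; they only
# exist to show how MultiCorpusDataset is constructed and queried.
if __name__ == "__main__":
    import torch

    class _ToyDataset(FairseqDataset):
        """Wraps a list of 1-D integer tensors, enough to drive the example."""

        def __init__(self, items):
            super().__init__()
            self.items = items

        def __getitem__(self, index):
            return self.items[index]

        def __len__(self):
            return len(self.items)

        def collater(self, samples):
            # Pad every sample in the batch to the longest one with zeros
            return data_utils.collate_tokens(samples, pad_idx=0)

        def num_tokens(self, index):
            return len(self.items[index])

        def size(self, index):
            return len(self.items[index])

    ds_a = _ToyDataset([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
    ds_b = _ToyDataset([torch.tensor([6]), torch.tensor([7, 8, 9, 10])])

    multi = MultiCorpusDataset(
        OrderedDict([("a", ds_a), ("b", ds_b)]),
        distribution=[0.5, 0.5],
        seed=0,
    )
    multi.set_epoch(1)  # ordered_indices() seeds on (seed, epoch)

    print(len(multi))  # 4: the sum of the two dataset lengths
    print(multi.ordered_indices())  # per-item sample over the combined indices
    batch = multi.collater([multi[i] for i in range(2)])
    print(batch.shape)  # (2, 3): two items from ds_a, padded to length 3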