# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, List, Optional, Tuple import numpy as np import torch.utils.data as pt_data from torch.utils.data import Dataset, IterableDataset __all__ = ['ConcatDataset', 'ConcatMapDataset'] class ConcatDataset(IterableDataset): """ A dataset that accepts as argument multiple datasets and then samples from them based on the specified sampling technique. Args: datasets (list): A list of datasets to sample from. shuffle (bool): Whether to shuffle individual datasets. Only works with non-iterable datasets. Defaults to True. sampling_technique (str): Sampling technique to choose which dataset to draw a sample from. Defaults to 'temperature'. Currently supports 'temperature', 'random' and 'round-robin'. sampling_temperature (int): Temperature value for sampling. Only used when sampling_technique = 'temperature'. Defaults to 5. sampling_probabilities (list): Probability values for sampling. Only used when sampling_technique = 'random'. global_rank (int): Worker rank, used for partitioning map style datasets. Defaults to 0. world_size (int): Total number of processes, used for partitioning map style datasets. Defaults to 1. """ def __init__( self, datasets: List[Any], shuffle: bool = True, sampling_technique: str = 'temperature', sampling_temperature: int = 5, sampling_probabilities: List[float] = None, global_rank: int = 0, world_size: int = 1, ): super().__init__() supported_sampling_techniques = ['temperature', 'random', 'round-robin'] self.datasets = datasets self.iterables = [None] * len(datasets) self.shuffle = shuffle self.global_rank = global_rank self.world_size = world_size self.sampling_kwargs = {} if sampling_technique == 'temperature': self.index_generator = ConcatDataset.temperature_generator self.sampling_kwargs['temperature'] = sampling_temperature elif sampling_technique == 'random': self.index_generator = ConcatDataset.random_generator self.sampling_kwargs['p'] = sampling_probabilities elif sampling_technique == 'round-robin': self.index_generator = ConcatDataset.round_robin_generator else: raise ValueError(f"Currently we only support sampling techniques in {supported_sampling_techniques}.") self.length = 0 if isinstance(datasets[0], IterableDataset): self.kind = 'iterable' else: self.kind = 'map' for idx, dataset in enumerate(datasets): isiterable = isinstance(dataset, IterableDataset) if (isiterable and not self.kind == 'iterable') or (not isiterable and self.kind == 'iterable'): raise ValueError("All datasets in ConcatDataset must be of the same kind (Iterable or Map).") if self.kind == 'map': self.length += len(dataset) // world_size else: self.length += len(dataset) def get_iterable(self, dataset): if isinstance(dataset, IterableDataset): return dataset.__iter__() else: indices = np.arange(len(dataset)) if self.shuffle: np.random.shuffle(indices) return iter(indices) def __iter__(self): worker_info = pt_data.get_worker_info() if worker_info is None: max_elements = self.length wid = 0 wnum = 1 else: wid = worker_info.id wnum = worker_info.num_workers max_elements = len(range(wid, self.length, wnum)) if self.kind == 'map': for idx in range(len(self.datasets)): start_idx = (len(self.datasets[idx]) // self.world_size) * self.global_rank end_idx = start_idx + (len(self.datasets[idx]) // self.world_size) if self.global_rank == self.world_size - 1: end_idx = len(self.datasets[idx]) indices = range(start_idx + wid, end_idx, wnum) self.datasets[idx] = pt_data.Subset(self.datasets[idx], indices) for idx, dataset in enumerate(self.datasets): iterable = self.get_iterable(dataset) self.iterables[idx] = iterable n = 0 ind_gen = self.index_generator(self.datasets, **self.sampling_kwargs) while n < max_elements: n += 1 try: ind = next(ind_gen) except StopIteration: return try: val = next(self.iterables[ind]) if self.kind == 'map': val = self.datasets[ind][val] yield val except StopIteration: self.iterables[ind] = self.get_iterable(self.datasets[ind]) n -= 1 def __len__(self): return self.length @staticmethod def temperature_generator(datasets, **kwargs): temp = kwargs.get('temperature') if not temp: raise ValueError("Temperature generator expects a 'temperature' keyword argument.") lengths = [] num = len(datasets) for dataset in datasets: lengths.append(len(dataset)) p = np.array(lengths) / np.sum(lengths) p = np.power(p, 1 / temp) p = p / np.sum(p) while True: ind = np.random.choice(np.arange(num), p=p) yield ind @staticmethod def round_robin_generator(datasets, **kwargs): num = len(datasets) while True: for i in range(num): yield i @staticmethod def random_generator(datasets, **kwargs): p = kwargs.get('p') if not p: raise ValueError("Random generator expects a 'p' keyowrd argument for sampling probabilities.") num = len(datasets) if len(p) != num: raise ValueError("Length of probabilities list must be equal to the number of datasets.") while True: ind = np.random.choice(np.arange(num), p=p) yield ind class ConcatMapDataset(Dataset): """ A dataset that accepts as argument multiple datasets and then samples from them based on the specified sampling technique. Args: datasets (list): A list of datasets to sample from. sampling_technique (str): Sampling technique to choose which dataset to draw a sample from. Defaults to 'temperature'. Currently supports 'temperature', 'random' and 'round-robin'. sampling_temperature (int): Temperature value for sampling. Only used when sampling_technique = 'temperature'. Defaults to 5. sampling_probabilities (list): Probability values for sampling. Only used when sampling_technique = 'random'. seed: Optional value to seed the numpy RNG. """ def __init__( self, datasets: List[Any], sampling_technique: str = 'temperature', sampling_temperature: int = 5, sampling_probabilities: Optional[List[float]] = None, seed: Optional[int] = None, ): super().__init__() self.datasets = datasets self.lengths = [len(x) for x in self.datasets] self.sampling_technique = sampling_technique self.sampling_temperature = sampling_temperature self.sampling_probabilities = sampling_probabilities self.np_rng = np.random.RandomState(seed) # Build a list of size `len(self)`. Each tuple contains (dataset_id, dataset_index) self.indices: List[Tuple[int, int]] = [] # Current position as we consume indices from each data set dataset_positions = [0] * len(self.datasets) # Random permutation of each dataset. Will be regenerated when exhausted. shuffled_indices = [self.np_rng.permutation(len(x)) for x in self.datasets] # Build the list of randomly-chosen datasets spanning the entire length, adhering to sampling technique if self.sampling_technique == "round-robin": # To exhaust longest dataset, need to draw `num_datasets * max_dataset_len` samples total_length = max(self.lengths) * len(self.lengths) # For round robin, iterate through each dataset dataset_ids = np.arange(total_length) % len(self.datasets) for dataset_id in dataset_ids: position = dataset_positions[dataset_id] index = shuffled_indices[dataset_id][position] self.indices.append((dataset_id, index)) dataset_positions[dataset_id] += 1 if dataset_positions[dataset_id] == len(shuffled_indices[dataset_id]): dataset_positions[dataset_id] = 0 shuffled_indices[dataset_id] = self.np_rng.permutation(len(self.datasets[dataset_id])) else: # Resolve probabilities of drawing from each data set if self.sampling_technique == "random": if sampling_probabilities is None or len(sampling_probabilities) != len(self.datasets): raise ValueError( f"Need {len(self.datasets)} probabilities; got " f"{len(sampling_probabilities) if sampling_probabilities is not None else 'None'}" ) p = np.array(self.sampling_probabilities) elif self.sampling_technique == "temperature": p = np.array([len(x) for x in self.datasets]) p = np.power(p, 1 / self.sampling_temperature) else: raise ValueError(f"Couldn't interpret sampling technique: {sampling_technique}") # Normalize probabilities p = p / np.sum(p) # Will randomly choose from datasets choices = np.arange(len(self.datasets)) # Keep going until largest dataset is exhausted. exhausted_datasets = set() while len(exhausted_datasets) < len(self.datasets): # Randomly choose a dataset for each position in accordance with p dataset_id = self.np_rng.choice(a=choices, p=p) dataset = self.datasets[dataset_id] # Pick next index from dataset position = dataset_positions[dataset_id] index = shuffled_indices[dataset_id][position] self.indices.append((dataset_id, index)) # Maybe reset this dataset's permutation dataset_positions[dataset_id] += 1 if dataset_positions[dataset_id] >= len(dataset): shuffled_indices[dataset_id] = self.np_rng.permutation(len(dataset)) dataset_positions[dataset_id] = 0 exhausted_datasets.add(dataset_id) def __len__(self): return len(self.indices) def __getitem__(self, idx): dataset_id, dataset_index = self.indices[idx] return self.datasets[dataset_id][dataset_index]