from torch.utils.data import Sampler
import numpy as np


class ConcatDatasetBatchSampler(Sampler):
    """Batch sampler built to work with a standard PyTorch ConcatDataset.

    From SpeechBrain dataio, see https://github.com/speechbrain/

    It retrieves elements from the different concatenated datasets and places
    them in the same batch, in the proportion specified by ``batch_sizes``.
    E.g. ``batch_sizes=[8, 16]`` means each batch has 24 elements, with the
    first 8 belonging to the first dataset in the ConcatDataset object and the
    last 16 to the second. More than two datasets are supported; in that case
    provide one batch size per dataset.

    Note
    ----
    Batches are drawn from the datasets until the one that can provide the
    fewest full sub-batches is exhausted. Thus the number of batches in an
    epoch is ``min(len(sampler_i) // batch_size_i)`` over all samplers.

    Arguments
    ---------
    samplers : list or tuple
        One sampler per concatenated dataset, in the same order as the
        datasets. Each must support ``len()`` and iteration over integer
        indices local to its own dataset.
    batch_sizes : list or tuple
        Number of elements to draw from each sampler per batch. Must have the
        same length as ``samplers``; every entry must be positive.
    epoch : int
        The epoch to start at.

    Raises
    ------
    ValueError
        If ``samplers`` or ``batch_sizes`` is not a list/tuple, if their
        lengths differ, if they are empty, or if a batch size is not positive.
    """

    def __init__(self, samplers, batch_sizes, epoch=0) -> None:
        if not isinstance(samplers, (list, tuple)):
            raise ValueError(
                "samplers should be a list or tuple of Pytorch Samplers, "
                "but got samplers={}".format(samplers)
            )

        if not isinstance(batch_sizes, (list, tuple)):
            raise ValueError(
                "batch_sizes should be a list or tuple of integers, "
                "but got batch_sizes={}".format(batch_sizes)
            )

        if not len(batch_sizes) == len(samplers):
            raise ValueError("batch_sizes and samplers should be have same length")

        if len(samplers) == 0:
            # Without this, __len__ has nothing to take a minimum over.
            raise ValueError("samplers should contain at least one sampler")

        if any(b <= 0 for b in batch_sizes):
            # A zero batch size would divide by zero in __len__.
            raise ValueError(
                "batch_sizes should contain only positive integers, "
                "but got batch_sizes={}".format(batch_sizes)
            )

        self.batch_sizes = batch_sizes
        self.samplers = samplers
        # Offset of each sampler's indices inside the concatenated index space.
        # NOTE(review): this uses len(sampler) as the dataset length, i.e. it
        # assumes each sampler spans its whole dataset -- confirm with callers.
        self.offsets = [0] + np.cumsum([len(x) for x in self.samplers]).tolist()[:-1]

        self.epoch = epoch
        self.set_epoch(self.epoch)

    def _iter_one_dataset(self, c_batch_size, c_sampler, c_offset):
        """Yield successive full batches of offset indices from one sampler.

        A trailing partial batch (fewer than ``c_batch_size`` elements) is
        dropped.
        """
        batch = []
        for idx in c_sampler:
            batch.append(c_offset + idx)
            if len(batch) == c_batch_size:
                yield batch
                # Start a fresh list so the yielded batch is not mutated and
                # subsequent batches are produced as well.
                batch = []

    def set_epoch(self, epoch):
        """Record ``epoch`` and forward it to every sampler that supports it.

        Samplers without a callable ``set_epoch`` attribute are left untouched.
        """
        self.epoch = epoch
        for sampler in self.samplers:
            setter = getattr(sampler, "set_epoch", None)
            if callable(setter):
                setter(epoch)

    def __iter__(self):
        """Yield ``len(self)`` batches, each a flat list of global indices.

        Within a batch, indices are grouped by sampler, in sampler order.
        """
        iterators = [iter(s) for s in self.samplers]
        for _ in range(len(self)):
            tot_batch = []
            for it, size, offset in zip(iterators, self.batch_sizes, self.offsets):
                # Pull exactly `size` local indices and shift them into the
                # concatenated index space.
                tot_batch.extend([offset + next(it) for _k in range(size)])
            yield tot_batch

    def __len__(self):
        """Return the number of full batches available in one epoch."""
        return min(
            len(sampler) // batch_size
            for sampler, batch_size in zip(self.samplers, self.batch_sizes)
        )