| | |
| | |
| | |
| | |
| |
|
| | import math |
| | import random |
| |
|
| | from torch.utils.data import ConcatDataset, Dataset |
| | from torch.utils.data.sampler import ( |
| | BatchSampler, |
| | RandomSampler, |
| | Sampler, |
| | SequentialSampler, |
| | ) |
| |
|
| |
|
class ScheduledSampler(Sampler):
    """Yield indices of a ConcatDataset grouped into per-dataset batches.

    Every consecutive run of ``batch_size`` indices comes from a single
    sub-dataset. The batches themselves are shuffled, so the sub-datasets
    are visited in an interleaved, random order.

    Args:
        concat_dataset (ConcatDataset): the concatenation of all datasets.
            Each sub-dataset must provide ``get_dataset_name()``.
        batch_size (int): number of indices per batch; must be positive.
        holistic_shuffle (bool): if True, indices inside each sub-dataset are
            drawn with ``RandomSampler``; otherwise ``SequentialSampler``.
        logger (logging.Logger, optional): logger used for the "dataset is
            smaller than the batch size" warning. When omitted, the module
            logger is used.
        loader_type (str): ``"valid"`` keeps the trailing partial batch of
            every sub-dataset; any other value drops it.

    Raises:
        ValueError: if an argument has the wrong type or ``batch_size`` is
            not positive.

    Usage:
        With batch_size = 3 and holistic_shuffle = False, three sub-datasets
        holding global indices [0, 1, 2], [3, 4, 5] and [6, 7, 8] may yield
        (the batch order is random):
            [3, 4, 5, 0, 1, 2, 6, 7, 8]
    """

    def __init__(
        self,
        concat_dataset,
        batch_size,
        holistic_shuffle,
        logger=None,
        loader_type="train",
    ):
        if not isinstance(concat_dataset, ConcatDataset):
            raise ValueError(
                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
                    type(concat_dataset)
                )
            )
        if not isinstance(batch_size, int):
            raise ValueError(
                "batch_size must be an integer, but got {}".format(type(batch_size))
            )
        if batch_size <= 0:
            # A non-positive batch size would otherwise surface later as a
            # ZeroDivisionError in __len__ or as an empty iteration.
            raise ValueError(
                "batch_size must be a positive integer, but got {}".format(batch_size)
            )
        if not isinstance(holistic_shuffle, bool):
            raise ValueError(
                "holistic_shuffle must be a boolean, but got {}".format(
                    type(holistic_shuffle)
                )
            )

        self.concat_dataset = concat_dataset
        self.batch_size = batch_size
        self.holistic_shuffle = holistic_shuffle
        self.type = loader_type

        # Validation keeps partial batches, so short datasets are only a
        # concern for the other loader types.
        if loader_type == "valid":
            return

        if logger is None:
            # `logger` is optional in the signature; fall back to the module
            # logger instead of failing on `None.warning`.
            import logging

            logger = logging.getLogger(__name__)

        for dataset in concat_dataset.datasets:
            dataset_len = len(dataset)
            if dataset_len < batch_size:
                logger.warning(
                    "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
                        loader_type,
                        dataset.get_dataset_name(),
                        dataset_len,
                        batch_size,
                    )
                )

    def __len__(self):
        """Return the number of indices that ``__iter__`` yields."""
        if self.type == "valid":
            # Validation keeps every trailing partial batch, so all samples
            # of the concatenated dataset are emitted.
            return len(self.concat_dataset)
        # Otherwise only the full batches of each sub-dataset are emitted.
        num_of_batches = sum(
            len(dataset) // self.batch_size
            for dataset in self.concat_dataset.datasets
        )
        return num_of_batches * self.batch_size

    def __iter__(self):
        """Return an iterator over global indices, one shuffled batch after another."""
        sampler_cls = RandomSampler if self.holistic_shuffle else SequentialSampler
        # Global index of the first element of each sub-dataset.
        init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]

        output_batches = []
        for dataset, offset in zip(self.concat_dataset.datasets, init_indices):
            cur_batch = []
            for idx in sampler_cls(dataset):
                cur_batch.append(idx + offset)
                if len(cur_batch) == self.batch_size:
                    output_batches.append(cur_batch)
                    cur_batch = []
            # Keep the incomplete last batch only for validation.
            if self.type == "valid" and cur_batch:
                output_batches.append(cur_batch)

        # Shuffle whole batches so each batch still comes from one dataset.
        random.shuffle(output_batches)
        output_indices = [item for batch in output_batches for item in batch]
        return iter(output_indices)
| |
|
| |
|
def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
    """Create the index sampler and the batch sampler for one data loader.

    Args:
        concat_dataset: the concatenated dataset to draw indices from.
        cfg: configuration object; reads ``cfg.train.batch_size``,
            ``cfg.train.sampler.holistic_shuffle`` and, for non-"valid"
            loaders, ``cfg.train.sampler.drop_last``.
        logger: logger forwarded to ``ScheduledSampler``.
        loader_type: loader kind; ``"valid"`` forces ``drop_last=False``.

    Returns:
        A ``(sampler, batch_sampler)`` tuple, where ``batch_sampler`` wraps
        ``sampler``.
    """
    batch_size = cfg.train.batch_size
    sampler = ScheduledSampler(
        concat_dataset,
        batch_size,
        cfg.train.sampler.holistic_shuffle,
        logger,
        loader_type,
    )

    # Validation never drops the trailing incomplete batch.
    if loader_type == "valid":
        drop_last = False
    else:
        drop_last = cfg.train.sampler.drop_last

    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    return sampler, batch_sampler
| |
|