from typing import Optional

import torch
from torch.utils.data import Sampler, ConcatDataset


class RandomConcatSampler(Sampler):
    """Random sampler for a ConcatDataset. At each epoch, `n_samples_per_subset`
    samples are drawn from each subset in the ConcatDataset. If `subset_replacement`
    is ``True``, sampling within each subset is done with replacement. However,
    sampling without replacement *across* epochs is impossible unless a stateful
    sampler lives through the entire training phase.

    In the current implementation, the randomness of sampling is preserved whether
    or not the sampler is recreated across epochs and whether or not
    `torch.manual_seed()` is called.

    Args:
        data_source (ConcatDataset): the concatenated dataset to sample from.
        n_samples_per_subset (int): number of samples drawn from each subset per epoch.
        subset_replacement (bool): sample within each subset with replacement.
        shuffle (bool): shuffle the sampled indices across all sub-datasets.
        repeat (int): reuse the sampled indices this many times for training.
            [arXiv:1902.05509, arXiv:1901.09335]
        seed (int, optional): seed for the sampler's random generator.

    NOTE: Do not re-initialize the sampler between epochs; doing so leads to
        repeated samples.
    NOTE: This sampler behaves differently from DistributedSampler: it assumes
        the dataset is split across ranks instead of replicated.
    TODO: Add a `set_epoch()` method to enable sampling without replacement
        across epochs.
        ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
    """

    def __init__(self,
                 data_source: ConcatDataset,
                 n_samples_per_subset: int,
                 subset_replacement: bool = True,
                 shuffle: bool = True,
                 repeat: int = 1,
                 seed: Optional[int] = None):
        if not isinstance(data_source, ConcatDataset):
            raise TypeError("data_source should be torch.utils.data.ConcatDataset")
        if repeat < 1:
            raise ValueError("repeat should be >= 1")

        self.data_source = data_source
        self.n_subset = len(self.data_source.datasets)
        self.n_samples_per_subset = n_samples_per_subset
        self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
        self.subset_replacement = subset_replacement
        self.repeat = repeat
        self.shuffle = shuffle
        # Use a dedicated generator (instead of seeding the global RNG) so that
        # sampling is unaffected by external `torch.manual_seed()` calls and a
        # `seed` of None does not crash.
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)

    def __len__(self):
        return self.n_samples

    def __iter__(self):
        indices = []
        # Draw `n_samples_per_subset` indices from each subset, offset into the
        # flat index space of the ConcatDataset via `cumulative_sizes`.
        for d_idx in range(self.n_subset):
            low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1]
            high = self.data_source.cumulative_sizes[d_idx]
            if self.subset_replacement:
                rand_tensor = torch.randint(low, high, (self.n_samples_per_subset,),
                                            generator=self.generator, dtype=torch.int64)
            else:  # sample within the subset without replacement
                len_subset = len(self.data_source.datasets[d_idx])
                rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
                if len_subset >= self.n_samples_per_subset:
                    rand_tensor = rand_tensor[:self.n_samples_per_subset]
                else:
                    # The subset is too small: top up the remainder with replacement.
                    rand_tensor_replacement = torch.randint(low, high,
                                                            (self.n_samples_per_subset - len_subset,),
                                                            generator=self.generator, dtype=torch.int64)
                    rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
            indices.append(rand_tensor)
        indices = torch.cat(indices)
        if self.shuffle:  # shuffle the sampled indices across all subsets
            rand_tensor = torch.randperm(len(indices), generator=self.generator)
            indices = indices[rand_tensor]

        if self.repeat > 1:  # reuse the sampled indices `repeat` times
            repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
            if self.shuffle:
                _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
                repeat_indices = map(_choice, repeat_indices)
            indices = torch.cat([indices, *repeat_indices], 0)

        assert indices.shape[0] == self.n_samples
        return iter(indices.tolist())
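

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): builds a small
# ConcatDataset from two hypothetical TensorDatasets and samples from it with
# RandomConcatSampler. The dataset contents and the names `subset_a`,
# `subset_b`, and `loader` are assumptions for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    subset_a = TensorDataset(torch.arange(10).float())        # 10 samples
    subset_b = TensorDataset(torch.arange(100, 130).float())  # 30 samples
    concat = ConcatDataset([subset_a, subset_b])

    sampler = RandomConcatSampler(concat, n_samples_per_subset=4,
                                  subset_replacement=True, shuffle=True,
                                  repeat=1, seed=42)
    loader = DataLoader(concat, batch_size=4, sampler=sampler)

    # One epoch yields n_subset * n_samples_per_subset * repeat = 8 samples,
    # 4 drawn from each subset and shuffled together.
    for (batch,) in loader:
        print(batch)

    # Without replacement: if a subset is smaller than `n_samples_per_subset`
    # (here subset_a has only 10 samples), the remainder is topped up with
    # replacement, so one epoch still yields 2 * 20 = 40 indices.
    sampler_nr = RandomConcatSampler(concat, n_samples_per_subset=20,
                                     subset_replacement=False, seed=42)
    print(len(list(iter(sampler_nr))))  # 40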