|
|
|
|
|
|
|
|
|
|
|
|
| """Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular"""
|
|
|
| from typing import Iterable
|
|
|
| import torch
|
| from torch.utils.data import (
|
| ConcatDataset as TorchConcatDataset,
|
| Dataset,
|
| Subset as TorchSubset,
|
| )
|
|
|
|
|
class ConcatDataset(TorchConcatDataset):
    """Concatenation of datasets that also concatenates their repeat factors.

    Every wrapped dataset must expose a ``repeat_factors`` tensor; the
    concatenated tensor is stored on this object as ``repeat_factors``.
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        """Build the concatenated dataset.

        Args:
            datasets: datasets to concatenate, each with a ``repeat_factors``
                tensor attribute. Any iterable is accepted, including
                one-shot iterables such as generators.
        """
        super().__init__(datasets)

        # Iterate over self.datasets (the list materialized by the parent
        # constructor) rather than the ``datasets`` argument: a generator
        # would already be exhausted at this point.
        self.repeat_factors = torch.cat([d.repeat_factors for d in self.datasets])

    def set_epoch(self, epoch: int):
        """Propagate the current epoch to every wrapped dataset.

        Datasets with an ``epoch`` attribute get it overwritten, and datasets
        with a ``set_epoch`` method get it called; a dataset having both
        receives both updates.
        """
        for dataset in self.datasets:
            if hasattr(dataset, "epoch"):
                dataset.epoch = epoch
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
|
|
|
|
|
class Subset(TorchSubset):
    """Subset of a dataset that also selects the matching repeat factors.

    The wrapped dataset must expose a ``repeat_factors`` tensor; the entries
    at ``indices`` are stored on this object as ``repeat_factors``.
    """

    def __init__(self, dataset, indices) -> None:
        """Build the subset.

        Args:
            dataset: dataset with a ``repeat_factors`` tensor attribute.
            indices: positions in ``dataset`` that make up the subset.
        """
        super().__init__(dataset, indices)

        # A tuple used as a tensor index is read as one index per dimension,
        # which fails on a 1-D tensor; a list selects elements instead.
        selector = list(indices) if isinstance(indices, tuple) else indices
        self.repeat_factors = dataset.repeat_factors[selector]

        # Sanity check: exactly one repeat factor per selected sample.
        assert len(indices) == len(self.repeat_factors), (
            "number of indices and number of selected repeat factors differ"
        )
|
|
|
|
|
|
|
class RepeatFactorWrapper(Dataset):
    """
    Thin wrapper around a dataset to implement repeat factor sampling.
    The underlying dataset must have a repeat_factors member to indicate the per-image factor.
    Set it to uniformly ones to disable repeat factor sampling.

    ``set_epoch`` must be called before the wrapper is indexed or its length
    is queried: it draws the list of (repeated) dataset indices for that epoch.
    """

    def __init__(self, dataset, seed: int = 0):
        """
        Args:
            dataset: wrapped dataset exposing a ``repeat_factors`` tensor.
            seed (int): base seed; the generator used for an epoch is seeded
                with ``seed + epoch``.
        """
        self.dataset = dataset
        self.epoch_ids = None
        self._seed = seed

        # Split each repeat factor into its whole and fractional parts; the
        # fractional part is rounded stochastically once per epoch.
        self._int_part = torch.trunc(dataset.repeat_factors)
        self._frac_part = dataset.repeat_factors - self._int_part

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: int64 tensor of dataset indices to use in one epoch.
                Each index is repeated based on its calculated repeat factor.
        """
        # Stochastic rounding: a sample gets one extra repeat when its random
        # draw falls below the fractional part of its repeat factor.
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()

        # Truncate to whole repeat counts; clamp so that a negative factor
        # contributes no entries instead of making repeat_interleave raise.
        counts = rep_factors.to(torch.int64).clamp(min=0)
        dataset_indices = torch.arange(counts.numel(), dtype=torch.int64)
        return torch.repeat_interleave(dataset_indices, counts)

    def __len__(self):
        """Return the number of (repeated) samples drawn for the current epoch.

        Raises:
            RuntimeError: if ``set_epoch`` has not been called yet.
        """
        if self.epoch_ids is None:
            raise RuntimeError("please call set_epoch first to get wrapped length")

        return len(self.epoch_ids)

    def set_epoch(self, epoch: int):
        """Draw the indices for ``epoch`` and forward the epoch to the wrapped dataset.

        The generator is seeded with ``seed + epoch``. The wrapped dataset's
        ``set_epoch`` is called only if it defines one.
        """
        g = torch.Generator()
        g.manual_seed(self._seed + epoch)
        self.epoch_ids = self._get_epoch_indices(g)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        """Return the wrapped dataset's item for position ``idx`` of this epoch.

        Raises:
            RuntimeError: if ``set_epoch`` has not been called yet.
        """
        if self.epoch_ids is None:
            raise RuntimeError(
                "Repeat ids haven't been computed. Did you forget to call set_epoch?"
            )

        dataset_index = self.epoch_ids[idx]
        # Hand a plain int (not a 0-dim tensor) to the wrapped dataset so that
        # datasets keyed on ints keep working; non-scalar selections are
        # forwarded unchanged.
        if isinstance(dataset_index, torch.Tensor) and dataset_index.dim() == 0:
            dataset_index = int(dataset_index)
        return self.dataset[dataset_index]
|
|
|