| """Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular""" |
|
|
| from typing import Iterable |
|
|
| import torch |
| from torch.utils.data import ( |
| ConcatDataset as TorchConcatDataset, |
| Dataset, |
| Subset as TorchSubset, |
| ) |
|
|
|
|


class ConcatDataset(TorchConcatDataset):
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__(datasets)

        # `datasets` may be a one-shot iterable; the parent class has already
        # materialized it into `self.datasets`, so iterate over that copy.
        self.repeat_factors = torch.cat([d.repeat_factors for d in self.datasets])

    def set_epoch(self, epoch: int) -> None:
        # Propagate the epoch to any child dataset that tracks it.
        for dataset in self.datasets:
            if hasattr(dataset, "epoch"):
                dataset.epoch = epoch
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)


class Subset(TorchSubset):
    def __init__(self, dataset: Dataset, indices) -> None:
        super(Subset, self).__init__(dataset, indices)

        self.repeat_factors = dataset.repeat_factors[indices]
        assert len(indices) == len(self.repeat_factors)
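

# Example (hypothetical datasets `a` and `b`, each exposing the 1-D float
# `repeat_factors` tensor assumed above): the wrappers keep factors aligned
# with samples, e.g.
#
#     combined = ConcatDataset([a, b])                    # factors: cat of a's and b's
#     first_half = Subset(combined, list(range(len(a))))  # factors follow `indices`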


class RepeatFactorWrapper(Dataset):
    """
    Thin wrapper around a dataset to implement repeat factor sampling.
    The underlying dataset must have a `repeat_factors` member indicating the
    per-image repeat factor. Set it to all ones to disable repeat factor
    sampling. See the usage sketch at the bottom of this module.
    """

    def __init__(self, dataset, seed: int = 0):
        self.dataset = dataset
        self.epoch_ids = None
        self._seed = seed

        # Split each repeat factor into its whole-number and fractional parts;
        # the fractional part is resolved by stochastic rounding every epoch.
        self._int_part = torch.trunc(dataset.repeat_factors)
        self._frac_part = dataset.repeat_factors - self._int_part

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: list of dataset indices to use in one epoch. Each index
                is repeated based on its calculated repeat factor.
        """
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()

        # Construct the index list, repeating each dataset index as specified.
        indices = []
        for dataset_index, rep_factor in enumerate(rep_factors):
            indices.extend([dataset_index] * int(rep_factor.item()))
        return torch.tensor(indices, dtype=torch.int64)

    def __len__(self):
        if self.epoch_ids is None:
            raise RuntimeError("Please call set_epoch first to get the wrapped length.")
        return len(self.epoch_ids)

    def set_epoch(self, epoch: int) -> None:
        # Reseed deterministically from (seed + epoch) so the repeated index
        # list is reproducible for a given epoch.
        g = torch.Generator()
        g.manual_seed(self._seed + epoch)
        self.epoch_ids = self._get_epoch_indices(g)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        if self.epoch_ids is None:
            raise RuntimeError(
                "Repeat ids haven't been computed. Did you forget to call set_epoch?"
            )
        return self.dataset[self.epoch_ids[idx]]
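

# Minimal usage sketch. `_ToyDataset` is a hypothetical stand-in for any
# dataset exposing the 1-D float `repeat_factors` tensor assumed above; the
# wrappers themselves are used exactly as shown.
if __name__ == "__main__":

    class _ToyDataset(Dataset):
        def __init__(self, n: int) -> None:
            self.repeat_factors = torch.full((n,), 1.5)

        def __len__(self) -> int:
            return len(self.repeat_factors)

        def __getitem__(self, idx):
            return int(idx)

    wrapped = RepeatFactorWrapper(ConcatDataset([_ToyDataset(4), _ToyDataset(4)]))
    wrapped.set_epoch(0)  # must be called before len() or indexing
    # With factor 1.5, each of the 8 indices appears once or twice (~12 expected).
    print(len(wrapped), [wrapped[i] for i in range(len(wrapped))])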