| from torch.utils.data.dataset import ConcatDataset as _ConcatDataset |
|
|
| from .builder import DATASETS |
|
|
|
|
@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but additionally
    exposes the ``CLASSES`` and ``PALETTE`` attributes of the first wrapped
    dataset, so the wrapper can stand in where a single dataset's metadata
    is read.

    Args:
        datasets (Iterable[:obj:`Dataset`]): The datasets to concatenate.
            As in the parent class, any iterable (including a generator)
            is accepted.
    """

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__(datasets)
        # The parent stores the inputs as a list in ``self.datasets`` (and
        # rejects an empty input). Index that list rather than the argument,
        # which may be a one-shot iterator the parent has already consumed.
        first = self.datasets[0]
        # NOTE(review): metadata is taken from the first dataset only; this
        # assumes all concatenated datasets share CLASSES/PALETTE — not
        # checked here.
        self.CLASSES = first.CLASSES
        self.PALETTE = first.PALETTE
|
|
|
|
@DATASETS.register_module()
class RepeatDataset:
    """Wrap a dataset so that it appears ``times`` times longer.

    Indexing past the end of the original dataset wraps around to its
    start, so every original sample is visited ``times`` times per pass.
    This helps when the dataset is small but per-epoch loading overhead is
    large: repeating it shortens the time spent between epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): How many times the dataset is repeated.
    """

    def __init__(self, dataset, times):
        self.times = times
        self.dataset = dataset
        # Mirror the wrapped dataset's metadata on the wrapper.
        self.CLASSES = dataset.CLASSES
        self.PALETTE = dataset.PALETTE
        # Length of the underlying dataset, cached for index wrapping.
        self._ori_len = len(dataset)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` modulo the original length."""
        wrapped_idx = idx % self._ori_len
        return self.dataset[wrapped_idx]

    def __len__(self):
        """Return the original length multiplied by ``times``."""
        return self._ori_len * self.times
|
|