| | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset |
| |
|
| | from .builder import DATASETS |
| |
|
| |
|
@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but additionally
    exposes the ``CLASSES`` and ``PALETTE`` attributes of the first dataset,
    so the wrapper can be used where a single dataset's metadata is expected.

    Note:
        Only the first dataset's ``CLASSES`` / ``PALETTE`` are used; the
        other datasets are not checked for agreement.

    Args:
        datasets (Iterable[:obj:`Dataset`]): The datasets to concatenate.
            As with the torch base class, any iterable is accepted.
    """

    def __init__(self, datasets):
        super().__init__(datasets)
        # Read metadata from ``self.datasets`` (the list materialized by the
        # torch base class, which also rejects an empty input) rather than
        # from the argument: ``datasets`` may be a one-shot iterable such as
        # a generator, which is neither indexable nor re-iterable once the
        # base class has consumed it.
        self.CLASSES = self.datasets[0].CLASSES
        self.PALETTE = self.datasets[0].PALETTE
| |
|
| |
|
@DATASETS.register_module()
class RepeatDataset(object):
    """A wrapper of repeated dataset.

    The length of repeated dataset will be `times` larger than the original
    dataset. This is useful when the data loading time is long but the dataset
    is small. Using RepeatDataset can reduce the data loading time between
    epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.CLASSES = dataset.CLASSES
        self.PALETTE = dataset.PALETTE
        # Cache the wrapped length; it is used on every index lookup.
        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        """Get item from original dataset.

        Indices in ``[-len(self), len(self))`` are mapped onto the original
        dataset modulo its length, so negative indices keep working.

        Raises:
            IndexError: If ``idx`` is out of range. Without this check every
                index would silently wrap around, so iterating the wrapper
                through the sequence protocol (``for x in ds`` /
                ``list(ds)``, which stop on ``IndexError``) would never
                terminate, and indexing an empty dataset would raise
                ``ZeroDivisionError`` instead.
        """
        length = len(self)
        if not -length <= idx < length:
            raise IndexError(
                'index {} is out of range for RepeatDataset of length '
                '{}'.format(idx, length))
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        """The length is multiplied by ``times``"""
        return self.times * self._ori_len
| |
|