import bisect
import random
from typing import Iterable

from torch.utils.data import Dataset, IterableDataset


class ConcatRepeatDataset(Dataset):
    """Concatenate map-style datasets, repeating each dataset a given number
    of times.

    Indexing behaves like ``torch.utils.data.ConcatDataset`` over the
    repeated sequence, but samples are looked up modulo each dataset's
    length instead of being materialized, so repeats cost no extra memory.
    """

    datasets: list[Dataset]
    cumulative_sizes: list[int]
    repeats: list[int]

    @staticmethod
    def cumsum(sequence, repeats):
        # Running totals of each dataset's effective length,
        # i.e. len(dataset) * repeat.
        r, s = [], 0
        for dataset, repeat in zip(sequence, repeats):
            l = len(dataset) * repeat
            r.append(l + s)
            s += l
        return r
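
    # A quick sanity check on cumsum (hypothetical sizes, not from the
    # source): datasets of lengths 3 and 5 with repeats [2, 1] yield
    # cumulative_sizes [6, 11] -- the first dataset contributes 3 * 2 = 6
    # samples, the second 5 * 1 = 5 more.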

    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
        super().__init__()

        self.datasets = list(datasets)
        self.repeats = repeats

        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            repeats
        ), "datasets and repeats should have the same length"

        # IterableDataset has no __len__ and no random access, so it cannot
        # be indexed through the bisect lookup below.
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"

        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Support negative indices the same way ConcatDataset does.
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx

        # Find the dataset whose (repeated) range contains idx, then convert
        # idx into an offset within that range.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)

        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]

        dataset = self.datasets[dataset_idx]

        # The offset spans `repeat` copies of the dataset; the modulo wraps
        # it back into the underlying dataset.
        return dataset[sample_idx % len(dataset)]
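

# A minimal usage sketch (hypothetical toy datasets, not part of the module
# above): the first dataset is repeated twice, the second used once, and
# indexing wraps each repeated range back onto the underlying dataset.
if __name__ == "__main__":

    class _RangeDataset(Dataset):
        """Toy map-style dataset yielding the integers [start, start + n)."""

        def __init__(self, start: int, n: int):
            self.start = start
            self.n = n

        def __len__(self):
            return self.n

        def __getitem__(self, idx):
            return self.start + idx

    ds = ConcatRepeatDataset(
        [_RangeDataset(0, 3), _RangeDataset(100, 2)], repeats=[2, 1]
    )
    assert len(ds) == 8  # 3 * 2 + 2 * 1
    # Indices 0-5 cycle through the first dataset twice; 6-7 hit the second.
    assert [ds[i] for i in range(len(ds))] == [0, 1, 2, 0, 1, 2, 100, 101]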