Spaces:
Sleeping
Sleeping
| from torch.utils.data import Dataset | |
| from typing import List | |
class ZipDataset(Dataset):
    """Dataset that zips several datasets together item-by-item.

    Item ``idx`` is the tuple ``(d0[idx % len(d0)], d1[idx % len(d1)], ...)``,
    so datasets shorter than the longest one are cycled to fill it out. If
    ``transforms`` is given, it is called as ``transforms(*items)`` and its
    return value is used as the item instead of the tuple.

    Args:
        datasets: The datasets to zip; each must support ``len()`` and
            integer indexing.
        transforms: Optional callable applied to the zipped items as
            positional arguments.
        assert_equal_length: If true, require every dataset to have the same
            length.

    Raises:
        ValueError: If ``assert_equal_length`` is true and the datasets'
            lengths differ.
    """

    def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):
        self.datasets = datasets
        self.transforms = transforms
        if assert_equal_length:
            # Explicit raise instead of ``assert`` so the check still runs
            # under ``python -O``.
            if len({len(d) for d in datasets}) > 1:
                raise ValueError('Datasets are not equal in length.')

    def __len__(self):
        """Return the length of the longest dataset (0 if there are none)."""
        return max((len(d) for d in self.datasets), default=0)

    def __getitem__(self, idx):
        """Return the zipped item at ``idx``.

        Negative indices count back from ``len(self)``, so ``ds[-1]`` is the
        same item as ``ds[len(ds) - 1]``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Raising here (instead of silently wrapping) is what lets
                ``for item in ds`` / ``list(ds)`` terminate, since iteration
                falls back to calling ``__getitem__`` until ``IndexError``.
            ZeroDivisionError: If one of the datasets is empty while another
                is not (the per-dataset modulo divides by its length).
        """
        n = len(self)
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(f'ZipDataset index {idx} out of range for length {n}')
        # Cycle shorter datasets so every position up to the longest is filled.
        items = tuple(d[pos % len(d)] for d in self.datasets)
        if self.transforms is not None:
            items = self.transforms(*items)
        return items