| | import math |
| |
|
| | from easydict import EasyDict |
| | from torch.utils.data import Dataset, default_collate |
| |
|
| |
|
class EmptyDataset(Dataset):
    """A dataset of fixed size whose every item is ``None``.

    Behaves as a length-only placeholder: ``__getitem__`` ignores its
    argument entirely, so any index (even out of range) yields ``None``.
    """

    def __init__(self, length):
        # The only state: the size reported by ``len()``.
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, _):
        return None
| |
|
class MultiLoader:
    """Iterator wrapper to iterate over multiple dataloaders at the same time.

    Iterating yields ``(item_from_a, item_from_b)`` pairs via ``zip``. By
    default iteration stops as soon as the shorter loader is exhausted.
    With ``repeat=True`` the shorter loader is wrapped in a ``RepeatLoader``
    that replays it enough times to cover the longer one, so every item of
    the longer loader is visited.
    """

    def __init__(self, a, b, *, repeat=False):
        # Keyword-only and off by default so existing callers keep the plain
        # ``zip`` (truncate-to-shortest) behaviour.
        if repeat:
            # Right-hand side is evaluated with the original ``a`` and ``b``.
            a, b = self._repeat(a, b), self._repeat(b, a)
        self.loaders = [a, b]

    def __iter__(self):
        return zip(*self.loaders)

    def __len__(self):
        # ``zip`` stops at the shortest iterable, so the length is the minimum
        # (assuming each loader yields exactly ``len(loader)`` items).
        return min(map(len, self.loaders))

    def _repeat(self, a, b):
        """Return ``a`` stretched (via ``RepeatLoader``) to at least ``len(b)``.

        ``a`` is returned unchanged when it is already at least as long as
        ``b``, or when it is empty: an empty loader cannot be stretched, and
        the guard avoids dividing by zero.
        """
        if 0 < len(a) < len(b):
            k = math.ceil(len(b) / len(a))
            return RepeatLoader(a, k)
        return a
| |
|
class RepeatLoader:
    """Replay an iterable loader ``k`` times back to back.

    Each pass calls ``iter(loader)`` afresh, so ``loader`` must be
    re-iterable (a list or a DataLoader, not a one-shot generator).
    """

    def __init__(self, loader, k):
        self.loader = loader
        # Number of complete passes over ``loader``.
        self.k = k

    def __iter__(self):
        for _pass in range(self.k):
            yield from self.loader

    def __len__(self):
        # One full pass contributes len(loader) items.
        return len(self.loader) * self.k
| |
|
def collate_fn(data):
    """Collate a list of samples into a batch.

    If any sample is ``None`` the list is returned untouched and no collation
    is attempted. Otherwise the samples are merged with torch's
    ``default_collate`` and the result is wrapped in an ``EasyDict`` for
    attribute-style access.

    Args:
        data: list of samples. Presumably each sample is a mapping, since the
            collated result is passed to ``EasyDict`` (TODO confirm against
            the datasets using this).

    Returns:
        ``data`` itself when it contains ``None``; otherwise an ``EasyDict``.
    """
    # Identity test instead of ``None in data``: ``in`` falls back to ``==``,
    # which for array-like samples (e.g. NumPy arrays) is element-wise and
    # raises "truth value of an array ... is ambiguous".
    if any(sample is None for sample in data):
        return data
    return EasyDict(default_collate(data))