File size: 1,136 Bytes
04c78c7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 | import math
from easydict import EasyDict
from torch.utils.data import Dataset, default_collate
class EmptyDataset(Dataset):
    """Placeholder dataset whose every item is ``None``.

    Only the reported length is configurable; indexing never touches any
    underlying data.
    """

    def __init__(self, length):
        """Remember how many (empty) items the dataset claims to hold."""
        self.length = length

    def __len__(self):
        """Return the length supplied at construction time."""
        return self.length

    def __getitem__(self, _):
        """Return ``None`` regardless of the requested index."""
        return None
class MultiLoader:
    """Iterator wrapper to iterate over multiple dataloaders at the same time.

    Every iteration step yields a tuple holding one item from each wrapped
    loader. Iteration stops as soon as the shortest loader is exhausted, and
    ``len()`` reports that shortest length.
    """

    def __init__(self, a, b):
        """Store the two loaders that will be iterated in lockstep.

        Args:
            a: First iterable; must support ``len()`` for ``len(self)`` to work.
            b: Second iterable; same requirement as ``a``.
        """
        # NOTE: `_repeat` is not applied to `a` here (the call was disabled),
        # so the shorter of the two loaders bounds each pass.
        self.loaders = [a, b]

    def __iter__(self):
        """Return a fresh iterator of ``(item_a, item_b)`` tuples."""
        return zip(*self.loaders)

    def __len__(self):
        """Return the number of steps one pass yields (the shortest loader)."""
        return min(map(len, self.loaders))

    def _repeat(self, a, b):
        """Return ``a`` repeated often enough to cover at least ``len(b)`` items.

        Args:
            a: Loader that may need to be repeated.
            b: Loader whose length ``a`` should reach.

        Returns:
            ``a`` itself when it is empty or already at least as long as
            ``b``; otherwise a ``RepeatLoader`` wrapping ``a``.
        """
        len_a, len_b = len(a), len(b)
        # An empty loader cannot be stretched by repetition; returning it
        # unchanged also avoids dividing by zero below.
        if len_a == 0 or len_a >= len_b:
            return a
        return RepeatLoader(a, math.ceil(len_b / len_a))
class RepeatLoader:
    """Iterable that replays a wrapped loader ``k`` times back to back."""

    def __init__(self, loader, k):
        """Keep the loader to replay and the number of passes over it."""
        self.k = k
        self.loader = loader

    def __len__(self):
        """Return the total item count over all ``k`` passes."""
        passes = self.k
        return passes * len(self.loader)

    def __iter__(self):
        """Yield every item of the wrapped loader, once per pass."""
        for _ in range(self.k):
            # A new iteration over the loader is started for each pass.
            yield from self.loader
def collate_fn(data):
    """Collate a batch, passing through batches that contain ``None`` samples.

    Args:
        data: Sequence of samples making up one batch.

    Returns:
        ``data`` unchanged when any sample is ``None``; otherwise the result
        of ``default_collate(data)`` wrapped in an ``EasyDict``.
    """
    # Identity check instead of `None in data`: membership falls back to `==`,
    # which array-like samples may overload element-wise and thereby raise.
    if any(sample is None for sample in data):
        return data
    return EasyDict(default_collate(data))