File size: 1,136 Bytes
04c78c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import math

from easydict import EasyDict
from torch.utils.data import Dataset, default_collate


class EmptyDataset(Dataset):
    """Placeholder dataset whose items are all ``None``.

    Useful when a dataloader of a fixed length is needed but the
    samples themselves carry no data.
    """

    def __init__(self, length):
        # Number of (empty) items this dataset reports via __len__.
        self.length = length

    def __getitem__(self, _):
        # Every index maps to the same empty sample.
        return None

    def __len__(self):
        return self.length

class MultiLoader:
    """Iterator wrapper to iterate over multiple dataloaders at the same time.

    Each iteration step yields a tuple ``(batch_a, batch_b)``. Iteration
    stops as soon as the shorter loader is exhausted (``zip`` semantics),
    which is why ``__len__`` reports the minimum of the two lengths.
    """

    def __init__(self, a, b):
        self.loaders = [a, b]

    def __iter__(self):
        # zip truncates at the shortest loader, consistent with __len__.
        return zip(*self.loaders)

    def __len__(self):
        return min(map(len, self.loaders))

    def _repeat(self, a, b):
        """Wrap ``a`` in a RepeatLoader so it covers at least ``len(b)`` items.

        NOTE(review): currently unused — the call in ``__init__`` was
        disabled, so the pair is zip-truncated instead of repeating the
        shorter loader. Kept for callers that want the repeat behavior.
        """
        if len(a) < len(b):
            k = math.ceil(len(b) / len(a))
            return RepeatLoader(a, k)
        return a

class RepeatLoader:
    """Iterable that replays a wrapped loader ``k`` times in a row."""

    def __init__(self, loader, k):
        self.loader = loader
        self.k = k

    def __iter__(self):
        # Chain k full passes over the wrapped loader.
        for _ in range(self.k):
            yield from self.loader

    def __len__(self):
        return len(self.loader) * self.k

def collate_fn(data):
    """Collate a list of samples into an ``EasyDict`` batch.

    If any sample is ``None`` (e.g. items from ``EmptyDataset``), the
    list is returned untouched instead of being collated.
    """
    # Use identity checks rather than `None in data`: membership tests
    # call __eq__ on each element, which can misbehave or raise for
    # tensor/array samples (elementwise comparison).
    if any(item is None for item in data):
        return data
    return EasyDict(default_collate(data))