import random

import numpy as np
from torch.utils.data.dataset import Dataset

from .agora import AGORA
from .bedlam import BEDLAM
datasets_dict = {'bedlam': BEDLAM, 'agora': AGORA}


class MultipleDatasets(Dataset):
    """Exposes several sub-datasets behind a single Dataset interface.

    With make_same_len=True, every sub-dataset contributes the same number of
    samples per epoch (the length of the largest one); otherwise the samples
    are simply concatenated in order.
    """

    def __init__(self, datasets_used, datasets_split=None, make_same_len=False, **kwargs):
        if datasets_split is None:
            self.dbs = [datasets_dict[ds](**kwargs) for ds in datasets_used]
        else:
            self.dbs = [datasets_dict[ds](split, **kwargs) for ds, split in zip(datasets_used, datasets_split)]
        self.db_num = len(self.dbs)
        self.max_db_data_num = max([len(db) for db in self.dbs])
        self.db_len_cumsum = np.cumsum([len(db) for db in self.dbs])
        self.make_same_len = make_same_len
        self.human_model = self.dbs[0].human_model

    def __len__(self):
        # all dbs have the same length
        if self.make_same_len:
            return self.max_db_data_num * self.db_num
        # each db has a different length
        else:
            return sum([len(db) for db in self.dbs])

    def __getitem__(self, index):
        if self.make_same_len:
            db_idx = index // self.max_db_data_num
            data_idx = index % self.max_db_data_num
            if data_idx >= len(self.dbs[db_idx]) * (self.max_db_data_num // len(self.dbs[db_idx])):  # last batch: random sampling
                data_idx = random.randint(0, len(self.dbs[db_idx]) - 1)
            else:  # before last batch: use modulo
                data_idx = data_idx % len(self.dbs[db_idx])
        else:
            # find the sub-dataset that contains this global index
            for i in range(self.db_num):
                if index < self.db_len_cumsum[i]:
                    db_idx = i
                    break
            if db_idx == 0:
                data_idx = index
            else:
                data_idx = index - self.db_len_cumsum[db_idx - 1]

        norm_img, meta_data = self.dbs[db_idx][data_idx]
        return norm_img, meta_data