| | """Generic dataset loader"""
|
| |
|
| | from importlib import import_module
|
| |
|
| | from torch.utils.data import DataLoader
|
| | from torch.utils.data import SequentialSampler, RandomSampler
|
| | from torch.utils.data.distributed import DistributedSampler
|
| | from .sampler import DistributedEvalSampler
|
| |
|
class Data():
    """Build and hold one ``DataLoader`` per run mode.

    The supported modes are ``'train'``, ``'val'``, ``'test'`` and ``'demo'``.
    A loader is only constructed for modes whose flag in ``args`` is truthy;
    every other mode maps to ``None`` in ``self.loaders``.

    Attributes read from ``args``:
        do_train, do_validate, do_test, demo: per-mode enable flags.
        data_train, data_val, data_test: dataset names for those modes
            (the demo mode always uses the dataset named ``'Demo'``).
        distributed, world_size, rank: distributed sampling settings.
        batch_size, n_GPUs, num_workers: loader sizing settings.
    """

    def __init__(self, args):
        """Record the configuration and create the loaders for enabled modes.

        Args:
            args: configuration namespace providing the attributes listed in
                the class docstring. It is also forwarded unchanged to each
                dataset constructor.
        """
        self.modes = ['train', 'val', 'test', 'demo']

        # Which modes are switched on for this run.
        self.action = {
            'train': args.do_train,
            'val': args.do_validate,
            'test': args.do_test,
            'demo': args.demo,
        }

        # Dataset class name per mode; also determines the module to import.
        self.dataset_name = {
            'train': args.data_train,
            'val': args.data_val,
            'test': args.data_test,
            'demo': 'Demo',
        }

        self.args = args

        self.loaders = {}
        for mode in self.modes:
            if self.action[mode]:
                self.loaders[mode] = self._get_data_loader(mode)
                print('===> Loading {} dataset: {}'.format(mode, self.dataset_name[mode]))
            else:
                self.loaders[mode] = None

    def _get_data_loader(self, mode='train'):
        """Instantiate the dataset for ``mode`` and wrap it in a ``DataLoader``.

        The dataset class is looked up dynamically: for a dataset name ``Foo``
        the module ``data.foo`` is imported and its attribute ``Foo`` is
        called as ``Foo(args, mode)``.

        Args:
            mode: one of ``'train'``, ``'val'``, ``'test'``, ``'demo'``.

        Returns:
            The configured ``DataLoader``.

        Raises:
            KeyError: if ``mode`` has no entry in ``self.dataset_name``.
            ValueError: if ``mode`` has a dataset name but is not one of the
                supported modes.
        """
        args = self.args

        dataset_name = self.dataset_name[mode]
        module = import_module('data.' + dataset_name.lower())
        dataset = getattr(module, dataset_name)(args, mode)

        if mode == 'train':
            if args.distributed:
                # args.batch_size is split evenly across n_GPUs (floor).
                batch_size = args.batch_size // args.n_GPUs
                sampler = DistributedSampler(
                    dataset,
                    shuffle=True,
                    num_replicas=args.world_size,
                    rank=args.rank,
                )
                # Ceiling division: workers are shared out across n_GPUs,
                # rounding up.
                num_workers = (args.num_workers + args.n_GPUs - 1) // args.n_GPUs
            else:
                batch_size = args.batch_size
                sampler = RandomSampler(dataset, replacement=False)
                num_workers = args.num_workers
            drop_last = True

        elif mode in ('val', 'test', 'demo'):
            if args.distributed:
                batch_size = 1
                # NOTE(review): DistributedEvalSampler is project-local;
                # presumably it avoids padding/duplicating samples across
                # ranks -- confirm against .sampler.
                sampler = DistributedEvalSampler(
                    dataset,
                    shuffle=False,
                    num_replicas=args.world_size,
                    rank=args.rank,
                )
                num_workers = (args.num_workers + args.n_GPUs - 1) // args.n_GPUs
            else:
                # Non-distributed evaluation uses a batch of n_GPUs samples.
                batch_size = args.n_GPUs
                sampler = SequentialSampler(dataset)
                num_workers = args.num_workers
            drop_last = False

        else:
            raise ValueError('Unsupported mode: {!r}'.format(mode))

        # Ordering is delegated entirely to the sampler; DataLoader does not
        # accept shuffle=True together with an explicit sampler.
        loader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=drop_last,
        )

        return loader

    def get_loader(self):
        """Return the dict mapping each mode to its ``DataLoader`` or ``None``."""
        return self.loaders
|
| |
|