Spaces:
Sleeping
Sleeping
| from __future__ import absolute_import | |
| from __future__ import print_function | |
| from __future__ import division | |
| import torch | |
| from .datasets import EvalDataset, DataFactory | |
| from ..utils.data_utils import make_collate_fn | |
def setup_eval_dataloader(cfg, data, split='test', backbone=None):
    """Create the evaluation ``DataLoader``.

    Args:
        cfg: Configuration object; ``cfg.MODEL.BACKBONE`` is read when
            ``backbone`` is not given.
        data: Dataset identifier forwarded unchanged to ``EvalDataset``.
        split: Split name forwarded unchanged to ``EvalDataset``.
        backbone: Backbone forwarded to ``EvalDataset``. ``None`` means
            "use ``cfg.MODEL.BACKBONE``".

    Returns:
        A ``torch.utils.data.DataLoader`` over an ``EvalDataset`` that
        yields one item per batch, in order, loaded in the main process.
    """
    # Fall back to the configured backbone only when the caller gave none.
    backbone = cfg.MODEL.BACKBONE if backbone is None else backbone

    eval_set = EvalDataset(cfg, data, split, backbone)

    # Fixed evaluation settings: single-item batches, no shuffling and
    # no worker processes.
    return torch.utils.data.DataLoader(
        eval_set,
        batch_size=1,
        num_workers=0,
        shuffle=False,
        pin_memory=True,
        collate_fn=make_collate_fn(),
    )
def setup_train_dataloader(cfg):
    """Create the training ``DataLoader``.

    Args:
        cfg: Configuration object. Reads ``cfg.DEBUG``,
            ``cfg.NUM_WORKERS``, ``cfg.TRAIN.STAGE`` and
            ``cfg.TRAIN.BATCH_SIZE``.

    Returns:
        A shuffling ``torch.utils.data.DataLoader`` over
        ``DataFactory(cfg, cfg.TRAIN.STAGE)``.
    """
    # In debug mode load data in the main process (zero workers);
    # otherwise use the configured worker count.
    if cfg.DEBUG:
        worker_count = 0
    else:
        worker_count = cfg.NUM_WORKERS

    training_set = DataFactory(cfg, cfg.TRAIN.STAGE)

    return torch.utils.data.DataLoader(
        training_set,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=worker_count,
        shuffle=True,
        pin_memory=True,
        collate_fn=make_collate_fn(),
    )
def setup_dloaders(cfg, dset='3dpw', split='val'):
    """Create the training and evaluation dataloaders together.

    Args:
        cfg: Configuration object passed to both loader factories;
            ``cfg.MODEL.BACKBONE`` is used as the evaluation backbone.
        dset: Evaluation dataset identifier passed to
            ``setup_eval_dataloader``.
        split: Evaluation split passed to ``setup_eval_dataloader``.

    Returns:
        A ``(train_dloader, test_dloader)`` tuple.
    """
    # The evaluation loader is constructed first, then the training loader.
    eval_loader = setup_eval_dataloader(
        cfg,
        data=dset,
        split=split,
        backbone=cfg.MODEL.BACKBONE,
    )
    train_loader = setup_train_dataloader(cfg)
    return train_loader, eval_loader