| """ | |
| @Date: 2021/07/18 | |
| @description: | |
| """ | |
| import numpy as np | |
| import torch.utils.data | |
| from dataset.mp3d_dataset import MP3DDataset | |
| from dataset.pano_s2d3d_dataset import PanoS2D3DDataset | |
| from dataset.pano_s2d3d_mix_dataset import PanoS2D3DMixDataset | |
| from dataset.zind_dataset import ZindDataset | |
def build_loader(config, logger):
    """Build the training and validation/test data loaders.

    Args:
        config: Configuration object. This function reads ``WORLD_SIZE``,
            ``MODE``, ``VAL_NAME``, ``DEBUG``, ``DATA.BATCH_SIZE``,
            ``DATA.NUM_WORKERS`` and ``DATA.PIN_MEMORY``; the whole object is
            also forwarded to ``build_dataset``.
        logger: Logger used for the info messages below and forwarded to
            ``build_dataset``.

    Returns:
        A ``(train_data_loader, val_data_loader)`` tuple. The train loader is
        ``None`` unless ``config.MODE == 'train'`` and the resulting train
        dataset is truthy.
    """
    ddp = config.WORLD_SIZE > 1
    train_dataset = None
    train_data_loader = None
    if config.MODE == 'train':
        train_dataset = build_dataset(mode='train', config=config, logger=logger)

    # In test mode the evaluation split is always 'test'; otherwise the split
    # named by config.VAL_NAME is used.
    val_dataset = build_dataset(
        mode=config.VAL_NAME if config.MODE != 'test' else 'test',
        config=config,
        logger=logger,
    )

    train_sampler = None
    val_sampler = None
    if ddp:
        if train_dataset:
            train_sampler = torch.utils.data.DistributedSampler(train_dataset, shuffle=True)
        val_sampler = torch.utils.data.DistributedSampler(val_dataset, shuffle=False)

    batch_size = config.DATA.BATCH_SIZE
    num_workers = 0 if config.DEBUG else config.DATA.NUM_WORKERS
    pin_memory = config.DATA.PIN_MEMORY

    if train_dataset:
        logger.info(f'Train data loader batch size: {batch_size}')
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset, sampler=train_sampler,
            batch_size=batch_size,
            # DataLoader rejects shuffle=True together with an explicit
            # sampler (ValueError). Under DDP the DistributedSampler already
            # shuffles, so only let the loader shuffle when no sampler is set.
            shuffle=train_sampler is None,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=True,
        )

    # Pick the largest batch size <= the configured one that divides
    # len(val_dataset) exactly: the remainders are computed for the candidates
    # batch_size, batch_size - 1, ..., 1 and the first zero remainder wins.
    batch_size = batch_size - (len(val_dataset) % np.arange(batch_size, 0, -1)).tolist().index(0)
    logger.info(f'Val data loader batch size: {batch_size}')
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, sampler=val_sampler,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )
    logger.info(f'Build data loader: num_workers:{num_workers} pin_memory:{pin_memory}')
    return train_data_loader, val_data_loader
def build_dataset(mode, config, logger):
    """Instantiate the dataset selected by ``config.DATA.DATASET``.

    Recognised names are ``'mp3d'``, ``'pano_s2d3d'``, ``'pano_s2d3d_mix'``
    and ``'zind'``.

    Args:
        mode: Split name passed to the dataset as ``mode``. ``config.DATA.AUG``
            is forwarded as ``aug`` only when this equals ``'train'``;
            otherwise ``aug`` is ``None``.
        config: Configuration object; ``config.DATA.*`` supplies the
            constructor arguments, and ``config.EVAL.POST_PROCESSING`` is read
            for the ``'zind'`` dataset only.
        logger: Logger forwarded to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``config.DATA.DATASET`` is not a recognised
            name.
    """
    name = config.DATA.DATASET
    takes_subset = name == 'pano_s2d3d' or name == 'pano_s2d3d_mix'
    is_zind = name == 'zind'
    if not (name == 'mp3d' or takes_subset or is_zind):
        raise NotImplementedError(f"Unknown dataset: {name}")

    data_cfg = config.DATA

    # Arguments shared by every dataset class, in the order they are passed.
    kwargs = {
        'root_dir': data_cfg.DIR,
        'mode': mode,
        'shape': data_cfg.SHAPE,
        'max_wall_num': data_cfg.WALL_NUM,
        'aug': data_cfg.AUG if mode == 'train' else None,
        'camera_height': data_cfg.CAMERA_HEIGHT,
        'logger': logger,
        'for_test_index': data_cfg.FOR_TEST_INDEX,
    }

    # Dataset-specific extras; config fields are only read when needed.
    if takes_subset:
        kwargs['subset'] = data_cfg.SUBSET
    if is_zind:
        kwargs['is_simple'] = True
        kwargs['is_ceiling_flat'] = False
    kwargs['keys'] = data_cfg.KEYS
    if is_zind:
        post_processing = config.EVAL.POST_PROCESSING
        kwargs['vp_align'] = post_processing is not None and 'manhattan' in post_processing

    dataset_classes = {
        'mp3d': MP3DDataset,
        'pano_s2d3d': PanoS2D3DDataset,
        'pano_s2d3d_mix': PanoS2D3DMixDataset,
        'zind': ZindDataset,
    }
    return dataset_classes[name](**kwargs)