| | import torch |
| | from torch.utils.data import DataLoader |
| | from datasets.citysundepth import CityScapesSunDepth |
| | from datasets.citysunrgb import CityScapesSunRGB |
| | from datasets.citysunrgbd import CityScapesSunRGBD |
| | from datasets.preprocessors import DepthTrainPre, DepthValPre, NYURGBDTrainPre, NYURGBDValPre, RGBDTrainPre, RGBDValPre, RGBTrainPre, RGBValPre |
| | from datasets.tfnyu import TFNYU |
| | from utils.constants import Constants as C |
| |
|
def get_dataset(args):
    """Return the dataset class selected by ``args.data`` and ``args.modalities``.

    ``"nyudv2"`` always maps to ``TFNYU``, whatever the modalities are.
    ``"city"``, ``"sunrgbd"`` and ``"stanford_indoor"`` share loaders that are
    chosen by the exact modality list:

    * ``['rgb']``          -> ``CityScapesSunRGB``
    * ``['depth']``        -> ``CityScapesSunDepth``
    * ``['rgb', 'depth']`` -> ``CityScapesSunRGBD``

    Raises:
        Exception: if the dataset name is unknown, or the modality combination
            is not one of the three listed above.
    """
    if args.data == "nyudv2":
        return TFNYU

    shared_loader_datasets = ("city", "sunrgbd", "stanford_indoor")
    if args.data not in shared_loader_datasets:
        raise Exception(f"{args.data} not configured in get_dataset function.")

    # Order matters: ['rgb', 'depth'] is supported, ['depth', 'rgb'] is not.
    requested = list(args.modalities)
    if requested == ['rgb']:
        return CityScapesSunRGB
    if requested == ['depth']:
        return CityScapesSunDepth
    if requested == ['rgb', 'depth']:
        return CityScapesSunRGBD
    raise Exception(f"{args.modalities} not configured in get_dataset function.")
| |
|
def get_preprocessors(args, dataset_settings, mode):
    """Build the preprocessing object for the given modalities and split.

    Args:
        args: config object; reads ``args.data`` and ``args.modalities``.
        dataset_settings: settings dict forwarded to the preprocessor
            constructor.
        mode: ``'train'`` or ``'val'``.

    Returns:
        An instance of the matching preprocessor class.

        * NYUDv2 with ``['rgb', 'depth']`` uses the ``NYURGBD*Pre`` classes.
        * Otherwise ``['rgb']``, ``['depth']`` and ``['rgb', 'depth']`` map to
          ``RGB*Pre``, ``Depth*Pre`` and ``RGBD*Pre``.
        * RGB-based preprocessors receive ``C.pytorch_mean`` and
          ``C.pytorch_std``; depth-only ones receive only ``dataset_settings``.

    Raises:
        ValueError: if the modality combination is unsupported, or ``mode``
            is neither ``'train'`` nor ``'val'``. (ValueError subclasses
            Exception, so existing ``except Exception`` handlers still apply.)
    """
    modalities = list(args.modalities)
    if modalities not in (['rgb'], ['depth'], ['rgb', 'depth']):
        # Wrap in a 1-tuple so a tuple-valued `modalities` is formatted,
        # not unpacked by `%`.
        raise ValueError("%s not configured for preprocessing" % (args.modalities,))
    # Raise instead of returning the exception: a returned Exception object
    # would silently be used as the "preprocessor" by the caller.
    if mode not in ('train', 'val'):
        raise ValueError("%s mode not defined" % mode)

    is_train = mode == 'train'
    if args.data == "nyudv2" and modalities == ['rgb', 'depth']:
        pre_cls = NYURGBDTrainPre if is_train else NYURGBDValPre
    elif modalities == ['rgb']:
        pre_cls = RGBTrainPre if is_train else RGBValPre
    elif modalities == ['depth']:
        # Depth-only preprocessors take no RGB normalisation statistics.
        pre_cls = DepthTrainPre if is_train else DepthValPre
        return pre_cls(dataset_settings)
    else:
        pre_cls = RGBDTrainPre if is_train else RGBDValPre
    return pre_cls(C.pytorch_mean, C.pytorch_std, dataset_settings)
| |
|
def get_train_loader(datasetClass, args, train_source, unsupervised=False):
    """Build the distributed training DataLoader.

    Args:
        datasetClass: dataset class to instantiate. It is called as
            ``datasetClass(dataset_settings, "train", unsupervised, preprocessing)``.
        args: config object. Reads the data paths, ``eval_source``,
            ``total_train_imgs``, ``train_scale_array``, ``image_height``,
            ``image_width``, ``modalities``, ``world_size``, ``rank``,
            ``batch_size``, ``num_workers`` and, optionally,
            ``unsup_batch_size``.
        train_source: stored under ``'train_source'`` in the settings dict.
        unsupervised: forwarded to the dataset. When True and
            ``args.unsup_batch_size`` is set (not None), that value is used
            as the global batch size instead of ``args.batch_size``.

    Returns:
        A ``DataLoader`` over a ``DistributedSampler``. Its per-process batch
        size is the chosen global batch size floor-divided by
        ``args.world_size``. Incomplete last batches are dropped.
    """
    dataset_settings = {'rgb_root': args.rgb_root,
                        'gt_root': args.gt_root,
                        'depth_root': args.depth_root,
                        'train_source': train_source,
                        'eval_source': args.eval_source,
                        'required_length': args.total_train_imgs,
                        'train_scale_array': args.train_scale_array,
                        'image_height': args.image_height,
                        'image_width': args.image_width,
                        'modalities': args.modalities}

    preprocessing = get_preprocessors(args, dataset_settings, "train")
    train_dataset = datasetClass(dataset_settings, "train", unsupervised, preprocessing)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=args.rank)

    # Global batch size. The unsupervised loader may have its own size; a
    # missing or None `unsup_batch_size` falls back to `batch_size`.
    unsup_batch_size = getattr(args, "unsup_batch_size", None)
    if unsupervised and unsup_batch_size is not None:
        batch_size = unsup_batch_size
    else:
        batch_size = args.batch_size

    return DataLoader(train_dataset,
                      batch_size=batch_size // args.world_size,
                      num_workers=args.num_workers,
                      drop_last=True,
                      # The sampler already shuffles/partitions; DataLoader
                      # forbids shuffle=True together with a sampler.
                      shuffle=False,
                      sampler=train_sampler)
| | |
def get_val_loader(datasetClass, args):
    """Build the validation DataLoader(s).

    For ``args.data == 'sunrgbd'``, evaluation is split by image resolution.
    One loader is built per list file derived from ``args.eval_source``: the
    extension is stripped and ``_<H>_<W>.txt`` is appended, so
    ``lists/eval.txt`` becomes ``lists/eval_427_561.txt``, and so on. For every
    other dataset, a single loader over ``args.eval_source`` is returned.

    Args:
        datasetClass: dataset class. It is called as ``datasetClass(settings,
            "val", False, preprocessing, args.sliding_eval, args.stride_rate)``.
        args: config object. Reads the data paths, ``eval_source``,
            ``train_scale_array``, ``image_height``, ``image_width``,
            ``modalities``, ``data``, ``sliding_eval``, ``stride_rate``,
            ``rank``, ``world_size`` and ``val_batch_size``.

    Returns:
        A list of ``DataLoader`` objects, one per evaluation list file. If
        ``args.rank`` is not None, each loader uses a ``DistributedSampler``
        and a per-process batch size of ``val_batch_size // world_size``.
    """
    import os  # local import: only needed for splitext below

    dataset_settings = {'rgb_root': args.rgb_root,
                        'gt_root': args.gt_root,
                        'depth_root': args.depth_root,
                        'train_source': None,
                        'eval_source': args.eval_source,
                        'required_length': None,
                        'max_samples': None,
                        'train_scale_array': args.train_scale_array,
                        'image_height': args.image_height,
                        'image_width': args.image_width,
                        'modalities': args.modalities}

    if args.data == 'sunrgbd':
        # splitext strips only the final extension. This keeps paths such as
        # "./lists/eval.txt" or "lists.v2/eval.txt" intact; splitting on the
        # first '.' would truncate them.
        base, _ = os.path.splitext(args.eval_source)
        eval_sources = [f"{base}_{shape}.txt"
                        for shape in ('427_561', '441_591', '530_730', '531_681')]
    else:
        eval_sources = [args.eval_source]

    preprocessing = get_preprocessors(args, dataset_settings, "val")
    collate_fn = _sliding_collate_fn if args.sliding_eval else None

    if args.rank is not None:
        batch_size = args.val_batch_size // args.world_size
    else:
        batch_size = args.val_batch_size

    val_loaders = []
    for eval_source in eval_sources:
        # Give each dataset its own settings dict. Mutating a shared dict
        # would change 'eval_source' under datasets that were already built.
        source_settings = dict(dataset_settings, eval_source=eval_source)
        val_dataset = datasetClass(source_settings, "val", False, preprocessing,
                                   args.sliding_eval, args.stride_rate)
        if args.rank is not None:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset, num_replicas=args.world_size, rank=args.rank)
        else:
            val_sampler = None

        val_loaders.append(DataLoader(val_dataset,
                                      batch_size=batch_size,
                                      num_workers=4,
                                      drop_last=False,
                                      shuffle=False,
                                      collate_fn=collate_fn,
                                      sampler=val_sampler))
    return val_loaders
| | |
| |
|
| | def _sliding_collate_fn(batch): |
| | gt = torch.stack([b['gt'] for b in batch]) |
| | sliding_output = [] |
| | num_modalities = len(batch[0]['sliding_output'][0][0]) |
| | for i in range(len(batch[0]['sliding_output'])): |
| | imgs = [torch.stack([b['sliding_output'][i][0][m] for b in batch]) for m in range(num_modalities)] |
| | pos = batch[0]['sliding_output'][i][1] |
| | pos_compare = [(b['sliding_output'][i][1] == pos).all() for b in batch] |
| | assert all(pos_compare), f"Position not same for all points in the batch: {pos_compare}, {[b['sliding_output'][i][1] for b in batch]}" |
| | margin = batch[0]['sliding_output'][i][2] |
| | margin_compare = [(b['sliding_output'][i][2] == margin).all() for b in batch] |
| | assert all(margin_compare), f"Margin not same for all points in the batch: {margin_compare}, {[b['sliding_output'][i][2] for b in batch]}" |
| | sliding_output.append((imgs, pos, margin)) |
| | return {"gt": gt, "sliding_output": sliding_output} |