import argparse
import json
import os

from dataset.scannet import (
    ITWDataset,
    ScanNetDataset,
    ScanNet18Dataset,
    ARKitDataset,
    ScannetPP2Dataset,
    ScanNet20Dataset,
    WildDataset,
)
from dataset.matterport import MatterportDataset
from dataset.scannetpp import ScanNetPPDataset
from dataset.demo import DemoDataset

# Directory holding the ``<config>.json`` files. This is the historical
# hard-coded location; it can be overridden per call through
# ``update_args(..., config_dir=...)`` or process-wide through the
# MASKCLUSTERING_CONFIG_DIR environment variable.
DEFAULT_CONFIG_DIR = '/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/configs'

# Config names that reuse another config's JSON file instead of their own.
_CONFIG_FILE_ALIASES = {
    'scannet18': 'scannet',
}


def update_args(args, config_dir=None):
    """Merge the JSON config selected by ``args.config`` into ``args``.

    The file ``<config_dir>/<name>.json`` is read, where ``<name>`` is
    ``args.config``, except that ``'scannet18'`` reuses ``scannet.json``.
    Every top-level key of the JSON object is set as an attribute on
    ``args``. This overwrites any attribute of the same name, including
    values parsed from the command line.

    Args:
        args: Namespace with at least a ``config`` attribute; mutated in place.
        config_dir: Directory containing the config files. When omitted, the
            ``MASKCLUSTERING_CONFIG_DIR`` environment variable is used if it is
            set and non-empty, otherwise ``DEFAULT_CONFIG_DIR``.

    Returns:
        The same ``args`` object, to allow call chaining.

    Raises:
        FileNotFoundError: If the resolved config file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    if config_dir is None:
        config_dir = os.environ.get('MASKCLUSTERING_CONFIG_DIR') or DEFAULT_CONFIG_DIR
    config_file = _CONFIG_FILE_ALIASES.get(args.config, args.config)
    config_path = os.path.join(config_dir, f'{config_file}.json')
    with open(config_path, 'r', encoding='utf-8') as f:
        config_data = json.load(f)
    for key, value in config_data.items():
        setattr(args, key, value)
    return args


def get_args():
    """Parse command-line arguments and merge in the selected JSON config.

    Returns:
        An ``argparse.Namespace`` containing the CLI options below plus every
        key from the config file chosen by ``--config`` (see ``update_args``).
        Config keys override CLI values of the same name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seq_name', type=str)
    parser.add_argument('--seq_name_list', type=str)
    parser.add_argument('--config', type=str, default='scannet')
    parser.add_argument('--debug', action="store_true")
    # Within this module, only read by get_dataset for the 'wild' dataset.
    parser.add_argument('--root', type=str)
    parser.add_argument('-d', '--devices', type=int, nargs='+', default=[0, 1, 2, 3])
    args = parser.parse_args()
    return update_args(args)


# Datasets constructed as ``cls(seq_name)``, with no ``root`` argument passed.
_DATASETS_WITHOUT_ROOT = {
    'scannet': ScanNetDataset,
    'scannet18': ScanNet18Dataset,
    'scannetpp': ScanNetPPDataset,
    'matterport3d': MatterportDataset,
    'demo': DemoDataset,
}

# Datasets constructed as ``cls(seq_name, root='data/<name>')``: every entry's
# root is its own dataset name under the relative ``data/`` directory.
_DATASETS_UNDER_DATA_DIR = {
    'itw': ITWDataset,
    # ScannetPP2Dataset variants.
    'scannetpp_dust3r_filtered_depth': ScannetPP2Dataset,
    'scannetpp_dust3r_posed': ScannetPP2Dataset,
    'scannetpp_dust3r_unposed': ScannetPP2Dataset,
    'scannetpp_v2_dust3r_posed': ScannetPP2Dataset,
    'scannetpp_v2_dust3r_unposed': ScannetPP2Dataset,
    'scannetpp_mapanything_posed': ScannetPP2Dataset,
    # ScanNet18Dataset variants.
    'scannet_dust3r_posed': ScanNet18Dataset,
    'scannet_dust3r_unposed': ScanNet18Dataset,
    # ScanNet20Dataset variants.
    'scannet_dust3r_posed_15': ScanNet20Dataset,
    'scannet_dust3r_posed_25': ScanNet20Dataset,
    'scannet_dust3r_posed_35': ScanNet20Dataset,
    'scannet_dust3r_posed_45': ScanNet20Dataset,
    'scannet_dust3r_posed_45_andrey': ScanNet20Dataset,
    'scannet_dust3r_posed_45_bulat': ScanNet20Dataset,
    'scannet_dust3r_posed_35_bulat': ScanNet20Dataset,
    'scannet_dust3r_unposed_15': ScanNet20Dataset,
    'scannet_dust3r_unposed_25': ScanNet20Dataset,
    'scannet_dust3r_unposed_35': ScanNet20Dataset,
    'scannet_dust3r_unposed_45': ScanNet20Dataset,
    # ARKitDataset variants.
    'arkit_dust3r_posed': ARKitDataset,
    'arkit_gt': ARKitDataset,
    'arkit_gt_train': ARKitDataset,
    'arkit_vggt': ARKitDataset,
}


def get_dataset(args):
    """Instantiate the dataset named by ``args.dataset`` for ``args.seq_name``.

    ``args.dataset`` is not defined as a CLI flag in ``get_args``. It is
    presumably supplied by the JSON config merged in by ``update_args`` —
    confirm against the config files.

    Args:
        args: Namespace providing ``dataset`` and ``seq_name``, plus ``root``
            when ``dataset == 'wild'``.

    Returns:
        The constructed dataset object.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a known dataset name.
    """
    name = args.dataset
    if name == 'wild':
        # The only dataset whose root comes from args (the --root option,
        # unless a config key of the same name overrides it).
        return WildDataset(args.seq_name, root=args.root)

    dataset_cls = _DATASETS_WITHOUT_ROOT.get(name)
    if dataset_cls is not None:
        return dataset_cls(args.seq_name)

    dataset_cls = _DATASETS_UNDER_DATA_DIR.get(name)
    if dataset_cls is not None:
        return dataset_cls(args.seq_name, root=f'data/{name}')

    raise NotImplementedError(f'Unsupported dataset: {name!r}')