| import argparse |
| from dataset.scannet import ITWDataset, ScanNetDataset, ScanNet18Dataset, ARKitDataset, ScannetPP2Dataset, ScanNet20Dataset, WildDataset, ITWDataset |
| from dataset.matterport import MatterportDataset |
| from dataset.scannetpp import ScanNetPPDataset |
| from dataset.demo import DemoDataset |
| import json |
|
|
def update_args(args, *, config_dir='/app/MaskClustering/configs'):
    """Merge the JSON config selected by ``args.config`` into ``args``.

    ``scannet`` and ``scannet18`` share ``scannet.json``; any other config
    name ``X`` is read from ``X.json``. Every key of the loaded JSON object is
    set as an attribute on ``args``, overwriting any existing attribute of
    the same name (including values that came from the command line).

    Args:
        args: Namespace-like object with a ``config`` attribute. It is
            mutated in place.
        config_dir: Directory containing the ``<name>.json`` config files.
            Defaults to the previously hard-coded location.

    Returns:
        The same ``args`` object, so calls can be chained.

    Raises:
        FileNotFoundError: If the selected config file does not exist.
        json.JSONDecodeError: If the config file is not valid JSON.
    """
    import os

    config = args.config
    # scannet18 reuses the ScanNet config; every other name maps 1:1 to a file.
    config_file = 'scannet' if config in ('scannet', 'scannet18') else config
    config_path = os.path.join(config_dir, f'{config_file}.json')
    with open(config_path, 'r', encoding='utf-8') as f:
        config_data = json.load(f)
    for key, value in config_data.items():
        setattr(args, key, value)
    return args
|
|
def get_args(argv=None):
    """Parse command-line options and merge in the selected JSON config.

    Args:
        argv: List of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before; passing a list allows
            programmatic use and testing.

    Returns:
        An ``argparse.Namespace`` holding the CLI options plus every key of
        the config file chosen by ``--config``, as merged by ``update_args``
        (config keys overwrite CLI attributes of the same name).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seq_name', type=str)
    parser.add_argument('--seq_name_list', type=str)
    parser.add_argument('--config', type=str, default='scannet')
    parser.add_argument('--debug', action="store_true")
    # Data root read by get_dataset for the 'wild' dataset.
    parser.add_argument('--root', type=str)
    # NOTE(review): presumably GPU device ids - confirm with the consumers.
    parser.add_argument('-d', '--devices', type=int, nargs='+', default=[0, 1, 2, 3])

    args = parser.parse_args(argv)
    return update_args(args)
|
|
# Dataset variants whose scenes are read from the relative path
# ``data/<dataset name>``, grouped by the loader class that reads them.
# A new variant of an existing loader only needs an entry here.
_SCANNETPP2_ROOTED = (
    'scannetpp_dust3r_filtered_depth',
    'scannetpp_dust3r_posed',
    'scannetpp_dust3r_unposed',
    'scannetpp_v2_dust3r_posed',
    'scannetpp_v2_dust3r_unposed',
    'scannetpp_mapanything_posed',
)
_SCANNET18_ROOTED = (
    'scannet_dust3r_posed',
    'scannet_dust3r_unposed',
)
_SCANNET20_ROOTED = (
    'scannet_dust3r_posed_15',
    'scannet_dust3r_posed_25',
    'scannet_dust3r_posed_35',
    'scannet_dust3r_posed_45',
    'scannet_dust3r_posed_45_andrey',
    'scannet_dust3r_posed_45_bulat',
    'scannet_dust3r_posed_35_bulat',
    'scannet_dust3r_unposed_15',
    'scannet_dust3r_unposed_25',
    'scannet_dust3r_unposed_35',
    'scannet_dust3r_unposed_45',
)
_ARKIT_ROOTED = (
    'arkit_dust3r_posed',
    'arkit_gt',
    'arkit_gt_train',
    'arkit_vggt',
)


def get_dataset(args):
    """Instantiate the dataset loader selected by ``args.dataset``.

    Most reconstructed variants are loaded from ``data/<args.dataset>``.
    ``'wild'`` uses ``args.root``. The plain ``scannet``, ``scannet18``,
    ``scannetpp``, ``matterport3d`` and ``demo`` loaders are constructed
    without an explicit root.

    Args:
        args: Namespace-like object with ``dataset`` and ``seq_name``
            attributes; ``root`` is also read when ``dataset == 'wild'``.

    Returns:
        The loader instance for sequence ``args.seq_name``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a recognised name.
    """
    name = args.dataset

    # Loaders constructed without an explicit root (class default applies).
    if name == 'scannet':
        return ScanNetDataset(args.seq_name)
    if name == 'scannet18':
        return ScanNet18Dataset(args.seq_name)
    if name == 'scannetpp':
        return ScanNetPPDataset(args.seq_name)
    if name == 'matterport3d':
        return MatterportDataset(args.seq_name)
    if name == 'demo':
        return DemoDataset(args.seq_name)
    # 'wild' takes its root from the caller (the --root option in get_args).
    if name == 'wild':
        return WildDataset(args.seq_name, root=args.root)

    # Every remaining variant lives under data/<name>.
    root = f'data/{name}'
    if name == 'itw':
        return ITWDataset(args.seq_name, root=root)
    if name in _SCANNETPP2_ROOTED:
        return ScannetPP2Dataset(args.seq_name, root=root)
    if name in _SCANNET18_ROOTED:
        return ScanNet18Dataset(args.seq_name, root=root)
    if name in _SCANNET20_ROOTED:
        return ScanNet20Dataset(args.seq_name, root=root)
    if name in _ARKIT_ROOTED:
        return ARKitDataset(args.seq_name, root=root)

    # Put the offending name in the exception instead of printing it to stdout.
    raise NotImplementedError(f'Unsupported dataset: {name!r}')
|
|
|
|