# File size: 5,241 Bytes
# 55e58d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
from dataset.scannet import ITWDataset, ScanNetDataset, ScanNet18Dataset, ARKitDataset, ScannetPP2Dataset, ScanNet20Dataset, WildDataset, ITWDataset
from dataset.matterport import MatterportDataset
from dataset.scannetpp import ScanNetPPDataset
from dataset.demo import DemoDataset
import json

def update_args(args, config_dir='/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/configs'):
    """Load the JSON config named by ``args.config`` and copy its keys onto ``args``.

    ``scannet`` and ``scannet18`` share ``scannet.json``; every other config
    name is looked up as ``<config_dir>/<config>.json``.

    Args:
        args: namespace-like object with a ``config`` attribute
            (e.g. an ``argparse.Namespace``).
        config_dir: directory containing the ``*.json`` config files. Defaults
            to the original hard-coded location so existing callers keep
            working; pass another path to run outside that machine layout.

    Returns:
        The same ``args`` object, mutated in place: every top-level key of the
        config file becomes an attribute, overwriting any existing attribute
        with the same name.

    Raises:
        FileNotFoundError: if the resolved config file does not exist.
        json.JSONDecodeError: if the config file is not valid JSON.
    """
    config = args.config
    # scannet18 reuses the plain scannet config; all other names map 1:1.
    config_file = 'scannet' if config in ('scannet', 'scannet18') else config
    config_path = f'{config_dir}/{config_file}.json'
    with open(config_path, 'r', encoding='utf-8') as f:
        config_data = json.load(f)
    for key, value in config_data.items():
        setattr(args, key, value)
    return args

def get_args(argv=None):
    """Parse command-line options and merge in the JSON config chosen by ``--config``.

    Args:
        argv: list of argument strings to parse instead of the process
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` holding the CLI options, augmented by
        ``update_args`` with every key from the selected config file. Keys in
        the config file overwrite CLI options of the same name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seq_name', type=str)
    parser.add_argument('--seq_name_list', type=str)
    parser.add_argument('--config', type=str, default='scannet')
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--root', type=str)
    parser.add_argument('-d', '--devices', type=int, nargs='+', default=[0, 1, 2, 3])

    args = parser.parse_args(argv)
    args = update_args(args)
    return args

def get_dataset(args):
    """Instantiate the dataset selected by ``args.dataset`` for ``args.seq_name``.

    Args:
        args: namespace with ``dataset`` and ``seq_name`` attributes; ``root``
            is additionally read when ``args.dataset == 'wild'``.

    Returns:
        The dataset object for the requested sequence.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a known dataset name
            (the offending name is included in the message).
    """
    name = args.dataset

    # Datasets constructed from the sequence name alone (class-default root).
    default_root = {
        'scannet': ScanNetDataset,
        'scannet18': ScanNet18Dataset,
        'scannetpp': ScanNetPPDataset,
        'matterport3d': MatterportDataset,
        'demo': DemoDataset,
    }

    # Datasets whose data lives under 'data/<dataset name>'.
    data_dir_root = {
        'scannetpp_dust3r_filtered_depth': ScannetPP2Dataset,
        'itw': ITWDataset,
        'scannetpp_dust3r_posed': ScannetPP2Dataset,
        'scannetpp_dust3r_unposed': ScannetPP2Dataset,
        'scannetpp_v2_dust3r_posed': ScannetPP2Dataset,
        'scannetpp_v2_dust3r_unposed': ScannetPP2Dataset,
        'scannetpp_mapanything_posed': ScannetPP2Dataset,
        'scannet_dust3r_posed': ScanNet18Dataset,
        'scannet_dust3r_unposed': ScanNet18Dataset,
        'scannet_dust3r_posed_15': ScanNet20Dataset,
        'scannet_dust3r_posed_25': ScanNet20Dataset,
        'scannet_dust3r_posed_35': ScanNet20Dataset,
        'scannet_dust3r_posed_45': ScanNet20Dataset,
        'scannet_dust3r_posed_45_andrey': ScanNet20Dataset,
        'scannet_dust3r_posed_45_bulat': ScanNet20Dataset,
        'scannet_dust3r_posed_35_bulat': ScanNet20Dataset,
        'scannet_dust3r_unposed_15': ScanNet20Dataset,
        'scannet_dust3r_unposed_25': ScanNet20Dataset,
        'scannet_dust3r_unposed_35': ScanNet20Dataset,
        'scannet_dust3r_unposed_45': ScanNet20Dataset,
        'arkit_dust3r_posed': ARKitDataset,
        'arkit_gt': ARKitDataset,
        'arkit_gt_train': ARKitDataset,
        'arkit_vggt': ARKitDataset,
    }

    if name in default_root:
        return default_root[name](args.seq_name)
    if name in data_dir_root:
        return data_dir_root[name](args.seq_name, root=f'data/{name}')
    if name == 'wild':
        # The only dataset whose root comes from the command line (--root).
        return WildDataset(args.seq_name, root=args.root)
    raise NotImplementedError(f'Unknown dataset: {name!r}')