# bulatko's picture
# adding real MK
# 55e58d1
# raw
# history blame
# 5.24 kB
import argparse
from dataset.scannet import ITWDataset, ScanNetDataset, ScanNet18Dataset, ARKitDataset, ScannetPP2Dataset, ScanNet20Dataset, WildDataset, ITWDataset
from dataset.matterport import MatterportDataset
from dataset.scannetpp import ScanNetPPDataset
from dataset.demo import DemoDataset
import json
# Original hard-coded location of the per-dataset JSON config files.
_DEFAULT_CONFIG_DIR = '/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/configs'


def update_args(args, config_dir=None):
    """Overlay the JSON config selected by ``args.config`` onto ``args``.

    The config name ``args.config`` is resolved to the file
    ``<config_dir>/<name>.json``; ``'scannet18'`` has no file of its own and
    reuses ``scannet.json``. Every top-level key of the loaded JSON object is
    set as an attribute on ``args``, overwriting any attribute already there.

    Args:
        args: Namespace-like object with a ``config`` attribute. Mutated in place.
        config_dir: Directory containing the ``*.json`` config files. ``None``
            (the default) uses ``_DEFAULT_CONFIG_DIR``.

    Returns:
        The same ``args`` object, for call chaining.

    Raises:
        FileNotFoundError: if the resolved config file does not exist.
        json.JSONDecodeError: if the config file is not valid JSON.
    """
    if config_dir is None:
        config_dir = _DEFAULT_CONFIG_DIR
    config = args.config
    # 'scannet18' shares the ScanNet config file; every other name maps to itself.
    config_file = 'scannet' if config in ('scannet', 'scannet18') else config
    config_path = f'{config_dir}/{config_file}.json'
    with open(config_path, 'r', encoding='utf-8') as f:
        config_data = json.load(f)
    for key, value in config_data.items():
        setattr(args, key, value)
    return args
def get_args(argv=None):
    """Parse command-line options and merge in the selected JSON config.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps argparse's usual
            behaviour of reading ``sys.argv[1:]``; passing a list is useful for
            tests and programmatic calls.

    Returns:
        The parsed namespace, after ``update_args`` has copied every key of the
        config file chosen by ``--config`` onto it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seq_name', type=str)
    parser.add_argument('--seq_name_list', type=str)
    parser.add_argument('--config', type=str, default='scannet')
    parser.add_argument('--debug', action="store_true")
    # Data root; read by get_dataset for the 'wild' dataset.
    parser.add_argument('--root', type=str)
    # A list default is safe here: a fresh parser (and list) is built on every call.
    parser.add_argument('-d', '--devices', type=int, nargs='+', default=[0, 1, 2, 3])
    args = parser.parse_args(argv)
    return update_args(args)
def get_dataset(args):
    """Instantiate the dataset selected by ``args.dataset`` for ``args.seq_name``.

    Dataset names fall into three groups:
      * ``'wild'``: constructed with ``root=args.root``.
      * names in ``rooted`` below: constructed with the fixed root
        ``'data/<name>'``.
      * names in ``unrooted`` below: constructed from the sequence name alone,
        with no ``root`` argument.

    Args:
        args: Namespace with at least ``dataset`` and ``seq_name`` (and
            ``root`` when ``dataset == 'wild'``).

    Returns:
        The constructed dataset object.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a known dataset name.
    """
    name = args.dataset

    # Datasets whose data root is always 'data/<name>'.
    rooted = {
        'scannetpp_dust3r_filtered_depth': ScannetPP2Dataset,
        'scannetpp_dust3r_posed': ScannetPP2Dataset,
        'scannetpp_dust3r_unposed': ScannetPP2Dataset,
        'scannetpp_v2_dust3r_posed': ScannetPP2Dataset,
        'scannetpp_v2_dust3r_unposed': ScannetPP2Dataset,
        'scannetpp_mapanything_posed': ScannetPP2Dataset,
        'itw': ITWDataset,
        'scannet_dust3r_posed': ScanNet18Dataset,
        'scannet_dust3r_unposed': ScanNet18Dataset,
        'scannet_dust3r_posed_15': ScanNet20Dataset,
        'scannet_dust3r_posed_25': ScanNet20Dataset,
        'scannet_dust3r_posed_35': ScanNet20Dataset,
        'scannet_dust3r_posed_45': ScanNet20Dataset,
        'scannet_dust3r_posed_45_andrey': ScanNet20Dataset,
        'scannet_dust3r_posed_45_bulat': ScanNet20Dataset,
        'scannet_dust3r_posed_35_bulat': ScanNet20Dataset,
        'scannet_dust3r_unposed_15': ScanNet20Dataset,
        'scannet_dust3r_unposed_25': ScanNet20Dataset,
        'scannet_dust3r_unposed_35': ScanNet20Dataset,
        'scannet_dust3r_unposed_45': ScanNet20Dataset,
        'arkit_dust3r_posed': ARKitDataset,
        'arkit_gt': ARKitDataset,
        'arkit_gt_train': ARKitDataset,
        'arkit_vggt': ARKitDataset,
    }
    # Datasets constructed without a root argument.
    unrooted = {
        'scannet': ScanNetDataset,
        'scannet18': ScanNet18Dataset,
        'scannetpp': ScanNetPPDataset,
        'matterport3d': MatterportDataset,
        'demo': DemoDataset,
    }

    if name == 'wild':
        return WildDataset(args.seq_name, root=args.root)
    if name in rooted:
        return rooted[name](args.seq_name, root=f'data/{name}')
    if name in unrooted:
        return unrooted[name](args.seq_name)
    raise NotImplementedError(f'Unknown dataset: {name!r}')