"""Run mask-graph construction, iterative clustering and post-processing per sequence."""
import os

import torch
from tqdm import tqdm

from utils.config import get_dataset, get_args
from utils.post_process import post_process
from graph.construction import mask_graph_construction
from graph.iterative_clustering import iterative_clustering


def main(args) -> None:
    """Process the single sequence currently selected on ``args``.

    The dataset is obtained via ``get_dataset(args)``. If
    ``<dataset.object_dict_dir>/<args.config>/object_dict.npy`` already exists
    the function returns immediately without doing any further work.
    Otherwise the scene points and frame list are loaded and passed through
    ``mask_graph_construction``, ``iterative_clustering`` and ``post_process``,
    all inside ``torch.no_grad()``.

    Args:
        args: Parsed configuration object. This function reads ``args.config``,
            ``args.step``, ``args.view_consensus_threshold`` and ``args.debug``
            and forwards the whole object to the project helpers.
    """
    dataset = get_dataset(args)

    # Check for an existing result *before* loading scene points / frames:
    # neither is needed on the early-return path.
    # NOTE(review): presumably this file is the output written by a previous
    # run of post_process -- confirm against utils.post_process.
    object_dict_path = os.path.join(
        dataset.object_dict_dir, args.config, 'object_dict.npy'
    )
    if os.path.exists(object_dict_path):
        return

    scene_points = dataset.get_scene_points()
    frame_list = dataset.get_frame_list(args.step)

    # Gradient tracking is disabled for the whole pipeline.
    with torch.no_grad():
        (
            nodes,
            observer_num_thresholds,
            mask_point_clouds,
            point_frame_matrix,
        ) = mask_graph_construction(args, scene_points, frame_list, dataset)
        object_list = iterative_clustering(
            nodes,
            observer_num_thresholds,
            args.view_consensus_threshold,
            args.debug,
        )
        post_process(
            dataset,
            object_list,
            mask_point_clouds,
            scene_points,
            point_frame_matrix,
            frame_list,
            args,
        )


if __name__ == '__main__':
    args = get_args()
    # Sequence names arrive as one '+'-separated string.
    seq_name_list = args.seq_name_list.split('+')
    for seq_name in tqdm(seq_name_list):
        # main() picks up the current sequence from args.seq_name.
        args.seq_name = seq_name
        main(args)