| import os |
| import copy |
| import argparse |
| import numpy as np |
| import random |
| import torch |
| from tabulate import tabulate |
| from tqdm import tqdm |
| import sys |
| from pathlib import Path |
|
|
# Put the repository root (the parent of this script's directory) at the
# front of sys.path so the `tools` and `r3pm_net` imports below resolve when
# the script is executed directly rather than as an installed package.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path[:0] = [str(_REPO_ROOT)]
|
|
| from tools import data, l3d_helper, visualization |
| from tools import icp_registration_and_evaluation, l3d_registration_and_evaluation, predator_registration_and_evaluation, geotransformer_registration_and_evaluation, logdesc_registration_and_evaluation, regtr_registration_and_evaluation |
| from r3pm_net.config_loader import get_pretrained_rpmnet_dir, get_sioux_data_root, get_method_paths |
|
|
'''
This script evaluates registration pipelines on the Sioux scan/CAD pairs. Each global
registration method (RPMNet, Predator, GeoTransformer, LoGDesc, RegTR and three R3PM-Net
variants) is followed by GICP as the local registration step.

The script takes the following arguments:
--local_reg: the local registration method to be used (default: gicp). Only 'gicp' is currently applied by this script.
--seed: random seed for python/numpy/torch. The default is 42.

Results are always printed in table format; there is no --verbose flag.
'''
def set_seed(seed: int) -> None:
    """Seed all random number generators and switch PyTorch to deterministic mode.

    Args:
        seed: value used for Python's `random`, NumPy's global RNG and
            PyTorch's CPU and CUDA generators.

    Side effects:
        * sets the `PYTHONHASHSEED` and `CUBLAS_WORKSPACE_CONFIG`
          environment variables;
        * disables cuDNN benchmarking, enables cuDNN deterministic mode and
          calls `torch.use_deterministic_algorithms(True)`.
    """
    # NOTE: PYTHONHASHSEED only influences interpreters started after this
    # point; CUBLAS_WORKSPACE_CONFIG is the setting PyTorch's reproducibility
    # notes require for deterministic cuBLAS calls.
    os.environ.update({
        "PYTHONHASHSEED": str(seed),
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
    })

    # Seed every generator the pipeline may draw from.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade speed for reproducibility: no autotuned / nondeterministic kernels.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True
    torch.use_deterministic_algorithms(True)
|
|
|
|
| |
# Command-line interface. Parsed at import time because `main` reads the
# module-level `args`; the seed is applied immediately so every later RNG
# use is reproducible.
parser = argparse.ArgumentParser(description="Choosing local registration method")
parser.add_argument(
    "--local_reg",
    # Case-insensitive, and restricted to the advertised values so a typo
    # fails fast instead of silently skipping local registration for every
    # method (the pipeline only records results for 'gicp').
    type=str.lower,
    choices=("gicp", "freg"),
    default="gicp",
    help="local registration: gicp or freg",
)
parser.add_argument("--seed", type=int, default=42, help="random seed (default: 42)")


args = parser.parse_args()
set_seed(args.seed)
print(f"Using {args.local_reg} for local registration")
if args.local_reg != "gicp":
    # Only the GICP refinement path is implemented in this script.
    print(
        "Warning: only 'gicp' local registration is applied in this script; "
        f"no results will be recorded for '{args.local_reg}'."
    )
|
|
def analyze_results(results: dict, recall_threshold=1, rmse_threshold=0.053, verbose=False):
    """Label each object's registration as 'success' or 'failed'.

    Args:
        results: maps an object name to its evaluation values, in the order
            of the printed table headers: (Chamfer distance, registration
            recall, inlier RMSE, computation time).
        recall_threshold: a run fails if its recall, rounded to 3 decimals,
            is below this value.
        rmse_threshold: a run fails if its inlier RMSE, rounded to 3
            decimals, is above this value.
        verbose: if True, print the results table and the success rate.

    Returns:
        A list with one row per object: `[name, *values, status]`, where
        `status` is 'success' or 'failed'.
    """
    table = []
    success_count = 0
    for name, values in results.items():
        row = [name, *values]
        recall, rmse = row[2], row[3]
        failed = round(recall, 3) < recall_threshold or round(rmse, 3) > rmse_threshold
        if failed:
            print(f'No match for {name}! Try a different method. If the issue persists, please check the data.')
        else:
            success_count += 1
            print(f'Found match for {name}!')
        row.append('failed' if failed else 'success')
        table.append(row)

    if verbose:
        print(tabulate(table, headers=['Object', 'Chamfer Distance', 'Reg. Recall', 'Inlier RMSE', 'Computation Time', 'Status'], tablefmt='grid'))
        # Guard against an empty `results` dict, which previously raised
        # ZeroDivisionError here.
        if table:
            print(f"Success rate: {success_count / len(table) * 100:.2f}%")
        else:
            print("Success rate: N/A (no results)")

    return table
|
|
def show_successful_resutls(table, sources, targets, pc_results, method_name=None):
    """Visualize each registration whose table row ends with status 'success'.

    Rows of `table`, entries of `targets` and the values of `pc_results`
    (in insertion order) are paired by position, so all three must list the
    objects in the same order. `sources` is accepted but not used.

    Args:
        table: rows as returned by `analyze_results`; the last column is the status.
        sources: unused.
        targets: target point clouds, one per table row.
        pc_results: mapping whose values are the registered point clouds.
        method_name: label forwarded to the visualization helper.
    """
    for idx, row in enumerate(table):
        if row[-1] != 'success':
            continue
        target_pc = targets[idx]
        registered_pc = list(pc_results.values())[idx]
        # Results are already aligned, so they are drawn with an identity transform.
        visualization.draw_registration_result(target_pc, registered_pc, np.eye(4), method_name)
|
|
# Sioux objects in evaluation order, each paired with the `every_k_points`
# down-sampling value passed to `data.load_data` for that object. The scan
# lives at `<scan_dir>/<name>_clean.ply` and the CAD model at
# `<cad_dir>/<name>.stl`.
_SIOUX_OBJECTS = (
    ('teeth', 100),
    ('lime', 100),
    ('cube', 1),
    ('lego', 10),
    ('elephant', 30),
    ('house', 25),
    ('shoe', 15),
)


def _refine_with_gicp(coarse_pc, target, gt_transformation):
    """Refine a coarse global alignment with GICP.

    Calls `icp_registration_and_evaluation.icp_reg_and_eval` with the
    arguments every method in this script uses ('gicp', 1, identity initial
    transform) and returns its `(registered_pc, results)` pair unchanged.
    """
    return icp_registration_and_evaluation.icp_reg_and_eval(
        coarse_pc, target, 'gicp', 1, np.identity(4), gt_transformation)


def main():
    """Register every Sioux scan to its CAD model with each global method.

    Each coarse result is refined with GICP when `--local_reg gicp` is
    selected. Afterwards the results are reported and the successful runs
    are visualized, one method at a time.

    Objects are identified by the names in `_SIOUX_OBJECTS` instead of by
    substring-matching the full file path. This prevents a data root whose
    path contains an object name (e.g. `/data/sublime`) from making several
    objects share, and overwrite, a single results key.
    """
    base_dir = get_sioux_data_root()
    scan_dir = os.path.join(base_dir, 'sioux_scans')
    cad_dir = os.path.join(base_dir, 'sioux_cranfield')

    pretrained_base_dir = get_pretrained_rpmnet_dir()
    method_paths = get_method_paths()
    path_zs = os.path.join(pretrained_base_dir, "clean-trained.pth")
    path_ft = os.path.join(pretrained_base_dir, "best_model_PointNet2.t7")
    path_ft_sub = os.path.join(pretrained_base_dir, "best_model_PointNet_subset.t7")

    # RPMNet baseline, loaded from the clean-trained.pth checkpoint.
    rpm_args = l3d_helper.options(modelName="RPMNet")
    rpm_args.pretrained = path_zs

    predator_cfg = method_paths.get("predator", {})
    predator_root = predator_cfg.get("root")
    predator_config_path = predator_cfg.get("config_path")
    predator_weights_path = predator_cfg.get("weights_path")

    geo_cfg = method_paths.get("geotransformer", {})
    geotransformer_root = geo_cfg.get("root")
    geotransformer_exp_subdir = geo_cfg.get("exp_subdir")
    geotransformer_weights_path = geo_cfg.get("weights_path")

    logdesc_cfg = method_paths.get("logdesc", {})
    logdesc_root = logdesc_cfg.get("root")
    logdesc_weights_path = logdesc_cfg.get("weights_path")

    regtr_cfg = method_paths.get("regtr", {})
    regtr_root = regtr_cfg.get("root")
    regtr_ckpt_path = regtr_cfg.get("ckpt_path")
    regtr_config_path = regtr_cfg.get("config_path")

    # Three R3PM-Net variants, differing only in their checkpoint.
    r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
    r3pm_net_args.pretrained = path_zs
    tuned_r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
    tuned_r3pm_net_args.pretrained = path_ft
    subset_tuned_r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
    subset_tuned_r3pm_net_args.pretrained = path_ft_sub

    # Each runner takes (source, target, gt_transformation) and returns the
    # (registered_pc, results) pair produced by its *_reg_and_eval helper.
    def l3d_runner(model_name, model_args):
        """Build a runner for an l3d-based model (RPMNet / R3PM-Net)."""
        return lambda source, target, gt: l3d_registration_and_evaluation.l3d_reg_and_eval(
            source, target, model_name, gt, model_args)

    def run_predator(source, target, gt):
        return predator_registration_and_evaluation.predator_reg_and_eval(
            source,
            target,
            gt_transformation=gt,
            predator_root=predator_root,
            config_path=predator_config_path,
            weights_path=predator_weights_path,
            ransac_n_points=1000,
            ransac_distance_threshold=0.05,
            ransac_n=3,
            sampling="prob",
            mutual=False,
            input_num_points=1024,
        )

    def run_geotransformer(source, target, gt):
        return geotransformer_registration_and_evaluation.geotransformer_reg_and_eval(
            source,
            target,
            gt_transformation=gt,
            geotransformer_root=geotransformer_root,
            exp_subdir=geotransformer_exp_subdir,
            weights_path=geotransformer_weights_path,
        )

    def run_logdesc(source, target, gt):
        return logdesc_registration_and_evaluation.logdesc_reg_and_eval(
            source,
            target,
            gt_transformation=gt,
            logdesc_root=logdesc_root,
            weights_path=logdesc_weights_path,
            max_keypoints=768,
            num_points_per_sample=128,
            sample_radius=0.3,
            topk_matches=128,
            use_kpt=False,
        )

    def run_regtr(source, target, gt):
        return regtr_registration_and_evaluation.regtr_reg_and_eval(
            source,
            target,
            gt_transformation=gt,
            regtr_root=regtr_root,
            ckpt_path=regtr_ckpt_path,
            config_path=regtr_config_path,
        )

    # (display name, report header, runner). List order is the order in
    # which the methods run for every object; it must stay fixed so that
    # seeded runs remain reproducible.
    methods = [
        ('RPMNet', "----- RPMNet: -----", l3d_runner('rpmnet', rpm_args)),
        ('Predator', "----- Predator: -----", run_predator),
        ('GeoTransformer', "----- GeoTransformer: -----", run_geotransformer),
        ('LoGDesc', "----- LoGDesc: -----", run_logdesc),
        ('RegTR', "----- RegTR: -----", run_regtr),
        ('R3PM-Net (ours) (ZS)', "----- R3PM-Net (ours) (ZS): -----",
         l3d_runner('r3pmnet', r3pm_net_args)),
        ('R3PM-Net (ours) (FT)', "----- R3PM-Net (ours) (FT): ----- ",
         l3d_runner('r3pmnet', tuned_r3pm_net_args)),
        ('R3PM-Net (ours) (FT) (Subset)', "----- R3PM-Net (ours) (FT) (Subset): ----- ",
         l3d_runner('r3pmnet', subset_tuned_r3pm_net_args)),
    ]
    results = {name: {} for name, _, _ in methods}
    pc_results = {name: {} for name, _, _ in methods}

    sources = []
    targets = []

    for key, every_k_points in tqdm(_SIOUX_OBJECTS, desc="Registering objects"):
        pcd_path = os.path.join(scan_dir, f'{key}_clean.ply')
        cad_path = os.path.join(cad_dir, f'{key}.stl')

        pcd, cad = data.load_data(pcd_path, cad_path, every_k_points=every_k_points)
        source = copy.deepcopy(pcd)
        target = copy.deepcopy(cad)
        source = data.normalize_pc(source)
        target = data.normalize_pc(target)

        sources.append(source)
        targets.append(target)

        # No ground-truth pose is available for the real scans.
        gt_transformation = None

        for name, _, register in methods:
            coarse_pc, _ = register(source, target, gt_transformation)
            if args.local_reg == 'gicp':
                refined_pc, refined_results = _refine_with_gicp(coarse_pc, target, gt_transformation)
                results[name][key] = refined_results
                pc_results[name][key] = refined_pc

    for name, header, _ in methods:
        print(header)
        if not results[name]:
            # Only reachable when no local registration was applied; avoids
            # analysing an empty result set.
            print(f"No results to analyze for {name} (local registration: {args.local_reg}).")
            continue
        table = analyze_results(results[name], verbose=True)
        show_successful_resutls(table, sources, targets, pc_results[name], name)


if __name__ == "__main__":
    main()