| import argparse
|
| import numpy as np
|
| import torch
|
| from collections import defaultdict
|
| from tqdm import tqdm
|
| from transforms3d.quaternions import mat2quat
|
| import pandas as pd
|
|
|
| from model import PL_RelPose, keypoint_dict
|
| from utils.reproject import reprojection_error, Pose, save_submission
|
| from utils.metrics import reproj, add, adi, compute_continuous_auc, relative_pose_error, rotation_angular_error
|
| from datasets import dataset_dict
|
| from configs.default import get_cfg_defaults
|
|
|
|
|
@torch.no_grad()
def main(args):
    """Evaluate a relative-pose checkpoint on the configured test split.

    Loads the YAML config and model checkpoint named by ``args``, runs the
    model over every test pair, and prints a pandas table of timing and
    accuracy metrics. For the 'mapfree' dataset it additionally writes a
    submission zip to ``assets/new_submission.zip``.

    Args:
        args: parsed CLI namespace with ``config``, ``ckpt_path``,
            ``device`` and optional ``obj_name`` (HO3D per-object filter).
    """
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK
    dataset = config.DATASET.DATA_SOURCE
    device = args.device

    # Fixed: original had a duplicated `x = x = ...` assignment here.
    test_num_keypoints = config.MODEL.TEST_NUM_KEYPOINTS

    build_fn = dataset_dict[task][dataset]
    testset = build_fn('test', config)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1)

    pl_relpose = PL_RelPose.load_from_checkpoint(args.ckpt_path)
    # Rebuild the keypoint extractor for eval with the test-time keypoint budget.
    pl_relpose.extractor = keypoint_dict[pl_relpose.hparams['features']](max_num_keypoints=test_num_keypoints, detection_threshold=0.0).eval().to(device)
    pl_relpose.module = pl_relpose.module.eval().to(device)

    preprocess_times, extract_times, regress_times = [], [], []
    adds, adis = [], []
    repr_errs = []
    # NOTE(review): R_errs / t_errs / ts_errs are accumulated but never
    # reported below; kept for parity with the original script.
    R_errs, t_errs = [], []
    ts_errs = []
    results_dict = defaultdict(list)
    for data in tqdm(testloader):
        # Optional per-object filtering (HO3D only).
        if dataset == 'ho3d' and args.obj_name is not None and data['objName'][0] != args.obj_name:
            continue
        # Only the second view's image/intrinsics are used (for reprojection);
        # the first view's are discarded.
        _, image1 = data['images'][0]
        _, K1 = data['intrinsics'][0]
        # Assemble the ground-truth relative pose as a 4x4 homogeneous matrix.
        T = torch.eye(4)
        T[:3, :3] = data['rotation'][0]
        T[:3, 3] = data['translation'][0]
        T = T.numpy()

        R_est, t_est, preprocess_time, extract_time, regress_time = pl_relpose.predict_one_data(data)
        preprocess_times.append(preprocess_time)
        extract_times.append(extract_time)
        regress_times.append(regress_time)

        t_err, R_err = relative_pose_error(T, R_est.cpu().numpy(), t_est.cpu().numpy(), ignore_gt_t_thr=0.0)
        R_errs.append(R_err)
        t_errs.append(t_err)
        # Euclidean norm of the translation residual.
        ts_errs.append(torch.tensor(T[:3, 3] - t_est.cpu().numpy()).norm(2))

        if dataset == 'mapfree':
            repr_err = reprojection_error(R_est.cpu().numpy(), t_est.cpu().numpy(), T[:3, :3], T[:3, 3], K=K1, W=image1.shape[-1], H=image1.shape[-2])
            repr_errs.append(repr_err)
            R = R_est.detach().cpu().numpy()
            t = t_est.reshape(-1).detach().cpu().numpy()
            scene = data['scene_id'][0]
            estimated_pose = Pose(
                image_name=data['pair_names'][1][0],
                q=mat2quat(R).reshape(-1),
                t=t.reshape(-1),
                inliers=0
            )
            results_dict[scene].append(estimated_pose)

        # Object-pose metrics need a model point cloud.
        if 'point_cloud' in data:
            adds.append(add(R_est.cpu().numpy(), t_est.cpu().numpy(), T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))
            adis.append(adi(R_est.cpu().numpy(), t_est.cpu().numpy(), T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))

    metrics = []
    values = []

    # Convert seconds -> milliseconds.
    # NOTE(review): preprocess_times is converted but never reported below.
    preprocess_times = np.array(preprocess_times) * 1000
    extract_times = np.array(extract_times) * 1000
    regress_times = np.array(regress_times) * 1000

    metrics.append('Extracting Time (ms)')
    values.append(f'{np.mean(extract_times):.1f}')

    metrics.append('Recovering Time (ms)')
    values.append(f'{np.mean(regress_times):.1f}')

    metrics.append('Total Time (ms)')
    values.append(f'{np.mean(extract_times) + np.mean(regress_times):.1f}')

    if task == 'object':
        # AUC of ADD / ADD-S over thresholds 0..0.1, reported as a percentage.
        metrics.append('Object ADD')
        values.append(f'{compute_continuous_auc(adds, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')

        metrics.append('Object ADD-S')
        values.append(f'{compute_continuous_auc(adis, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')

    if dataset == 'mapfree':
        re = np.array(repr_errs)

        # Fraction of pairs with reprojection error under 90 px.
        metrics.append('VCRE @90px Prec.')
        values.append(f'{(re < 90).mean() * 100:.2f}')

        metrics.append('VCRE Med.')
        values.append(f'{np.median(re):.2f}')

        save_submission(results_dict, 'assets/new_submission.zip')

    res = pd.DataFrame({'Metrics': metrics, 'Values': values})
    print(res)
|
|
|
|
|
def get_parser():
    """Construct the CLI argument parser for this evaluation script.

    Returns:
        argparse.ArgumentParser with two positional arguments (config,
        ckpt_path) and two options (--device, --obj_name).
    """
    p = argparse.ArgumentParser()
    # (name, add_argument kwargs) pairs, positionals first.
    arg_specs = [
        ('config', {'type': str, 'help': '.yaml configure file path'}),
        ('ckpt_path', {'type': str}),
        ('--device', {'type': str, 'default': 'cuda:0'}),
        ('--obj_name', {'type': str, 'default': None}),
    ]
    for name, kwargs in arg_specs:
        p.add_argument(name, **kwargs)
    return p
|
|
|
|
|
| if __name__ == "__main__":
|
| parser = get_parser()
|
| args = parser.parse_args()
|
| main(args)
|
|
|