| import numpy as np
|
| import argparse
|
| from tqdm import tqdm
|
| import torch
|
| import pandas as pd
|
|
|
|
|
| from configs.default import get_cfg_defaults
|
| from datasets import dataset_dict
|
| from baselines.pose import PoseRecover
|
| from utils.metrics import relative_pose_error, rotation_angular_error, error_auc, add, adi, compute_continuous_auc
|
|
|
|
|
def main(args):
    """Evaluate a feature-matching pose-recovery baseline on a test split.

    Builds the dataset named in the YAML config, runs ``PoseRecover`` on
    every test pair, and prints a pandas table of timing and pose-accuracy
    metrics (AUC, rotation/translation errors, and ADD/ADD-S for objects).

    Args:
        args: parsed CLI namespace from ``get_parser()`` (config path,
            matcher/solver names, resize, depth/mask flags, device, ...).
    """
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK            # e.g. 'object' enables bbox crops + ADD metrics
    dataset = config.DATASET.DATA_SOURCE  # dataset key, e.g. 'ho3d'

    build_fn = dataset_dict[task][dataset]
    testset = build_fn('test', config)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1)

    device = args.device
    img_resize = args.resize
    poseRec = PoseRecover(matcher=args.matcher, solver=args.solver, img_resize=img_resize, device=device)

    preprocess_times, extract_times, match_times, recover_times = [], [], [], []
    R_errs, t_errs = [], []
    ts_errs = []          # metric translation errors (m); filled only with --depth
    adds, adis = [], []   # ADD / ADD-S distances; filled only for the object task
    for i, data in enumerate(tqdm(testloader)):
        # Optionally restrict HO3D evaluation to a single object name.
        if dataset == 'ho3d' and args.obj_name is not None and data['objName'][0] != args.obj_name:
            continue

        image0, image1 = data['images'][0].to(device)

        bbox0, bbox1 = None, None
        if task == 'object':
            # Crop both views to their object bounding boxes (x1,y1,x2,y2).
            bbox0, bbox1 = data['bboxes'][0]
            x1, y1, x2, y2 = bbox0
            u1, v1, u2, v2 = bbox1
            image0 = image0[:, y1:y2, x1:x2]
            image1 = image1[:, v1:v2, u1:u2]

        mask0, mask1 = None, None
        if args.mask:
            mask0, mask1 = data['masks'][0].to(device)

        depth0, depth1 = None, None
        if args.depth:
            depth0, depth1 = data['depths'][0]

        # Assemble the ground-truth relative pose as a 4x4 homogeneous matrix.
        K0, K1 = data['intrinsics'][0]
        T = torch.eye(4)
        T[:3, :3] = data['rotation'][0]
        T[:3, 3] = data['translation'][0]
        T = T.numpy()

        R, t, points0, points1, preprocess_time, extract_time, match_time, recover_time = poseRec.recover(image0, image1, K0, K1, bbox0, bbox1, mask0, mask1, depth0, depth1)
        preprocess_times.append(preprocess_time)
        extract_times.append(extract_time)
        match_times.append(match_time)
        recover_times.append(recover_time)

        if np.isnan(R).any():
            # Recovery failed: charge the worst-case angular errors.
            R_err = 180
            R = np.identity(3)
            t_err = 180
            t = np.array([0., 0., 0.])
        else:
            t_err, R_err = relative_pose_error(T, R, t, ignore_gt_t_thr=0.0)

        R_errs.append(R_err)
        t_errs.append(t_err)

        if args.depth:
            # With depth the recovered translation is metric; record its L2 error.
            t = np.nan_to_num(t)
            ts_errs.append(torch.tensor(T[:3, 3] - t).norm(2))

        if task == 'object':
            if np.isnan(R).any():
                # Failed recovery: worst-case (normalized) ADD/ADD-S of 1.
                adds.append(1.)
                adis.append(1.)
            else:
                adds.append(add(R, t, T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))
                adis.append(adi(R, t, T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))

    metrics = []
    values = []

    # BUGFIX: the original converted the *last-iteration scalars*
    # (preprocess_time / extract_time / recover_time) left over from the loop
    # instead of the accumulated lists, so the reported mean extract/recover
    # times described a single sample. Convert the full lists (seconds -> ms).
    preprocess_times = np.array(preprocess_times) * 1000  # collected but not reported below
    extract_times = np.array(extract_times) * 1000
    match_times = np.array(match_times) * 1000
    recover_times = np.array(recover_times) * 1000

    metrics.append('Extracting Time (ms)')
    values.append(f'{np.mean(extract_times):.1f}')

    metrics.append('Matching Time (ms)')
    values.append(f'{np.mean(match_times):.1f}')

    metrics.append('Recovering Time (ms)')
    values.append(f'{np.mean(recover_times):.1f}')

    metrics.append('Total Time (ms)')
    values.append(f'{np.mean(extract_times) + np.mean(match_times) + np.mean(recover_times):.1f}')

    # Pose AUC over max(R_err, t_err) at standard angular thresholds.
    angular_thresholds = [5, 10, 20]
    pose_errors = np.max(np.stack([R_errs, t_errs]), axis=0)
    aucs = error_auc(pose_errors, angular_thresholds, mode='Pose estimation')
    for k in aucs:
        metrics.append(k)
        values.append(f'{aucs[k] * 100:.2f}')

    R_errs = torch.tensor(R_errs)
    t_errs = torch.tensor(t_errs)

    metrics.append('Rotation Avg. Error (°)')
    values.append(f'{R_errs.mean():.2f}')

    metrics.append('Rotation Med. Error (°)')
    values.append(f'{R_errs.median():.2f}')

    metrics.append('Rotation @30° ACC')
    values.append(f'{(R_errs < 30).float().mean() * 100:.1f}')

    metrics.append('Rotation @15° ACC')
    values.append(f'{(R_errs < 15).float().mean() * 100:.1f}')

    if args.depth:
        # Metric translation statistics only make sense when depth was used.
        ts_errs = torch.tensor(ts_errs)

        metrics.append('Translation Avg. Error (m)')
        values.append(f'{ts_errs.mean():.4f}')

        metrics.append('Translation Med. Error (m)')
        values.append(f'{ts_errs.median():.4f}')

        metrics.append('Translation @1m ACC')
        values.append(f'{(ts_errs < 1.0).float().mean() * 100:.1f}')

        metrics.append('Translation @10cm ACC')
        values.append(f'{(ts_errs < 0.1).float().mean() * 100:.1f}')

    if task == 'object':
        # Area under the accuracy curve over 0..10 cm distance thresholds.
        metrics.append('Object ADD')
        values.append(f'{compute_continuous_auc(adds, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')

        metrics.append('Object ADD-S')
        values.append(f'{compute_continuous_auc(adis, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')

    res = pd.DataFrame({'Metrics': metrics, 'Values': values})
    print(res)
|
|
|
|
|
def get_parser():
    """Build the command-line interface for the pose-evaluation script.

    Returns:
        argparse.ArgumentParser: parser with two positionals (``config``,
        ``matcher``) and the optional solver / input / device settings.
    """
    parser = argparse.ArgumentParser()

    # Required positionals: which experiment config and which matcher to run.
    parser.add_argument('config', type=str, help='.yaml configure file path')
    parser.add_argument('matcher', type=str)

    # Pose solver and input preprocessing options.
    parser.add_argument('--solver', type=str, default='procrustes')
    parser.add_argument('--resize', type=int, default=None)

    # Extra input modalities and dataset filtering.
    parser.add_argument('--depth', action='store_true')
    parser.add_argument('--mask', action='store_true')
    parser.add_argument('--obj_name', type=str, default=None)

    # Compute device.
    parser.add_argument('--device', type=str, default='cuda:0')

    return parser
|
|
|
|
|
| if __name__ == "__main__":
|
| parser = get_parser()
|
| args = parser.parse_args()
|
| main(args)
|
|
|