| |
| |
| |
| |
| |
| |
| |
| import os |
| import numpy as np |
| import random |
| import torch |
| import torchvision.transforms as tvf |
| import argparse |
| from tqdm import tqdm |
| from PIL import Image |
| import math |
|
|
| from mast3r.model import AsymmetricMASt3R |
| from mast3r.fast_nn import fast_reciprocal_NNs |
| from mast3r.utils.coarse_to_fine import select_pairs_of_crops, crop_slice |
| from mast3r.utils.collate import cat_collate, cat_collate_fn_map |
| from mast3r.utils.misc import mkdir_for |
| from mast3r.datasets.utils.cropping import crop_to_homography |
|
|
| import mast3r.utils.path_to_dust3r |
| from dust3r.inference import inference, loss_of_one_batch |
| from dust3r.utils.geometry import geotrf, colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics |
| from dust3r.datasets.utils.transforms import ImgNorm |
| from dust3r_visloc.datasets import * |
| from dust3r_visloc.localization import run_pnp |
| from dust3r_visloc.evaluation import get_pose_error, aggregate_stats, export_results |
| from dust3r_visloc.datasets.utils import get_HW_resolution, rescale_points3d |
|
|
|
|
def get_args_parser():
    """Build the command-line parser for the visual-localization evaluation script.

    Options fall into these groups:
      * data / model: ``--dataset`` (required) and exactly one of ``--weights`` / ``--model_name``;
      * matching: confidence threshold, pixel tolerance, coarse-to-fine options;
      * PnP: backend and reprojection error, either absolute or as a ratio of the image diagonal
        (mutually exclusive);
      * misc: crop-inference batch size, PnP point cap, debug visualisation, output location.

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True, help="visloc dataset to eval")
    parser_weights = parser.add_mutually_exclusive_group(required=True)
    parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
    parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
                                choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"])

    # The main script keeps matches with confidence >= this threshold, so values *below* it are
    # discarded. The previous help text stated the opposite.
    parser.add_argument("--confidence_threshold", type=float, default=1.001,
                        help="confidence values lower than threshold are invalid")
    parser.add_argument('--pixel_tol', default=5, type=int)

    # coarse-to-fine matching
    parser.add_argument("--coarse_to_fine", action='store_true', default=False,
                        help="do the matching from coarse to fine")
    parser.add_argument("--max_image_size", type=int, default=None,
                        help="max image size for the fine resolution")
    parser.add_argument("--c2f_crop_with_homography", action='store_true', default=False,
                        help="when using coarse to fine, crop with homographies to keep cx, cy centered")

    # device and PnP solver
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--pnp_mode", type=str, default="cv2", choices=['cv2', 'poselib', 'pycolmap'],
                        help="pnp lib to use")
    parser_reproj = parser.add_mutually_exclusive_group()
    parser_reproj.add_argument("--reprojection_error", type=float, default=5.0, help="pnp reprojection error")
    parser_reproj.add_argument("--reprojection_error_diag_ratio", type=float, default=None,
                               help="pnp reprojection error as a ratio of the diagonal of the image")

    # performance / debugging
    parser.add_argument("--max_batch_size", type=int, default=48,
                        help="max batch size for inference on crops when using coarse to fine")
    parser.add_argument("--pnp_max_points", type=int, default=100_000, help="pnp maximum number of points kept")
    parser.add_argument("--viz_matches", type=int, default=0, help="debug matches")

    # outputs
    parser.add_argument("--output_dir", type=str, default=None, help="output path")
    parser.add_argument("--output_label", type=str, default='', help="prefix for results files")
    return parser
|
|
|
|
@torch.no_grad()
def coarse_matching(query_view, map_view, model, device, pixel_tol, fast_nn_params):
    """Match a query view against a map view using their downscaled ('*_rescaled') images.

    Both images are run through the model as a single pair. Matches are the reciprocal nearest
    neighbours of the two dense descriptor maps.

    Args:
        query_view, map_view: view dicts. Reads 'rgb_rescaled' and 'to_orig' from both, and also
            'valid_rescaled' and 'pts3d_rescaled' from the map view when ``pixel_tol != 0``.
        model, device: network and torch device forwarded to ``inference``.
        pixel_tol: 0 selects seeded matching (``subsample_or_initxy1=8``) with no 3D lookup.
            Any other value seeds from the map's valid pixels and is forwarded to
            ``fast_reciprocal_NNs``.
        fast_nn_params: extra keyword arguments for ``fast_reciprocal_NNs``.

    Returns:
        (valid_pts3d, matches_im_query, matches_im_map, matches_confs).
        - Match coordinates are (x, y), mapped into each original image through that view's
          'to_orig', with a half-pixel shift around the transform.
        - With ``pixel_tol == 0``, valid_pts3d and matches_confs are empty lists, and matches
          within 3 px of either image border are dropped.
        - Four empty lists are returned when either descriptor map is empty.
    """
    query_rgb = query_view['rgb_rescaled']
    map_rgb = map_view['rgb_rescaled']
    pair = tuple(
        dict(img=rgb.unsqueeze(0), true_shape=np.int32([rgb.shape[1:]]), idx=i, instance=str(i))
        for i, rgb in enumerate((query_rgb, map_rgb))
    )
    output = inference([pair], model, device, batch_size=1, verbose=False)
    pred_query, pred_map = output['pred1'], output['pred2']
    conf_query = pred_query['desc_conf'].squeeze(0).cpu().numpy()
    conf_map = pred_map['desc_conf'].squeeze(0).cpu().numpy()
    desc_query = pred_query['desc'].squeeze(0).detach()
    desc_map = pred_map['desc'].squeeze(0).detach()

    if len(desc_query) == 0 or len(desc_map) == 0:
        return [], [], [], []

    if pixel_tol == 0:
        xy_map, xy_query = fast_reciprocal_NNs(desc_map, desc_query, subsample_or_initxy1=8, **fast_nn_params)
        HM, WM = map_rgb.shape[1:]
        HQ, WQ = query_rgb.shape[1:]

        def away_from_border(xy, width, height, margin=3):
            # True for (x, y) points at least `margin` px inside the image
            return ((xy[:, 0] >= margin) & (xy[:, 0] < width - margin)
                    & (xy[:, 1] >= margin) & (xy[:, 1] < height - margin))

        keep = away_from_border(xy_map, WM, HM) & away_from_border(xy_query, WQ, HQ)
        xy_map = xy_map[keep]
        xy_query = xy_query[keep]
        valid_pts3d, matches_confs = [], []
    else:
        ys, xs = torch.where(map_view['valid_rescaled'])
        xy_map, xy_query = fast_reciprocal_NNs(desc_map, desc_query, (xs, ys), pixel_tol=pixel_tol, **fast_nn_params)
        map_rows, map_cols = xy_map[:, 1], xy_map[:, 0]
        valid_pts3d = map_view['pts3d_rescaled'].cpu().numpy()[map_rows, map_cols]
        # a match is only as confident as its least confident endpoint
        matches_confs = np.minimum(conf_map[map_rows, map_cols],
                                   conf_query[xy_query[:, 1], xy_query[:, 0]])

    def to_original(xy, to_orig):
        # move to pixel centres (+0.5), apply to_orig, then move back (-0.5)
        xy = xy.astype(np.float64)
        xy[:, :2] += 0.5
        xy = geotrf(to_orig, xy, norm=True)
        xy[:, :2] -= 0.5
        return xy

    matches_im_query = to_original(xy_query, query_view['to_orig'])
    matches_im_map = to_original(xy_map, map_view['to_orig'])
    return valid_pts3d, matches_im_query, matches_im_map, matches_confs
|
|
|
|
@torch.no_grad()
def crops_inference(pairs, model, device, batch_size=48, verbose=True):
    """Run the model on a batch of crop pairs, in chunks of at most ``batch_size`` pairs.

    Args:
        pairs: a 2-element sequence of view dicts (first and second image of each pair). Every
            non-None value is sliced along its first dimension; ``pairs[0]['img'].shape[0]`` gives
            the number of pairs.
        model, device: forwarded to ``loss_of_one_batch``.
        batch_size: maximum number of pairs per forward pass.
        verbose: accepted for interface compatibility; not used.

    Returns:
        The ``loss_of_one_batch`` output directly when there are fewer than ``batch_size`` pairs.
        Otherwise, the per-chunk outputs merged with ``cat_collate``.
    """
    assert len(pairs) == 2, "Error, data should be a tuple of dicts containing the batch of image pairs"

    num_pairs = pairs[0]['img'].shape[0]
    if num_pairs < batch_size:
        return loss_of_one_batch(pairs, model, None, device=device, symmetrize_batch=False)

    def take(view, sel):
        # slice every populated entry of a view dict; None entries are dropped
        return {key: value[sel] for key, value in view.items() if value is not None}

    chunk_outputs = []
    for start in range(0, num_pairs, batch_size):
        sel = slice(start, min(start + batch_size, num_pairs))
        chunk = [take(pairs[0], sel), take(pairs[1], sel)]
        chunk_outputs.append(loss_of_one_batch(chunk, model, None, device=device, symmetrize_batch=False))

    return cat_collate(chunk_outputs, collate_fn_map=cat_collate_fn_map)
|
|
|
|
@torch.no_grad()
def fine_matching(query_views, map_views, model, device, max_batch_size, pixel_tol, fast_nn_params):
    """Match every (query crop, map crop) pair and gather the results.

    Args:
        query_views: batched query-crop dict with 'img' and 'to_orig' (one 3x3 per crop).
        map_views: batched map-crop dict with 'img', 'valid', 'pts3d' and 'to_orig'.
        model, device: forwarded to ``crops_inference``.
        max_batch_size: chunk size for ``crops_inference``.
        pixel_tol: must be > 0; forwarded to ``fast_reciprocal_NNs``.
        fast_nn_params: extra keyword arguments for ``fast_reciprocal_NNs``.

    Returns:
        (valid_pts3d, matches_im_query, matches_im_map, matches_confs), concatenated over all crops.
        Match coordinates are mapped through each crop's 'to_orig'. Four empty lists are returned
        when there are no crop pairs.

    NOTE(review): unlike ``coarse_matching``, no +/-0.5 pixel-centre shift is applied around the
    'to_orig' transform here. This is harmless for pure-translation crops but matters for
    homography crops. Confirm this is intended.
    """
    assert pixel_tol > 0
    output = crops_inference([query_views, map_views], model, device, batch_size=max_batch_size, verbose=False)
    pred_query, pred_map = output['pred1'], output['pred2']
    descs_query = pred_query['desc'].clone()
    descs_map = pred_map['desc'].clone()
    confs_query = pred_query['desc_conf'].clone()
    confs_map = pred_map['desc_conf'].clone()

    pts3d_chunks, query_chunks, map_chunks, conf_chunks = [], [], [], []
    per_crop = zip(descs_query, descs_map, confs_query, confs_map)
    for crop_idx, (desc_q, desc_m, conf_q, conf_m) in enumerate(per_crop):
        valid_mask = map_views['valid'][crop_idx]
        pts3d_np = map_views['pts3d'][crop_idx].cpu().numpy()
        conf_q_np, conf_m_np = conf_q.cpu().numpy(), conf_m.cpu().numpy()

        ys, xs = torch.where(valid_mask)
        xy_map, xy_query = fast_reciprocal_NNs(desc_m, desc_q, (xs, ys), pixel_tol=pixel_tol, **fast_nn_params)

        map_rows, map_cols = xy_map[:, 1], xy_map[:, 0]
        pts3d_chunks.append(pts3d_np[map_rows, map_cols])
        conf_chunks.append(np.minimum(conf_m_np[map_rows, map_cols],
                                      conf_q_np[xy_query[:, 1], xy_query[:, 0]]))

        # undo the crop: express the pixel coordinates in the frame the crop was cut from
        map_chunks.append(geotrf(map_views['to_orig'][crop_idx].cpu().numpy(), xy_map.copy(), norm=True))
        query_chunks.append(geotrf(query_views['to_orig'][crop_idx].cpu().numpy(), xy_query.copy(), norm=True))

    if not pts3d_chunks:
        return [], [], [], []

    return (np.concatenate(pts3d_chunks, axis=0),
            np.concatenate(query_chunks, axis=0),
            np.concatenate(map_chunks, axis=0),
            np.concatenate(conf_chunks, axis=0))
|
|
|
|
def crop(img, mask, pts3d, crop, intrinsics=None):
    """Crop an image, and optionally its validity mask and 3D point map, to one crop box.

    Two modes:
      * ``intrinsics is None``: plain slicing with ``crop_slice(crop)``. ``to_orig`` is a 3x3
        translation by ``crop[:2]``.
      * ``intrinsics`` given: each array is resampled through the homography returned by
        ``crop_to_homography(intrinsics, crop)``. The image uses bicubic resampling; the mask and
        points use nearest. ``to_orig`` is that homography normalised so that ``H[2, 2] == 1``.

    Args:
        img: image tensor indexed as (H, W, C). The PIL path maps it from [-1, 1] to [0, 255]
            (presumably ImgNorm output -- confirm against callers).
        mask: optional mask tensor with the same H, W as ``img``, or None.
        pts3d: optional (H, W, 3) point-map tensor, or None.
        crop: crop box understood by ``crop_slice`` / ``crop_to_homography``. Its first two
            entries are the translation used for ``to_orig`` in slicing mode.
        intrinsics: optional camera matrix; when given, selects homography mode.

    Returns:
        (cropped_img, cropped_mask or None, cropped_pts3d or None, to_orig). ``to_orig`` maps
        crop pixel coordinates back to ``img`` pixel coordinates.
    """
    # mask / pts3d are copied so the returned crops never alias the caller's tensors
    out_cropped_mask = mask.clone() if mask is not None else None
    out_cropped_pts3d = pts3d.clone() if pts3d is not None else None
    to_orig = torch.eye(3, device=img.device)

    if intrinsics is None:
        out_cropped_img = img[crop_slice(crop)]
        if out_cropped_mask is not None:
            out_cropped_mask = out_cropped_mask[crop_slice(crop)]
        if out_cropped_pts3d is not None:
            out_cropped_pts3d = out_cropped_pts3d[crop_slice(crop)]
        to_orig[:2, -1] = torch.tensor(crop[:2])
        return out_cropped_img, out_cropped_mask, out_cropped_pts3d, to_orig

    imsize, _K_new, _R, H = crop_to_homography(intrinsics, crop)
    H /= H[2, 2]
    # PIL's PERSPECTIVE transform takes the first 8 coefficients of the (output -> input) homography
    homo8 = H.ravel().tolist()[:8]

    def warp(array, resample):
        # resample a 2D/3D numpy array through the crop homography and return it as numpy
        pil_image = Image.fromarray(array)
        return np.array(pil_image.transform(imsize, Image.Transform.PERSPECTIVE, homo8, resample=resample))

    # the image goes through uint8 for PIL and is mapped back to [-1, 1] afterwards
    img_uint8 = (255 * (img + 1.) / 2).to(torch.uint8).numpy()
    out_cropped_img = 2. * torch.tensor(warp(img_uint8, Image.Resampling.BICUBIC)).to(img) / 255. - 1.

    if out_cropped_mask is not None:
        mask_uint8 = (255 * out_cropped_mask).to(torch.uint8).numpy()
        warped_mask = warp(mask_uint8, Image.Resampling.NEAREST)
        out_cropped_mask = torch.from_numpy(warped_mask > 0).to(out_cropped_mask.dtype)

    if out_cropped_pts3d is not None:
        pts3d_np = out_cropped_pts3d.numpy()
        # nearest resampling per coordinate channel so 3D values are never interpolated
        channels = [warp(pts3d_np[:, :, c], Image.Resampling.NEAREST) for c in range(3)]
        out_cropped_pts3d = torch.from_numpy(np.stack(channels, axis=-1))

    to_orig = torch.tensor(H, device=img.device)
    return out_cropped_img, out_cropped_mask, out_cropped_pts3d, to_orig
|
|
|
|
def resize_image_to_max(max_image_size, rgb, K):
    """Normalise a PIL image and, if needed, downscale it so its longer side equals max_image_size.

    Args:
        max_image_size: target length of the longer side. A falsy value (None or 0) disables
            resizing, as does an image that is already small enough.
        rgb: PIL image; ``rgb.size`` is read as (W, H).
        K: camera intrinsics of ``rgb``, rescaled together with the image.

    Returns:
        (rgb_tensor, new_K, to_orig_max, to_resize_max, (HMax, WMax)):
        - ``rgb_tensor`` is the ImgNorm'd image permuted to (H, W, C).
        - ``to_orig_max`` and ``to_resize_max`` are 3x3 diagonal scalings from the resized grid to
          the original one, and back. Both are the identity when no resize happens; in that case
          ``K`` is returned unchanged.
    """
    W, H = rgb.size
    if not (max_image_size and max(W, H) > max_image_size):
        return ImgNorm(rgb).permute(1, 2, 0), K, np.eye(3), np.eye(3), (H, W)

    # the longer side becomes max_image_size; the other side keeps the aspect ratio (truncated)
    if W >= H:
        WMax = max_image_size
        HMax = int(H * (WMax / W))
    else:
        HMax = max_image_size
        WMax = int(W * (HMax / H))

    resize_op = tvf.Compose([ImgNorm, tvf.Resize(size=[HMax, WMax])])
    rgb_tensor = resize_op(rgb).permute(1, 2, 0)

    to_orig_max = np.diag([W / WMax, H / HMax, 1])
    to_resize_max = np.diag([WMax / W, HMax / H, 1])

    # K is scaled in the COLMAP intrinsics convention and converted back, presumably to account
    # for the half-pixel offset between conventions
    scaled_K = opencv_to_colmap_intrinsics(K)
    scaled_K[0, :] *= WMax / W
    scaled_K[1, :] *= HMax / H
    new_K = colmap_to_opencv_intrinsics(scaled_K)
    return rgb_tensor, new_K, to_orig_max, to_resize_max, (HMax, WMax)
|
|
|
|
def _show_matches(query_rgb, map_rgb, matches_im_query, matches_im_map, n_viz):
    """Debug helper: plot up to ``n_viz`` evenly spaced matches between two images, side by side.

    Prints the number of matches, then draws both images next to each other and one coloured line
    per displayed match. The shorter image is zero-padded at the bottom. Blocks in ``pyplot.show``
    until the window is closed. Returns right after the print when there are no matches, since
    there is nothing to index.

    Args:
        query_rgb, map_rgb: images convertible with ``np.array`` to (H, W, C) arrays.
        matches_im_query, matches_im_map: aligned sequences of (x, y) pixel coordinates in
            ``query_rgb`` / ``map_rgb`` (numpy arrays, or empty lists).
        n_viz: number of matches to display.
    """
    num_matches = len(matches_im_query)
    print(f'found {num_matches} matches')
    if num_matches == 0 or n_viz <= 0:
        return
    from matplotlib import pyplot as pl

    viz_imgs = [np.array(query_rgb), np.array(map_rgb)]
    match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
    viz_matches_im_query = matches_im_query[match_idx_to_viz]
    viz_matches_im_map = matches_im_map[match_idx_to_viz]

    H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2]
    img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = np.concatenate((img0, img1), axis=1)
    pl.figure()
    pl.imshow(img)
    cmap = pl.get_cmap('jet')
    color_denom = max(n_viz - 1, 1)  # avoids ZeroDivisionError when a single match is shown
    for i in range(n_viz):
        (x0, y0), (x1, y1) = viz_matches_im_query[i].T, viz_matches_im_map[i].T
        pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / color_denom), scalex=False, scaley=False)
    pl.show(block=True)


def _coarse_to_fine_matching(args, model, device, fast_nn_params, maxdim, query_view, map_view,
                             query_rgb_tensor, query_K, query_to_orig_max, query_to_resize_max,
                             query_resolution):
    """Coarse-to-fine matching of one (query, map) pair.

    Steps:
      1. Run ``coarse_matching`` (pixel_tol=0) on the rescaled images.
      2. Downscale the map image to ``args.max_image_size``, reprojecting its valid 3D points if
         the size changed.
      3. Select pairs of crops around the coarse matches and run ``fine_matching`` on them.

    Returns:
        (valid_pts3d, matches_im_query, matches_im_map, matches_conf), with match coordinates in
        the original query / map image frames. Four empty lists are returned when no crop pair is
        selected.
    """
    _, coarse_matches_im0, coarse_matches_im1, _ = coarse_matching(query_view, map_view, model, device,
                                                                   0, fast_nn_params)
    if args.viz_matches > 0:
        _show_matches(query_view['rgb'], map_view['rgb'], coarse_matches_im0, coarse_matches_im1,
                      args.viz_matches)

    valid_all = map_view['valid']
    pts3d = map_view['pts3d']

    # bring the map image (and, if its size changed, its sparse 3D points) to max_image_size
    WM_full, HM_full = map_view['rgb'].size
    map_rgb_tensor, map_K, map_to_orig_max, map_to_resize_max, (HM, WM) = resize_image_to_max(
        args.max_image_size, map_view['rgb'], map_view['intrinsics'])
    if WM_full != WM or HM_full != HM:
        y_full, x_full = torch.where(valid_all)
        pos2d_cv2 = torch.stack([x_full, y_full], dim=-1).cpu().numpy().astype(np.float64)
        sparse_pts3d = pts3d[y_full, x_full].cpu().numpy()
        _, _, pts3d_max, valid_max = rescale_points3d(pos2d_cv2, sparse_pts3d, map_to_resize_max, HM, WM)
        pts3d = torch.from_numpy(pts3d_max)
        valid_all = torch.from_numpy(valid_max)

    # coarse matches are in original-image pixels; express them in the resized frames used for cropping
    coarse_matches_im0 = geotrf(query_to_resize_max, coarse_matches_im0, norm=True)
    coarse_matches_im1 = geotrf(map_to_resize_max, coarse_matches_im1, norm=True)

    if not args.c2f_crop_with_homography:
        # without intrinsics, crop() falls back to plain slicing
        map_K = None
        query_K = None

    map_resolution = get_HW_resolution(HM, WM, maxdim=maxdim, patchsize=model.patch_embed.patch_size)
    crops1, crops2 = [], []
    crops_v1, crops_p1 = [], []
    to_orig1, to_orig2 = [], []
    # select_pairs_of_crops is called map-first, so it yields (map crop, query crop, tag)
    for crop_map, crop_query, _pair_tag in select_pairs_of_crops(map_rgb_tensor,
                                                                 query_rgb_tensor,
                                                                 coarse_matches_im1,
                                                                 coarse_matches_im0,
                                                                 maxdim=maxdim,
                                                                 overlap=.5,
                                                                 forced_resolution=[map_resolution,
                                                                                    query_resolution]):
        c1, v1, p1, trf1 = crop(map_rgb_tensor, valid_all, pts3d, crop_map, map_K)
        c2, _, _, trf2 = crop(query_rgb_tensor, None, None, crop_query, query_K)
        crops1.append(c1)
        crops2.append(c2)
        crops_v1.append(v1)
        crops_p1.append(p1)
        to_orig1.append(trf1)
        to_orig2.append(trf2)

    if len(crops1) == 0 or len(crops2) == 0:
        # nothing to match; skip geotrf, which would otherwise be applied to empty lists
        return [], [], [], []

    crops1, crops2 = torch.stack(crops1), torch.stack(crops2)
    if len(crops1.shape) == 3:
        crops1, crops2 = crops1[None], crops2[None]
    crops_v1 = torch.stack(crops_v1)
    crops_p1 = torch.stack(crops_p1)
    to_orig1, to_orig2 = torch.stack(to_orig1), torch.stack(to_orig2)
    map_crop_view = dict(img=crops1.permute(0, 3, 1, 2),
                         instance=['1' for _ in range(crops1.shape[0])],
                         valid=crops_v1, pts3d=crops_p1,
                         to_orig=to_orig1)
    query_crop_view = dict(img=crops2.permute(0, 3, 1, 2),
                           instance=['2' for _ in range(crops2.shape[0])],
                           to_orig=to_orig2)

    valid_pts3d, matches_im_query, matches_im_map, matches_conf = fine_matching(query_crop_view,
                                                                                map_crop_view,
                                                                                model, device,
                                                                                args.max_batch_size,
                                                                                args.pixel_tol,
                                                                                fast_nn_params)
    # fine matches live in the resized frames; map them back to the original images
    matches_im_query = geotrf(query_to_orig_max, matches_im_query, norm=True)
    matches_im_map = geotrf(map_to_orig_max, matches_im_map, norm=True)
    return valid_pts3d, matches_im_query, matches_im_map, matches_conf


if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()
    conf_thr = args.confidence_threshold
    device = args.device
    pnp_mode = args.pnp_mode
    # explicit check rather than `assert`, so it is still enforced under `python -O`
    if args.pixel_tol <= 0:
        parser.error("--pixel_tol must be strictly positive")
    reprojection_error = args.reprojection_error
    reprojection_error_diag_ratio = args.reprojection_error_diag_ratio
    pnp_max_points = args.pnp_max_points
    viz_matches = args.viz_matches

    weights_path = args.weights if args.weights is not None else "naver/" + args.model_name
    model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
    fast_nn_params = dict(device=device, dist='dot', block_size=2**13)
    # SECURITY: --dataset is evaluated as a Python expression (e.g. a constructor from the
    # dust3r_visloc.datasets star import). Never pass untrusted input to this argument.
    dataset = eval(args.dataset)
    dataset.set_resolution(model)
    maxdim = max(model.patch_embed.img_size)  # loop-invariant

    query_names = []
    poses_pred = []
    pose_errors = []
    angular_errors = []
    params_str = f'tol_{args.pixel_tol}' + ("_c2f" if args.coarse_to_fine else '')
    if args.max_image_size is not None:
        params_str = params_str + f'_{args.max_image_size}'
    if args.coarse_to_fine and args.c2f_crop_with_homography:
        params_str = params_str + '_with_homography'

    for idx in tqdm(range(len(dataset))):
        views = dataset[idx]
        query_view = views[0]
        map_views = views[1:]
        query_names.append(query_view['image_name'])

        query_pts2d = []
        query_pts3d = []
        query_rgb_tensor, query_K, query_to_orig_max, query_to_resize_max, (HQ, WQ) = resize_image_to_max(
            args.max_image_size, query_view['rgb'], query_view['intrinsics'])
        query_resolution = get_HW_resolution(HQ, WQ, maxdim=maxdim, patchsize=model.patch_embed.patch_size)
        # coarse-to-fine only pays off when the query is larger than the model resolution
        use_c2f = args.coarse_to_fine and (maxdim < max(WQ, HQ))

        for map_view in map_views:
            cache_file = None
            if args.output_dir is not None:
                cache_file = os.path.join(args.output_dir, 'matches', params_str,
                                          query_view['image_name'], map_view['image_name'] + '.npz')

            if cache_file is not None and os.path.isfile(cache_file):
                # the context manager closes the underlying .npz file handle
                with np.load(cache_file) as matches:
                    valid_pts3d = matches['valid_pts3d']
                    matches_im_query = matches['matches_im_query']
                    matches_im_map = matches['matches_im_map']
                    matches_conf = matches['matches_conf']
            else:
                if use_c2f:
                    valid_pts3d, matches_im_query, matches_im_map, matches_conf = _coarse_to_fine_matching(
                        args, model, device, fast_nn_params, maxdim, query_view, map_view,
                        query_rgb_tensor, query_K, query_to_orig_max, query_to_resize_max,
                        query_resolution)
                else:
                    valid_pts3d, matches_im_query, matches_im_map, matches_conf = coarse_matching(
                        query_view, map_view, model, device, args.pixel_tol, fast_nn_params)
                if cache_file is not None:
                    mkdir_for(cache_file)
                    np.savez(cache_file, valid_pts3d=valid_pts3d, matches_im_query=matches_im_query,
                             matches_im_map=matches_im_map, matches_conf=matches_conf)

            # keep only matches whose confidence reaches the threshold
            if len(matches_conf) > 0:
                mask = matches_conf >= conf_thr
                valid_pts3d = valid_pts3d[mask]
                matches_im_query = matches_im_query[mask]
                matches_im_map = matches_im_map[mask]
                matches_conf = matches_conf[mask]

            if viz_matches > 0:
                _show_matches(query_view['rgb'], map_view['rgb'], matches_im_query, matches_im_map,
                              viz_matches)

            if len(valid_pts3d) > 0:
                query_pts3d.append(valid_pts3d)
                query_pts2d.append(matches_im_query)

        if len(query_pts2d) == 0:
            success = False
            pr_querycam_to_world = None
        else:
            query_pts2d = np.concatenate(query_pts2d, axis=0).astype(np.float32)
            query_pts3d = np.concatenate(query_pts3d, axis=0)
            if len(query_pts2d) > pnp_max_points:
                idxs = random.sample(range(len(query_pts2d)), pnp_max_points)
                query_pts3d = query_pts3d[idxs]
                query_pts2d = query_pts2d[idxs]

            W, H = query_view['rgb'].size
            if reprojection_error_diag_ratio is not None:
                reprojection_error_img = reprojection_error_diag_ratio * math.sqrt(W**2 + H**2)
            else:
                reprojection_error_img = reprojection_error
            success, pr_querycam_to_world = run_pnp(query_pts2d, query_pts3d,
                                                    query_view['intrinsics'], query_view['distortion'],
                                                    pnp_mode, reprojection_error_img, img_size=[W, H])

        if not success:
            abs_transl_error = float('inf')
            abs_angular_error = float('inf')
        else:
            abs_transl_error, abs_angular_error = get_pose_error(pr_querycam_to_world, query_view['cam_to_world'])

        pose_errors.append(abs_transl_error)
        angular_errors.append(abs_angular_error)
        poses_pred.append(pr_querycam_to_world)

    xp_label = params_str + f'_conf_{conf_thr}'
    if args.output_label:
        xp_label = args.output_label + "_" + xp_label
    if reprojection_error_diag_ratio is not None:
        xp_label = xp_label + f'_reproj_diag_{reprojection_error_diag_ratio}'
    else:
        xp_label = xp_label + f'_reproj_err_{reprojection_error}'
    export_results(args.output_dir, xp_label, query_names, poses_pred)
    out_string = aggregate_stats(f'{args.dataset}', pose_errors, angular_errors)
    print(out_string)
|
|