| | |
| | |
| | |
| | |
| | |
| | |
| | import os |
| | import torch |
| | import copy |
| | import numpy as np |
| | import torchvision |
| | import numpy as np |
| | from tqdm import tqdm |
| | from scipy.cluster.hierarchy import DisjointSet |
| | from scipy.spatial.transform import Rotation as R |
| |
|
| | from mast3r.utils.misc import hash_md5 |
| |
|
| | from mast3r.fast_nn import extract_correspondences_nonsym, bruteforce_reciprocal_nns |
| |
|
| | import mast3r.utils.path_to_dust3r |
| | from dust3r.utils.geometry import find_reciprocal_matches, xy_grid, geotrf |
| |
|
| |
|
def convert_im_matches_pairs(img0, img1, image_to_colmap, im_keypoints, matches_im0, matches_im1, viz):
    """Turn pixel matches between two views into pairs of flat keypoint ids.

    Every match coordinate is rounded, clamped to the view's ``true_shape``
    and raveled as ``x + W * y``. Each raveled id is also tallied in
    ``im_keypoints[view['idx']]`` (number of times the id was seen).

    Args:
        img0, img1: view dicts. ``'idx'`` and ``'true_shape'`` are read, and
            ``'img'`` too when ``viz`` is set.
        image_to_colmap: mapping ``idx -> {'colmap_imid': ...}``.
        im_keypoints: mapping ``idx -> {raveled id: count}``, updated in place.
        matches_im0, matches_im1: arrays with one ``(x, y)`` row per match;
            row ``k`` of one corresponds to row ``k`` of the other.
        viz: when truthy, draw 100 evenly spaced matches with matplotlib and
            show the figure in blocking mode before continuing.

    Returns:
        ``(imidx0, imidx1, colmap_matches)`` where the two view indices are
        ordered so that the first one's ``colmap_imid`` is not greater than
        the second's, and ``colmap_matches`` holds the unique
        ``(id in first view, id in second view)`` rows.
    """
    if viz:
        from matplotlib import pyplot as pl

        # Undo a (mean=0.5, std=0.5) normalisation before converting to RGB
        # arrays -- presumably how the input tensors were normalised.
        mean = torch.as_tensor([0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
        std = torch.as_tensor([0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
        to_pil = torchvision.transforms.functional.to_pil_image
        frame0 = np.array(to_pil((img0['img'] * std + mean)[0]))
        frame1 = np.array(to_pil((img1['img'] * std + mean)[0]))

        n_viz = 100
        last_match = matches_im0.shape[0] - 1
        picked = np.round(np.linspace(0, last_match, n_viz)).astype(int)
        shown0 = matches_im0[picked]
        shown1 = matches_im1[picked]

        (h0, w0), (h1, _) = frame0.shape[:2], frame1.shape[:2]
        # Pad the shorter image at the bottom so both can sit side by side.
        left = np.pad(frame0, ((0, max(h1 - h0, 0)), (0, 0), (0, 0)),
                      'constant', constant_values=0)
        right = np.pad(frame1, ((0, max(h0 - h1, 0)), (0, 0), (0, 0)),
                       'constant', constant_values=0)
        pl.figure()
        pl.imshow(np.concatenate((left, right), axis=1))
        cmap = pl.get_cmap('jet')
        for rank, (pt0, pt1) in enumerate(zip(shown0, shown1)):
            (x0, y0), (x1, y1) = pt0.T, pt1.T
            pl.plot([x0, x1 + w0], [y0, y1], '-+',
                    color=cmap(rank / (n_viz - 1)), scalex=False, scaley=False)
        pl.show(block=True)

    coords = [matches_im0.astype(np.float64), matches_im1.astype(np.float64)]
    views = [img0, img1]
    imidx0 = img0['idx']
    imidx1 = img1['idx']

    flat_ids = []
    for view, xy in zip(views, coords):
        height, width = view['true_shape'][0]
        # NaN/inf coordinates would warn when cast to int; silence that.
        with np.errstate(invalid='ignore'):
            cols, rows = xy.round().astype(np.int32).T
        cols.clip(min=0, max=width - 1, out=cols)
        rows.clip(min=0, max=height - 1, out=rows)
        ids = cols + width * rows
        flat_ids.append(ids)

        view_idx = view['idx']
        for kp in ids:
            tally = im_keypoints[view_idx]
            tally[kp] = tally.get(kp, 0) + 1

    # Order the pair by database image id; swap the match columns to follow.
    swap = image_to_colmap[imidx0]['colmap_imid'] > image_to_colmap[imidx1]['colmap_imid']
    first, second = (1, 0) if swap else (0, 1)
    if swap:
        imidx0, imidx1 = imidx1, imidx0
    colmap_matches = np.unique(
        np.stack([flat_ids[first], flat_ids[second]], axis=-1), axis=0)
    return imidx0, imidx1, colmap_matches
| |
|
| |
|
def get_im_matches(pred1, pred2, pairs, image_to_colmap, im_keypoints, conf_thr,
                   is_sparse=True, subsample=8, pixel_tol=0, viz=False, device='cuda'):
    """Extract pixel matches for every predicted pair and convert them.

    Matching uses local descriptors when ``pred1`` has a ``'desc'`` entry
    (with ``'desc_conf'`` as confidence). Otherwise it uses the 3-D points
    ``pred1['pts3d']`` / ``pred2['pts3d_in_other_view']`` with ``'conf'``.

    Args:
        pred1, pred2: prediction dicts indexed per pair.
        pairs: sequence where ``pairs[i][0]`` / ``pairs[i][1]`` are the two
            view dicts of pair ``i``.
        image_to_colmap, im_keypoints: forwarded to
            ``convert_im_matches_pairs`` (``im_keypoints`` is updated there).
        conf_thr: minimum confidence. In sparse mode it filters the returned
            correspondences; in dense mode it filters pixels before matching.
        is_sparse: use ``extract_correspondences_nonsym`` when true, otherwise
            reciprocal nearest neighbours over all confident pixels.
        subsample, pixel_tol: forwarded to ``extract_correspondences_nonsym``.
        viz: forwarded to ``convert_im_matches_pairs``.
        device: device handed to the matching helpers.

    Returns:
        dict mapping ``(imidx0, imidx1)``, as returned by
        ``convert_im_matches_pairs``, to that pair's match array. Pairs with
        no usable match are left out.
    """
    num_pairs = len(pred1['pts3d'])
    use_desc = 'desc' in pred1

    im_matches = {}
    for k in range(num_pairs):
        view0, view1 = pairs[k][0], pairs[k][1]

        # Pick the per-pixel features to match on and their confidences.
        if use_desc:
            feats = [pred1['desc'][k], pred2['desc'][k]]
            confidences = [pred1['desc_conf'][k], pred2['desc_conf'][k]]
            feat_dim = feats[0].shape[-1]
            sparse_extra = {}
        else:
            feats = [pred1['pts3d'][k], pred2['pts3d_in_other_view'][k]]
            confidences = [pred1['conf'][k], pred2['conf'][k]]
            feat_dim = 3
            sparse_extra = {'ptmap_key': '3d'}

        if is_sparse:
            corres = extract_correspondences_nonsym(feats[0], feats[1], confidences[0], confidences[1],
                                                    device=device, subsample=subsample, pixel_tol=pixel_tol,
                                                    **sparse_extra)
            keep = corres[2] >= conf_thr
            matches_im0 = corres[0][keep].cpu().numpy()
            matches_im1 = corres[1][keep].cpu().numpy()
        else:
            confident = [confidences[0] >= conf_thr, confidences[1] >= conf_thr]

            grid_xy, kept_feats = [], []
            for j, view in enumerate((view0, view1)):
                selected = confident[j].cpu().numpy().flatten()
                shape_hw = view['true_shape'][0]
                grid_xy.append(xy_grid(shape_hw[1], shape_hw[0]).reshape(-1, 2)[selected])
                kept_feats.append(feats[j].detach().cpu().numpy().reshape(-1, feat_dim)[selected])

            if len(kept_feats[0]) == 0 or len(kept_feats[1]) == 0:
                continue

            if use_desc:
                nn0, nn1 = bruteforce_reciprocal_nns(kept_feats[0], kept_feats[1],
                                                     device=device, dist='dot', block_size=2**13)
                # Keep i only if its neighbour's own neighbour is i again.
                mutual = (nn1[nn0] == np.arange(len(nn0)))
                matches_im1 = grid_xy[1][nn0][mutual]
                matches_im0 = grid_xy[0][mutual]
            else:
                reciprocal_in_PM, nnM_in_PQ, _ = find_reciprocal_matches(kept_feats[0], kept_feats[1])
                matches_im1 = grid_xy[1][reciprocal_in_PM]
                matches_im0 = grid_xy[0][nnM_in_PQ][reciprocal_in_PM]

        if len(matches_im0) == 0:
            continue
        key0, key1, colmap_matches = convert_im_matches_pairs(view0, view1,
                                                              image_to_colmap, im_keypoints,
                                                              matches_im0, matches_im1, viz)
        im_matches[(key0, key1)] = colmap_matches
    return im_matches
| |
|
| |
|
def get_im_matches_from_cache(pairs, cache_path, desc_conf, subsample,
                              image_to_colmap, im_keypoints, conf_thr,
                              viz=False, device='cuda'):
    """Load cached correspondences for each pair and convert them.

    For a pair, the file
    ``<cache_path>/corres_conf=<desc_conf>_subsample=<subsample>/<h0>-<h1>.pth``
    is used when it exists, where ``h0`` / ``h1`` are ``hash_md5`` of the
    views' ``'instance'`` values. Otherwise the file with the two hashes
    swapped is loaded and its two coordinate sets are swapped back.

    Args:
        pairs: sequence where ``pair[0]`` / ``pair[1]`` are the view dicts.
        cache_path: cache root (string, concatenated as-is).
        desc_conf, subsample: used only to build the cache directory name.
        image_to_colmap, im_keypoints, viz: forwarded to
            ``convert_im_matches_pairs``.
        conf_thr: minimum confidence for a cached correspondence to be kept.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        dict mapping ``(imidx0, imidx1)``, as returned by
        ``convert_im_matches_pairs``, to that pair's match array. Pairs with
        no correspondence above the threshold are left out.
    """
    corres_dir = cache_path + f'/corres_conf={desc_conf}_{subsample=}'

    im_matches = {}
    for pair in pairs:
        view0, view1 = pair[0], pair[1]
        hash0 = hash_md5(view0['instance'])
        hash1 = hash_md5(view1['instance'])

        # NOTE(review): torch.load may unpickle arbitrary objects; the cache
        # directory is presumably trusted -- confirm.
        forward_path = f'{corres_dir}/{hash0}-{hash1}.pth'
        if os.path.isfile(forward_path):
            _, (xy0, xy1, confs) = torch.load(forward_path, map_location=device)
        else:
            # The pair was cached in the opposite order: swap the point sets.
            reverse_path = f'{corres_dir}/{hash1}-{hash0}.pth'
            _, (xy1, xy0, confs) = torch.load(reverse_path, map_location=device)

        keep = confs >= conf_thr
        matches_im0 = xy0[keep].cpu().numpy()
        matches_im1 = xy1[keep].cpu().numpy()
        if len(matches_im0) == 0:
            continue

        key0, key1, colmap_matches = convert_im_matches_pairs(view0, view1,
                                                              image_to_colmap, im_keypoints,
                                                              matches_im0, matches_im1, viz)
        im_matches[(key0, key1)] = colmap_matches
    return im_matches
| |
|
| |
|
def export_images(db, images, image_paths, focals, ga_world_to_cam, camera_model):
    """Register one camera and one image per path in ``db``.

    Args:
        db: object exposing ``add_camera(model_id, W, H, params,
            prior_focal_length=...)`` and ``add_image(path, camera_id,
            prior_q=..., prior_t=...)``.
        images: per-image dicts; ``'orig_shape'`` (H, W) and ``'to_orig'``
            (indexed at ``[0, 0]`` and ``[1, 1]``) are read.
        image_paths: one path per image; its length drives the loop.
        focals: ``None`` (a default of ``1.2 * max(W, H)`` is used and no
            focal prior is flagged), or a per-image sequence whose entries
            are either a 2-D ``np.ndarray`` read as an intrinsics matrix or
            something convertible with ``float``.
        ga_world_to_cam: ``None`` (zero pose priors) or per-image matrices
            whose ``[:3, :3]`` / ``[:3, 3]`` parts give rotation/translation.
        camera_model: ``"SIMPLE_PINHOLE"``, ``"PINHOLE"``, ``"SIMPLE_RADIAL"``
            or ``"OPENCV"``.

    Returns:
        ``(image_to_colmap, im_keypoints)``. The first maps each image index
        to ``{'colmap_imid': ..., 'colmap_camid': ...}``; the second maps
        each image index to an empty dict.

    Raises:
        ValueError: for any other ``camera_model`` (raised while processing
            the first image).
    """
    model_ids = {"SIMPLE_PINHOLE": 0, "PINHOLE": 1, "SIMPLE_RADIAL": 2, "OPENCV": 4}
    # Number of trailing zero (distortion) parameters per model.
    num_zero_params = {"SIMPLE_PINHOLE": 0, "PINHOLE": 0, "SIMPLE_RADIAL": 1, "OPENCV": 4}
    single_focal_models = ("SIMPLE_PINHOLE", "SIMPLE_RADIAL")

    image_to_colmap = {}
    im_keypoints = {}
    for idx in range(len(image_paths)):
        im_keypoints[idx] = {}
        view = images[idx]
        H, W = view["orig_shape"]
        to_orig = view["to_orig"]

        prior_focal_length = focals is not None
        if prior_focal_length and isinstance(focals[idx], np.ndarray) and focals[idx].ndim == 2:
            # Full intrinsics: take focal and principal point from the matrix.
            intrinsics = focals[idx]
            focal_x = intrinsics[0, 0]
            focal_y = intrinsics[1, 1]
            cx = intrinsics[0, 2] * to_orig[0, 0]
            cy = intrinsics[1, 2] * to_orig[1, 1]
        else:
            if prior_focal_length:
                focal_x = focal_y = float(focals[idx])
            else:
                focal_x = focal_y = 1.2 * max(W, H)
            cx = W / 2.0
            cy = H / 2.0
        # In every case the focal lengths are scaled by the to_orig diagonal.
        focal_x = focal_x * to_orig[0, 0]
        focal_y = focal_y * to_orig[1, 1]

        if camera_model not in model_ids:
            raise ValueError(f"invalid camera model {camera_model}")
        model_id = model_ids[camera_model]
        if camera_model in single_focal_models:
            values = [(focal_x + focal_y) / 2.0, cx, cy]
        else:
            values = [focal_x, focal_y, cx, cy]
        values.extend([0.0] * num_zero_params[camera_model])
        params = np.asarray(values, np.float64)

        camid = db.add_camera(model_id, int(W), int(H), params,
                              prior_focal_length=prior_focal_length)

        if ga_world_to_cam is None:
            prior_t = np.zeros(3)
            prior_q = np.zeros(4)
        else:
            pose = ga_world_to_cam[idx]
            quat = R.from_matrix(pose[:3, :3]).as_quat()
            prior_t = pose[:3, 3]
            # Move the last quaternion component to the front
            # (scipy's scalar-last order -> scalar-first).
            prior_q = np.array([quat[-1], quat[0], quat[1], quat[2]])

        imid = db.add_image(image_paths[idx], camid, prior_q=prior_q, prior_t=prior_t)
        image_to_colmap[idx] = {'colmap_imid': imid, 'colmap_camid': camid}
    return image_to_colmap, im_keypoints
| |
|
| |
|
def export_matches(db, images, image_to_colmap, im_keypoints, im_matches, min_len_track, skip_geometric_verification):
    """Group pairwise matches into tracks and write keypoints/matches to ``db``.

    Steps:
      1. Chain matches into tracks (lists of ``(image index, keypoint id)``).
         A match joining two existing, different tracks is queued for merging.
      2. Merge queued tracks with a disjoint-set. Each merged group gets a new
         track id, and its old tracks are no longer referenced.
      3. Per image, keep keypoints whose track has at least ``min_len_track``
         entries, convert their flat ids to (x, y) in the original image and
         store them with ``db.add_keypoints``.
      4. Per pair, re-index matches to the kept keypoints and store them with
         ``db.add_matches`` (and ``db.add_two_view_geometry`` when
         ``skip_geometric_verification`` is set).

    Args:
        db: object exposing ``add_keypoints``, ``add_matches`` and
            ``add_two_view_geometry``.
        images: per-image dicts; ``'true_shape'``, ``'to_orig'``,
            ``'orig_shape'`` and ``'instance'`` are read.
        image_to_colmap: mapping ``image index -> {'colmap_imid': ...}``.
        im_keypoints: mapping ``image index -> {flat keypoint id: count}``;
            only the keys (and their order) are used.
        im_matches: mapping ``(imidx0, imidx1) -> (N, 2)`` array of flat
            keypoint ids, with ``colmap_imid`` of the first below the second.
        min_len_track: minimum number of keypoints for a track to be kept.
        skip_geometric_verification: also store the matches as two-view
            geometry when true.

    Returns:
        List of ``(instance0, instance1)`` for every pair that had at least
        one match left after filtering.
    """
    colmap_image_pairs = []

    print("building tracks")
    keypoints_to_track_id = {}
    track_id_to_kpt_list = []
    to_merge = []
    for (imidx0, imidx1), colmap_matches in tqdm(im_matches.items()):
        tracks0 = keypoints_to_track_id.setdefault(imidx0, {})
        tracks1 = keypoints_to_track_id.setdefault(imidx1, {})

        for m in colmap_matches:
            kp0, kp1 = m[0], m[1]
            seen0 = kp0 in tracks0
            seen1 = kp1 in tracks1
            if not seen0 and not seen1:
                # Neither keypoint is tracked yet: start a new track.
                track_idx = len(track_id_to_kpt_list)
                tracks0[kp0] = track_idx
                tracks1[kp1] = track_idx
                track_id_to_kpt_list.append([(imidx0, kp0), (imidx1, kp1)])
            elif not seen1:
                # Extend the track of kp0 with kp1.
                track_idx = tracks0[kp0]
                tracks1[kp1] = track_idx
                track_id_to_kpt_list[track_idx].append((imidx1, kp1))
            elif not seen0:
                # Extend the track of kp1 with kp0.
                track_idx = tracks1[kp1]
                tracks0[kp0] = track_idx
                track_id_to_kpt_list[track_idx].append((imidx0, kp0))
            else:
                # Both tracked: if the tracks differ they must be fused later.
                track_idx0 = tracks0[kp0]
                track_idx1 = tracks1[kp1]
                if track_idx0 != track_idx1:
                    to_merge.append((track_idx0, track_idx1))

    print("merging tracks")
    tree = DisjointSet(np.unique(to_merge))
    for track_idx0, track_idx1 in tqdm(to_merge):
        tree.merge(track_idx0, track_idx1)

    print("applying merge")
    # Ids of tracks that were folded into a merged track. Their lists stay in
    # track_id_to_kpt_list but no keypoint points at them any more.
    superseded = set()
    for setvals in tqdm(tree.subsets()):
        new_trackid = len(track_id_to_kpt_list)
        kpt_list = []
        for track_idx in setvals:
            kpt_list.extend(track_id_to_kpt_list[track_idx])
            for imidx, kpid in track_id_to_kpt_list[track_idx]:
                keypoints_to_track_id[imidx][kpid] = new_trackid
            superseded.add(int(track_idx))
        track_id_to_kpt_list.append(kpt_list)

    # Count only live tracks; superseded ones would otherwise be counted in
    # addition to the merged track that replaced them.
    num_valid_tracks = sum(
        1 for track_idx, kpts in enumerate(track_id_to_kpt_list)
        if track_idx not in superseded and len(kpts) >= min_len_track)

    keypoints_to_idx = {}
    print(f"squashing keypoints - {num_valid_tracks} valid tracks")
    for imidx, keypoints_imid in tqdm(im_keypoints.items()):
        imid = image_to_colmap[imidx]['colmap_imid']
        # An image absent from every match pair has no tracks at all.
        image_tracks = keypoints_to_track_id.get(imidx, {})
        kept_index = keypoints_to_idx[imidx] = {}
        keypoints_kept = []
        for kp in keypoints_imid:
            track_idx = image_tracks.get(kp)
            if track_idx is None:
                continue
            if len(track_id_to_kpt_list[track_idx]) < min_len_track:
                continue
            kept_index[kp] = len(keypoints_kept)
            keypoints_kept.append(kp)
        if not keypoints_kept:
            continue

        # Flat id -> (row, col) -> (x, y). Uses the documented return tuple
        # of unravel_index instead of its undocumented ``.base`` array.
        rows, cols = np.unravel_index(np.array(keypoints_kept), images[imidx]['true_shape'][0])
        keypoints_xy = np.stack([cols, rows], axis=-1).astype(np.float32)
        # Presumably moves integer pixel indices to pixel centres.
        keypoints_xy += 0.5
        keypoints_xy = geotrf(images[imidx]['to_orig'], keypoints_xy, norm=True)

        H, W = images[imidx]['orig_shape']
        keypoints_xy[:, 0] = keypoints_xy[:, 0].clip(min=0, max=W - 0.01)
        keypoints_xy[:, 1] = keypoints_xy[:, 1].clip(min=0, max=H - 0.01)

        db.add_keypoints(imid, keypoints_xy)

    print("exporting im_matches")
    for (imidx0, imidx1), colmap_matches in im_matches.items():
        imid0 = image_to_colmap[imidx0]['colmap_imid']
        imid1 = image_to_colmap[imidx1]['colmap_imid']
        assert imid0 < imid1
        index0 = keypoints_to_idx[imidx0]
        index1 = keypoints_to_idx[imidx1]
        # Keep only matches whose two endpoints both survived the filtering.
        final_matches = np.array([[index0[m[0]], index1[m[1]]]
                                  for m in colmap_matches
                                  if m[0] in index0 and m[1] in index1])
        if len(final_matches) > 0:
            colmap_image_pairs.append(
                (images[imidx0]['instance'], images[imidx1]['instance']))
            db.add_matches(imid0, imid1, final_matches)
            if skip_geometric_verification:
                db.add_two_view_geometry(imid0, imid1, final_matches)
    return colmap_image_pairs
| |
|