from tqdm import tqdm
import roma
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from collections import namedtuple
from functools import lru_cache
from scipy import sparse as sp

from mast3r.utils.misc import mkdir_for, hash_md5
from mast3r.cloud_opt.utils.losses import gamma_loss
from mast3r.cloud_opt.utils.schedules import linear_schedule, cosine_schedule
from mast3r.fast_nn import fast_reciprocal_NNs, merge_corres

import mast3r.utils.path_to_dust3r  # noqa -- imported for its side effect (adds dust3r to sys.path)
from dust3r.utils.geometry import inv, geotrf
from dust3r.utils.device import to_cpu, to_numpy, todevice
from dust3r.post_process import estimate_focal_knowing_depth
from dust3r.optim_factory import adjust_learning_rate_by_lr
from dust3r.cloud_opt.base_opt import clean_pointcloud
from dust3r.viz import SceneViz


class SparseGA():
    """ Holds the result of sparse global alignment:
        per-image intrinsics, cam2w poses, sparse depthmaps and sparse 3D points.
    """

    def __init__(self, img_paths, pairs_in, res_fine, anchors, canonical_paths=None):
        def fetch_img(im):
            # recover the RGB image of a given instance from the input pairs
            def torgb(x): return (x[0].permute(1, 2, 0).numpy() * .5 + .5).clip(min=0., max=1.)
            for im1, im2 in pairs_in:
                if im1['instance'] == im:
                    return torgb(im1['img'])
                if im2['instance'] == im:
                    return torgb(im2['img'])
        self.canonical_paths = canonical_paths
        self.img_paths = img_paths
        self.imgs = [fetch_img(img) for img in img_paths]
        self.intrinsics = res_fine['intrinsics']
        self.cam2w = res_fine['cam2w']
        self.depthmaps = res_fine['depthmaps']
        self.pts3d = res_fine['pts3d']
        self.pts3d_colors = []
        self.working_device = self.cam2w.device
        for i in range(len(self.imgs)):
            im = self.imgs[i]
            x, y = anchors[i][0][..., :2].detach().cpu().numpy().T
            self.pts3d_colors.append(im[y, x])
            assert self.pts3d_colors[-1].shape == self.pts3d[i].shape
        self.n_imgs = len(self.imgs)

    def get_focals(self):
        return torch.tensor([ff[0, 0] for ff in self.intrinsics]).to(self.working_device)

    def get_principal_points(self):
        return torch.stack([ff[:2, -1] for ff in self.intrinsics]).to(self.working_device)

    def get_im_poses(self):
        return self.cam2w

    def get_sparse_pts3d(self):
        return self.pts3d

    def get_dense_pts3d(self, clean_depth=True, subsample=8):
        assert self.canonical_paths, 'cache_path is required for dense 3d points'
        device = self.cam2w.device
        confs = []
        base_focals = []
        anchors = {}
        for i, canon_path in enumerate(self.canonical_paths):
            (canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
            confs.append(conf)
            base_focals.append(focal)

            H, W = conf.shape
            pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device)
            idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
            anchors[i] = (pixels, idxs[i], offsets[i])

        # densify sparse depthmaps
        pts3d, depthmaps = make_pts3d(anchors, self.intrinsics, self.cam2w, [
            d.ravel() for d in self.depthmaps], base_focals=base_focals, ret_depth=True)

        if clean_depth:
            confs = clean_pointcloud(confs, self.intrinsics, inv(self.cam2w), depthmaps, pts3d)

        return pts3d, depthmaps, confs

    def get_pts3d_colors(self):
        return self.pts3d_colors

    def get_depthmaps(self):
        return self.depthmaps

    def get_masks(self):
        return [slice(None, None) for _ in range(len(self.imgs))]

    def show(self, show_cams=True):
        pts3d, _, confs = self.get_dense_pts3d()
        show_reconstruction(self.imgs, self.intrinsics if show_cams else None, self.cam2w,
                            [p.clip(min=-50, max=50) for p in pts3d],
                            masks=[c > 1 for c in confs])


def convert_dust3r_pairs_naming(imgs, pairs_in):
    for pair_id in range(len(pairs_in)):
        for i in range(2):
            pairs_in[pair_id][i]['instance'] = imgs[pairs_in[pair_id][i]['idx']]
    return pairs_in


def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc_conf='desc_conf',
                            device='cuda', dtype=torch.float32, **kw):
    """ Sparse alignment with MASt3R
        imgs: list of image paths
        cache_path: path where to dump temporary files (str)

        lr1, niter1: learning rate and #iterations for coarse global alignment (3D matching)
        lr2, niter2: learning rate and #iterations for refinement (2D reprojection error)

        lora_depth: low-rank spectral parameterization of the depthmaps (False, or a dict of options)
    """
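    # Minimal usage sketch (assumes the usual mast3r/dust3r helpers; exact loading code is
    # up to the caller and the names below are illustrative, not part of this module):
    #   from mast3r.model import AsymmetricMASt3R
    #   from dust3r.image_pairs import make_pairs
    #   from dust3r.utils.image import load_images
    #   model = AsymmetricMASt3R.from_pretrained(checkpoint_path).to('cuda')
    #   pairs = make_pairs(load_images(imgs, size=512), scene_graph='complete')
    #   scene = sparse_global_alignment(imgs, pairs, '/tmp/cache', model)
    #   pts3d, depthmaps, confs = scene.get_dense_pts3d()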
    # convert pair naming convention from dust3r to mast3r
    pairs_in = convert_dust3r_pairs_naming(imgs, pairs_in)

    # forward pass: run MASt3R on all pairs (results are cached on disk)
    pairs, cache_path = forward_mast3r(pairs_in, model,
                                       cache_path=cache_path, subsample=subsample,
                                       desc_conf=desc_conf, device=device)

    # extract canonical pointmaps
    tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21 = \
        prepare_canonical_data(imgs, pairs, subsample, cache_path=cache_path, mode='avg-angle', device=device)

    # compute minimal spanning tree over the pairwise matching scores
    mst = compute_min_spanning_tree(pairwise_scores)

    # smartly combine all useful data
    imsizes, pps, base_focals, core_depth, anchors, corres, corres2d = \
        condense_data(imgs, tmp_pairs, canonical_views, dtype)

    imgs, res_coarse, res_fine = sparse_scene_optimizer(
        imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21, canonical_paths,
        mst, cache_path=cache_path, device=device, dtype=dtype, **kw)
    return SparseGA(imgs, pairs_in, res_fine or res_coarse, anchors, canonical_paths)


def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d,
                           preds_21, canonical_paths, mst, cache_path,
                           lr1=0.2, niter1=500, loss1=gamma_loss(1.1),
                           lr2=0.02, niter2=500, loss2=gamma_loss(0.4),
                           lossd=gamma_loss(1.1),
                           opt_pp=True, opt_depth=True,
                           schedule=cosine_schedule, depth_mode='add', exp_depth=False,
                           lora_depth=False,  # or a dict of options, e.g. dict(k=96, gamma=15, min_norm=.5)
                           init={}, device='cuda', dtype=torch.float32,
                           matching_conf_thr=4., loss_dust3r_w=0.01,
                           verbose=True, dbg=()):

    # extrinsic parameters: one unit quaternion + one translation per image
    vec0001 = torch.tensor((0, 0, 0, 1), dtype=dtype, device=device)
    quats = [nn.Parameter(vec0001.clone()) for _ in range(len(imgs))]
    trans = [nn.Parameter(torch.zeros(3, device=device, dtype=dtype)) for _ in range(len(imgs))]

    # initialize from user-provided values, if any
    ones = torch.ones((len(imgs), 1), device=device, dtype=dtype)
    median_depths = torch.ones(len(imgs), device=device, dtype=dtype)
    for img in imgs:
        idx = imgs.index(img)
        init_values = init.setdefault(img, {})
        if verbose and init_values:
            print(f' >> initializing img=...{img[-25:]} [{idx}] for {set(init_values)}')

        K = init_values.get('intrinsics')
        if K is not None:
            K = K.detach()
            focal = K[:2, :2].diag().mean()
            pp = K[:2, 2]
            base_focals[idx] = focal
            pps[idx] = pp
        pps[idx] /= imsizes[idx]  # normalize, so the default principal point is (0.5, 0.5)

        depth = init_values.get('depthmap')
        if depth is not None:
            core_depth[idx] = depth.detach()

        median_depths[idx] = med_depth = core_depth[idx].median()
        core_depth[idx] /= med_depth

        cam2w = init_values.get('cam2w')
        if cam2w is not None:
            rot = cam2w[:3, :3].detach()
            cam_center = cam2w[:3, 3].detach()
            quats[idx].data[:] = roma.rotmat_to_unitquat(rot)
            trans_offset = med_depth * torch.cat((imsizes[idx] / base_focals[idx] * (0.5 - pps[idx]), ones[:1, 0]))
            trans[idx].data[:] = cam_center + rot @ trans_offset
            del rot
            # poses are parametrized relative to their parent in the spanning tree,
            # so initializing from absolute poses is not supported yet
            assert False, 'inverse kinematic chain not yet implemented'

    # intrinsics parameters
    pps = [nn.Parameter(pp.to(dtype)) for pp in pps]
    diags = imsizes.float().norm(dim=1)
    min_focals = 0.25 * diags  # loose focal bounds, relative to the image diagonal
    max_focals = 10 * diags
    log_focals = [nn.Parameter(f.view(1).log().to(dtype)) for f in base_focals]
    assert len(mst[1]) == len(pps) - 1
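    # Pose parameterization: each camera stores a unit quaternion and a translation
    # *relative to its parent* in the spanning tree `mst`. make_K_cam_depth composes them
    # along the tree into absolute cam2w poses, builds K from the (log-)focals and
    # principal points, and rescales all depthmaps by per-image log-sizes, with a global
    # scaling that prevents the scene scale from collapsing during optimization.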
|
| def make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth): |
| |
| focals = torch.cat(log_focals).exp().clip(min=min_focals, max=max_focals) |
| pps = torch.stack(pps) |
| K = torch.eye(3, dtype=dtype, device=device)[None].expand(len(imgs), 3, 3).clone() |
| K[:, 0, 0] = K[:, 1, 1] = focals |
| K[:, 0:2, 2] = pps * imsizes |
| if trans is None: |
| return K |
|
|
| |
| sizes = torch.cat(log_sizes).exp() |
| global_scaling = 1 / sizes.min() |
|
|
| |
| |
| z_cameras = sizes * median_depths * focals / base_focals |
|
|
| |
| rel_cam2cam = torch.eye(4, dtype=dtype, device=device)[None].expand(len(imgs), 4, 4).clone() |
| rel_cam2cam[:, :3, :3] = roma.unitquat_to_rotmat(F.normalize(torch.stack(quats), dim=1)) |
| rel_cam2cam[:, :3, 3] = torch.stack(trans) |
|
|
| |
| tmp_cam2w = [None] * len(K) |
| tmp_cam2w[mst[0]] = rel_cam2cam[mst[0]] |
| for i, j in mst[1]: |
| |
| tmp_cam2w[j] = tmp_cam2w[i] @ rel_cam2cam[j] |
| tmp_cam2w = torch.stack(tmp_cam2w) |
|
|
| |
| trans_offset = z_cameras.unsqueeze(1) * torch.cat((imsizes / focals.unsqueeze(1) * (0.5 - pps), ones), dim=-1) |
| new_trans = global_scaling * (tmp_cam2w[:, :3, 3:4] - tmp_cam2w[:, :3, :3] @ trans_offset.unsqueeze(-1)) |
| cam2w = torch.cat((torch.cat((tmp_cam2w[:, :3, :3], new_trans), dim=2), |
| vec0001.view(1, 1, 4).expand(len(K), 1, 4)), dim=1) |
|
|
| depthmaps = [] |
| for i in range(len(imgs)): |
| core_depth_img = core_depth[i] |
| if exp_depth: |
| core_depth_img = core_depth_img.exp() |
| if lora_depth: |
| core_depth_img = lora_depth_proj[i] @ core_depth_img |
| if depth_mode == 'add': |
| core_depth_img = z_cameras[i] + (core_depth_img - 1) * (median_depths[i] * sizes[i]) |
| elif depth_mode == 'mul': |
| core_depth_img = z_cameras[i] * core_depth_img |
| else: |
| raise ValueError(f'Bad {depth_mode=}') |
| depthmaps.append(global_scaling * core_depth_img) |
|
|
| return K, (inv(cam2w), cam2w), depthmaps |

    K = make_K_cam_depth(log_focals, pps, None, None, None, None)
    print('init focals =', to_numpy(K[:, 0, 0]))

    # spectral low-rank projection of depthmaps
    if lora_depth:
        core_depth, lora_depth_proj = spectral_projection_of_depthmaps(
            imgs, K, core_depth, subsample, cache_path=cache_path, **lora_depth)
    if exp_depth:
        core_depth = [d.clip(min=1e-4).log() for d in core_depth]
    core_depth = [nn.Parameter(d.ravel().to(dtype)) for d in core_depth]
    log_sizes = [nn.Parameter(torch.zeros(1, dtype=dtype, device=device)) for _ in range(len(imgs))]

    # fetch img slices
    _, confs_sum, imgs_slices = corres

    # define which pairs are reliable enough to be used with matching
    def matching_check(x): return x.max() > matching_conf_thr
    is_matching_ok = {}
    for s in imgs_slices:
        is_matching_ok[s.img1, s.img2] = matching_check(s.confs)

    # subsample preds_21 to the anchor grid
    subsamp_preds_21 = {}
    for imk, imv in preds_21.items():
        subsamp_preds_21[imk] = {}
        for im2k, (pred, conf) in imv.items():
            subpred = pred[::subsample, ::subsample].reshape(-1, 3)
            subconf = conf[::subsample, ::subsample].ravel()
            idxs = anchors[imgs.index(im2k)][1]
            subsamp_preds_21[imk][im2k] = (subpred[idxs], subconf[idxs])
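    # Three loss terms are defined below:
    #  - loss_3d: 3D distance between matched anchor points (used in the coarse stage),
    #  - loss_2d: 2D reprojection error of the matches (used in the refinement stage),
    #  - loss_dust3r: DUSt3R-style regression fallback, applied only to pairs whose
    #    matching confidence never exceeds matching_conf_thr.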
    def loss_dust3r(cam2w, pts3d, pix_loss):
        # fallback to DUSt3R's (sparsified) regression loss when no reliable matching was found
        loss = 0.
        cf_sum = 0.
        for s in imgs_slices:
            if not is_matching_ok[s.img1, s.img2]:
                # img1's points as predicted in img2's frame, brought to world coordinates
                tgt_pts, tgt_confs = subsamp_preds_21[imgs[s.img2]][imgs[s.img1]]
                tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
                cf_sum += tgt_confs.sum()
                loss += tgt_confs @ pix_loss(pts3d[s.img1], tgt_pts)
        return loss / cf_sum if cf_sum != 0. else 0.

    def loss_3d(K, w2cam, pts3d, pix_loss):
        # for each pair, compare the two 3D point clouds of matched anchors (confidence-weighted)
        if any(v.get('freeze') for v in init.values()):
            pts3d_1 = []
            pts3d_2 = []
            confs = []
            for s in imgs_slices:
                if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
                    continue
                if is_matching_ok[s.img1, s.img2]:
                    pts3d_1.append(pts3d[s.img1][s.slice1])
                    pts3d_2.append(pts3d[s.img2][s.slice2])
                    confs.append(s.confs)
        else:
            pts3d_1 = [pts3d[s.img1][s.slice1] for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
            pts3d_2 = [pts3d[s.img2][s.slice2] for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
            confs = [s.confs for s in imgs_slices if is_matching_ok[s.img1, s.img2]]

        if pts3d_1 != []:
            confs = torch.cat(confs)
            pts3d_1 = torch.cat(pts3d_1)
            pts3d_2 = torch.cat(pts3d_2)
            loss = confs @ pix_loss(pts3d_1, pts3d_2)
            cf_sum = confs.sum()
        else:
            loss = 0.
            cf_sum = 1.

        return loss / cf_sum

    def loss_2d(K, w2cam, pts3d, pix_loss):
        # reprojection error: matched 3D points from img2 are projected into img1
        # and compared with the corresponding pixel positions
        proj_matrix = K @ w2cam[:, :3]
        loss = npix = 0
        for img1, pix1, confs, cf_sum, imgs_slices in corres2d:
            if init[imgs[img1]].get('freeze', 0) >= 1:
                continue  # fully frozen image, nothing to optimize
            pts3d_in_img1 = [pts3d[img2][slice2] for img2, slice2 in imgs_slices if is_matching_ok[img1, img2]]
            pix1_filtered = []
            confs_filtered = []
            curstep = 0
            for img2, slice2 in imgs_slices:
                if is_matching_ok[img1, img2]:
                    tslice = slice(curstep, curstep + slice2.stop - slice2.start, slice2.step)
                    pix1_filtered.append(pix1[tslice])
                    confs_filtered.append(confs[tslice])
                curstep += slice2.stop - slice2.start
            if pts3d_in_img1 != []:
                pts3d_in_img1 = torch.cat(pts3d_in_img1)
                pix1_filtered = torch.cat(pix1_filtered)
                confs_filtered = torch.cat(confs_filtered)
                loss += confs_filtered @ pix_loss(pix1_filtered, reproj2d(proj_matrix[img1], pts3d_in_img1))
                npix += confs_filtered.sum()
        return loss / npix if npix != 0 else 0.

    def optimize_loop(loss_func, lr_base, niter, pix_loss, lr_end=0):
        # create optimizer
        params = pps + log_focals + quats + trans + log_sizes + core_depth
        optimizer = torch.optim.Adam(params, lr=1, weight_decay=0, betas=(0.9, 0.9))
        # a 'meta' loss is re-instantiated at every iteration as a function of (1 - alpha)
        ploss = pix_loss if 'meta' in repr(pix_loss) else (lambda a: pix_loss)

        with tqdm(total=niter) as bar:
            for iter in range(niter or 1):
                K, (w2cam, cam2w), depthmaps = make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth)
                pts3d = make_pts3d(anchors, K, cam2w, depthmaps, base_focals=base_focals)
                if niter == 0:
                    break

                alpha = (iter / niter)
                lr = schedule(alpha, lr_base, lr_end)
                adjust_learning_rate_by_lr(optimizer, lr)
                pix_loss = ploss(1 - alpha)
                optimizer.zero_grad()
                loss = loss_func(K, w2cam, pts3d, pix_loss) + loss_dust3r_w * loss_dust3r(cam2w, pts3d, lossd)
                loss.backward()
                optimizer.step()

                # re-normalize quaternions so the pose remains well optimizable
                for i in range(len(imgs)):
                    quats[i].data[:] /= quats[i].data.norm()

                loss = float(loss)
                if loss != loss:
                    break  # NaN loss
                bar.set_postfix_str(f'{lr=:.4f}, {loss=:.3f}')
                bar.update(1)

        if niter:
            print(f'>> final loss = {loss}')
        return dict(intrinsics=K.detach(), cam2w=cam2w.detach(),
                    depthmaps=[d.detach() for d in depthmaps], pts3d=[p.detach() for p in pts3d])
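    # Two-phase schedule: a coarse pass with the 3D matching loss where only poses and
    # per-image sizes move, then (if niter2 > 0) a refinement pass with the 2D
    # reprojection loss where focals, principal points and depthmaps are also freed,
    # subject to the per-image 'freeze' flags in `init`.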
    # at start, don't optimize 3d points
    for i, img in enumerate(imgs):
        trainable = not (init[img].get('freeze'))
        pps[i].requires_grad_(False)
        log_focals[i].requires_grad_(False)
        quats[i].requires_grad_(trainable)
        trans[i].requires_grad_(trainable)
        log_sizes[i].requires_grad_(trainable)
        core_depth[i].requires_grad_(False)

    res_coarse = optimize_loop(loss_3d, lr_base=lr1, niter=niter1, pix_loss=loss1)

    res_fine = None
    if niter2:
        # now we can optimize 3d points
        for i, img in enumerate(imgs):
            if init[img].get('freeze', 0) >= 1:
                continue
            pps[i].requires_grad_(bool(opt_pp))
            log_focals[i].requires_grad_(True)
            core_depth[i].requires_grad_(opt_depth)

        # refinement with 2d reprojection error
        res_fine = optimize_loop(loss_2d, lr_base=lr2, niter=niter2, pix_loss=loss2)

    return imgs, res_coarse, res_fine


@lru_cache
def mask110(device, dtype):
    return torch.tensor((1, 1, 0), device=device, dtype=dtype)
def proj3d(inv_K, pixels, z):
    if pixels.shape[-1] == 2:
        pixels = torch.cat((pixels, torch.ones_like(pixels[..., :1])), dim=-1)
    return z.unsqueeze(-1) * (pixels * inv_K.diag() + inv_K[:, 2] * mask110(z.device, z.dtype))


def make_pts3d(anchors, K, cam2w, depthmaps, base_focals=None, ret_depth=False):
    focals = K[:, 0, 0]
    invK = inv(K)
    all_pts3d = []
    depth_out = []

    for img, (pixels, idxs, offsets) in anchors.items():
        # from depthmaps to 3d points, in the camera frame
        if base_focals is not None:
            # compensate for focal changes:
            # depth + depth * (offset - 1) * base_focal / focal
            #   = depth * (1 + (offset - 1) * (base_focal / focal))
            offsets = 1 + (offsets - 1) * (base_focals[img] / focals[img])

        pts3d = proj3d(invK[img], pixels, depthmaps[img][idxs] * offsets)
        if ret_depth:
            depth_out.append(pts3d[..., 2])  # depth before rotation to the world frame

        # rotate to world coordinates
        pts3d = geotrf(cam2w[img], pts3d)
        all_pts3d.append(pts3d)

    if ret_depth:
        return all_pts3d, depth_out
    return all_pts3d


def make_dense_pts3d(intrinsics, cam2w, depthmaps, canonical_paths, subsample, device='cuda'):
    base_focals = []
    anchors = {}
    confs = []
    for i, canon_path in enumerate(canonical_paths):
        (canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
        confs.append(conf)
        base_focals.append(focal)
        H, W = conf.shape
        pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device)
        idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
        anchors[i] = (pixels, idxs[i], offsets[i])

    # densify sparse depthmaps
    pts3d, depthmaps_out = make_pts3d(anchors, intrinsics, cam2w, [
        d.ravel() for d in depthmaps], base_focals=base_focals, ret_depth=True)

    return pts3d, depthmaps_out, confs


@torch.no_grad()
def forward_mast3r(pairs, model, cache_path, desc_conf='desc_conf',
                   device='cuda', subsample=8, **matching_kw):
    res_paths = {}

    for img1, img2 in tqdm(pairs):
        idx1 = hash_md5(img1['instance'])
        idx2 = hash_md5(img2['instance'])

        path1 = cache_path + f'/forward/{idx1}/{idx2}.pth'
        path2 = cache_path + f'/forward/{idx2}/{idx1}.pth'
        path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx1}-{idx2}.pth'
        path_corres2 = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx2}-{idx1}.pth'

        # the correspondences of the flipped pair can be reused as-is
        if os.path.isfile(path_corres2) and not os.path.isfile(path_corres):
            score, (xy1, xy2, confs) = torch.load(path_corres2)
            torch.save((score, (xy2, xy1, confs)), path_corres)

        if not all(os.path.isfile(p) for p in (path1, path2, path_corres)):
            if model is None:
                continue
            res = symmetric_inference(model, img1, img2, device=device)
            X11, X21, X22, X12 = [r['pts3d'][0] for r in res]
            C11, C21, C22, C12 = [r['conf'][0] for r in res]
            descs = [r['desc'][0] for r in res]
            qonfs = [r[desc_conf][0] for r in res]

            # save
            torch.save(to_cpu((X11, C11, X21, C21)), mkdir_for(path1))
            torch.save(to_cpu((X22, C22, X12, C12)), mkdir_for(path2))

            # perform reciprocal matching
            corres = extract_correspondences(descs, qonfs, device=device, subsample=subsample)

            conf_score = (C11.mean() * C12.mean() * C21.mean() * C22.mean()).sqrt().sqrt()
            matching_score = (float(conf_score), float(corres[2].sum()), len(corres[2]))
            if cache_path is not None:
                torch.save((matching_score, corres), mkdir_for(path_corres))

        res_paths[img1['instance'], img2['instance']] = (path1, path2), path_corres

    del model
    torch.cuda.empty_cache()

    return res_paths, cache_path


def symmetric_inference(model, img1, img2, device):
    shape1 = torch.from_numpy(img1['true_shape']).to(device, non_blocking=True)
    shape2 = torch.from_numpy(img2['true_shape']).to(device, non_blocking=True)
    img1 = img1['img'].to(device, non_blocking=True)
    img2 = img2['img'].to(device, non_blocking=True)

    # compute encoder only once
    feat1, feat2, pos1, pos2 = model._encode_image_pairs(img1, img2, shape1, shape2)

    def decoder(feat1, feat2, pos1, pos2, shape1, shape2):
        dec1, dec2 = model._decoder(feat1, pos1, feat2, pos2)
        with torch.cuda.amp.autocast(enabled=False):
            res1 = model._downstream_head(1, [tok.float() for tok in dec1], shape1)
            res2 = model._downstream_head(2, [tok.float() for tok in dec2], shape2)
        return res1, res2

    # decoder 1-2
    res11, res21 = decoder(feat1, feat2, pos1, pos2, shape1, shape2)
    # decoder 2-1
    res22, res12 = decoder(feat2, feat1, pos2, pos1, shape2, shape1)

    return (res11, res21, res22, res12)
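# Matching below is purely descriptor-based: reciprocal nearest neighbors are extracted
# in both directions for each of the two symmetric decoder passes, then merged, and each
# correspondence is scored by the geometric mean of its two per-pixel confidences.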
def extract_correspondences(feats, qonfs, subsample=8, device=None, ptmap_key='pred_desc'):
    feat11, feat21, feat22, feat12 = feats
    qonf11, qonf21, qonf22, qonf12 = qonfs
    assert feat11.shape[:2] == feat12.shape[:2] == qonf11.shape == qonf12.shape
    assert feat21.shape[:2] == feat22.shape[:2] == qonf21.shape == qonf22.shape

    if '3d' in ptmap_key:
        opt = dict(device='cpu', workers=32)
    else:
        opt = dict(device=device, dist='dot', block_size=2**13)

    # matching the two pairs
    idx1 = []
    idx2 = []
    qonf1 = []
    qonf2 = []
    for A, B, QA, QB in [(feat11, feat21, qonf11.cpu(), qonf21.cpu()),
                         (feat12, feat22, qonf12.cpu(), qonf22.cpu())]:
        nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt)
        nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt)

        idx1.append(np.r_[nn1to2[0], nn2to1[1]])
        idx2.append(np.r_[nn1to2[1], nn2to1[0]])
        qonf1.append(QA.ravel()[idx1[-1]])
        qonf2.append(QB.ravel()[idx2[-1]])

    # merge corres from opposite pairs
    H1, W1 = feat11.shape[:2]
    H2, W2 = feat22.shape[:2]
    cat = np.concatenate

    xy1, xy2, idx = merge_corres(cat(idx1), cat(idx2), (H1, W1), (H2, W2), ret_xy=True, ret_index=True)
    corres = (xy1.copy(), xy2.copy(), np.sqrt(cat(qonf1)[idx] * cat(qonf2)[idx]))

    return todevice(corres, device)
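# For each image, prepare_canonical_data fuses all pairwise predictions expressed in that
# image's own frame into a single "canonical" pointmap, estimates a focal from it, and
# fills the pairwise score matrix later used to build the spanning tree.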
@torch.no_grad()
def prepare_canonical_data(imgs, tmp_pairs, subsample, order_imgs=False, min_conf_thr=0,
                           cache_path=None, device='cuda', **kw):
    canonical_views = {}
    pairwise_scores = torch.zeros((len(imgs), len(imgs)), device=device)
    canonical_paths = []
    preds_21 = {}

    for img in tqdm(imgs):
        if cache_path:
            cache = os.path.join(cache_path, 'canon_views', hash_md5(img) + f'_{subsample=}_{kw=}.pth')
            canonical_paths.append(cache)
        try:
            (canon, canon2, cconf), focal = torch.load(cache, map_location=device)
        except IOError:
            # cache does not exist yet, we create it
            canon = focal = None

        # collect all predictions involving this image
        n_pairs = sum((img in pair) for pair in tmp_pairs)

        ptmaps11 = None
        pixels = {}
        n = 0
        for (img1, img2), ((path1, path2), path_corres) in tmp_pairs.items():
            score = None
            if img == img1:
                X, C, X2, C2 = torch.load(path1, map_location=device)
                score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
                pixels[img2] = xy1, confs
                if img not in preds_21:
                    preds_21[img] = {}
                preds_21[img][img2] = X2, C2

            if img == img2:
                X, C, X2, C2 = torch.load(path2, map_location=device)
                score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
                pixels[img1] = xy2, confs
                if img not in preds_21:
                    preds_21[img] = {}
                preds_21[img][img1] = X2, C2

            if score is not None:
                i, j = imgs.index(img1), imgs.index(img2)
                # score = score[0]  # mean pointmap confidence
                # score = np.log1p(score[2])  # log of the number of correspondences
                score = score[2]  # number of correspondences
                pairwise_scores[i, j] = score
                pairwise_scores[j, i] = score

                if canon is not None:
                    continue
                if ptmaps11 is None:
                    H, W = C.shape
                    ptmaps11 = torch.empty((n_pairs, H, W, 3), device=device)
                    confs11 = torch.empty((n_pairs, H, W), device=device)

                ptmaps11[n] = X
                confs11[n] = C
                n += 1

        if canon is None:
            canon, canon2, cconf = canonical_view(ptmaps11, confs11, subsample, **kw)
            del ptmaps11
            del confs11

        # compute focals
        H, W = canon.shape[:2]
        pp = torch.tensor([W / 2, H / 2], device=device)
        if focal is None:
            focal = estimate_focal_knowing_depth(canon[None], pp, focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5)
            if cache:
                torch.save(to_cpu(((canon, canon2, cconf), focal)), mkdir_for(cache))

        # extract depth offsets with correspondences
        core_depth = canon[subsample // 2::subsample, subsample // 2::subsample, 2]
        idxs, offsets = anchor_depth_offsets(canon2, pixels, subsample=subsample)

        canonical_views[img] = (pp, (H, W), focal.view(1), core_depth, pixels, idxs, offsets)

    return tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21


def load_corres(path_corres, device, min_conf_thr):
    score, (xy1, xy2, confs) = torch.load(path_corres, map_location=device)
    # optionally filter out correspondences with low confidence
    valid = confs > min_conf_thr if min_conf_thr else slice(None)
    return score, (xy1[valid], xy2[valid], confs[valid])


PairOfSlices = namedtuple(
    'ImgPair', 'img1, slice1, pix1, anchor_idxs1, img2, slice2, pix2, anchor_idxs2, confs, confs_sum')


def condense_data(imgs, tmp_paths, canonical_views, dtype=torch.float32):
    # aggregate all data properly
    set_imgs = set(imgs)

    principal_points = []
    shapes = []
    focals = []
    core_depth = []
    img_anchors = {}
    tmp_pixels = {}

    for idx1, img1 in enumerate(imgs):
        # unpack the canonical view of img1
        pp, shape, focal, anchors, pixels_confs, idxs, offsets = canonical_views[img1]

        principal_points.append(pp)
        shapes.append(shape)
        focals.append(focal)
        core_depth.append(anchors)

        img_uv1 = []
        img_idxs = []
        img_offs = []
        cur_n = [0]

        for img2, (pixels, match_confs) in pixels_confs.items():
            if img2 not in set_imgs:
                continue
            assert len(pixels) == len(idxs[img2]) == len(offsets[img2])
            img_uv1.append(torch.cat((pixels, torch.ones_like(pixels[:, :1])), dim=-1))
            img_idxs.append(idxs[img2])
            img_offs.append(offsets[img2])
            cur_n.append(cur_n[-1] + len(pixels))
            # store the position of the 3d points
            tmp_pixels[img1, img2] = pixels.to(dtype), match_confs.to(dtype), slice(*cur_n[-2:])
        img_anchors[idx1] = (torch.cat(img_uv1), torch.cat(img_idxs), torch.cat(img_offs))

    all_confs = []
    imgs_slices = []
    corres2d = {img: [] for img in range(len(imgs))}

    for img1, img2 in tmp_paths:
        try:
            pix1, confs1, slice1 = tmp_pixels[img1, img2]
            pix2, confs2, slice2 = tmp_pixels[img2, img1]
        except KeyError:
            continue
        img1 = imgs.index(img1)
        img2 = imgs.index(img2)
        confs = (confs1 * confs2).sqrt()

        # prepare for loss_3d
        all_confs.append(confs)
        anchor_idxs1 = canonical_views[imgs[img1]][5][imgs[img2]]
        anchor_idxs2 = canonical_views[imgs[img2]][5][imgs[img1]]
        imgs_slices.append(PairOfSlices(img1, slice1, pix1, anchor_idxs1,
                                        img2, slice2, pix2, anchor_idxs2,
                                        confs, float(confs.sum())))

        # prepare for loss_2d
        corres2d[img1].append((pix1, confs, img2, slice2))
        corres2d[img2].append((pix2, confs, img1, slice1))

    all_confs = torch.cat(all_confs)
    corres = (all_confs, float(all_confs.sum()), imgs_slices)

    def aggreg_matches(img1, list_matches):
        pix1, confs, img2, slice2 = zip(*list_matches)
        all_pix1 = torch.cat(pix1).to(dtype)
        all_confs = torch.cat(confs).to(dtype)
        return img1, all_pix1, all_confs, float(all_confs.sum()), [(j, sl2) for j, sl2 in zip(img2, slice2)]
    corres2d = [aggreg_matches(img, m) for img, m in corres2d.items()]

    imsizes = torch.tensor([(W, H) for H, W in shapes], device=pp.device)  # (W, H)
    principal_points = torch.stack(principal_points)
    focals = torch.cat(focals)
    return imsizes, principal_points, focals, core_depth, img_anchors, corres, corres2d
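# canonical_view: confidence-weighted fusion of the n_pairs pointmaps predicted for one
# image in its own frame. 'canon' is the fused pointmap; 'canon2' stores per-pixel depth
# ratios relative to the subsampled anchor grid, estimated either from relative depths
# ('avg-reldepth') or from viewing angles ('avg-angle').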
def canonical_view(ptmaps11, confs11, subsample, mode='avg-angle'):
    assert len(ptmaps11) == len(confs11) > 0, 'not a single view1 for this image'

    # canonical pointmap is just a weighted average
    # (confidences are >= 1, so shift them to get meaningful weights)
    confs11 = confs11.unsqueeze(-1) - 0.999
    canon = (confs11 * ptmaps11).sum(0) / confs11.sum(0)

    canon_depth = ptmaps11[..., 2].unsqueeze(1)
    S = slice(subsample // 2, None, subsample)
    center_depth = canon_depth[:, :, S, S]
    assert (center_depth > 0).all()
    stacked_depth = F.pixel_unshuffle(canon_depth, subsample)
    stacked_confs = F.pixel_unshuffle(confs11[:, None, :, :, 0], subsample)

    if mode == 'avg-reldepth':
        rel_depth = stacked_depth / center_depth
        stacked_canon = (stacked_confs * rel_depth).sum(dim=0) / stacked_confs.sum(dim=0)
        canon2 = F.pixel_shuffle(stacked_canon.unsqueeze(0), subsample).squeeze()

    elif mode == 'avg-angle':
        xy = ptmaps11[..., 0:2].permute(0, 3, 1, 2)
        stacked_xy = F.pixel_unshuffle(xy, subsample)
        B, _, H, W = stacked_xy.shape
        stacked_radius = (stacked_xy.view(B, 2, -1, H, W) - xy[:, :, None, S, S]).norm(dim=1)
        stacked_radius.clip_(min=1e-8)

        stacked_angle = torch.arctan((stacked_depth - center_depth) / stacked_radius)
        avg_angle = (stacked_confs * stacked_angle).sum(dim=0) / stacked_confs.sum(dim=0)

        # back to depth
        stacked_depth = stacked_radius.mean(dim=0) * torch.tan(avg_angle)

        canon2 = F.pixel_shuffle((1 + stacked_depth / canon[S, S, 2]).unsqueeze(0), subsample).squeeze()
    else:
        raise ValueError(f'bad {mode=}')

    confs = (confs11.square().sum(dim=0) / confs11.sum(dim=0)).squeeze()
    return canon, canon2, confs


def anchor_depth_offsets(canon_depth, pixels, subsample=8):
    device = canon_depth.device

    # create a 2D grid of anchor 3D points
    H1, W1 = canon_depth.shape
    yx = np.mgrid[subsample // 2:H1:subsample, subsample // 2:W1:subsample]
    H2, W2 = yx.shape[1:]
    cy, cx = yx.reshape(2, -1)
    core_depth = canon_depth[cy, cx]
    assert (core_depth > 0).all()

    # each correspondence is attached to its nearest anchor
    core_idxs = {}  # core_idxs[img2] = anchor index of each correspondence
    core_offs = {}  # core_offs[img2] = depth ratio w.r.t. the anchor

    for img2, (xy1, _confs) in pixels.items():
        px, py = xy1.long().T

        # find the anchor cell of each pixel (block quantization)
        core_idx = (py // subsample) * W2 + (px // subsample)
        core_idxs[img2] = core_idx.to(device)

        # relative depth offset w.r.t. the anchor
        ref_z = core_depth[core_idx]
        pts_z = canon_depth[py, px]
        offset = pts_z / ref_z
        core_offs[img2] = offset.detach().to(device)

    return core_idxs, core_offs
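# Low-rank ("lora") depth parameterization: each depthmap is re-expressed in the basis of
# the first k eigenvectors of the graph Laplacian of a 3D-proximity graph between its
# back-projected points, so the optimizer updates k spectral coefficients per image
# instead of one value per anchor.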
def spectral_clustering(graph, k=None, normalized_cuts=False):
    graph.fill_diagonal_(0)

    # graph laplacian
    degrees = graph.sum(dim=-1)
    laplacian = torch.diag(degrees) - graph
    if normalized_cuts:
        i_inv = torch.diag(degrees.sqrt().reciprocal())
        laplacian = i_inv @ laplacian @ i_inv

    # eigendecomposition (torch.linalg.eigh returns eigenvalues in ascending order)
    eigval, eigvec = torch.linalg.eigh(laplacian)
    return eigval[:k], eigvec[:, :k]


def sim_func(p1, p2, gamma):
    # similarity kernel between 3d points, based on their distance relative to depth
    diff = (p1 - p2).norm(dim=-1)
    avg_depth = (p1[:, :, 2] + p2[:, :, 2])  # sum of the two depths (2x the average; the factor is absorbed by gamma)
    rel_distance = diff / avg_depth
    sim = torch.exp(-gamma * rel_distance.square())
    return sim


def backproj(K, depthmap, subsample):
    H, W = depthmap.shape
    uv = np.mgrid[subsample // 2:subsample * W:subsample, subsample // 2:subsample * H:subsample].T.reshape(H, W, 2)
    xyz = depthmap.unsqueeze(-1) * geotrf(inv(K), todevice(uv, K.device), ncol=3)
    return xyz


def spectral_projection_depth(K, depthmap, subsample, k=64, cache_path='',
                              normalized_cuts=True, gamma=7, min_norm=5):
    try:
        if cache_path:
            cache_path = cache_path + f'_{k=}_norm={normalized_cuts}_{gamma=}.pth'
        lora_proj = torch.load(cache_path, map_location=K.device)

    except IOError:
        # reconstruct 3d points in camera coordinates
        xyz = backproj(K, depthmap, subsample)

        # compute all pairwise similarities
        xyz = xyz.reshape(-1, 3)
        graph = sim_func(xyz[:, None], xyz[None, :], gamma=gamma)
        _, lora_proj = spectral_clustering(graph, k, normalized_cuts=normalized_cuts)

        if cache_path:
            torch.save(lora_proj.cpu(), mkdir_for(cache_path))

    lora_proj, coeffs = lora_encode_normed(lora_proj, depthmap.ravel(), min_norm=min_norm)

    # depthmap ~= lora_proj @ coeffs
    return coeffs, lora_proj


def lora_encode_normed(lora_proj, x, min_norm, global_norm=False):
    # encode the depthmap in the low-rank basis
    coeffs = torch.linalg.pinv(lora_proj) @ x

    # rescale the basis vectors so that all coefficients have a similar magnitude
    if coeffs.ndim == 1:
        coeffs = coeffs[:, None]
    if global_norm:
        lora_proj *= coeffs[1:].norm() * min_norm / coeffs.shape[1]
    elif min_norm:
        lora_proj *= coeffs.norm(dim=1).clip(min=min_norm)
    # re-encode with the rescaled basis (in double precision, to limit rounding errors)
    coeffs = (torch.linalg.pinv(lora_proj.double()) @ x.double()).float()

    return lora_proj.detach(), coeffs.detach()


@torch.no_grad()
def spectral_projection_of_depthmaps(imgs, intrinsics, depthmaps, subsample, cache_path=None, **kw):
    # spectral low-rank projection of all depthmaps (results are cached on disk)
    core_depth = []
    lora_proj = []

    for i, img in enumerate(tqdm(imgs)):
        cache = os.path.join(cache_path, 'lora_depth', hash_md5(img)) if cache_path else None
        depth, proj = spectral_projection_depth(intrinsics[i], depthmaps[i], subsample,
                                                cache_path=cache, **kw)
        core_depth.append(depth)
        lora_proj.append(proj)

    return core_depth, lora_proj


def reproj2d(Trf, pts3d):
    res = (pts3d @ Trf[:3, :3].transpose(-1, -2)) + Trf[:3, 3]
    clipped_z = res[:, 2:3].clip(min=1e-3)  # make sure we don't get NaNs for points behind the camera
    uv = res[:, 0:2] / clipped_z
    return uv.clip(min=-1000, max=2000)


def bfs(tree, start_node):
    order, predecessors = sp.csgraph.breadth_first_order(tree, start_node, directed=False)
    ranks = np.arange(len(order))
    ranks[order] = ranks.copy()  # rank of each node in BFS order from start_node
    return ranks, predecessors
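# scipy only computes *minimum* spanning trees, so the scores are negated to obtain the
# maximum-score spanning tree. The root is then chosen as an approximately central node:
# two BFS passes find the two ends of a tree diameter, and the node maximizing the
# distance to the nearer end sits near the center.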
def compute_min_spanning_tree(pws):
    sparse_graph = sp.dok_array(pws.shape)
    for i, j in pws.nonzero().cpu().tolist():
        sparse_graph[i, j] = -float(pws[i, j])
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph)

    # now reorder the oriented edges, starting from the central point
    ranks1, _ = bfs(msp, 0)
    ranks2, _ = bfs(msp, ranks1.argmax())
    ranks1, _ = bfs(msp, ranks2.argmax())
    # this is the point farthest from any leaf
    root = np.minimum(ranks1, ranks2).argmax()

    # find the ordered list of edges that describe the tree
    order, predecessors = sp.csgraph.breadth_first_order(msp, root, directed=False)
    order = order[1:]  # skip the root itself
    edges = [(predecessors[i], i) for i in order]

    return root, edges


def show_reconstruction(shapes_or_imgs, K, cam2w, pts3d, gt_cam2w=None, gt_K=None, cam_size=None, masks=None, **kw):
    viz = SceneViz()

    cc = cam2w[:, :3, 3]
    cs = cam_size or float(torch.cdist(cc, cc).fill_diagonal_(np.inf).min(dim=0).values.median())
    colors = 64 + np.random.randint(255 - 64, size=(len(cam2w), 3))

    if isinstance(shapes_or_imgs, np.ndarray) and shapes_or_imgs.ndim == 2:
        cam_kws = dict(imsizes=shapes_or_imgs[:, ::-1], cam_size=cs)
    else:
        imgs = shapes_or_imgs
        cam_kws = dict(images=imgs, cam_size=cs)
    if K is not None:
        viz.add_cameras(to_numpy(cam2w), to_numpy(K), colors=colors, **cam_kws)

    if gt_cam2w is not None:
        if gt_K is None:
            gt_K = K
        viz.add_cameras(to_numpy(gt_cam2w), to_numpy(gt_K), colors=colors, marker='o', **cam_kws)

    if pts3d is not None:
        for i, p in enumerate(pts3d):
            if not len(p):
                continue
            if masks is None:
                viz.add_pointcloud(to_numpy(p), color=tuple(colors[i].tolist()))
            else:
                viz.add_pointcloud(to_numpy(p), mask=masks[i], color=imgs[i])
    viz.show(**kw)