# Long-term loop closure for DPVO: image retrieval, Sim(3) estimation, and pose-graph optimization.
| import os | |
| import kornia as K | |
| import kornia.feature as KF | |
| import numpy as np | |
| import pypose as pp | |
| import torch | |
| import torch.multiprocessing as mp | |
| import torch.nn.functional as F | |
| from einops import asnumpy, rearrange, repeat | |
| from torch_scatter import scatter_max | |
| from .. import fastba | |
| from .. import projective_ops as pops | |
| from ..lietorch import SE3 | |
| from .optim_utils import SE3_to_Sim3, make_pypose_Sim3, ransac_umeyama, run_DPVO_PGO | |
| from .retrieval import ImageCache, RetrievalDBOW | |
class LongTermLoopClosure:
    """Detects and closes long-term loops.

    Combines place recognition (DBoW retrieval), DISK+LightGlue keypoint
    matching, RANSAC Sim(3) alignment, and a background pose-graph
    optimization (PGO) process that feeds results back via a queue.
    """

    def __init__(self, cfg, patchgraph):
        self.cfg = cfg
        # Data structures to manage retrieval
        self.retrieval = RetrievalDBOW()
        self.imcache = ImageCache()
        # Process to run PGO in parallel
        self.lc_pool = mp.Pool(processes=1)
        # Dummy async job so self.lc_process is always a valid AsyncResult
        # (terminate() can unconditionally call .get() on it).
        self.lc_process = self.lc_pool.apply_async(os.getpid)
        self.manager = mp.Manager()
        self.result_queue = self.manager.Queue()
        self.lc_in_progress = False
        # Patch graph + loop edges
        self.pg = patchgraph
        # Accumulated loop-edge endpoints (keyframe indices), grown per closure
        self.loop_ii = torch.zeros(0, dtype=torch.long)
        self.loop_jj = torch.zeros(0, dtype=torch.long)
        self.lc_count = 0
        # warmup the jit compiler
        ransac_umeyama(np.random.randn(3,3), np.random.randn(3,3), iterations=200, threshold=0.01)
        # DISK keypoints + LightGlue matcher, both frozen in eval mode on CUDA
        self.detector = KF.DISK.from_pretrained("depth").to("cuda").eval()
        self.matcher = KF.LightGlue("disk").to("cuda").eval()
| def detect_keypoints(self, images, num_features=2048): | |
| """ Pretty self explanitory! Alas, we can only use disk w/ lightglue. ORB is brittle """ | |
| _, _, h, w = images.shape | |
| wh = torch.tensor([w, h]).view(1, 2).float().cuda() | |
| features = self.detector(images, num_features, pad_if_not_divisible=True, window_size=15, score_threshold=40.0) | |
| return [{ | |
| "keypoints": f.keypoints[None], | |
| "descriptors": f.descriptors[None], | |
| "image_size": wh | |
| } for f in features] | |
| def __call__(self, img, n): | |
| img_np = K.tensor_to_image(img) | |
| self.retrieval(img_np, n) | |
| self.imcache(img_np, n) | |
| def keyframe(self, k): | |
| self.retrieval.keyframe(k) | |
| self.imcache.keyframe(k) | |
| def estimate_3d_keypoints(self, i): | |
| """ Detect, match and triangulate 3D points """ | |
| """ Load the triplet of frames """ | |
| image_orig = self.imcache.load_frames([i-1,i,i+1], self.pg.intrinsics.device) | |
| image = image_orig.float() / 255 | |
| fl = self.detect_keypoints(image) | |
| """ Form keypoint trajectories """ | |
| trajectories = torch.full((2048, 3), -1, device='cuda', dtype=torch.long) | |
| trajectories[:,1] = torch.arange(2048) | |
| out = self.matcher({"image0": fl[0], "image1": fl[1]}) | |
| i0, i1 = out["matches"][0].mT | |
| trajectories[i1, 0] = i0 | |
| out = self.matcher({"image0": fl[2], "image1": fl[1]}) | |
| i2, i1 = out["matches"][0].mT | |
| trajectories[i1, 2] = i2 | |
| trajectories = trajectories[torch.randperm(2048)] | |
| trajectories = trajectories[trajectories.min(dim=1).values >= 0] | |
| a,b,c = trajectories.mT | |
| n, _ = trajectories.shape | |
| kps0 = fl[0]['keypoints'][:,a] | |
| kps1 = fl[1]['keypoints'][:,b] | |
| kps2 = fl[2]['keypoints'][:,c] | |
| desc1 = fl[1]['descriptors'][:,b] | |
| image_size = fl[1]["image_size"] | |
| kk = torch.arange(n).cuda().repeat(2) | |
| ii = torch.ones(2*n, device='cuda', dtype=torch.long) | |
| jj = torch.zeros(2*n, device='cuda', dtype=torch.long) | |
| jj[n:] = 2 | |
| """ Construct "mini" patch graph. """ | |
| true_disp = self.pg.patches_[i,:,2,1,1].median() | |
| patches = torch.cat((kps1, torch.ones(1, n, 1).cuda() * true_disp), dim=-1) | |
| patches = repeat(patches, '1 n uvd -> 1 n uvd 3 3', uvd=3) | |
| target = rearrange(torch.stack((kps0, kps2)), 'ot 1 n uv -> 1 (ot n) uv', uv=2, n=n, ot=2) | |
| weight = torch.ones_like(target) | |
| poses = self.pg.poses[:,i-1:i+2].clone() | |
| intrinsics = self.pg.intrinsics[:,i-1:i+2].clone() * 4 | |
| coords = pops.transform(SE3(poses), patches, intrinsics, ii, jj, kk) | |
| coords = coords[:,:,1,1] | |
| residual = (coords - target).norm(dim=-1).squeeze(0) | |
| """ structure-only bundle adjustment """ | |
| lmbda = torch.as_tensor([1e-3], device="cuda") | |
| fastba.BA(poses, patches, intrinsics, | |
| target, weight, lmbda, ii, jj, kk, 3, 3, M=-1, iterations=6, eff_impl=False) | |
| """ Only keep points with small residuals """ | |
| coords = pops.transform(SE3(poses), patches, intrinsics, ii, jj, kk) | |
| coords = coords[:,:,1,1] | |
| residual = (coords - target).norm(dim=-1).squeeze(0) | |
| assert residual.numel() == 2*n | |
| mask = scatter_max(residual, kk)[0] < 2 | |
| """ Un-project keypoints """ | |
| points = pops.iproj(patches, intrinsics[:,torch.ones(n, device='cuda', dtype=torch.long)]) | |
| points = (points[...,1,1,:3] / points[...,1,1,3:]) | |
| return points[:,mask].squeeze(0), {"keypoints": kps1[:,mask], "descriptors": desc1[:,mask], "image_size": image_size} | |
    def attempt_loop_closure(self, n):
        """Poll for a detected loop, try to close it, then flush old frames.

        No-op while a previous PGO is still in flight (non-blocking).
        """
        if self.lc_in_progress:
            return
        """ Check if a loop was detected """
        cands = self.retrieval.detect_loop(thresh=self.cfg.LOOP_RETR_THRESH, num_repeat=self.cfg.LOOP_CLOSE_WINDOW_SIZE)
        if cands is not None:
            i, j = cands
            """ A loop was detected. Try to close it """
            lc_result = self.close_loop(i, j, n)
            self.lc_count += int(lc_result)
            """ Avoid multiple back-to-back detections """
            if lc_result:
                self.retrieval.confirm_loop(i, j)
                self.retrieval.found.clear()
        """ "Flush" the queue of frames into the loop-closure pipeline """
        # NOTE(review): offsets differ by one between the two stores — presumably
        # intentional (cache keeps one extra frame for triplet loading); confirm.
        self.retrieval.save_up_to(n - self.cfg.REMOVAL_WINDOW - 2)
        self.imcache.save_up_to(n - self.cfg.REMOVAL_WINDOW - 1)
    def terminate(self, n):
        """Flush all remaining frames, finish any in-flight PGO, and shut down.

        Order matters: frames are flushed before the final loop-closure
        attempt, and the PGO result is consumed (blocking) before the worker
        pool and caches are closed.
        """
        self.retrieval.save_up_to(n-1)
        self.imcache.save_up_to(n-1)
        self.attempt_loop_closure(n)
        if self.lc_in_progress:
            # Block until the background PGO result is available and applied
            self.lc_callback(skip_if_empty=False)
        # Wait for the last async job (or the warmup dummy) to finish
        self.lc_process.get()
        self.imcache.close()
        self.lc_pool.close()
        self.retrieval.close()
        print(f"LC COUNT: {self.lc_count}")
| def _rescale_deltas(self, s): | |
| """ Rescale the poses of removed frames by their predicted scales """ | |
| tstamp_2_rescale = {} | |
| for i in range(self.pg.n): | |
| tstamp_2_rescale[self.pg.tstamps_[i]] = s[i] | |
| for t, (t0, dP) in self.pg.delta.items(): | |
| t_src = t | |
| while t_src in self.pg.delta: | |
| t_src, _ = self.pg.delta[t_src] | |
| s1 = tstamp_2_rescale[t_src] | |
| self.pg.delta[t] = (t0, dP.scale(s1)) | |
    def lc_callback(self, skip_if_empty=True):
        """ Check if the PGO finished running """
        # Non-blocking by default: bail out if no result is ready yet.
        # With skip_if_empty=False this blocks on result_queue.get().
        if skip_if_empty and self.result_queue.empty():
            return
        self.lc_in_progress = False
        # PGO output: Sim(3) estimates for the first safe_i keyframes,
        # 7 SE(3) params + 1 scale per row
        final_est = self.result_queue.get()
        safe_i, _ = final_est.shape
        res, s = final_est.tensor().cuda().split([7,1], dim=1)
        # Per-keyframe scale; frames beyond safe_i keep scale 1
        s1 = torch.ones(self.pg.n, device=s.device)
        s1[:safe_i] = s.squeeze()
        # NOTE(review): poses_ appear stored in the inverse convention of the
        # PGO output (matches the .Inv() in close_loop) — confirm.
        self.pg.poses_[:safe_i] = SE3(res).inv().data
        # Patch disparities are divided by the scale correction
        self.pg.patches_[:safe_i,:,2] /= s.view(safe_i, 1, 1, 1)
        self._rescale_deltas(s1)
        self.pg.normalize()
    def close_loop(self, i, j, n):
        """ This function tries to actually execute the loop closure """
        # Returns True iff a new loop edge was accepted and PGO was launched;
        # False on any of the three early exits below.
        MIN_NUM_INLIERS = 30 # Minimum number of inlier matches
        # print("Found a match!", i, j)
        """ Estimate 3d keypoints w/ features"""
        i_pts, i_feat = self.estimate_3d_keypoints(i)
        j_pts, j_feat = self.estimate_3d_keypoints(j)
        _, _, iz = i_pts.mT
        _, _, jz = j_pts.mT
        th = 20 # a depth threshold. Far-away points aren't helpful
        i_pts = i_pts[iz < th]
        j_pts = j_pts[jz < th]
        # Keep the feature dicts aligned with the depth-filtered point sets
        for key in ['keypoints', 'descriptors']:
            i_feat[key] = i_feat[key][:,iz < th]
            j_feat[key] = j_feat[key][:,jz < th]
        # Early exit
        if i_pts.numel() < MIN_NUM_INLIERS:
            # print(f"Too few inliers (A): {i_pts.numel()=}")
            return False
        """ Match between the two point clouds """
        out = self.matcher({"image0": i_feat, "image1": j_feat})
        i_ind, j_ind = out["matches"][0].mT
        i_pts = i_pts[i_ind]
        j_pts = j_pts[j_ind]
        assert i_pts.shape == j_pts.shape, (i_pts.shape, j_pts.shape)
        i_pts, j_pts = asnumpy(i_pts.double()), asnumpy(j_pts.double())
        # Early exit
        if i_pts.size < MIN_NUM_INLIERS:
            # print(f"Too few inliers (B): {i_pts.size=}")
            return False
        """ Estimate Sim(3) transformation """
        r, t, s, num_inliers = ransac_umeyama(i_pts, j_pts, iterations=400, threshold=0.1) # threshold shouldn't be too low
        # Exit if number of inlier matches is too small
        if num_inliers < MIN_NUM_INLIERS:
            # print(f"Too few inliers (C): {num_inliers=}")
            return False
        """ Run Pose-Graph Optimization (PGO) """
        # New Sim(3) loop edge from RANSAC (batched with [None])
        far_rel_pose = make_pypose_Sim3(r, t, s)[None]
        # Previously accepted loop edges as relative SE(3), lifted to Sim(3)
        Gi = pp.SE3(self.pg.poses[:,self.loop_ii])
        Gj = pp.SE3(self.pg.poses[:,self.loop_jj])
        Gij = Gj * Gi.Inv()
        prev_sim3 = SE3_to_Sim3(Gij).data[0].cpu()
        loop_poses = pp.Sim3(torch.cat((prev_sim3, far_rel_pose)))
        loop_ii = torch.cat((self.loop_ii, torch.tensor([i])))
        loop_jj = torch.cat((self.loop_jj, torch.tensor([j])))
        pred_poses = pp.SE3(self.pg.poses_[:n]).Inv().cpu()
        # Record the new edge before launching the background PGO
        self.loop_ii = loop_ii
        self.loop_jj = loop_jj
        # NOTE(review): presumably prevents thread oversubscription between the
        # main process and the PGO worker — confirm.
        torch.set_num_threads(1)
        self.lc_in_progress = True
        # Asynchronous: results come back through self.result_queue (see lc_callback)
        self.lc_process = self.lc_pool.apply_async(run_DPVO_PGO, (pred_poses.data, loop_poses.data, loop_ii, loop_jj, self.result_queue))
        return True