Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import torch.multiprocessing as mp | |
| import torch.nn.functional as F | |
| from . import altcorr, fastba, lietorch | |
| from . import projective_ops as pops | |
| from .lietorch import SE3 | |
| from .net import VONet | |
| from .patchgraph import PatchGraph | |
| from .utils import * | |
# CUDA tensors cannot be shared across fork()ed workers; force 'spawn'.
mp.set_start_method('spawn', True)

# Shorthand for mixed-precision regions used throughout this module.
autocast = torch.cuda.amp.autocast

# Identity SE3 pose (batch of 1) on the GPU; reused when recording
# "no motion" deltas for skipped frames.
Id = SE3.Identity(1, device="cuda")
class DPVO:
    """Deep patch visual odometry tracker.

    Holds the rolling frame/patch buffers, the factor (patch) graph and the
    update network; frames are fed in one at a time via __call__.
    """

    def __init__(self, cfg, network, ht=480, wd=640, viz=False):
        """Allocate all state buffers.

        Args:
            cfg: config object (PATCHES_PER_FRAME, BUFFER_SIZE, LOOP_CLOSURE, ...).
            network: VONet instance, or path to a checkpoint file (str).
            ht, wd: input image height/width in pixels.
            viz: if True, launch the interactive viewer.
        """
        self.cfg = cfg
        self.load_weights(network)
        self.is_initialized = False
        self.enable_timing = False
        torch.set_num_threads(2)

        self.M = self.cfg.PATCHES_PER_FRAME   # patches extracted per frame
        self.N = self.cfg.BUFFER_SIZE         # max number of keyframes kept

        self.ht = ht    # image height
        self.wd = wd    # image width

        # DIM / RES / P are "stolen" from the network in load_weights()
        DIM = self.DIM
        RES = self.RES

        ### state attributes ###
        self.tlist = []      # raw input timestamps, one per processed frame
        self.counter = 0     # total frames seen (including dropped ones)

        # keep track of global-BA calls
        self.ran_global_ba = np.zeros(100000, dtype=bool)

        # feature-map resolution (images are downsampled by RES in the network)
        ht = ht // RES
        wd = wd // RES

        # dummy image for visualization
        self.image_ = torch.zeros(self.ht, self.wd, 3, dtype=torch.uint8, device="cpu")

        ### network attributes ###
        if self.cfg.MIXED_PRECISION:
            self.kwargs = kwargs = {"device": "cuda", "dtype": torch.half}
        else:
            self.kwargs = kwargs = {"device": "cuda", "dtype": torch.float}

        ### frame memory size ###
        self.pmem = self.mem = 36  # 32 was too small given default settings
        if self.cfg.LOOP_CLOSURE:
            self.last_global_ba = -1000  # keep track of time since last global opt
            self.pmem = self.cfg.MAX_EDGE_AGE  # patch memory

        # per-frame patch context features and patch feature maps (ring buffers)
        self.imap_ = torch.zeros(self.pmem, self.M, DIM, **kwargs)
        self.gmap_ = torch.zeros(self.pmem, self.M, 128, self.P, self.P, **kwargs)

        # factor graph holding poses, patches, edges, and hidden states
        self.pg = PatchGraph(self.cfg, self.P, self.DIM, self.pmem, **kwargs)

        # classic backend
        if self.cfg.CLASSIC_LOOP_CLOSURE:
            self.load_long_term_loop_closure()

        # two-level correlation pyramid source maps (full and 1/4 resolution)
        self.fmap1_ = torch.zeros(1, self.mem, 128, ht // 1, wd // 1, **kwargs)
        self.fmap2_ = torch.zeros(1, self.mem, 128, ht // 4, wd // 4, **kwargs)

        # feature pyramid
        self.pyramid = (self.fmap1_, self.fmap2_)

        self.viewer = None
        if viz:
            self.start_viewer()
| def load_long_term_loop_closure(self): | |
| try: | |
| from .loop_closure.long_term import LongTermLoopClosure | |
| self.long_term_lc = LongTermLoopClosure(self.cfg, self.pg) | |
| except ModuleNotFoundError as e: | |
| self.cfg.CLASSIC_LOOP_CLOSURE = False | |
| print(f"WARNING: {e}") | |
| def load_weights(self, network): | |
| # load network from checkpoint file | |
| if isinstance(network, str): | |
| from collections import OrderedDict | |
| state_dict = torch.load(network) | |
| new_state_dict = OrderedDict() | |
| for k, v in state_dict.items(): | |
| if "update.lmbda" not in k: | |
| new_state_dict[k.replace('module.', '')] = v | |
| self.network = VONet() | |
| self.network.load_state_dict(new_state_dict) | |
| else: | |
| self.network = network | |
| # steal network attributes | |
| self.DIM = self.network.DIM | |
| self.RES = self.network.RES | |
| self.P = self.network.P | |
| self.network.cuda() | |
| self.network.eval() | |
| def start_viewer(self): | |
| from dpviewer import Viewer | |
| intrinsics_ = torch.zeros(1, 4, dtype=torch.float32, device="cuda") | |
| self.viewer = Viewer( | |
| self.image_, | |
| self.pg.poses_, | |
| self.pg.points_, | |
| self.pg.colors_, | |
| intrinsics_) | |
| def poses(self): | |
| return self.pg.poses_.view(1, self.N, 7) | |
| def patches(self): | |
| return self.pg.patches_.view(1, self.N*self.M, 3, 3, 3) | |
| def intrinsics(self): | |
| return self.pg.intrinsics_.view(1, self.N, 4) | |
| def ix(self): | |
| return self.pg.index_.view(-1) | |
| def imap(self): | |
| return self.imap_.view(1, self.pmem * self.M, self.DIM) | |
| def gmap(self): | |
| return self.gmap_.view(1, self.pmem * self.M, 128, 3, 3) | |
| def n(self): | |
| return self.pg.n | |
| def n(self, val): | |
| self.pg.n = val | |
| def m(self): | |
| return self.pg.m | |
| def m(self, val): | |
| self.pg.m = val | |
| def get_pose(self, t): | |
| if t in self.traj: | |
| return SE3(self.traj[t]) | |
| t0, dP = self.pg.delta[t] | |
| return dP * self.get_pose(t0) | |
    def terminate(self):
        """Finish tracking: final refinement, then export the full trajectory.

        Returns:
            (poses, tstamps): poses as an (counter, 7) numpy array in
            [x y z qx qy qz qw] order (poses are inverted before export —
            presumably world-frame camera poses; confirm against consumers),
            and the matching float64 timestamp array.
        """
        if self.cfg.CLASSIC_LOOP_CLOSURE:
            self.long_term_lc.terminate(self.n)

        if self.cfg.LOOP_CLOSURE:
            # pull in any remaining loop-closure edges before the final opt
            self.append_factors(*self.pg.edges_loop())

        # 12 final update iterations; clearing the flag each time keeps the
        # global-BA branch inside update() eligible to run
        for _ in range(12):
            self.ran_global_ba[self.n] = False
            self.update()

        """ interpolate missing poses """
        self.traj = {}
        for i in range(self.n):
            # map keyframe timestamp (frame counter value) -> pose vector
            self.traj[self.pg.tstamps_[i]] = self.pg.poses_[i]

        # reconstruct a pose for every input frame, including dropped ones
        # (get_pose chains the stored deltas for non-keyframes)
        poses = [self.get_pose(t) for t in range(self.counter)]
        poses = lietorch.stack(poses, dim=0)
        poses = poses.inv().data.cpu().numpy()
        tstamps = np.array(self.tlist, dtype=np.float64)

        if self.viewer is not None:
            self.viewer.join()

        # Poses: x y z qx qy qz qw
        return poses, tstamps
| def corr(self, coords, indicies=None): | |
| """ local correlation volume """ | |
| ii, jj = indicies if indicies is not None else (self.pg.kk, self.pg.jj) | |
| ii1 = ii % (self.M * self.pmem) | |
| jj1 = jj % (self.mem) | |
| corr1 = altcorr.corr(self.gmap, self.pyramid[0], coords / 1, ii1, jj1, 3) | |
| corr2 = altcorr.corr(self.gmap, self.pyramid[1], coords / 4, ii1, jj1, 3) | |
| return torch.stack([corr1, corr2], -1).view(1, len(ii), -1) | |
| def reproject(self, indicies=None): | |
| """ reproject patch k from i -> j """ | |
| (ii, jj, kk) = indicies if indicies is not None else (self.pg.ii, self.pg.jj, self.pg.kk) | |
| coords = pops.transform(SE3(self.poses), self.patches, self.intrinsics, ii, jj, kk) | |
| return coords.permute(0, 1, 4, 2, 3).contiguous() | |
| def append_factors(self, ii, jj): | |
| self.pg.jj = torch.cat([self.pg.jj, jj]) | |
| self.pg.kk = torch.cat([self.pg.kk, ii]) | |
| self.pg.ii = torch.cat([self.pg.ii, self.ix[ii]]) | |
| net = torch.zeros(1, len(ii), self.DIM, **self.kwargs) | |
| self.pg.net = torch.cat([self.pg.net, net], dim=1) | |
| def remove_factors(self, m, store: bool): | |
| assert self.pg.ii.numel() == self.pg.weight.shape[1] | |
| if store: | |
| self.pg.ii_inac = torch.cat((self.pg.ii_inac, self.pg.ii[m])) | |
| self.pg.jj_inac = torch.cat((self.pg.jj_inac, self.pg.jj[m])) | |
| self.pg.kk_inac = torch.cat((self.pg.kk_inac, self.pg.kk[m])) | |
| self.pg.weight_inac = torch.cat((self.pg.weight_inac, self.pg.weight[:,m]), dim=1) | |
| self.pg.target_inac = torch.cat((self.pg.target_inac, self.pg.target[:,m]), dim=1) | |
| self.pg.weight = self.pg.weight[:,~m] | |
| self.pg.target = self.pg.target[:,~m] | |
| self.pg.ii = self.pg.ii[~m] | |
| self.pg.jj = self.pg.jj[~m] | |
| self.pg.kk = self.pg.kk[~m] | |
| self.pg.net = self.pg.net[:,~m] | |
| assert self.pg.ii.numel() == self.pg.weight.shape[1] | |
    def motion_probe(self):
        """ kinda hacky way to ensure enough motion for initialization """
        # probe edges: every patch of the newest frame -> the current frame
        kk = torch.arange(self.m-self.M, self.m, device="cuda")
        jj = self.n * torch.ones_like(kk)
        ii = self.ix[kk]

        # throwaway hidden state — this probe must not touch graph state
        net = torch.zeros(1, len(ii), self.DIM, **self.kwargs)
        coords = self.reproject(indicies=(ii, jj, kk))

        with autocast(enabled=self.cfg.MIXED_PRECISION):
            corr = self.corr(coords, indicies=(kk, jj))
            ctx = self.imap[:,kk % (self.M * self.pmem)]
            net, (delta, weight, _) = \
                self.network.update(net, ctx, corr, None, ii, jj, kk)

        # median norm of the predicted flow update — used as a motion magnitude
        return torch.quantile(delta.norm(dim=-1).float(), 0.5)
| def motionmag(self, i, j): | |
| k = (self.pg.ii == i) & (self.pg.jj == j) | |
| ii = self.pg.ii[k] | |
| jj = self.pg.jj[k] | |
| kk = self.pg.kk[k] | |
| flow, _ = pops.flow_mag(SE3(self.poses), self.patches, self.intrinsics, ii, jj, kk, beta=0.5) | |
| return flow.mean().item() | |
    def keyframe(self):
        """Drop the candidate keyframe if it shows too little motion, then
        prune edges that fell out of the optimization window.

        When frame k is removed, every per-frame buffer is compacted in place
        and all edge indices referring to frames after k are shifted down.
        """
        # measure motion around the candidate (skip the candidate itself)
        i = self.n - self.cfg.KEYFRAME_INDEX - 1
        j = self.n - self.cfg.KEYFRAME_INDEX + 1
        m = self.motionmag(i, j) + self.motionmag(j, i)
        if m / 2 < self.cfg.KEYFRAME_THRESH:
            k = self.n - self.cfg.KEYFRAME_INDEX
            t0 = self.pg.tstamps_[k-1]
            t1 = self.pg.tstamps_[k]

            # remember the removed frame's pose relative to its predecessor,
            # so terminate()/get_pose() can reconstruct it later
            dP = SE3(self.pg.poses_[k]) * SE3(self.pg.poses_[k-1]).inv()
            self.pg.delta[t1] = (t0, dP)

            # drop every edge touching the removed frame (no archiving)
            to_remove = (self.pg.ii == k) | (self.pg.jj == k)
            self.remove_factors(to_remove, store=False)

            # shift indices of everything after k down by one frame / M patches
            # NOTE: kk must be updated while `self.pg.ii > k` is still computed
            # on the un-shifted ii values
            self.pg.kk[self.pg.ii > k] -= self.M
            self.pg.ii[self.pg.ii > k] -= 1
            self.pg.jj[self.pg.jj > k] -= 1

            # compact all per-frame ring buffers in place
            for i in range(k, self.n-1):
                self.pg.tstamps_[i] = self.pg.tstamps_[i+1]
                self.pg.colors_[i] = self.pg.colors_[i+1]
                self.pg.poses_[i] = self.pg.poses_[i+1]
                self.pg.patches_[i] = self.pg.patches_[i+1]
                self.pg.intrinsics_[i] = self.pg.intrinsics_[i+1]

                self.imap_[i % self.pmem] = self.imap_[(i+1) % self.pmem]
                self.gmap_[i % self.pmem] = self.gmap_[(i+1) % self.pmem]
                self.fmap1_[0,i%self.mem] = self.fmap1_[0,(i+1)%self.mem]
                self.fmap2_[0,i%self.mem] = self.fmap2_[0,(i+1)%self.mem]

            self.n -= 1
            self.m -= self.M

            if self.cfg.CLASSIC_LOOP_CLOSURE:
                self.long_term_lc.keyframe(k)

        to_remove = self.ix[self.pg.kk] < self.n - self.cfg.REMOVAL_WINDOW # Remove edges falling outside the optimization window
        if self.cfg.LOOP_CLOSURE:
            # ...unless they are being used for loop closure
            lc_edges = ((self.pg.jj - self.pg.ii) > 30) & (self.pg.jj > (self.n - self.cfg.OPTIMIZATION_WINDOW))
            to_remove = to_remove & ~lc_edges
        # archived (store=True) so global BA can still see these constraints
        self.remove_factors(to_remove, store=True)
    def __run_global_BA(self):
        """ Global bundle adjustment
        Includes both active and inactive edges """
        # concatenate archived ('_inac') and active edge data into one problem
        full_target = torch.cat((self.pg.target_inac, self.pg.target), dim=1)
        full_weight = torch.cat((self.pg.weight_inac, self.pg.weight), dim=1)
        full_ii = torch.cat((self.pg.ii_inac, self.pg.ii))
        full_jj = torch.cat((self.pg.jj_inac, self.pg.jj))
        full_kk = torch.cat((self.pg.kk_inac, self.pg.kk))

        self.pg.normalize()
        lmbda = torch.as_tensor([1e-4], device="cuda")  # LM damping
        # optimize from the earliest frame referenced by an *active* edge
        t0 = self.pg.ii.min().item()
        # eff_impl=True selects fastba's memory-efficient path — confirm
        # semantics against the fastba implementation
        fastba.BA(self.poses, self.patches, self.intrinsics,
            full_target, full_weight, lmbda, full_ii, full_jj, full_kk, t0, self.n, M=self.M, iterations=2, eff_impl=True)
        # at most one global BA per frame count (checked in update())
        self.ran_global_ba[self.n] = True
| def update(self): | |
| with Timer("other", enabled=self.enable_timing): | |
| coords = self.reproject() | |
| with autocast(enabled=True): | |
| corr = self.corr(coords) | |
| ctx = self.imap[:, self.pg.kk % (self.M * self.pmem)] | |
| self.pg.net, (delta, weight, _) = \ | |
| self.network.update(self.pg.net, ctx, corr, None, self.pg.ii, self.pg.jj, self.pg.kk) | |
| lmbda = torch.as_tensor([1e-4], device="cuda") | |
| weight = weight.float() | |
| target = coords[...,self.P//2,self.P//2] + delta.float() | |
| self.pg.target = target | |
| self.pg.weight = weight | |
| with Timer("BA", enabled=self.enable_timing): | |
| try: | |
| # run global bundle adjustment if there exist long-range edges | |
| if (self.pg.ii < self.n - self.cfg.REMOVAL_WINDOW - 1).any() and not self.ran_global_ba[self.n]: | |
| self.__run_global_BA() | |
| else: | |
| t0 = self.n - self.cfg.OPTIMIZATION_WINDOW if self.is_initialized else 1 | |
| t0 = max(t0, 1) | |
| fastba.BA(self.poses, self.patches, self.intrinsics, | |
| target, weight, lmbda, self.pg.ii, self.pg.jj, self.pg.kk, t0, self.n, M=self.M, iterations=2, eff_impl=False) | |
| except: | |
| print("Warning BA failed...") | |
| points = pops.point_cloud(SE3(self.poses), self.patches[:, :self.m], self.intrinsics, self.ix[:self.m]) | |
| points = (points[...,1,1,:3] / points[...,1,1,3:]).reshape(-1, 3) | |
| self.pg.points_[:len(points)] = points[:] | |
| def __edges_forw(self): | |
| r=self.cfg.PATCH_LIFETIME | |
| t0 = self.M * max((self.n - r), 0) | |
| t1 = self.M * max((self.n - 1), 0) | |
| return flatmeshgrid( | |
| torch.arange(t0, t1, device="cuda"), | |
| torch.arange(self.n-1, self.n, device="cuda"), indexing='ij') | |
| def __edges_back(self): | |
| r=self.cfg.PATCH_LIFETIME | |
| t0 = self.M * max((self.n - 1), 0) | |
| t1 = self.M * max((self.n - 0), 0) | |
| return flatmeshgrid(torch.arange(t0, t1, device="cuda"), | |
| torch.arange(max(self.n-r, 0), self.n, device="cuda"), indexing='ij') | |
    def __call__(self, tstamp, image, intrinsics):
        """ track new frame """
        if self.cfg.CLASSIC_LOOP_CLOSURE:
            self.long_term_lc(image, self.n)

        if (self.n+1) >= self.N:
            raise Exception(f'The buffer size is too small. You can increase it using "--opts BUFFER_SIZE={self.N*2}"')

        if self.viewer is not None:
            self.viewer.update_image(image.contiguous())

        # normalize to roughly [-0.5, 1.5] and add batch/time dims
        image = 2 * (image[None,None] / 255.0) - 0.5

        # extract feature maps, patch context and M patches for this frame
        with autocast(enabled=self.cfg.MIXED_PRECISION):
            fmap, gmap, imap, patches, _, clr = \
                self.network.patchify(image,
                    patches_per_image=self.cfg.PATCHES_PER_FRAME,
                    centroid_sel_strat=self.cfg.CENTROID_SEL_STRAT,
                    return_color=True)

        ### update state attributes ###
        self.tlist.append(tstamp)
        # tstamps_ stores the frame counter, not the raw timestamp
        self.pg.tstamps_[self.n] = self.counter
        self.pg.intrinsics_[self.n] = intrinsics / self.RES

        # color info for visualization (BGR -> RGB, de-normalized to 0..255)
        clr = (clr[0,:,[2,1,0]] + 0.5) * (255.0 / 2)
        self.pg.colors_[self.n] = clr.to(torch.uint8)

        self.pg.index_[self.n + 1] = self.n + 1
        self.pg.index_map_[self.n + 1] = self.m + self.M

        if self.n > 1:
            if self.cfg.MOTION_MODEL == 'DAMPED_LINEAR':
                # extrapolate the new pose from the last two poses
                P1 = SE3(self.pg.poses_[self.n-1])
                P2 = SE3(self.pg.poses_[self.n-2])

                # To deal with varying camera hz: scale the motion by the
                # ratio of the last two frame intervals ([1]*3 pads tlist so
                # the unpack works on the first frames)
                *_, a,b,c = [1]*3 + self.tlist
                fac = (c-b) / (b-a)

                xi = self.cfg.MOTION_DAMPING * fac * (P1 * P2.inv()).log()
                tvec_qvec = (SE3.exp(xi) * P1).data
                self.pg.poses_[self.n] = tvec_qvec
            else:
                # constant-position model: copy the previous pose
                tvec_qvec = self.poses[self.n-1]
                self.pg.poses_[self.n] = tvec_qvec

        # TODO better depth initialization
        patches[:,:,2] = torch.rand_like(patches[:,:,2,0,0,None,None])
        if self.is_initialized:
            # once tracking is stable, seed depth with the recent median
            s = torch.median(self.pg.patches_[self.n-3:self.n,:,2])
            patches[:,:,2] = s

        self.pg.patches_[self.n] = patches

        ### update network attributes ###
        self.imap_[self.n % self.pmem] = imap.squeeze()
        self.gmap_[self.n % self.pmem] = gmap.squeeze()
        self.fmap1_[:, self.n % self.mem] = F.avg_pool2d(fmap[0], 1, 1)
        self.fmap2_[:, self.n % self.mem] = F.avg_pool2d(fmap[0], 4, 4)

        self.counter += 1

        # before initialization, drop frames with too little motion and just
        # record an identity delta so the pose can be reconstructed later
        if self.n > 0 and not self.is_initialized:
            if self.motion_probe() < 2.0:
                self.pg.delta[self.counter - 1] = (self.counter - 2, Id[0])
                return

        self.n += 1
        self.m += self.M

        if self.cfg.LOOP_CLOSURE:
            if self.n - self.last_global_ba >= self.cfg.GLOBAL_OPT_FREQ:
                """ Add loop closure factors """
                lii, ljj = self.pg.edges_loop()
                if lii.numel() > 0:
                    self.last_global_ba = self.n
                    self.append_factors(lii, ljj)

        # Add forward and backward factors
        self.append_factors(*self.__edges_forw())
        self.append_factors(*self.__edges_back())

        if self.n == 8 and not self.is_initialized:
            # bootstrap: once 8 frames are accumulated, optimize hard
            self.is_initialized = True

            for itr in range(12):
                self.update()

        elif self.is_initialized:
            self.update()
            self.keyframe()

        if self.cfg.CLASSIC_LOOP_CLOSURE:
            self.long_term_lc.attempt_loop_closure(self.n)
            self.long_term_lc.lc_callback()