Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from einops import asnumpy, reduce, repeat | |
| from . import projective_ops as pops | |
| from .lietorch import SE3 | |
| from .loop_closure.optim_utils import reduce_edges | |
| from .utils import * | |
class PatchGraph:
    """Container for patch-graph state: per-frame buffers plus edge index lists.

    Parameters
    ----------
    cfg : config object
        Read for PATCHES_PER_FRAME, BUFFER_SIZE, MAX_EDGE_AGE, REMOVAL_WINDOW,
        GLOBAL_OPT_FREQ, KEYFRAME_INDEX and BACKEND_THRESH.
    P : int
        Patch side length; each stored patch is a (3, P, P) tensor.
    DIM : int
        Feature width of the per-edge hidden state ``net``.
    pmem :
        Stored as-is; not used inside this class.
    **kwargs :
        Forwarded to ``torch.zeros`` when allocating ``net``
        (presumably dtype/device — confirm with callers).
    """

    def __init__(self, cfg, P, DIM, pmem, **kwargs):
        self.cfg = cfg
        self.P = P
        self.pmem = pmem
        self.DIM = DIM

        self.n = 0  # number of frames
        self.m = 0  # number of patches

        self.M = self.cfg.PATCHES_PER_FRAME  # patches per frame
        self.N = self.cfg.BUFFER_SIZE  # capacity (frames) of every per-frame buffer

        self.tstamps_ = np.zeros(self.N, dtype=np.int64)
        self.poses_ = torch.zeros(self.N, 7, dtype=torch.float, device="cuda")
        self.patches_ = torch.zeros(self.N, self.M, 3, self.P, self.P, dtype=torch.float, device="cuda")
        self.intrinsics_ = torch.zeros(self.N, 4, dtype=torch.float, device="cuda")

        self.points_ = torch.zeros(self.N * self.M, 3, dtype=torch.float, device="cuda")
        self.colors_ = torch.zeros(self.N, self.M, 3, dtype=torch.uint8, device="cuda")

        self.index_ = torch.zeros(self.N, self.M, dtype=torch.long, device="cuda")
        self.index_map_ = torch.zeros(self.N, dtype=torch.long, device="cuda")

        # initialize poses to identity: sets component 6 of each 7-vector
        # (presumably the quaternion real part in lietorch's SE3 layout — confirm)
        self.poses_[:, 6] = 1.0

        # store relative poses for removed frames
        self.delta = {}

        ### edge information ###
        self.net = torch.zeros(1, 0, DIM, **kwargs)
        self.ii = torch.as_tensor([], dtype=torch.long, device="cuda")
        self.jj = torch.as_tensor([], dtype=torch.long, device="cuda")
        self.kk = torch.as_tensor([], dtype=torch.long, device="cuda")

        ### inactive edge information (i.e., no longer updated, but useful for BA) ###
        self.ii_inac = torch.as_tensor([], dtype=torch.long, device="cuda")
        self.jj_inac = torch.as_tensor([], dtype=torch.long, device="cuda")
        self.kk_inac = torch.as_tensor([], dtype=torch.long, device="cuda")
        # NOTE(review): weight/target are allocated as long; if float values are
        # concatenated onto these later, this relies on torch.cat type promotion.
        # Confirm the intended dtype.
        self.weight_inac = torch.zeros(1, 0, 2, dtype=torch.long, device="cuda")
        self.target_inac = torch.zeros(1, 0, 2, dtype=torch.long, device="cuda")

    def edges_loop(self):
        """Add candidate edges from old patches to recent frames.

        Candidates pair every frame in
        ``[n - GLOBAL_OPT_FREQ, n - KEYFRAME_INDEX)`` with every patch of the
        "old" frames ``[l - MAX_EDGE_AGE, l)``, where ``l = n - REMOVAL_WINDOW``.
        Candidate groups are filtered by mean flow magnitude and then thinned
        by ``reduce_edges``.

        Returns
        -------
        (kk, jj) : tuple of two 1-D long tensors
            Patch indices and target frame indices of the kept edges. Both are
            empty when there are no "old" frames yet.
        """
        lc_range = self.cfg.MAX_EDGE_AGE
        l = self.n - self.cfg.REMOVAL_WINDOW  # l is the upper bound for "old" patches

        if l <= 0:
            # Same (kk, jj) shape as the normal path, so callers can always unpack.
            empty = torch.empty(0, dtype=torch.long, device="cuda")
            return empty, empty.clone()

        # create candidate edges: each recent frame j x each old patch k.
        # Clamp at 0 so the frame range never yields negative (wrapping) indices.
        j_start = max(self.n - self.cfg.GLOBAL_OPT_FREQ, 0)
        jj, kk = flatmeshgrid(
            torch.arange(j_start, self.n - self.cfg.KEYFRAME_INDEX, device="cuda"),
            torch.arange(max(l - lc_range, 0) * self.M, l * self.M, device="cuda"),
            indexing='ij')
        ii = self.ix[kk]  # frame index recorded in index_ for each candidate patch

        # Remove edges which have too large flow magnitude (measured at each patch's centre pixel)
        c = self.P // 2
        flow_mg, val = pops.flow_mag(SE3(self.poses), self.patches[..., c, c].view(1, -1, 3, 1, 1),
                                     self.intrinsics, ii, jj, kk, beta=0.5)
        # Sum over consecutive runs of M candidates. Assuming patches are stored
        # M per frame, each run is one (source frame, target frame) pair.
        flow_mg_sum = reduce(flow_mg * val, '1 (fl M) 1 1 -> fl', 'sum', M=self.M).float()
        num_val = reduce(val, '1 (fl M) 1 1 -> fl', 'sum', M=self.M).clamp(min=1)
        # Mean flow over valid patches; groups with <= 75% valid patches are rejected (inf).
        flow_mag = torch.where(num_val > (self.M * 0.75), flow_mg_sum / num_val, torch.inf)

        mask = (flow_mag < self.cfg.BACKEND_THRESH)
        es = reduce_edges(asnumpy(flow_mag[mask]), asnumpy(ii[::self.M][mask]),
                          asnumpy(jj[::self.M][mask]), max_num_edges=1000, nms=1)

        # Force an (E, 2) long layout so the 'E ij' pattern also holds when no edge survives.
        edges = torch.as_tensor(es, dtype=torch.long, device=ii.device).reshape(-1, 2)
        ii, jj = repeat(edges, 'E ij -> ij E M', M=self.M, ij=2)
        # Expand each kept (i, j) frame pair to all M patches of frame i.
        kk = ii.mul(self.M) + torch.arange(self.M, device=ii.device)
        return kk.flatten(), jj.flatten()

    def normalize(self):
        """Normalize depth and poses.

        Divides channel 2 of the stored patches by its mean over the first
        ``n`` frames (presumably inverse depth — confirm). It then multiplies
        pose translations and the stored ``delta`` relative poses by the same
        scale, re-expresses poses so frame 0 becomes the identity, and
        refreshes ``points_`` from the patch centre pixels.
        No-op when no frame is stored.
        """
        if self.n == 0:
            # Mean of an empty slice is NaN and would poison every delta pose.
            return

        s = self.patches_[:self.n, :, 2].mean()
        self.patches_[:self.n, :, 2] /= s
        self.poses_[:self.n, :3] *= s
        for t, (t0, dP) in self.delta.items():
            self.delta[t] = (t0, dP.scale(s))
        # P_i * P_0^{-1}: frame 0 becomes the identity pose.
        self.poses_[:self.n] = (SE3(self.poses_[:self.n]) * SE3(self.poses_[[0]]).inv()).data

        c = self.P // 2
        points = pops.point_cloud(SE3(self.poses), self.patches[:, :self.m], self.intrinsics, self.ix[:self.m])
        # Dehomogenize the centre pixel of every patch.
        points = (points[..., c, c, :3] / points[..., c, c, 3:]).reshape(-1, 3)
        self.points_[:len(points)] = points

    @property
    def poses(self):
        """Pose buffer viewed as (1, N, 7)."""
        return self.poses_.view(1, self.N, 7)

    @property
    def patches(self):
        """Patch buffer viewed as (1, N*M, 3, P, P)."""
        return self.patches_.view(1, self.N * self.M, 3, self.P, self.P)

    @property
    def intrinsics(self):
        """Intrinsics buffer viewed as (1, N, 4)."""
        return self.intrinsics_.view(1, self.N, 4)

    @property
    def ix(self):
        """Flat (N*M,) view of ``index_``."""
        return self.index_.view(-1)