Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
def se3_inverse(T):
    """
    Computes the inverse of one or more SE(3) matrices.

    The inverse is built in closed form from the rotation block R and the
    translation t of each matrix: inv([R | t]) = [R^T | -R^T t], completed
    with a fresh [0, 0, 0, 1] bottom row (the input bottom row is not read).

    T: torch.Tensor or np.ndarray of shape (..., 4, 4). Any number of leading
       batch dimensions is accepted, including none (a single 4x4 matrix).
    Returns: same kind of object and same dtype as T, of shape (..., 4, 4).
    """
    batch_shape = tuple(T.shape[:-2])
    if torch.is_tensor(T):
        R_inv = T[..., :3, :3].transpose(-2, -1)
        # 3:4 keeps the trailing axis so the translation stays a column vector
        t_inv = -torch.matmul(R_inv, T[..., :3, 3:4])
        bottom_row = T.new_tensor([0, 0, 0, 1]).expand(*batch_shape, 1, 4)
        top_part = torch.cat([R_inv, t_inv], dim=-1)
        T_inv = torch.cat([top_part, bottom_row], dim=-2)
    else:
        R_inv = np.swapaxes(T[..., :3, :3], -2, -1)
        t_inv = -R_inv @ T[..., :3, 3:4]
        bottom_row = np.broadcast_to(
            np.array([0, 0, 0, 1], dtype=T.dtype), batch_shape + (1, 4)
        )
        top_part = np.concatenate([R_inv, t_inv], axis=-1)
        T_inv = np.concatenate([top_part, bottom_row], axis=-2)
    return T_inv
def get_pixel(H, W):
    """
    Build homogeneous pixel-center coordinates for an H x W image.

    Returns an array of shape (3, H*W) whose rows are u + 0.5, v + 0.5 and 1,
    with pixels enumerated row by row (u varies fastest).
    """
    cols, rows = np.meshgrid(np.arange(W), np.arange(H))
    u_centers = cols.reshape(-1) + 0.5
    v_centers = rows.reshape(-1) + 0.5
    homogeneous = np.ones_like(u_centers)
    return np.stack([u_centers, v_centers, homogeneous], axis=0)
def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, z_far=0, **kw):
    """
    Unproject a depth map and, if a pose is given, move the points with it.

    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
        - camera_pose: a cam2world matrix whose top-left 3x3 block is the
          rotation and whose first three entries of column 3 are the
          translation (e.g. 3x4 or 4x4), or None to stay in the camera frame
        - z_far: when > 0, pixels with depth >= z_far are marked invalid
        - **kw: accepted and ignored
    Returns:
        pointmap (HxWx3 array) and a boolean mask specifying valid pixels.
    """
    points_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
    if z_far > 0:
        valid_mask = valid_mask & (depthmap < z_far)

    # Without a pose the camera-frame points are returned unchanged.
    if camera_pose is None:
        return points_cam, valid_mask

    rotation = camera_pose[:3, :3]
    translation = camera_pose[:3, 3]
    # Every pixel is transformed, including those flagged invalid by the mask.
    points_world = np.einsum("ik, vuk -> vui", rotation, points_cam) + translation[None, None, :]
    return points_world, valid_mask
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Unproject a depth map into 3D points expressed in the camera frame.

    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix; only the focal entries [0, 0],
          [1, 1] and the principal point [0, 2], [1, 2] are read (skew is
          ignored)
        - pseudo_focal: optional HxW array of per-pixel focals that replaces
          both focal entries of the intrinsics
    Returns:
        pointmap of camera-frame coordinates (HxWx3 float32 array), and a
        boolean mask that is True where depth > 0.
    """
    intrinsics = np.float32(camera_intrinsics)
    height, width = depthmap.shape

    if pseudo_focal is not None:
        assert pseudo_focal.shape == (height, width)
        focal_u = focal_v = pseudo_focal
    else:
        focal_u, focal_v = intrinsics[0, 0], intrinsics[1, 1]
    center_u, center_v = intrinsics[0, 2], intrinsics[1, 2]

    # Integer pixel indices: cols runs along the width, rows along the height.
    cols, rows = np.meshgrid(np.arange(width), np.arange(height))
    x_cam = (cols - center_u) * depthmap / focal_u
    y_cam = (rows - center_v) * depthmap / focal_v
    points = np.stack((x_cam, y_cam, depthmap), axis=-1).astype(np.float32)

    return points, depthmap > 0.0
def homogenize_points(points):
    """Append a trailing coordinate of 1 to batched points: (..., d) -> (..., d + 1)."""
    ones_column = torch.ones_like(points[..., :1])
    return torch.cat((points, ones_column), dim=-1)
def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
    """
    Warp a dense grid of image-1 pixel centers into image 2 using depth and pose.

    A regular grid of normalized coordinates in [-1 + 1/n, 1 - 1/n] is built
    for an H x W image and handed to `warp_kpts` in double precision. The
    whole computation runs under `torch.no_grad()`.

    Args:
        depth1, depth2: depth maps; when H is None, depth1 must be (B, H, W)
            and supplies the grid size.
        T_1to2, K1, K2: relative pose and intrinsics, forwarded to `warp_kpts`.
        depth_interpolation_mode, relative_depth_error_threshold: forwarded
            to `warp_kpts`.
        H, W: optional grid size. They are only used when H is given;
            otherwise both are taken from depth1.
    Returns:
        x2: (B, H, W, 2) warped normalized coordinates.
        prob: (B, H, W) float version of the mask returned by `warp_kpts`.
    """
    B, H, W = depth1.shape if H is None else (depth1.shape[0], H, W)

    with torch.no_grad():
        device = depth1.device
        axes = [
            torch.linspace(-1 + 1 / size, 1 - 1 / size, size, device=device)
            for size in (B, H, W)
        ]
        # The batch axis only serves to broadcast the grid over the batch.
        _, grid_y, grid_x = torch.meshgrid(*axes, indexing = 'ij')
        grid_xy = torch.stack((grid_x, grid_y), dim=-1).reshape(B, H * W, 2)

        mask, x2 = warp_kpts(
            grid_xy.double(),
            depth1.double(),
            depth2.double(),
            T_1to2.double(),
            K1.double(),
            K2.double(),
            depth_interpolation_mode = depth_interpolation_mode,
            relative_depth_error_threshold = relative_depth_error_threshold,
        )
        prob = mask.float().reshape(B, H, W)
        x2 = x2.reshape(B, H, W, 2)
    return x2, prob
def _warp_kpts_single_mode(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask, depth_interpolation_mode, relative_depth_error_threshold):
    """Warp keypoints with one grid_sample mode; returns (valid_mask, relative_depth_error, warped_kpts)."""
    _, h, w = depth0.shape
    kpts0_depth = F.grid_sample(
        depth0[:, None], kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
    )[:, 0, :, 0]
    kpts0_px = torch.stack(
        (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
    )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]

    # Points are only usable where the sampled depth is strictly positive
    nonzero_mask = kpts0_depth > 0

    # Unproject
    kpts0_h = (
        torch.cat([kpts0_px, torch.ones_like(kpts0_px[:, :, [0]])], dim=-1)
        * kpts0_depth[..., None]
    )  # (N, L, 3)
    kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)

    # Rigid Transform
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

    # Project
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    w_kpts0_px = w_kpts0_h[:, :, :2] / (
        w_kpts0_h[:, :, [2]] + 1e-4
    )  # (N, L, 2), +1e-4 to avoid zero depth

    # Covisible Check
    h1, w1 = depth1.shape[1:3]
    covisible_mask = (
        (w_kpts0_px[:, :, 0] > 0)
        * (w_kpts0_px[:, :, 0] < w1 - 1)
        * (w_kpts0_px[:, :, 1] > 0)
        * (w_kpts0_px[:, :, 1] < h1 - 1)
    )
    w_kpts0 = torch.stack(
        (2 * w_kpts0_px[..., 0] / w1 - 1, 2 * w_kpts0_px[..., 1] / h1 - 1), dim=-1
    )  # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]

    # Depth consistency: compare the depth sampled in image 1 with the depth
    # obtained by transforming the image-0 point.
    w_kpts0_depth = F.grid_sample(
        depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
    )[:, 0, :, 0]
    relative_depth_error = (
        (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
    ).abs()
    if not smooth_mask:
        consistent_mask = relative_depth_error < relative_depth_error_threshold
    else:
        # Soft consistency: smooth_mask acts as the scale of an exponential falloff
        consistent_mask = (-relative_depth_error / smooth_mask).exp()
    valid_mask = nonzero_mask * covisible_mask * consistent_mask
    return valid_mask, relative_depth_error, w_kpts0


def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
    """Warp kpts0 from I0 to I1 with depth, K and Rt
    Also check covisibility and depth consistency.
    Depth is consistent if the relative error is below
    `relative_depth_error_threshold`.
    # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
    Args:
        kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
        depth0 (torch.Tensor): [N, H, W],
        depth1 (torch.Tensor): [N, H, W],
        T_0to1 (torch.Tensor): [N, 3, 4] (only the top three rows are read),
        K0 (torch.Tensor): [N, 3, 3],
        K1 (torch.Tensor): [N, 3, 3],
        smooth_mask: falsy for a hard threshold on the relative depth error;
            otherwise the consistency term is exp(-error / smooth_mask).
        return_relative_depth_error (bool): return the relative depth error
            in place of the mask.
        depth_interpolation_mode (str): a `F.grid_sample` mode, or
            "combined" to use bilinear sampling and fall back to nearest
            sampling where only the latter yields a valid point.
        relative_depth_error_threshold (float): hard consistency threshold.
    Returns:
        calculable_mask (torch.Tensor): [N, L] (or the relative depth error,
            [N, L], when return_relative_depth_error is set)
        warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>, normalized
    Raises:
        NotImplementedError: "combined" mode together with smooth_mask.
    """
    # Kept from the original entry point: depth0 must be exactly 3-dimensional.
    n, h, w = depth0.shape
    if depth_interpolation_mode == "combined":
        # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
        if smooth_mask:
            raise NotImplementedError("Combined bilinear and NN warp not implemented")
        valid_bilinear, error_bilinear, warp_bilinear = _warp_kpts_single_mode(
            kpts0, depth0, depth1, T_0to1, K0, K1,
            smooth_mask, "bilinear", relative_depth_error_threshold)
        # grid_sample only accepts 'bilinear', 'nearest' and 'bicubic';
        # "nearest-exact" is an F.interpolate mode and is rejected here.
        valid_nearest, error_nearest, warp_nearest = _warp_kpts_single_mode(
            kpts0, depth0, depth1, T_0to1, K0, K1,
            smooth_mask, "nearest", relative_depth_error_threshold)
        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
        warp = warp_bilinear.clone()
        warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
        if return_relative_depth_error:
            # Report the error of whichever sampling produced each warped point
            error = error_bilinear.clone()
            error[nearest_valid_bilinear_invalid] = error_nearest[nearest_valid_bilinear_invalid]
            return error, warp
        valid = valid_bilinear | valid_nearest
        return valid, warp

    valid_mask, relative_depth_error, w_kpts0 = _warp_kpts_single_mode(
        kpts0, depth0, depth1, T_0to1, K0, K1,
        smooth_mask, depth_interpolation_mode, relative_depth_error_threshold)
    if return_relative_depth_error:
        return relative_depth_error, w_kpts0
    else:
        return valid_mask, w_kpts0
def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a set of points.

    Trf: numpy array or torch tensor with at least 2 dimensions; the last two
         hold the transformation matrix (typically 3x3 or 4x4), any leading
         ones are batch dimensions that must match those of pts.
    pts: numpy/torch/tuple of coordinates, converted to the same library as
         Trf. When the matrix is one column wider than the points, its last
         column is applied as an additive (homogeneous) term.
    ncol: int. number of columns of the result; defaults to pts.shape[-1].
    norm: float. if != 0, the result is divided by its last coordinate and
          then scaled by norm (i.e. projected on the z=norm plane).

    Returns the transformed points, with the leading shape of pts and ncol
    columns.
    """
    assert Trf.ndim >= 2
    trf_is_torch = isinstance(Trf, torch.Tensor)
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif trf_is_torch:
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # Remember the caller-visible layout before any internal flattening
    leading_shape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    if trf_is_torch and isinstance(pts, torch.Tensor) and Trf.ndim == 3 and pts.ndim == 4:
        # Fast path: (B,d,d) or (B,d+1,d+1) matrices applied to (B,H,W,d) point maps
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            linear_part = Trf[:, :d, :d]
            offset_part = Trf[:, None, None, :d, d]
            pts = torch.einsum("bij, bhwj -> bhwi", linear_part, pts) + offset_part
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        if Trf.ndim >= 3:
            n_batch_dims = Trf.ndim - 2
            assert Trf.shape[:n_batch_dims] == pts.shape[:n_batch_dims], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        pt_dim, trf_dim = pts.shape[-1], Trf.shape[-1]
        if trf_dim == pt_dim + 1:
            # Points are row vectors, so multiply by the transposed matrix;
            # its last row is the additive term.
            Trf_t = Trf.swapaxes(-1, -2)
            pts = pts @ Trf_t[..., :-1, :] + Trf_t[..., -1:, :]
        elif trf_dim == pt_dim:
            pts = pts @ Trf.swapaxes(-1, -2)
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    return pts[..., :ncol].reshape(*leading_shape, ncol)
def inv(mat):
    """ Invert a torch or numpy matrix.

    Dispatches to `torch.linalg.inv` for tensors and `np.linalg.inv` for
    arrays; any other type raises ValueError.
    """
    backends = (
        (torch.Tensor, torch.linalg.inv),
        (np.ndarray, np.linalg.inv),
    )
    for matrix_type, invert in backends:
        if isinstance(mat, matrix_type):
            return invert(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')
def opencv_camera_to_plucker(poses, K, H, W):
    """
    Build a per-pixel Plucker ray map for a batch of cameras.

    Args:
        poses: batched tensor (B, ...) whose [:3, :3] block is used as the
            rotation and whose [:3, 3] column is used as the ray origin
            (presumably cam2world matrices in the OpenCV convention, as the
            function name suggests; confirm with callers).
        K: (B, 3, 3) intrinsics; inverted to turn pixel centers into rays.
        H, W: image size.
    Returns:
        (B, H, W, 6) tensor: unit ray direction followed by
        cross(origin, direction).
    """
    device = poses.device
    batch = poses.shape[0]

    # Homogeneous pixel centers, laid out as an image and copied per batch item: (B, H, W, 3)
    centers = torch.from_numpy(get_pixel(H, W).astype(np.float32)).to(device)
    pixel = centers.T.reshape(H, W, 3)[None].repeat(batch, 1, 1, 1)

    cam_rays = torch.einsum('bij, bhwj -> bhwi', torch.inverse(K), pixel)
    directions = torch.einsum('bij, bhwj -> bhwi', poses[..., :3, :3], cam_rays)
    origins = poses[..., :3, 3][:, None, None].repeat(1, H, W, 1)

    directions = directions / directions.norm(dim=-1, keepdim=True)
    moment = torch.cross(origins, directions, dim=-1)
    return torch.cat([directions, moment], dim=-1)
def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor:
    """
    Compute the edge mask of a depth map. A pixel is an edge when the spread
    (max minus min) of depth inside its kernel_size x kernel_size
    neighborhood exceeds a tolerance.

    Args:
        depth (torch.Tensor): shape (..., height, width), linear depth map
        atol (float): absolute tolerance on the spread; skipped when None
        rtol (float): tolerance on spread / depth; skipped when None
        kernel_size (int): side of the square neighborhood
        mask (torch.Tensor): optional boolean tensor shaped like depth;
            pixels where it is False are left out of the neighborhood
            max/min

    Returns:
        edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
    """
    original_shape = depth.shape
    depth = depth.reshape(-1, 1, *original_shape[-2:])

    def window_max(values):
        return F.max_pool2d(values, kernel_size, stride=1, padding=kernel_size // 2)

    if mask is not None:
        mask = mask.reshape(-1, 1, *original_shape[-2:])
        # Masked-out pixels become -inf so they never win either max-pool
        highest = window_max(torch.where(mask, depth, -torch.inf))
        neg_lowest = window_max(torch.where(mask, -depth, -torch.inf))
    else:
        highest = window_max(depth)
        neg_lowest = window_max(-depth)
    # max(depth) + max(-depth) == max(depth) - min(depth)
    spread = highest + neg_lowest

    edge = torch.zeros_like(depth, dtype=torch.bool)
    if atol is not None:
        edge |= spread > atol
    if rtol is not None:
        # nan_to_num_ zeroes the NaN ratios produced by 0 / 0
        edge |= (spread / depth).nan_to_num_() > rtol
    return edge.reshape(*original_shape)