#!/usr/bin/env python3
# --------------------------------------------------------
# Utils for geometric calculations
# Adopted from AnyMap (Nikhil Keetha)
# Includes functions from DUSt3R (Naver Corporation, CC BY-NC-SA 4.0 (non-commercial use only)) & GradSLAM (MIT License)
# --------------------------------------------------------
from functools import lru_cache

import numpy as np
import torch
from scipy.spatial import KDTree  # required by find_reciprocal_matches (was missing)


def depthmap_to_camera_frame(depthmap, intrinsics):
    """
    Convert a depth image to a pointcloud in the camera frame.

    Args:
    - depthmap: HxW torch tensor of z-depths
    - intrinsics: 3x3 torch tensor of pinhole intrinsics (no skew assumed)

    Returns: pointmap in camera frame (HxWx3 tensor), and a mask specifying valid pixels.
    """
    height, width = depthmap.shape
    device = depthmap.device

    fx = intrinsics[0, 0]
    fy = intrinsics[1, 1]
    cx = intrinsics[0, 2]
    cy = intrinsics[1, 2]

    # Pixel grid; with "xy" indexing, x_grid[j, i] = i and y_grid[j, i] = j.
    x_grid, y_grid = torch.meshgrid(
        torch.arange(width).to(device).float(),
        torch.arange(height).to(device).float(),
        indexing="xy",
    )

    # Pinhole backprojection of every pixel.
    depth_z = depthmap
    xx = (x_grid - cx) * depth_z / fx
    yy = (y_grid - cy) * depth_z / fy
    pts3d_cam = torch.stack((xx, yy, depth_z), dim=-1)

    # Zero (or negative) depth marks an invalid measurement.
    valid_mask = depthmap > 0.0
    return pts3d_cam, valid_mask


def depthmap_to_world_frame(depthmap, intrinsics, camera_pose=None):
    """
    Convert a depth image to a pointcloud in the world frame.

    Args:
    - depthmap: HxW torch tensor of z-depths
    - intrinsics: 3x3 torch tensor of pinhole intrinsics
    - camera_pose: optional 4x4 torch tensor (cam2world). If None, the points
      are returned in the camera frame (previously this raised UnboundLocalError).

    Returns: pointmap in world frame (HxWx3 tensor), and a mask specifying valid pixels.
    """
    pts3d_cam, valid_mask = depthmap_to_camera_frame(depthmap, intrinsics)

    # Default: no pose given, world frame == camera frame.
    pts3d_world = pts3d_cam
    if camera_pose is not None:
        # Homogenize and apply the 4x4 cam2world transform to every pixel.
        pts3d_cam_homo = torch.cat([pts3d_cam, torch.ones_like(pts3d_cam[..., :1])], dim=-1)
        pts3d_world = torch.einsum("ik,hwk->hwi", camera_pose, pts3d_cam_homo)
        pts3d_world = pts3d_world[..., :3]

    return pts3d_world, valid_mask


def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
    """Output a (H,W,2) array of integers with
    output[j,i,0] = i + origin[0]
    output[j,i,1] = j + origin[1]

    Uses numpy when device is None, torch otherwise.
    """
    if device is None:
        # numpy backend
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
    else:
        # torch backend
        arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
        meshgrid, stack = torch.meshgrid, torch.stack
        ones = lambda *a: torch.ones(*a, device=device)

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    grid = meshgrid(tw, th, indexing="xy")
    if homogeneous:
        # np.meshgrid returns a list, torch.meshgrid a tuple; normalize to a
        # tuple before appending the homogeneous ones-plane (list + tuple raises).
        grid = tuple(grid) + (ones((H, W)),)
    if unsqueeze is not None:
        grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid


def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a list of 3-D points.

    H: 3x3 or 4x4 projection matrix (typically a Homography)
    p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)

    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the result is projected on the z=norm plane.

    Returns an array of projected 2d points.
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code path: batched torch transform of (B,H,W,d) pointmaps
    if isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and Trf.ndim == 3 and pts.ndim == 4:
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # homogeneous transform: rotate/scale then translate
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res


def inv(mat):
    """Invert a torch or numpy matrix."""
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")


def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
    """
    Args:
    - depth (BxHxW or BxHxWxn torch tensor)
    - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
    - pp: optional Bx2 principal point; defaults to the image center.

    Returns: pointmap of absolute coordinates (BxHxWx3 tensor, or BxHxWx3xn).
    """
    if len(depth.shape) == 4:
        B, H, W, n = depth.shape
    else:
        B, H, W = depth.shape
        n = None

    if len(pseudo_focal.shape) == 3:  # [B,H,W]
        pseudo_focalx = pseudo_focaly = pseudo_focal
    elif len(pseudo_focal.shape) == 4:  # [B,2,H,W] or [B,1,H,W]
        pseudo_focalx = pseudo_focal[:, 0]
        if pseudo_focal.shape[1] == 2:
            pseudo_focaly = pseudo_focal[:, 1]
        else:
            pseudo_focaly = pseudo_focalx
    else:
        raise NotImplementedError("Error, unknown input focal shape format.")

    assert pseudo_focalx.shape == depth.shape[:3]
    assert pseudo_focaly.shape == depth.shape[:3]

    # (2,1,H,W) stacked grid, unpacked to grid_x, grid_y of shape (1,H,W)
    grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]

    # set principal point
    if pp is None:
        grid_x = grid_x - (W - 1) / 2
        grid_y = grid_y - (H - 1) / 2
    else:
        grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
        grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]

    if n is None:
        # dtype follows depth so float64 inputs are not silently downcast
        pts3d = torch.empty((B, H, W, 3), dtype=depth.dtype, device=depth.device)
        pts3d[..., 0] = depth * grid_x / pseudo_focalx
        pts3d[..., 1] = depth * grid_y / pseudo_focaly
        pts3d[..., 2] = depth
    else:
        pts3d = torch.empty((B, H, W, 3, n), dtype=depth.dtype, device=depth.device)
        pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
        pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
        pts3d[..., 2, :] = depth
    return pts3d


@lru_cache(maxsize=10)
def get_meshgrid(W, H):
    """Cached numpy pixel grid: u[j,i] = i, v[j,i] = j, each of shape (H, W)."""
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    return u, v


@lru_cache(maxsize=10)
def get_meshgrid_torch(W, H, device):
    """Cached torch pixel grid of shape (H, W, 2) with uv[j,i] = (i, j)."""
    u, v = torch.meshgrid(
        torch.arange(W, device=device).float(), torch.arange(H, device=device).float(), indexing="xy"
    )
    uv = torch.stack((u, v), dim=-1)
    return uv


def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Args:
    - depthmap (HxW array)
    - camera_intrinsics: a 3x3 matrix
    - pseudo_focal: optional HxW per-pixel focal overriding the intrinsics

    Returns: pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Compute 3D ray associated with each pixel
    # Strong assumption: there are no skew terms
    assert camera_intrinsics[0, 1] == 0.0
    assert camera_intrinsics[1, 0] == 0.0
    if pseudo_focal is None:
        fu = camera_intrinsics[0, 0]
        fv = camera_intrinsics[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    cu = camera_intrinsics[0, 2]
    cv = camera_intrinsics[1, 2]

    u, v = get_meshgrid(W, H)
    X_cam = np.zeros((H, W, 3), dtype=np.float32)
    X_cam[..., 0] = (u - cu) * depthmap / fu
    X_cam[..., 1] = (v - cv) * depthmap / fv
    X_cam[..., 2] = depthmap

    # Mask for valid coordinates
    valid_mask = depthmap > 0.0
    return X_cam, valid_mask


def z_depthmap_to_norm_depthmap(z_depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Convert a z-depth map to a distance-along-ray ("norm") depth map by
    multiplying each pixel's z-depth with the norm of its unit-z viewing ray.

    Args:
    - z_depthmap (HxW array)
    - camera_intrinsics: a 3x3 matrix
    - pseudo_focal: optional HxW per-pixel focal overriding the intrinsics

    Returns: HxW array of per-pixel euclidean distances along the camera ray.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = z_depthmap.shape

    # Compute 3D ray associated with each pixel
    # Strong assumption: there are no skew terms
    assert camera_intrinsics[0, 1] == 0.0
    assert camera_intrinsics[1, 0] == 0.0
    if pseudo_focal is None:
        fu = camera_intrinsics[0, 0]
        fv = camera_intrinsics[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    cu = camera_intrinsics[0, 2]
    cv = camera_intrinsics[1, 2]

    # Rays with z = 1: ((u-cu)/fu, (v-cv)/fv, 1)
    rays = np.ones((H, W, 3), dtype=np.float32)
    u, v = get_meshgrid(W, H)
    rays[..., 0] = (u - cu) / fu
    rays[..., 1] = (v - cv) / fv

    ray_norm = np.linalg.norm(rays, axis=-1)
    return z_depthmap * ray_norm


def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw):
    """
    Args:
    - depthmap (HxW array)
    - camera_intrinsics: a 3x3 matrix
    - camera_pose: a 4x3 or 4x4 cam2world matrix (or None for camera frame)

    Returns: pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
    """
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    X_world = X_cam  # default: no pose, stay in the camera frame
    if camera_pose is not None:
        R_cam2world = camera_pose[:3, :3]
        t_cam2world = camera_pose[:3, 3]

        # Express in absolute coordinates: X_world = R @ X_cam + t (per pixel)
        X_world = X_cam @ (R_cam2world.T) + t_cam2world[None, None, :]

    return X_world, valid_mask


def global_points_to_local(pts, camera_pose):
    """
    Transform points from world coordinates into the camera frame.

    Args:
    - pts: HxWx3 points in world coordinates
    - camera_pose: 4x4 camera-to-world transformation (inverted here)
    """
    world_to_camera = np.linalg.inv(camera_pose)
    R_world2cam = world_to_camera[:3, :3]
    t_world2cam = world_to_camera[:3, 3]
    pts_local = np.einsum("ik, vuk -> vui", R_world2cam, pts) + t_world2cam[None, None, :]
    return pts_local


def project_points_to_pixels(pts_camera, camera_intrinsics, pseudo_focal=None):
    """
    Args:
    - pts_camera (HxWx3 array): points in camera coordinates
    - camera_intrinsics: a 3x3 matrix
    - pseudo_focal: optional HxW per-pixel focal overriding the intrinsics

    Returns: pixel coordinates (HxWx2 array), and a mask specifying valid pixels.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = pts_camera.shape[:2]

    # Strong assumption: there are no skew terms
    assert camera_intrinsics[0, 1] == 0.0
    assert camera_intrinsics[1, 0] == 0.0
    if pseudo_focal is None:
        fu = camera_intrinsics[0, 0]
        fv = camera_intrinsics[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    cu = camera_intrinsics[0, 2]
    cv = camera_intrinsics[1, 2]

    x, y, z = pts_camera[..., 0], pts_camera[..., 1], pts_camera[..., 2]
    uv = np.zeros((H, W, 2), dtype=np.float32)
    uv[..., 0] = fu * x / z + cu
    uv[..., 1] = fv * y / z + cv

    # Valid = in front of the camera and inside the image (pixel-center convention).
    valid_mask = (
        (z > 0.0)
        & (uv[..., 0] >= -0.5)
        & (uv[..., 0] < W - 0.5)
        & (uv[..., 1] >= -0.5)
        & (uv[..., 1] < H - 0.5)
    )
    return uv, valid_mask


def project_points_to_pixels_batched(pts_camera, camera_intrinsics, pseudo_focal=None):
    """
    Args:
    - pts_camera (BxHxWx3 torch.Tensor): points in camera coordinates
    - camera_intrinsics: a Bx3x3 torch.Tensor
    - pseudo_focal: optional BxHxW per-pixel focal overriding the intrinsics

    Returns: pixel coordinates (BxHxWx2 torch.Tensor), and a mask (BxHxW) specifying valid pixels.
    """
    B, H, W, C = pts_camera.shape

    # Strong assumption: there are no skew terms
    assert (camera_intrinsics[..., 0, 1] == 0.0).all()
    assert (camera_intrinsics[..., 1, 0] == 0.0).all()
    if pseudo_focal is None:
        fu = camera_intrinsics[..., 0, 0]
        fv = camera_intrinsics[..., 1, 1]
    else:
        assert pseudo_focal.shape == (B, H, W)
        fu = fv = pseudo_focal
    cu = camera_intrinsics[..., 0, 2]
    cv = camera_intrinsics[..., 1, 2]

    x, y, z = pts_camera[..., 0], pts_camera[..., 1], pts_camera[..., 2]
    uv = torch.zeros((B, H, W, 2), dtype=pts_camera.dtype, device=pts_camera.device)
    # view(B,1,1) broadcasts the per-batch intrinsics over the H, W dims
    uv[..., 0] = fu.view(B, 1, 1) * x / z + cu.view(B, 1, 1)
    uv[..., 1] = fv.view(B, 1, 1) * y / z + cv.view(B, 1, 1)

    # Valid = in front of the camera and inside the image (pixel-center convention).
    valid_mask = (
        (z > 0.0)
        & (uv[..., 0] >= -0.5)
        & (uv[..., 0] < W - 0.5)
        & (uv[..., 1] >= -0.5)
        & (uv[..., 1] < H - 0.5)
    )
    return uv, valid_mask


def z_depthmap_to_norm_depthmap_batched(z_depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Batched variant of z_depthmap_to_norm_depthmap.

    Args:
    - z_depthmap (BxHxW tensor)
    - camera_intrinsics: a Bx3x3 tensor
    - pseudo_focal: optional BxHxW per-pixel focal overriding the intrinsics

    Returns: BxHxW tensor of per-pixel euclidean distances along the camera ray.
    """
    B, H, W = z_depthmap.shape

    # Strong assumption: there are no skew terms
    assert (camera_intrinsics[..., 0, 1] == 0.0).all()
    assert (camera_intrinsics[..., 1, 0] == 0.0).all()
    if pseudo_focal is None:
        fu = camera_intrinsics[..., 0, 0]
        fv = camera_intrinsics[..., 1, 1]
    else:
        assert pseudo_focal.shape == (B, H, W)
        fu = fv = pseudo_focal
    cu = camera_intrinsics[..., 0, 2]
    cv = camera_intrinsics[..., 1, 2]

    # Rays with z = 1, built from the shared (cached) pixel grid.
    rays = torch.ones((B, H, W, 3), dtype=z_depthmap.dtype, device=z_depthmap.device)
    uv = get_meshgrid_torch(W, H, device=z_depthmap.device)
    rays[..., 0] = (uv[..., 0].view(1, H, W) - cu.view(B, 1, 1)) / fu.view(B, 1, 1)
    rays[..., 1] = (uv[..., 1].view(1, H, W) - cv.view(B, 1, 1)) / fv.view(B, 1, 1)

    # torch.linalg.norm takes `dim`, not numpy's `axis`
    ray_norm = torch.linalg.norm(rays, dim=-1)
    return z_depthmap * ray_norm


def colmap_to_opencv_intrinsics(K):
    """
    Modify camera intrinsics to follow a different convention.
    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    """
    K = K.copy()
    K[0, 2] -= 0.5
    K[1, 2] -= 0.5
    return K


def opencv_to_colmap_intrinsics(K):
    """
    Modify camera intrinsics to follow a different convention.
    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    """
    K = K.copy()
    K[0, 2] += 0.5
    K[1, 2] += 0.5
    return K


def invalid_to_nans(arr, valid_mask):
    """Return a copy of `arr` with entries where `valid_mask` is False set to NaN.

    NOTE(review): this helper was referenced by the two functions below but
    defined nowhere in this module (NameError at call time). This definition
    matches the local usage (valid_mask may be None, meaning "all valid") —
    confirm against the upstream DUSt3R `invalid_to_nans` if behavior differs.
    """
    if valid_mask is None:
        return arr
    arr = arr.clone()
    arr[~valid_mask] = float("nan")
    return arr


@torch.no_grad()
def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
    """Compute a per-batch depth quantile over one or two depth maps,
    ignoring invalid pixels. Returns a (B,) tensor."""
    # set invalid points to NaN
    _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
    _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None
    _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1

    # compute median depth overall (ignoring nans)
    if quantile == 0.5:
        shift_z = torch.nanmedian(_z, dim=-1).values
    else:
        shift_z = torch.nanquantile(_z, quantile, dim=-1)
    return shift_z  # (B,)


@torch.no_grad()
def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True):
    """Compute a robust (median-based) center and scale over one or two
    pointclouds, ignoring invalid pixels. Returns (B,1,1,3) and (B,1,1,1)."""
    # set invalid points to NaN
    _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
    _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None
    _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1

    # compute median center
    _center = torch.nanmedian(_pts, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        _center[..., :2] = 0  # do not center X and Y

    # compute median norm
    _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
    scale = torch.nanmedian(_norm, dim=1).values
    return _center[:, None, :, :], scale[:, None, None, None]


def find_reciprocal_matches(P1, P2):
    """
    returns 3 values:
    1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
    2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
    3 - reciprocal_in_P2.sum(): the number of matches
    """
    tree1 = KDTree(P1)
    tree2 = KDTree(P2)

    _, nn1_in_P2 = tree2.query(P1, workers=8)
    _, nn2_in_P1 = tree1.query(P2, workers=8)

    # A match is reciprocal iff each point is its neighbor's nearest neighbor.
    reciprocal_in_P1 = nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2))
    reciprocal_in_P2 = nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1))

    assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
    return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()


def rotate_vector_with_quaternion(
    v: torch.Tensor, quat: torch.Tensor, scalar_first: bool = False, skip_norm=False
) -> torch.Tensor:
    """
    Rotate a 3D vector by a quaternion.

    Args:
        v (torch.Tensor): A tensor of shape (..., 3) representing the vectors to rotate.
        quat (torch.Tensor): A tensor of shape (..., 4) representing the quaternions.
            The last dimension is [w, x, y, z] if scalar_first is True,
            or [x, y, z, w] if scalar_first is False.
        scalar_first (bool): If True, assumes the quaternion is in the format [w, x, y, z].
            Otherwise, assumes the format [x, y, z, w].
        skip_norm (bool): If True, skip the quaternion normalization step.

    Returns:
        torch.Tensor: A tensor of shape (..., 3) representing the rotated vectors.
    """
    if scalar_first:
        w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]
    else:
        x, y, z, w = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]

    # Normalize the quaternion to ensure a valid rotation
    if not skip_norm:
        norm_quat = torch.sqrt(w**2 + x**2 + y**2 + z**2 + 1e-8)
        w, x, y, z = w / norm_quat, x / norm_quat, y / norm_quat, z / norm_quat

    # Vector part of the quaternion
    q_vec = torch.stack([x, y, z], dim=-1)  # Shape (..., 3)

    # Rodrigues-style expansion: v' = v + 2w (q x v) + 2 q x (q x v)
    t = 2 * torch.cross(q_vec, v, dim=-1)  # Intermediate vector, shape (..., 3)

    # Ensure proper broadcasting of w
    v_rotated = v + w.unsqueeze(-1) * t + torch.cross(q_vec, t, dim=-1)
    return v_rotated


def quaternion_to_rot_matrix(quat: torch.Tensor, scalar_first: bool = False) -> torch.Tensor:
    """Convert quaternions of shape (..., 4) to rotation matrices of shape (..., 3, 3).

    The quaternion layout follows `scalar_first` as in rotate_vector_with_quaternion;
    quaternions are normalized (with a small epsilon) before conversion.
    """
    if scalar_first:
        w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]
    else:
        x, y, z, w = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]

    norm_quat = torch.sqrt(w**2 + x**2 + y**2 + z**2 + 1e-8)
    w, x, y, z = w / norm_quat, x / norm_quat, y / norm_quat, z / norm_quat

    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z

    rot_matrix_shape = quat.shape[:-1] + (3, 3)
    # dtype follows quat so float64 quaternions are not silently downcast
    rot_matrix = torch.empty(rot_matrix_shape, dtype=quat.dtype, device=quat.device)
    rot_matrix[..., 0, 0] = 1 - 2 * (yy + zz)
    rot_matrix[..., 0, 1] = 2 * (xy - wz)
    rot_matrix[..., 0, 2] = 2 * (xz + wy)
    rot_matrix[..., 1, 0] = 2 * (xy + wz)
    rot_matrix[..., 1, 1] = 1 - 2 * (xx + zz)
    rot_matrix[..., 1, 2] = 2 * (yz - wx)
    rot_matrix[..., 2, 0] = 2 * (xz - wy)
    rot_matrix[..., 2, 1] = 2 * (yz + wx)
    rot_matrix[..., 2, 2] = 1 - 2 * (xx + yy)
    return rot_matrix