|
|
""" |
|
|
Utilities for geometry operations. |
|
|
|
|
|
References: DUSt3R, MoGe |
|
|
""" |
|
|
|
|
|
from numbers import Number |
|
|
from typing import Tuple, Union |
|
|
|
|
|
import einops as ein |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from mapanything.utils.misc import invalid_to_zeros |
|
|
from mapanything.utils.warnings import no_warnings |
|
|
|
|
|
|
|
|
def depthmap_to_camera_frame(depthmap, intrinsics):
    """
    Convert depth image to a pointcloud in camera frame.

    Args:
    - depthmap: HxW or BxHxW torch tensor
    - intrinsics: 3x3 or Bx3x3 torch tensor

    Returns:
        pointmap in camera frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels.
    """
    # Promote unbatched inputs to a batch of one so a single path handles both.
    added_batch_dim = depthmap.dim() == 2
    if added_batch_dim:
        depthmap = depthmap[None]
        intrinsics = intrinsics[None]

    bsz, hgt, wid = depthmap.shape
    dev = depthmap.device

    # Pixel-coordinate grids, broadcast over the batch.
    cols = torch.arange(wid, device=dev, dtype=torch.float32)
    rows = torch.arange(hgt, device=dev, dtype=torch.float32)
    u, v = torch.meshgrid(cols, rows, indexing="xy")
    u = u[None].expand(bsz, -1, -1)
    v = v[None].expand(bsz, -1, -1)

    # Pinhole parameters reshaped for broadcasting against the HxW grids.
    fx = intrinsics[:, 0, 0].reshape(-1, 1, 1)
    fy = intrinsics[:, 1, 1].reshape(-1, 1, 1)
    cx = intrinsics[:, 0, 2].reshape(-1, 1, 1)
    cy = intrinsics[:, 1, 2].reshape(-1, 1, 1)

    # Back-project: x = (u - cx) * z / fx, y = (v - cy) * z / fy, keeping z.
    pts3d_cam = torch.stack(
        ((u - cx) * depthmap / fx, (v - cy) * depthmap / fy, depthmap), dim=-1
    )

    # Pixels with non-positive depth carry no usable geometry.
    valid_mask = depthmap > 0.0

    if added_batch_dim:
        pts3d_cam = pts3d_cam[0]
        valid_mask = valid_mask[0]

    return pts3d_cam, valid_mask
|
|
|
|
|
|
|
|
def depthmap_to_world_frame(depthmap, intrinsics, camera_pose=None):
    """
    Convert depth image to a pointcloud in world frame.

    Args:
    - depthmap: HxW or BxHxW torch tensor
    - intrinsics: 3x3 or Bx3x3 torch tensor
    - camera_pose: 4x4 or Bx4x4 torch tensor

    Returns:
        pointmap in world frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels.
    """
    pts3d_cam, valid_mask = depthmap_to_camera_frame(depthmap, intrinsics)

    if camera_pose is None:
        # No pose supplied: the camera frame doubles as the world frame.
        return pts3d_cam, valid_mask

    # Promote unbatched inputs to a batch of one.
    added_batch_dim = camera_pose.dim() == 2
    if added_batch_dim:
        camera_pose = camera_pose[None]
        pts3d_cam = pts3d_cam[None]

    # Lift points to homogeneous coordinates and apply the 4x4 cam2world pose.
    homo = torch.cat([pts3d_cam, torch.ones_like(pts3d_cam[..., :1])], dim=-1)
    pts3d_world = torch.einsum("bik,bhwk->bhwi", camera_pose, homo)[..., :3]

    if added_batch_dim:
        pts3d_world = pts3d_world[0]

    return pts3d_world, valid_mask
|
|
|
|
|
|
|
|
def transform_pts3d(pts3d, transformation):
    """
    Transform 3D points using a 4x4 transformation matrix.

    Args:
    - pts3d: HxWx3 or BxHxWx3 torch tensor
    - transformation: 4x4 or Bx4x4 torch tensor

    Returns:
        transformed points (HxWx3 or BxHxWx3 tensor)
    """
    # Promote unbatched inputs to a batch of one.
    added_batch_dim = pts3d.dim() == 3
    if added_batch_dim:
        pts3d = pts3d[None]
        transformation = transformation[None]

    # Append a homogeneous 1, apply the full 4x4 transform, drop the w row.
    homo = torch.cat([pts3d, torch.ones_like(pts3d[..., :1])], dim=-1)
    transformed = torch.einsum("bik,bhwk->bhwi", transformation, homo)[..., :3]

    if added_batch_dim:
        transformed = transformed[0]

    return transformed
|
|
|
|
|
|
|
|
def project_pts3d_to_image(pts3d, intrinsics, return_z_dim):
    """
    Project 3D points to image plane (assumes pinhole camera model with no distortion).

    Args:
    - pts3d: HxWx3 or BxHxWx3 torch tensor
    - intrinsics: 3x3 or Bx3x3 torch tensor
    - return_z_dim: bool, whether to return the third dimension of the projected points

    Returns:
        projected points (HxWx2)
    """
    # Promote unbatched inputs to a batch of one.
    added_batch_dim = pts3d.dim() == 3
    if added_batch_dim:
        pts3d = pts3d[None]
        intrinsics = intrinsics[None]

    # Apply intrinsics, then perspective-divide the first two channels by z
    # (clamped away from zero to avoid division blow-ups).
    proj = torch.einsum("bik,bhwk->bhwi", intrinsics, pts3d)
    denom = proj[..., 2:3].clamp(min=1e-6)
    proj = torch.cat([proj[..., :2] / denom, proj[..., 2:3]], dim=-1)

    if not return_z_dim:
        proj = proj[..., :2]

    if added_batch_dim:
        proj = proj[0]

    return proj
|
|
|
|
|
|
|
|
def get_rays_in_camera_frame(intrinsics, height, width, normalize_to_unit_sphere):
    """
    Convert camera intrinsics to a raymap (ray origins + directions) in camera frame.
    Note: Currently only supports pinhole camera model.

    Args:
    - intrinsics: 3x3 or Bx3x3 torch tensor
    - height: int
    - width: int
    - normalize_to_unit_sphere: bool

    Returns:
    - ray_origins: (HxWx3 or BxHxWx3) tensor
    - ray_directions: (HxWx3 or BxHxWx3) tensor
    """
    # Promote an unbatched matrix to a batch of one.
    added_batch_dim = intrinsics.dim() == 2
    if added_batch_dim:
        intrinsics = intrinsics[None]

    bsz = intrinsics.shape[0]
    dev = intrinsics.device

    # Pixel-coordinate grids, broadcast over the batch.
    cols = torch.arange(width, device=dev, dtype=torch.float32)
    rows = torch.arange(height, device=dev, dtype=torch.float32)
    u, v = torch.meshgrid(cols, rows, indexing="xy")
    u = u[None].expand(bsz, -1, -1)
    v = v[None].expand(bsz, -1, -1)

    fx = intrinsics[:, 0, 0].reshape(-1, 1, 1)
    fy = intrinsics[:, 1, 1].reshape(-1, 1, 1)
    cx = intrinsics[:, 0, 2].reshape(-1, 1, 1)
    cy = intrinsics[:, 1, 2].reshape(-1, 1, 1)

    # All rays originate at the pinhole; directions lie on the z = 1 plane.
    ray_origins = torch.zeros((bsz, height, width, 3), device=dev)
    dir_x = (u - cx) / fx
    dir_y = (v - cy) / fy
    ray_directions = torch.stack((dir_x, dir_y, torch.ones_like(dir_x)), dim=-1)

    if normalize_to_unit_sphere:
        # Rescale each direction to unit Euclidean length.
        ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)

    if added_batch_dim:
        ray_origins = ray_origins[0]
        ray_directions = ray_directions[0]

    return ray_origins, ray_directions
|
|
|
|
|
|
|
|
def get_rays_in_world_frame(
    intrinsics, height, width, normalize_to_unit_sphere, camera_pose=None
):
    """
    Convert camera intrinsics & camera_pose (if provided) to a raymap (ray origins + directions) in camera or world frame (if camera_pose is provided).
    Note: Currently only supports pinhole camera model.

    Args:
    - intrinsics: 3x3 or Bx3x3 torch tensor
    - height: int
    - width: int
    - normalize_to_unit_sphere: bool
    - camera_pose: 4x4 or Bx4x4 torch tensor

    Returns:
    - ray_origins: (HxWx3 or BxHxWx3) tensor
    - ray_directions: (HxWx3 or BxHxWx3) tensor
    """
    ray_origins, ray_directions = get_rays_in_camera_frame(
        intrinsics, height, width, normalize_to_unit_sphere
    )

    if camera_pose is None:
        # Without a pose the camera frame is taken to be the world frame.
        return ray_origins, ray_directions

    # Promote unbatched inputs to a batch of one.
    added_batch_dim = camera_pose.dim() == 2
    if added_batch_dim:
        camera_pose = camera_pose[None]
        ray_origins = ray_origins[None]
        ray_directions = ray_directions[None]

    # Origins transform as points (w = 1); directions as vectors (w = 0).
    origins_homo = torch.cat(
        [ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1
    )
    dirs_homo = torch.cat(
        [ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1
    )
    world_origins = torch.einsum("bik,bhwk->bhwi", camera_pose, origins_homo)[..., :3]
    world_dirs = torch.einsum("bik,bhwk->bhwi", camera_pose, dirs_homo)[..., :3]

    if added_batch_dim:
        world_origins = world_origins[0]
        world_dirs = world_dirs[0]

    return world_origins, world_dirs
|
|
|
|
|
|
|
|
def recover_pinhole_intrinsics_from_ray_directions(
    ray_directions, use_geometric_calculation=False
):
    """
    Recover pinhole camera intrinsics from ray directions, supporting both batched and non-batched inputs.

    Two strategies are used:
    - Geometric (5-point): sample the center ray plus rays at the quarter and
      three-quarter marks along each axis, normalize them onto the z=1 plane,
      and solve u = cx + fx * (x/z) (and the vertical analogue) in closed form.
      Chosen when requested explicitly or when the image exceeds ~1 megapixel.
    - Least squares: subsample roughly a 50x50 pixel grid and fit (cx, fx) and
      (cy, fy) via the normal equations.

    Args:
        ray_directions: Tensor of shape [H, W, 3] or [B, H, W, 3] containing unit normalized ray directions
        use_geometric_calculation: bool, force the geometric 5-point solution
            instead of the least-squares fit.

    Returns:
        Pinhole intrinsics as a 3x3 (or Bx3x3 when the input was batched) tensor.
    """
    # Promote unbatched input to a batch of one.
    if ray_directions.dim() == 3:
        ray_directions = ray_directions.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    batch_size, height, width, _ = ray_directions.shape
    device = ray_directions.device

    # Per-pixel (u, v) pixel-coordinate grids.
    x_grid, y_grid = torch.meshgrid(
        torch.arange(width, device=device).float(),
        torch.arange(height, device=device).float(),
        indexing="xy",
    )

    x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)
    y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)

    # Above ~1MP the O(1) geometric route is preferred to keep the solve cheap.
    is_high_res = height * width > 1000000

    if is_high_res or use_geometric_calculation:
        # Reference pixel locations along the central row and column.
        center_h, center_w = height // 2, width // 2
        quarter_w, three_quarter_w = width // 4, 3 * width // 4
        quarter_h, three_quarter_h = height // 4, 3 * height // 4

        center_rays = ray_directions[:, center_h, center_w, :].clone()
        left_rays = ray_directions[:, center_h, quarter_w, :].clone()
        right_rays = ray_directions[:, center_h, three_quarter_w, :].clone()
        top_rays = ray_directions[:, quarter_h, center_w, :].clone()
        bottom_rays = ray_directions[:, three_quarter_h, center_w, :].clone()

        # Scale each ray so z == 1; x and y then become pinhole-plane coords.
        # NOTE(review): assumes the sampled rays have non-zero z — confirm upstream.
        center_rays = center_rays / center_rays[:, 2].unsqueeze(1)
        left_rays = left_rays / left_rays[:, 2].unsqueeze(1)
        right_rays = right_rays / right_rays[:, 2].unsqueeze(1)
        top_rays = top_rays / top_rays[:, 2].unsqueeze(1)
        bottom_rays = bottom_rays / bottom_rays[:, 2].unsqueeze(1)

        # fx from finite differences along the central row, averaged over
        # the left and right samples.
        fx_left = (quarter_w - center_w) / (left_rays[:, 0] - center_rays[:, 0])
        fx_right = (three_quarter_w - center_w) / (right_rays[:, 0] - center_rays[:, 0])
        fx = (fx_left + fx_right) / 2

        # cx from the center-ray relation u_center = cx + fx * x_center.
        cx = center_w - fx * center_rays[:, 0]

        # Same construction for the vertical axis.
        fy_top = (quarter_h - center_h) / (top_rays[:, 1] - center_rays[:, 1])
        fy_bottom = (three_quarter_h - center_h) / (
            bottom_rays[:, 1] - center_rays[:, 1]
        )
        fy = (fy_top + fy_bottom) / 2

        cy = center_h - fy * center_rays[:, 1]
    else:
        # Subsample roughly 50 pixels per axis to bound the solve size.
        step_h = max(1, height // 50)
        step_w = max(1, width // 50)

        h_indices = torch.arange(0, height, step_h, device=device)
        w_indices = torch.arange(0, width, step_w, device=device)

        # Gather the sampled pixel coordinates and their ray directions.
        x_sampled = x_grid[:, h_indices[:, None], w_indices[None, :]]
        y_sampled = y_grid[:, h_indices[:, None], w_indices[None, :]]
        rays_sampled = ray_directions[
            :, h_indices[:, None], w_indices[None, :], :
        ]

        x_flat = x_sampled.reshape(batch_size, -1)
        y_flat = y_sampled.reshape(batch_size, -1)

        dx = rays_sampled[..., 0].reshape(batch_size, -1)
        dy = rays_sampled[..., 1].reshape(batch_size, -1)
        dz = rays_sampled[..., 2].reshape(batch_size, -1)

        # Pinhole model: u = cx + fx * (dx/dz), v = cy + fy * (dy/dz).
        ratio_x = dx / dz
        ratio_y = dy / dz

        # Linear system [1, dx/dz] @ [cx, fx]^T = u, solved per batch element
        # via the normal equations (A^T A) x = A^T b.
        ones = torch.ones_like(x_flat)
        A_x = torch.stack([ones, ratio_x], dim=2)
        b_x = x_flat.unsqueeze(2)

        ATA_x = torch.bmm(A_x.transpose(1, 2), A_x)
        ATb_x = torch.bmm(A_x.transpose(1, 2), b_x)

        solution_x = torch.linalg.solve(ATA_x, ATb_x).squeeze(2)
        cx, fx = solution_x[:, 0], solution_x[:, 1]

        # Identical least-squares fit for the vertical axis.
        A_y = torch.stack([ones, ratio_y], dim=2)
        b_y = y_flat.unsqueeze(2)

        ATA_y = torch.bmm(A_y.transpose(1, 2), A_y)
        ATb_y = torch.bmm(A_y.transpose(1, 2), b_y)

        solution_y = torch.linalg.solve(ATA_y, ATb_y).squeeze(2)
        cy, fy = solution_y[:, 0], solution_y[:, 1]

    # Assemble the 3x3 intrinsics matrix per batch element.
    batch_size = fx.shape[0]
    intrinsics = torch.zeros(batch_size, 3, 3, device=ray_directions.device)

    intrinsics[:, 0, 0] = fx
    intrinsics[:, 1, 1] = fy
    intrinsics[:, 0, 2] = cx
    intrinsics[:, 1, 2] = cy
    intrinsics[:, 2, 2] = 1.0

    if squeeze_batch_dim:
        intrinsics = intrinsics.squeeze(0)

    return intrinsics
|
|
|
|
|
|
|
|
def transform_rays(ray_origins, ray_directions, transformation):
    """
    Transform 6D rays (ray origins and ray directions) using a 4x4 transformation matrix.

    Args:
    - ray_origins: HxWx3 or BxHxWx3 torch tensor
    - ray_directions: HxWx3 or BxHxWx3 torch tensor
    - transformation: 4x4 or Bx4x4 torch tensor

    Returns:
        transformed ray_origins (HxWx3 or BxHxWx3 tensor) and ray_directions (HxWx3 or BxHxWx3 tensor)
    """
    # Promote unbatched inputs to a batch of one.
    added_batch_dim = ray_origins.dim() == 3
    if added_batch_dim:
        ray_origins = ray_origins[None]
        ray_directions = ray_directions[None]
        transformation = transformation[None]

    # Origins transform as points (homogeneous w = 1); directions as vectors
    # (w = 0, so they are rotated but not translated).
    origins_homo = torch.cat(
        [ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1
    )
    dirs_homo = torch.cat(
        [ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1
    )
    new_origins = torch.einsum("bik,bhwk->bhwi", transformation, origins_homo)[..., :3]
    new_directions = torch.einsum("bik,bhwk->bhwi", transformation, dirs_homo)[..., :3]

    if added_batch_dim:
        new_origins = new_origins[0]
        new_directions = new_directions[0]

    return new_origins, new_directions
|
|
|
|
|
|
|
|
def convert_z_depth_to_depth_along_ray(z_depth, intrinsics):
    """
    Convert z-depth image to depth along camera rays.

    Args:
    - z_depth: HxW or BxHxW torch tensor
    - intrinsics: 3x3 or Bx3x3 torch tensor

    Returns:
    - depth_along_ray: HxW or BxHxW torch tensor
    """
    # Promote unbatched inputs to a batch of one.
    added_batch_dim = z_depth.dim() == 2
    if added_batch_dim:
        z_depth = z_depth[None]
        intrinsics = intrinsics[None]

    # Unit-plane ray directions (z component == 1) for every pixel.
    bsz, hgt, wid = z_depth.shape
    _, ray_directions = get_rays_in_camera_frame(
        intrinsics, hgt, wid, normalize_to_unit_sphere=False
    )

    # Scale each ray by its z-depth; the Euclidean length of the resulting
    # camera-frame point is the depth along the ray.
    depth_along_ray = (z_depth[..., None] * ray_directions).norm(dim=-1)

    if added_batch_dim:
        depth_along_ray = depth_along_ray[0]

    return depth_along_ray
|
|
|
|
|
|
|
|
def convert_raymap_z_depth_quats_to_pointmap(ray_origins, ray_directions, depth, quats):
    """
    Convert raymap (ray origins + directions on unit plane), z-depth and
    unit quaternions (representing rotation) to a pointmap in world frame.

    Args:
    - ray_origins: (HxWx3 or BxHxWx3) torch tensor
    - ray_directions: (HxWx3 or BxHxWx3) torch tensor
    - depth: (HxWx1 or BxHxWx1) torch tensor
    - quats: (HxWx4 or BxHxWx4) torch tensor (unit quaternions and notation is (x, y, z, w))

    Returns:
    - pointmap: (HxWx3 or BxHxWx3) torch tensor
    """
    # Promote unbatched inputs to a batch of one.
    if ray_origins.dim() == 3:
        ray_origins = ray_origins.unsqueeze(0)
        ray_directions = ray_directions.unsqueeze(0)
        depth = depth.unsqueeze(0)
        quats = quats.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    batch_size, height, width, _ = depth.shape
    device = depth.device

    # Defensive re-normalization: the matrix expansion below assumes unit quats.
    quats = quats / torch.norm(quats, dim=-1, keepdim=True)

    # Per-pixel quaternion -> rotation matrix via the standard unit-quaternion
    # expansion, laid out row-major and reshaped to (..., 3, 3).
    qx, qy, qz, qw = quats[..., 0], quats[..., 1], quats[..., 2], quats[..., 3]
    rot_mat = (
        torch.stack(
            [
                qw**2 + qx**2 - qy**2 - qz**2,
                2 * (qx * qy - qw * qz),
                2 * (qw * qy + qx * qz),
                2 * (qw * qz + qx * qy),
                qw**2 - qx**2 + qy**2 - qz**2,
                2 * (qy * qz - qw * qx),
                2 * (qx * qz - qw * qy),
                2 * (qw * qx + qy * qz),
                qw**2 - qx**2 - qy**2 + qz**2,
            ],
            dim=-1,
        )
        .reshape(batch_size, height, width, 3, 3)
        .to(device)
    )

    # Scale the unit-plane directions by z-depth to get local 3D offsets.
    pts3d_local = depth * ray_directions

    # Rotate each local offset by its per-pixel rotation matrix.
    rotated_pts3d_local = ein.einsum(
        rot_mat, pts3d_local, "b h w i k, b h w k -> b h w i"
    )

    # Translate by the ray origins to express the points in the world frame.
    pts3d = ray_origins + rotated_pts3d_local

    if squeeze_batch_dim:
        pts3d = pts3d.squeeze(0)

    return pts3d
|
|
|
|
|
|
|
|
def quaternion_to_rotation_matrix(quat):
    """
    Convert a quaternion into a 3x3 rotation matrix.

    Args:
    - quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))

    Returns:
    - rot_matrix: 3x3 or Bx3x3 torch tensor
    """
    # Promote an unbatched quaternion to a batch of one.
    added_batch_dim = quat.dim() == 1
    if added_batch_dim:
        quat = quat[None]

    # Re-normalize defensively so the output is a proper rotation matrix.
    quat = quat / quat.norm(dim=1, keepdim=True)
    x, y, z, w = quat.unbind(dim=1)

    # Pairwise products used by the standard quaternion-to-matrix formula.
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z

    # Row-major entries of the rotation matrix.
    entries = [
        1 - 2 * (yy + zz),
        2 * (xy - wz),
        2 * (xz + wy),
        2 * (xy + wz),
        1 - 2 * (xx + zz),
        2 * (yz - wx),
        2 * (xz - wy),
        2 * (yz + wx),
        1 - 2 * (xx + yy),
    ]
    rot_matrix = torch.stack(entries, dim=1).view(-1, 3, 3)

    if added_batch_dim:
        rot_matrix = rot_matrix[0]

    return rot_matrix
|
|
|
|
|
|
|
|
def rotation_matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Follows PyTorch3D's ``matrix_to_quaternion``: four candidate quaternions
    are built (one per diagonal trace combination) and, per element, the best
    numerically conditioned candidate — the one with the largest magnitude
    component — is selected.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part last, as tensor of shape (..., 4).
        Quaternion Order: XYZW or say ijkr, scalar-last
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    # |2w|, |2x|, |2y|, |2z| from the four trace combinations (clamped at 0).
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # Four candidate (w, x, y, z) quaternions; each row is scaled by one of the
    # q_abs components, which is divided back out below.
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # Floor the denominator to avoid division by tiny values; the floored
    # candidates are never the selected ones, so this does not bias the result.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # Per element, pick the candidate tied to the largest q_abs component
    # (best numerical conditioning).
    out = quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))

    # Reorder from (w, x, y, z) to this module's (x, y, z, w) convention.
    out = out[..., [1, 2, 3, 0]]

    # Canonicalize the sign so the scalar part is non-negative.
    out = standardize_quaternion(out)

    return out
|
|
|
|
|
|
|
|
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Returns torch.sqrt(torch.max(0, x)) |
|
|
but with a zero subgradient where x is 0. |
|
|
""" |
|
|
ret = torch.zeros_like(x) |
|
|
positive_mask = x > 0 |
|
|
if torch.is_grad_enabled(): |
|
|
ret[positive_mask] = torch.sqrt(x[positive_mask]) |
|
|
else: |
|
|
ret = torch.where(positive_mask, torch.sqrt(x), ret) |
|
|
return ret |
|
|
|
|
|
|
|
|
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    # q and -q encode the same rotation; flip sign wherever the scalar part
    # (last channel, w) is negative.
    negative_real = quaternions[..., 3:4] < 0
    return torch.where(negative_real, -quaternions, quaternions)
|
|
|
|
|
|
|
|
def quaternion_inverse(quat):
    """
    Compute the inverse of a quaternion.

    Args:
    - quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))

    Returns:
    - inv_quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
    """
    # Promote an unbatched quaternion to a batch of one.
    added_batch_dim = quat.dim() == 1
    if added_batch_dim:
        quat = quat[None]

    # inverse = conjugate / squared norm (reduces to the conjugate for unit quats).
    conjugate = torch.cat([-quat[:, :3], quat[:, 3:]], dim=1)
    inv_quat = conjugate / (quat * quat).sum(dim=1, keepdim=True)

    if added_batch_dim:
        inv_quat = inv_quat[0]

    return inv_quat
|
|
|
|
|
|
|
|
def quaternion_multiply(q1, q2):
    """
    Multiply two quaternions.

    Args:
    - q1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
    - q2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))

    Returns:
    - qm: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
    """
    # Promote unbatched quaternions to a batch of one.
    added_batch_dim = q1.dim() == 1
    if added_batch_dim:
        q1 = q1[None]
        q2 = q2[None]

    x1, y1, z1, w1 = q1.unbind(dim=1)
    x2, y2, z2, w2 = q2.unbind(dim=1)

    # Hamilton product, components arranged for the (x, y, z, w) layout.
    qm = torch.stack(
        [
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        ],
        dim=1,
    )

    if added_batch_dim:
        qm = qm[0]

    return qm
|
|
|
|
|
|
|
|
def transform_pose_using_quats_and_trans_2_to_1(quats1, trans1, quats2, trans2):
    """
    Transform quats and translation of pose2 from absolute frame (pose2 to world) to relative frame (pose2 to pose1).

    Args:
    - quats1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
    - trans1: 3 or Bx3 torch tensor
    - quats2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
    - trans2: 3 or Bx3 torch tensor

    Returns:
    - quats: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
    - trans: 3 or Bx3 torch tensor
    """
    # Promote unbatched inputs to a batch of one.
    added_batch_dim = quats1.dim() == 1
    if added_batch_dim:
        quats1 = quats1[None]
        trans1 = trans1[None]
        quats2 = quats2[None]
        trans2 = trans2[None]

    # Inverse of pose1 as rotation matrix + translation: (R1^-1, -R1^-1 t1).
    inv_quats1 = quaternion_inverse(quats1)
    rot1_inv = quaternion_to_rotation_matrix(inv_quats1)
    trans1_inv = -torch.einsum("bij,bj->bi", rot1_inv, trans1)

    # Compose: pose_rel = pose1^-1 @ pose2.
    quats = quaternion_multiply(inv_quats1, quats2)
    trans = torch.einsum("bij,bj->bi", rot1_inv, trans2) + trans1_inv

    if added_batch_dim:
        quats = quats[0]
        trans = trans[0]

    return quats, trans
|
|
|
|
|
|
|
|
def convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
    ray_directions, depth_along_ray, pose_trans, pose_quats
):
    """
    Convert ray directions, depth along ray, pose translation, and
    unit quaternions (representing pose rotation) to a pointmap in world frame.

    Args:
    - ray_directions: (HxWx3 or BxHxWx3) torch tensor
    - depth_along_ray: (HxWx1 or BxHxWx1) torch tensor
    - pose_trans: (3 or Bx3) torch tensor
    - pose_quats: (4 or Bx4) torch tensor (unit quaternions and notation is (x, y, z, w))

    Returns:
    - pointmap: (HxWx3 or BxHxWx3) torch tensor
    """
    # Promote unbatched inputs to a batch of one.
    added_batch_dim = ray_directions.dim() == 3
    if added_batch_dim:
        ray_directions = ray_directions[None]
        depth_along_ray = depth_along_ray[None]
        pose_trans = pose_trans[None]
        pose_quats = pose_quats[None]

    bsz = depth_along_ray.shape[0]
    dev = depth_along_ray.device

    # Defensive re-normalization of the pose quaternions.
    pose_quats = pose_quats / pose_quats.norm(dim=-1, keepdim=True)
    rot_mat = quaternion_to_rotation_matrix(pose_quats)

    # Assemble a 4x4 cam2world matrix per batch element.
    pose_mat = torch.eye(4, device=dev)[None].repeat(bsz, 1, 1)
    pose_mat[:, :3, :3] = rot_mat
    pose_mat[:, :3, 3] = pose_trans

    # Camera-frame points, lifted to the world frame via homogeneous coords.
    pts3d_cam = depth_along_ray * ray_directions
    pts3d_homo = torch.cat([pts3d_cam, torch.ones_like(pts3d_cam[..., :1])], dim=-1)
    pts3d_world = torch.einsum("bik,bhwk->bhwi", pose_mat, pts3d_homo)[..., :3]

    if added_batch_dim:
        pts3d_world = pts3d_world[0]

    return pts3d_world
|
|
|
|
|
|
|
|
def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """
    Generate a coordinate grid of shape (H,W,2) or (H,W,3) if homogeneous=True.

    Args:
        W (int): Width of the grid
        H (int): Height of the grid
        device (torch.device, optional): Device to place the grid on. If None, uses numpy arrays
        origin (tuple, optional): Origin coordinates (x,y) for the grid. Default is (0,0)
        unsqueeze (int, optional): Dimension to unsqueeze in the output tensors
        cat_dim (int, optional): Dimension to concatenate the x,y coordinates. If None, returns tuple
        homogeneous (bool, optional): If True, adds a third dimension of ones to make homogeneous coordinates
        **arange_kw: Additional keyword arguments passed to np.arange or torch.arange

    Returns:
        numpy.ndarray or torch.Tensor: Coordinate grid where:
        - output[j,i,0] = i + origin[0] (x-coordinate)
        - output[j,i,1] = j + origin[1] (y-coordinate)
        - output[j,i,2] = 1 (if homogeneous=True)
    """
    if device is None:
        # numpy backend
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones

        def expand(arr, dim):
            return np.expand_dims(arr, dim)
    else:
        # torch backend, pinned to the requested device
        def arange(*a, **kw):
            return torch.arange(*a, device=device, **kw)

        meshgrid, stack = torch.meshgrid, torch.stack

        def ones(*a):
            return torch.ones(*a, device=device)

        def expand(arr, dim):
            return arr.unsqueeze(dim)

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    # BUGFIX: np.meshgrid returns a *list* while torch.meshgrid returns a
    # tuple; normalize to a tuple so the homogeneous concatenation below
    # works for both backends (list + tuple raises TypeError).
    grid = tuple(meshgrid(tw, th, indexing="xy"))
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        # BUGFIX: unsqueeze every component — the previous code rebuilt the
        # tuple from only the first two entries (dropping the homogeneous
        # ones-channel) and crashed on numpy arrays (no .unsqueeze method).
        grid = tuple(expand(g, unsqueeze) for g in grid)
    if cat_dim is not None:
        grid = stack(grid, cat_dim)

    return grid
|
|
|
|
|
|
|
|
def geotrf(Trf, pts, ncol=None, norm=False):
    """
    Apply a geometric transformation to a set of 3-D points.

    Args:
        Trf: 3x3 or 4x4 projection matrix (typically a Homography) or batch of matrices
            with shape (B, 3, 3) or (B, 4, 4)
        pts: numpy/torch/tuple of coordinates with shape (..., 2) or (..., 3)
        ncol: int, number of columns of the result (2 or 3)
        norm: float, if not 0, the result is projected on the z=norm plane
            (homogeneous normalization)

    Returns:
        Array or tensor of projected points with the same type as input and shape (..., ncol)
    """
    assert Trf.ndim >= 2
    # Coerce pts into the same array library (and dtype, for torch) as Trf.
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # Remember the leading shape so the result can be restored at the end.
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # Fast path: batched torch matrices applied to (B, H, W, d) point maps.
    if (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    ):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            # Purely linear transform (no translation column).
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # Affine transform: linear part, then add the translation column.
            pts = (
                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
                + Trf[:, None, None, :d, d]
            )
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        # Generic path: flatten batch dims so matmul broadcasting applies.
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Collapse extra point dims into one: (B, N, d).
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Add a points-per-batch dim of 1: (B, 1, d).
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            # Homogeneous transform applied without materializing a 1s column:
            # last row of the (transposed) matrix acts as the translation.
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            # Plain linear transform.
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            # Fallback: treat points as column vectors.
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        # Homogeneous normalization, then optional rescale onto z = norm.
        pts = pts / pts[..., -1:]
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)

    return res
|
|
|
|
|
|
|
|
def inv(mat):
    """
    Invert a torch or numpy matrix
    """
    # Dispatch on the array library; any other type is rejected.
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
|
|
|
|
|
|
|
|
def closed_form_pose_inverse(
    pose_matrices, rotation_matrices=None, translation_vectors=None
):
    """
    Compute the inverse of each 4x4 (or 3x4) SE3 pose matrix in a batch.

    If `rotation_matrices` and `translation_vectors` are provided, they must correspond to the rotation and translation
    components of `pose_matrices`. Otherwise, they will be extracted from `pose_matrices`.

    Args:
        pose_matrices: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
        rotation_matrices (optional): Nx3x3 array or tensor of rotation matrices.
        translation_vectors (optional): Nx3x1 array or tensor of translation vectors.

    Returns:
        Inverted SE3 matrices (Nx4x4) with the same type, dtype and device as
        the input `pose_matrices`.

    Shapes:
        pose_matrices: (N, 4, 4) or (N, 3, 4)
        rotation_matrices: (N, 3, 3)
        translation_vectors: (N, 3, 1)
    """
    is_numpy = isinstance(pose_matrices, np.ndarray)

    # Both full homogeneous (4x4) and truncated (3x4) matrices are accepted.
    if pose_matrices.shape[-2:] != (4, 4) and pose_matrices.shape[-2:] != (3, 4):
        raise ValueError(
            f"pose_matrices must be of shape (N,4,4) or (N,3,4), got {pose_matrices.shape}."
        )

    if rotation_matrices is None:
        rotation_matrices = pose_matrices[:, :3, :3]
    if translation_vectors is None:
        translation_vectors = pose_matrices[:, :3, 3:]

    # Closed form: inv([R t; 0 1]) = [R^T  -R^T t; 0 1].
    if is_numpy:
        rotation_transposed = np.transpose(rotation_matrices, (0, 2, 1))
        new_translation = -np.matmul(rotation_transposed, translation_vectors)
        # BUGFIX: match the input dtype. Previously the numpy branch always
        # returned float64 (np.eye default), unlike the torch branch which
        # casts to the input dtype.
        inverted_matrix = np.tile(
            np.eye(4, dtype=pose_matrices.dtype), (len(rotation_matrices), 1, 1)
        )
    else:
        rotation_transposed = rotation_matrices.transpose(1, 2)
        new_translation = -torch.bmm(rotation_transposed, translation_vectors)
        inverted_matrix = torch.eye(4, 4)[None].repeat(len(rotation_matrices), 1, 1)
        inverted_matrix = inverted_matrix.to(rotation_matrices.dtype).to(
            rotation_matrices.device
        )
    inverted_matrix[:, :3, :3] = rotation_transposed
    inverted_matrix[:, :3, 3:] = new_translation

    return inverted_matrix
|
|
|
|
|
|
|
|
def relative_pose_transformation(trans_01, trans_02):
    r"""
    Function that computes the relative homogenous transformation from a
    reference transformation :math:`T_1^{0} = \begin{bmatrix} R_1 & t_1 \\
    \mathbf{0} & 1 \end{bmatrix}` to destination :math:`T_2^{0} =
    \begin{bmatrix} R_2 & t_2 \\ \mathbf{0} & 1 \end{bmatrix}`.

    The relative transformation is computed as follows:

    .. math::

        T_1^{2} = (T_0^{1})^{-1} \cdot T_0^{2}

    Arguments:
        trans_01 (torch.Tensor): reference transformation tensor of shape
            :math:`(N, 4, 4)` or :math:`(4, 4)`.
        trans_02 (torch.Tensor): destination transformation tensor of shape
            :math:`(N, 4, 4)` or :math:`(4, 4)`.

    Shape:
        - Output: :math:`(N, 4, 4)` or :math:`(4, 4)`.

    Returns:
        torch.Tensor: the relative transformation between the transformations.

    Raises:
        TypeError: if either input is not a torch.Tensor.
        ValueError: if either input is not of shape Nx4x4 or 4x4, or the
            inputs' number of dimensions do not match.

    Example::
        >>> trans_01 = torch.eye(4)  # 4x4
        >>> trans_02 = torch.eye(4)  # 4x4
        >>> trans_12 = relative_pose_transformation(trans_01, trans_02)  # 4x4
    """
    if not torch.is_tensor(trans_01):
        raise TypeError(
            "Input trans_01 type is not a torch.Tensor. Got {}".format(type(trans_01))
        )
    if not torch.is_tensor(trans_02):
        raise TypeError(
            "Input trans_02 type is not a torch.Tensor. Got {}".format(type(trans_02))
        )
    # BUGFIX: the original condition `dim() not in (2, 3) and shape == (4, 4)`
    # could never flag an invalid shape (it required the trailing shape to BE
    # (4, 4) in order to raise). Raise when the input is NOT Nx4x4 or 4x4.
    if trans_01.dim() not in (2, 3) or trans_01.shape[-2:] != (4, 4):
        raise ValueError(
            "Input must be a of the shape Nx4x4 or 4x4. Got {}".format(trans_01.shape)
        )
    if trans_02.dim() not in (2, 3) or trans_02.shape[-2:] != (4, 4):
        raise ValueError(
            "Input must be a of the shape Nx4x4 or 4x4. Got {}".format(trans_02.shape)
        )
    if not trans_01.dim() == trans_02.dim():
        raise ValueError(
            "Input number of dims must match. Got {} and {}".format(
                trans_01.dim(), trans_02.dim()
            )
        )

    # Promote 4x4 inputs to a batch of one so the batched inverse/matmul work.
    squeeze_batch_dim = False
    if trans_01.dim() == 2:
        trans_01 = trans_01.unsqueeze(0)
        trans_02 = trans_02.unsqueeze(0)
        squeeze_batch_dim = True

    # T_1^2 = (T_0^1)^-1 @ T_0^2, using the closed-form SE3 inverse.
    trans_10 = closed_form_pose_inverse(trans_01)
    trans_12 = torch.matmul(trans_10, trans_02)

    if squeeze_batch_dim:
        trans_12 = trans_12.squeeze(0)

    return trans_12
|
|
|
|
|
|
|
|
def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
    """
    Backproject a (batched) depth map to a camera-frame pointmap using
    per-pixel pseudo focal lengths.

    Args:
        - depth: [B,H,W] depth map, or [B,H,W,n] for n depth values per pixel
        - pseudo_focal: [B,H,W] (shared fx=fy) ; [B,2,H,W] (separate fx, fy) or [B,1,H,W]
        - pp (optional): [B,2] per-batch principal point (x, y) in pixels;
          if None, the image center ((W-1)/2, (H-1)/2) is used.
    Returns:
        pointmap of absolute coordinates (BxHxWx3 array, or BxHxWx3xn)
    """
    # Support an optional trailing dimension n of multiple depths per pixel.
    if len(depth.shape) == 4:
        B, H, W, n = depth.shape
    else:
        B, H, W = depth.shape
        n = None

    # Resolve per-axis focal maps from the supported pseudo_focal layouts.
    if len(pseudo_focal.shape) == 3:
        pseudo_focalx = pseudo_focaly = pseudo_focal
    elif len(pseudo_focal.shape) == 4:
        pseudo_focalx = pseudo_focal[:, 0]
        if pseudo_focal.shape[1] == 2:
            pseudo_focaly = pseudo_focal[:, 1]
        else:
            pseudo_focaly = pseudo_focalx
    else:
        raise NotImplementedError("Error, unknown input focal shape format.")

    assert pseudo_focalx.shape == depth.shape[:3]
    assert pseudo_focaly.shape == depth.shape[:3]
    # Pixel-coordinate grids; [:, None] adds a broadcastable batch dim.
    # NOTE(review): relies on the xy_grid helper defined elsewhere in this module.
    grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]

    # Center the grids on the principal point (image center when pp is None).
    if pp is None:
        grid_x = grid_x - (W - 1) / 2
        grid_y = grid_y - (H - 1) / 2
    else:
        grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
        grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]

    # Pinhole backprojection: (x, y, z) = (z*(u-cx)/fx, z*(v-cy)/fy, z).
    if n is None:
        pts3d = torch.empty((B, H, W, 3), device=depth.device)
        pts3d[..., 0] = depth * grid_x / pseudo_focalx
        pts3d[..., 1] = depth * grid_y / pseudo_focaly
        pts3d[..., 2] = depth
    else:
        # One 3D point per depth value, stacked along the last dimension.
        pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
        pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
        pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
        pts3d[..., 2, :] = depth
    return pts3d
|
|
|
|
|
|
|
|
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None, flow2d=None):
    """
    Backproject a depth map into camera-frame 3D coordinates.

    Args:
        - depthmap (HxW array): per-pixel z-depth
        - camera_intrinsics: a 3x3 matrix (zero skew required)
        - pseudo_focal (optional, HxW array): per-pixel focal overriding fx/fy
        - flow2d (optional, HxWx2 array): per-pixel (du, dv) offsets added to
          the pixel coordinates before backprojection
    Returns:
        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # The per-axis backprojection below assumes zero skew.
    assert camera_intrinsics[0, 1] == 0.0
    assert camera_intrinsics[1, 0] == 0.0
    if pseudo_focal is not None:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    else:
        fu = camera_intrinsics[0, 0]
        fv = camera_intrinsics[1, 1]
    cu = camera_intrinsics[0, 2]
    cv = camera_intrinsics[1, 2]

    u, v = np.meshgrid(np.arange(W), np.arange(H))

    # Optionally displace pixel coordinates by a 2D flow field.
    if flow2d is not None:
        u = u + flow2d[..., 0]
        v = v + flow2d[..., 1]

    z_cam = depthmap
    x_cam = (u - cu) * z_cam / fu
    y_cam = (v - cv) * z_cam / fv
    X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)

    # Only strictly positive depths are considered valid.
    valid_mask = depthmap > 0.0

    return X_cam, valid_mask
|
|
|
|
|
|
|
|
def depthmap_to_absolute_camera_coordinates(
    depthmap, camera_intrinsics, camera_pose, flow2d=None, **kw
):
    """
    Backproject a depth map into world-frame 3D coordinates.

    Args:
        - depthmap (HxW array): per-pixel z-depth
        - camera_intrinsics: a 3x3 matrix
        - camera_pose: a 4x3 or 4x4 cam2world matrix (None keeps points in the
          camera frame)
        - flow2d (optional): per-pixel 2D offsets forwarded to the camera-frame
          backprojection
    Returns:
        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
    """
    X_cam, valid_mask = depthmap_to_camera_coordinates(
        depthmap, camera_intrinsics, flow2d=flow2d
    )

    # Without a pose, the camera frame IS the world frame.
    if camera_pose is None:
        return X_cam, valid_mask

    R_cam2world = camera_pose[:3, :3]
    t_cam2world = camera_pose[:3, 3]

    # X_world = R @ X_cam + t, applied per pixel.
    X_world = (
        np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
    )
    return X_world, valid_mask
|
|
|
|
|
|
|
|
def get_absolute_pointmaps_and_rays_info(depthmap, camera_intrinsics, camera_pose, **kw):
    """
    Backproject a depth map and derive the associated camera-ray representation.

    Args:
        - depthmap (HxW array): per-pixel z-depth
        - camera_intrinsics: a 3x3 matrix (zero skew required)
        - camera_pose: a 4x3 or 4x4 cam2world matrix (None keeps everything in
          the camera frame)
    Returns:
        pointmap of absolute coordinates (HxWx3 array),
        a mask specifying valid pixels,
        ray origins of absolute coordinates (HxWx3 array),
        ray directions of absolute coordinates (HxWx3 array),
        depth along ray (HxWx1 array),
        ray directions of camera/local coordinates (HxWx3 array),
        pointmap of camera/local coordinates (HxWx3 array).
    """
    K = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Zero-skew pinhole model is assumed by the per-axis backprojection below.
    assert K[0, 1] == 0.0
    assert K[1, 0] == 0.0
    fx = K[0, 0]
    fy = K[1, 1]
    px = K[0, 2]
    py = K[1, 2]

    # Ray directions through each pixel, intersected with the z=1 plane.
    cols, rows = np.meshgrid(np.arange(W), np.arange(H))
    x_plane = (cols - px) / fx
    y_plane = (rows - py) / fy
    dirs_unit_plane = np.stack((x_plane, y_plane, np.ones_like(x_plane)), axis=-1).astype(
        np.float32
    )

    # Scale by z-depth to obtain camera-frame points.
    pts_cam = depthmap[..., None] * dirs_unit_plane

    # Euclidean distance along each ray and unit-norm ray directions.
    depth_along_ray = np.linalg.norm(pts_cam, axis=-1, keepdims=True)
    dirs_cam = dirs_unit_plane / np.linalg.norm(
        dirs_unit_plane, axis=-1, keepdims=True
    )

    # Only strictly positive depths are considered valid.
    valid_mask = depthmap > 0.0

    # Default to the camera frame; rotate/translate into the world if a pose
    # is supplied.
    origins_world = np.zeros_like(dirs_cam)
    dirs_world = dirs_cam
    pts_world = pts_cam
    if camera_pose is not None:
        rot_c2w = camera_pose[:3, :3]
        tra_c2w = camera_pose[:3, 3]

        origins_world = origins_world + tra_c2w[None, None, :]
        dirs_world = np.einsum("ik, vuk -> vui", rot_c2w, dirs_cam)
        pts_world = origins_world + dirs_world * depth_along_ray

    return (
        pts_world,
        valid_mask,
        origins_world,
        dirs_world,
        depth_along_ray,
        dirs_cam,
        pts_cam,
    )
|
|
|
|
|
|
|
|
def adjust_camera_params_for_rotation(camera_params, original_size, k):
    """
    Adjust camera intrinsic parameters for an in-plane image rotation.

    Args:
        camera_params: Camera parameters [fx, fy, cx, cy, ...]
        original_size: Original image size as (width, height)
        k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise)

    Returns:
        Adjusted camera parameters; the input is returned unchanged when
        k is a multiple of 4.
    """
    fx, fy, cx, cy = camera_params[:4]
    width, height = original_size

    rotation = k % 4
    if rotation == 0:
        # No net rotation: nothing to adjust.
        return camera_params

    if rotation == 1:
        # 90 deg CCW: axes swap, principal point reflects around the height.
        adjusted_params = [fy, fx, height - cy, cx]
    elif rotation == 2:
        # 180 deg: principal point reflects around both image dimensions.
        adjusted_params = [fx, fy, width - cx, height - cy]
    else:  # rotation == 3
        # 90 deg CW: axes swap, principal point reflects around the width.
        adjusted_params = [fy, fx, cy, width - cx]

    # Preserve any trailing parameters (e.g. distortion coefficients) verbatim.
    adjusted_params.extend(camera_params[4:])

    return adjusted_params
|
|
|
|
|
|
|
|
def adjust_pose_for_rotation(pose, k):
    """
    Adjust a camera pose for an in-plane image rotation.

    Args:
        pose: 4x4 camera pose matrix (camera-to-world, OpenCV convention - X right, Y down, Z forward),
            as a numpy array.
        k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise)

    Returns:
        Adjusted 4x4 camera pose matrix. The input `pose` is NOT modified;
        when k % 4 == 0 the input object is returned as-is.
    """
    # Rotation of the image axes, expressed in the camera frame.
    if k % 4 == 1:
        rot_transform = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
    elif k % 4 == 2:
        rot_transform = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]])
    elif k % 4 == 3:
        rot_transform = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
    else:
        return pose

    # BUGFIX: copy before editing -- the previous version assigned
    # `adjusted_pose = pose` and therefore rotated the caller's pose matrix
    # in place as a side effect.
    adjusted_pose = pose.copy()
    adjusted_pose[:3, :3] = adjusted_pose[:3, :3] @ rot_transform.T

    return adjusted_pose
|
|
|
|
|
|
|
|
def crop_to_aspect_ratio(image, depth, camera_params, target_ratio=1.5):
    """
    Crop image and depth to the largest possible target aspect ratio while
    keeping the left side if aspect ratio is wider and the bottom of image if the aspect ratio is taller.

    Args:
        image: PIL image
        depth: Depth map as numpy array (HxW or HxWxC)
        camera_params: Camera parameters [fx, fy, cx, cy, ...]
        target_ratio: Target width/height ratio

    Returns:
        Cropped image, cropped depth, adjusted camera parameters
    """
    width, height = image.size
    fx, fy, cx, cy = camera_params[:4]
    extra_params = list(camera_params[4:])
    current_ratio = width / height

    # Already at the requested ratio: return the inputs untouched.
    if abs(current_ratio - target_ratio) < 1e-6:
        return image, depth, camera_params

    if current_ratio > target_ratio:
        # Too wide: keep the leftmost columns.
        new_width = int(height * target_ratio)
        cropped_image = image.crop((0, 0, new_width, height))

        if len(depth.shape) == 3:
            cropped_depth = depth[:, :new_width, :]
        else:
            cropped_depth = depth[:, :new_width]

        # Crop starts at x=0, so the principal point is unchanged.
        adjusted_params = [fx, fy, cx, cy] + extra_params
    else:
        # Too tall: keep the bottom rows.
        new_height = int(width / target_ratio)
        top = max(0, height - new_height)
        cropped_image = image.crop((0, top, width, height))

        if len(depth.shape) == 3:
            cropped_depth = depth[top:height, :, :]
        else:
            cropped_depth = depth[top:height, :]

        # Shift the vertical principal point by the number of removed rows.
        adjusted_params = [fx, fy, cx, cy - top] + extra_params

    return cropped_image, cropped_depth, adjusted_params
|
|
|
|
|
|
|
|
def colmap_to_opencv_intrinsics(K):
    """
    Convert intrinsics from the COLMAP to the OpenCV pixel-center convention.

    The center of the top-left pixel is (0.5, 0.5) in COLMAP but (0, 0) in
    OpenCV, so the principal point moves by half a pixel. The input matrix is
    not modified.
    """
    shifted = K.copy()
    shifted[:2, 2] -= 0.5

    return shifted
|
|
|
|
|
|
|
|
def opencv_to_colmap_intrinsics(K):
    """
    Convert intrinsics from the OpenCV to the COLMAP pixel-center convention.

    The center of the top-left pixel is (0, 0) in OpenCV but (0.5, 0.5) in
    COLMAP, so the principal point moves by half a pixel. The input matrix is
    not modified.
    """
    shifted = K.copy()
    shifted[:2, 2] += 0.5

    return shifted
|
|
|
|
|
|
|
|
def normalize_depth_using_non_zero_pixels(depth, return_norm_factor=False):
    """
    Normalize depth by the per-batch mean of its strictly positive pixels.

    Args:
        depth (torch.Tensor): Depth tensor of size [B, H, W, 1].
        return_norm_factor (bool): If True, also return the per-batch scale.
    Returns:
        normalized_depth (torch.Tensor): Normalized depth tensor, [B, H, W, 1].
        norm_factor (torch.Tensor): Norm factor tensor of size B (only when
        return_norm_factor is True).
    """
    assert depth.ndim == 4 and depth.shape[3] == 1

    positive = depth > 0
    total = torch.sum(depth * positive, dim=(1, 2, 3))
    count = torch.sum(positive, dim=(1, 2, 3))

    # Mean over positive pixels; the epsilon guards against all-zero maps,
    # and the clip keeps the divisor strictly positive.
    scale = total / (count + 1e-8)
    scale = scale.clip(min=1e-8)

    normalized_depth = depth / scale.view(-1, 1, 1, 1)

    if return_norm_factor:
        return normalized_depth, scale
    return normalized_depth
|
|
|
|
|
|
|
|
def normalize_pose_translations(pose_translations, return_norm_factor=False):
    """
    Normalize pose translations by the average norm of the non-zero ones.

    Args:
        pose_translations (torch.Tensor): Pose translations tensor of size [B, V, 3]. B is the batch size, V is the number of views.
        return_norm_factor (bool): If True, also return the per-batch scale.
    Returns:
        normalized_pose_translations (torch.Tensor): Normalized pose translations tensor of size [B, V, 3].
        norm_factor (torch.Tensor): Norm factor tensor of size B (only when
        return_norm_factor is True).
    """
    assert pose_translations.ndim == 3 and pose_translations.shape[2] == 3

    norms = pose_translations.norm(dim=-1)  # [B, V]
    nonzero = norms > 0

    # Average translation magnitude per batch; only non-zero views are
    # counted, and the epsilon guards against all views being zero.
    norm_factor = norms.sum(dim=1) / (nonzero.sum(dim=1) + 1e-8)
    norm_factor = norm_factor.clip(min=1e-8)

    normalized_pose_translations = pose_translations / norm_factor[:, None, None]

    if return_norm_factor:
        return normalized_pose_translations, norm_factor
    return normalized_pose_translations
|
|
|
|
|
|
|
|
def normalize_multiple_pointclouds(
    pts_list, valid_masks=None, norm_mode="avg_dis", ret_factor=False
):
    """
    Normalize multiple point clouds using a joint normalization strategy.

    Args:
        pts_list: List of point clouds, each with shape (..., H, W, 3) or (B, H, W, 3)
        valid_masks: Optional list of masks indicating valid points in each point cloud
        norm_mode: String in format "{norm}_{dis}" where:
            - norm: Normalization strategy (currently only "avg" is supported)
            - dis: Distance transformation ("dis" for raw distance, "log1p" for log(1+distance),
              "warp-log1p" to warp points using log distance)
        ret_factor: If True, return the normalization factor as the last element in the result list

    Returns:
        List of normalized point clouds with the same shapes as inputs.
        If ret_factor is True, the last element is the normalization factor.

    Note:
        In "warp-log1p" mode, the entries of the caller's `pts_list` are
        replaced in place before normalization.
    """
    assert all(pts.ndim >= 3 and pts.shape[-1] == 3 for pts in pts_list)
    if valid_masks is not None:
        assert len(pts_list) == len(valid_masks)

    # e.g. "avg_dis" -> ("avg", "dis"); only the distance part is branched on.
    norm_mode, dis_mode = norm_mode.split("_")

    # Flatten each cloud with invalid points zeroed.
    # NOTE(review): invalid_to_zeros presumably returns (flattened pts,
    # per-batch valid count) -- verify against mapanything.utils.misc.
    nan_pts_list = [
        invalid_to_zeros(pts, valid_masks[i], ndim=3)
        if valid_masks
        else invalid_to_zeros(pts, None, ndim=3)
        for i, pts in enumerate(pts_list)
    ]
    all_pts = torch.cat([nan_pts for nan_pts, _ in nan_pts_list], dim=1)
    nnz_list = [nnz for _, nnz in nan_pts_list]

    # Distance of every point from the origin, optionally transformed.
    all_dis = all_pts.norm(dim=-1)
    if dis_mode == "dis":
        pass
    elif dis_mode == "log1p":
        all_dis = torch.log1p(all_dis)
    elif dis_mode == "warp-log1p":
        # Warp the points themselves so their norms become log1p(norm).
        log_dis = torch.log1p(all_dis)
        warp_factor = log_dis / all_dis.clip(min=1e-8)
        for i, pts in enumerate(pts_list):
            # Slice this cloud's H*W columns out of the concatenated factor.
            H, W = pts.shape[1:-1]
            pts_list[i] = pts * warp_factor[:, i * (H * W) : (i + 1) * (H * W)].view(
                -1, H, W, 1
            )
        all_dis = log_dis
    else:
        raise ValueError(f"bad {dis_mode=}")

    # Joint "avg" factor: mean (transformed) distance over all valid points.
    norm_factor = all_dis.sum(dim=1) / (sum(nnz_list) + 1e-8)
    norm_factor = norm_factor.clip(min=1e-8)
    while norm_factor.ndim < pts_list[0].ndim:
        norm_factor.unsqueeze_(-1)

    res = [pts / norm_factor for pts in pts_list]
    if ret_factor:
        res.append(norm_factor)

    return res
|
|
|
|
|
|
|
|
def apply_log_to_norm(input_data):
    """
    Rescale vectors so their magnitude becomes log1p of the original magnitude.

    Args:
        input_data (torch.Tensor): Tensor of vectors along the last dimension.

    Returns:
        torch.Tensor: Tensor with the same directions but log1p-scaled norms.
    """
    magnitude = input_data.norm(dim=-1, keepdim=True)
    # Normalize to unit length (guarded against zero vectors), then rescale.
    unit_vectors = input_data / magnitude.clip(min=1e-8)
    return unit_vectors * torch.log1p(magnitude)
|
|
|
|
|
|
|
|
def angle_diff_vec3(v1, v2, eps=1e-12):
    """
    Compute the angle between pairs of 3D vectors.

    Args:
        v1: torch.Tensor of shape (..., 3)
        v2: torch.Tensor of shape (..., 3)
        eps: Small epsilon value for numerical stability

    Returns:
        torch.Tensor: Angle differences in radians
    """
    # atan2(|v1 x v2|, v1 . v2) is numerically stable over the full [0, pi].
    sin_term = torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps
    cos_term = (v1 * v2).sum(dim=-1)
    return torch.atan2(sin_term, cos_term)
|
|
|
|
|
|
|
|
def angle_diff_vec3_numpy(v1: np.ndarray, v2: np.ndarray, eps: float = 1e-12):
    """
    Compute the angle between pairs of 3D vectors using NumPy.

    Args:
        v1 (np.ndarray): First vector of shape (..., 3)
        v2 (np.ndarray): Second vector of shape (..., 3)
        eps (float, optional): Small epsilon value for numerical stability. Defaults to 1e-12.

    Returns:
        np.ndarray: Angle differences in radians
    """
    # atan2(|v1 x v2|, v1 . v2) is numerically stable over the full [0, pi].
    sin_term = np.linalg.norm(np.cross(v1, v2, axis=-1), axis=-1) + eps
    cos_term = (v1 * v2).sum(axis=-1)
    return np.arctan2(sin_term, cos_term)
|
|
|
|
|
|
|
|
@no_warnings(category=RuntimeWarning)
def points_to_normals(
    point: np.ndarray, mask: np.ndarray = None, edge_threshold: float = None
) -> np.ndarray:
    """
    Calculate normal map from point map. Value range is [-1, 1].

    Args:
        point (np.ndarray): shape (height, width, 3), point map
        mask (optional, np.ndarray): shape (height, width), dtype=bool. Mask of valid depth pixels. Defaults to None.
        edge_threshold (optional, float): threshold for the angle (in degrees) between the normal and the view direction. Defaults to None.

    Returns:
        normal (np.ndarray): shape (height, width, 3), normal map.
        normal_mask (np.ndarray): shape (height, width) — returned only when a
        mask was supplied; True where at least one neighbor pair was valid.
    """
    height, width = point.shape[-3:-1]
    has_mask = mask is not None

    # Zero-pad mask and points by one pixel so every pixel has 4 neighbors.
    if mask is None:
        mask = np.ones_like(point[..., 0], dtype=bool)
    mask_pad = np.zeros((height + 2, width + 2), dtype=bool)
    mask_pad[1:-1, 1:-1] = mask
    mask = mask_pad

    pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype)
    pts[1:-1, 1:-1, :] = point
    # Finite-difference vectors to the 4-neighborhood of each pixel.
    up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :]
    left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :]
    down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :]
    right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :]
    # One normal candidate per adjacent neighbor pair (4 candidates / pixel).
    normal = np.stack(
        [
            np.cross(up, left, axis=-1),
            np.cross(left, down, axis=-1),
            np.cross(down, right, axis=-1),
            np.cross(right, up, axis=-1),
        ]
    )
    normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)

    # A candidate is valid only if the center pixel and both contributing
    # neighbors are valid (same pair order as the normal stack above).
    valid = (
        np.stack(
            [
                mask[:-2, 1:-1] & mask[1:-1, :-2],
                mask[1:-1, :-2] & mask[2:, 1:-1],
                mask[2:, 1:-1] & mask[1:-1, 2:],
                mask[1:-1, 2:] & mask[:-2, 1:-1],
            ]
        )
        & mask[None, 1:-1, 1:-1]
    )
    if edge_threshold is not None:
        # Drop candidates nearly parallel to the viewing ray (depth edges).
        # NOTE(review): uses the point position as the view direction, which
        # assumes `point` is expressed in the camera frame — confirm.
        view_angle = angle_diff_vec3_numpy(pts[None, 1:-1, 1:-1, :], normal)
        view_angle = np.minimum(view_angle, np.pi - view_angle)
        valid = valid & (view_angle < np.deg2rad(edge_threshold))

    # Average the valid candidates and renormalize to unit length.
    normal = (normal * valid[..., None]).sum(axis=0)
    normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)

    if has_mask:
        normal_mask = valid.any(axis=0)
        normal = np.where(normal_mask[..., None], normal, 0)
        return normal, normal_mask
    else:
        return normal
|
|
|
|
|
|
|
|
def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1):
    """
    Create a sliding window view of the input array along a specified axis.

    This function creates a memory-efficient view of the input array with sliding windows
    of the specified size and stride. The window dimension is appended to the end of the
    output array's shape. This is useful for operations like convolution, pooling, or
    any analysis that requires examining local neighborhoods in the data.

    Args:
        x (np.ndarray): Input array with shape (..., axis_size, ...)
        window_size (int): Size of the sliding window
        stride (int): Stride of the sliding window (step size between consecutive windows)
        axis (int, optional): Axis to perform sliding window over. Defaults to -1 (last axis)

    Returns:
        np.ndarray: View of the input array with shape (..., n_windows, ..., window_size),
            where n_windows = (axis_size - window_size) // stride + 1

    Raises:
        AssertionError: If window_size is larger than the size of the specified axis

    Example:
        >>> x = np.array([1, 2, 3, 4, 5, 6])
        >>> sliding_window_1d(x, window_size=3, stride=2)
        array([[1, 2, 3],
               [3, 4, 5]])
    """
    assert x.shape[axis] >= window_size, (
        f"kernel_size ({window_size}) is larger than axis_size ({x.shape[axis]})"
    )
    axis = axis % x.ndim
    # BUGFIX: the window count is (L - w) // s + 1. The previous formula
    # (L - w + 1) // s dropped the final window whenever (L - w) % s != 0,
    # e.g. L=7, w=3, s=2 produced 2 windows instead of 3.
    n_windows = (x.shape[axis] - window_size) // stride + 1
    shape = (
        *x.shape[:axis],
        n_windows,
        *x.shape[axis + 1 :],
        window_size,
    )
    # Window axis steps `stride` elements; the appended window dimension
    # reuses the original axis stride (zero-copy view).
    strides = (
        *x.strides[:axis],
        stride * x.strides[axis],
        *x.strides[axis + 1 :],
        x.strides[axis],
    )
    x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
    return x_sliding
|
|
|
|
|
|
|
|
def sliding_window_nd(
    x: np.ndarray,
    window_size: Tuple[int, ...],
    stride: Tuple[int, ...],
    axis: Tuple[int, ...],
) -> np.ndarray:
    """
    Create sliding windows along multiple dimensions of the input array.

    Applies sliding_window_1d once per requested axis, so the window
    dimensions are appended to the end of the shape in the given axis order.

    Args:
        x (np.ndarray): Input array
        window_size (Tuple[int, ...]): Size of the sliding window for each axis
        stride (Tuple[int, ...]): Stride of the sliding window for each axis
        axis (Tuple[int, ...]): Axes to perform sliding window over

    Returns:
        np.ndarray: Array with sliding windows along the specified dimensions.

    Note:
        The length of window_size, stride, and axis tuples must be equal.
    """
    # Resolve negative axes against the ORIGINAL ndim before any windows are
    # appended, so each subsequent call still targets the intended axis.
    resolved_axes = [ax % x.ndim for ax in axis]
    for win, step, ax in zip(window_size, stride, resolved_axes):
        x = sliding_window_1d(x, win, step, ax)
    return x
|
|
|
|
|
|
|
|
def sliding_window_2d(
    x: np.ndarray,
    window_size: Union[int, Tuple[int, int]],
    stride: Union[int, Tuple[int, int]],
    axis: Tuple[int, int] = (-2, -1),
) -> np.ndarray:
    """
    Create 2D sliding windows over the input array.

    Convenience wrapper around sliding_window_nd: scalar window_size/stride
    are broadcast to both dimensions. Commonly used for patch extraction.

    Args:
        x (np.ndarray): Input array
        window_size (Union[int, Tuple[int, int]]): Window size, int or (h, w)
        stride (Union[int, Tuple[int, int]]): Window stride, int or (h, w)
        axis (Tuple[int, int], optional): Axes to slide over; defaults to the
            last two dimensions.

    Returns:
        np.ndarray: Array with the 2D window dimensions appended to the shape.
    """
    if isinstance(window_size, int):
        window_size = (window_size,) * 2
    if isinstance(stride, int):
        stride = (stride,) * 2
    return sliding_window_nd(x, window_size, stride, axis)
|
|
|
|
|
|
|
|
def max_pool_1d(
    x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1
):
    """
    Perform 1D max pooling on the input array.

    Takes the maximum value inside each sliding window along the given axis.

    Args:
        x (np.ndarray): Input array
        kernel_size (int): Size of the pooling kernel
        stride (int): Stride of the pooling operation
        padding (int, optional): Amount of padding to add on both sides. Defaults to 0.
        axis (int, optional): Axis to perform max pooling over. Defaults to -1.

    Returns:
        np.ndarray: Max pooled array with reduced size along the specified axis

    Note:
        - Floating point arrays are padded with np.nan; integer arrays with the
          dtype minimum, so padding never wins the maximum.
        - np.nanmax is used so NaN padding is ignored in the reduction.
    """
    axis = axis % x.ndim
    if padding > 0:
        # Pad symmetrically with values that cannot affect the max.
        pad_value = np.nan if x.dtype.kind == "f" else np.iinfo(x.dtype).min
        pad_shape = (*x.shape[:axis], padding, *x.shape[axis + 1 :])
        pad_block = np.full(pad_shape, fill_value=pad_value, dtype=x.dtype)
        x = np.concatenate([pad_block, x, pad_block], axis=axis)
    windows = sliding_window_1d(x, kernel_size, stride, axis)
    return np.nanmax(windows, axis=-1)
|
|
|
|
|
|
|
|
def max_pool_nd(
    x: np.ndarray,
    kernel_size: Tuple[int, ...],
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    axis: Tuple[int, ...],
) -> np.ndarray:
    """
    Perform N-dimensional max pooling on the input array.

    Applies max_pool_1d sequentially along each requested axis, in the order
    the axes are given.

    Args:
        x (np.ndarray): Input array
        kernel_size (Tuple[int, ...]): Size of the pooling kernel for each axis
        stride (Tuple[int, ...]): Stride of the pooling operation for each axis
        padding (Tuple[int, ...]): Amount of padding for each axis
        axis (Tuple[int, ...]): Axes to perform max pooling over

    Returns:
        np.ndarray: Max pooled array with reduced size along the specified axes

    Note:
        The length of kernel_size, stride, padding, and axis tuples must be equal.
    """
    for dim_idx, ax in enumerate(axis):
        x = max_pool_1d(x, kernel_size[dim_idx], stride[dim_idx], padding[dim_idx], ax)
    return x
|
|
|
|
|
|
|
|
def max_pool_2d(
    x: np.ndarray,
    kernel_size: Union[int, Tuple[int, int]],
    stride: Union[int, Tuple[int, int]],
    padding: Union[int, Tuple[int, int]],
    axis: Tuple[int, int] = (-2, -1),
):
    """
    Perform 2D max pooling on the input array.

    Convenience wrapper around max_pool_nd for the common image case:
    scalar arguments are broadcast to both dimensions before delegating.

    Args:
        x (np.ndarray): Input array
        kernel_size (Union[int, Tuple[int, int]]): 2D pooling window size.
            A single int is used for both dimensions.
        stride (Union[int, Tuple[int, int]]): 2D pooling stride.
            A single int is used for both dimensions.
        padding (Union[int, Tuple[int, int]]): Padding for both dimensions.
            A single int is used for both dimensions.
        axis (Tuple[int, int], optional): The two axes to pool over.
            Defaults to (-2, -1) (the last two dimensions).

    Returns:
        np.ndarray: 2D max pooled array, reduced along the specified axes

    Example:
        >>> image = np.random.rand(64, 64)
        >>> pooled = max_pool_2d(image, kernel_size=2, stride=2, padding=0)
        >>> # 64x64 -> 32x32 with 2x2 max pooling
    """
    # Normalize scalar arguments to (h, w) pairs.
    kernel_size = (
        (kernel_size, kernel_size) if isinstance(kernel_size, Number) else kernel_size
    )
    stride = (stride, stride) if isinstance(stride, Number) else stride
    padding = (padding, padding) if isinstance(padding, Number) else padding
    return max_pool_nd(x, kernel_size, stride, padding, tuple(axis))
|
|
|
|
|
|
|
|
@no_warnings(category=RuntimeWarning)
def depth_edge(
    depth: np.ndarray,
    atol: float = None,
    rtol: float = None,
    kernel_size: int = 3,
    mask: np.ndarray = None,
) -> np.ndarray:
    """
    Compute an edge mask from a depth map.

    A pixel is an edge when the spread of depth values (local max minus
    local min) inside its kernel_size x kernel_size neighborhood exceeds
    the absolute and/or relative tolerance.

    Args:
        depth (np.ndarray): shape (..., height, width), linear depth map
        atol (float): absolute tolerance on the local depth spread
        rtol (float): relative tolerance (spread divided by center depth)
        kernel_size (int): side length of the square neighborhood
        mask (np.ndarray): optional boolean validity mask; invalid pixels
            are excluded from the local extremes via -inf substitution

    Returns:
        edge (np.ndarray): shape (..., height, width) of dtype bool
    """
    half = kernel_size // 2
    if mask is None:
        local_max = max_pool_2d(depth, kernel_size, stride=1, padding=half)
        neg_local_min = max_pool_2d(-depth, kernel_size, stride=1, padding=half)
    else:
        # Masked-out pixels become -inf so they never win the max.
        local_max = max_pool_2d(
            np.where(mask, depth, -np.inf),
            kernel_size,
            stride=1,
            padding=half,
        )
        neg_local_min = max_pool_2d(
            np.where(mask, -depth, -np.inf),
            kernel_size,
            stride=1,
            padding=half,
        )
    # max(d) + max(-d) == local max - local min, i.e. the window's depth spread.
    diff = local_max + neg_local_min

    edge = np.zeros_like(depth, dtype=bool)
    if atol is not None:
        edge |= diff > atol

    if rtol is not None:
        edge |= diff / depth > rtol
    return edge
|
|
|
|
|
|
|
|
def depth_aliasing(
    depth: np.ndarray,
    atol: float = None,
    rtol: float = None,
    kernel_size: int = 3,
    mask: np.ndarray = None,
) -> np.ndarray:
    """
    Compute a map indicating aliased pixels of a depth map.

    A pixel is aliased when it is close to neither the maximum nor the
    minimum depth of its kernel_size x kernel_size neighborhood, i.e. its
    distance to the nearer local extreme exceeds the tolerance(s).

    Args:
        depth (np.ndarray): shape (..., height, width), linear depth map
        atol (float): absolute tolerance
        rtol (float): relative tolerance
        kernel_size (int): side length of the square neighborhood
        mask (np.ndarray): optional boolean validity mask; invalid pixels
            are excluded from the local extremes via -inf substitution

    Returns:
        edge (np.ndarray): shape (..., height, width) of dtype bool

    NOTE(review): unlike the sibling depth_edge / normals_edge functions,
    this one is not decorated with @no_warnings(category=RuntimeWarning),
    even though the masked branch and the diff / depth division can emit
    RuntimeWarnings — confirm whether the omission is intentional.
    """
    half = kernel_size // 2
    if mask is None:
        gap_to_max = (
            max_pool_2d(depth, kernel_size, stride=1, padding=half) - depth
        )
        gap_to_min = (
            max_pool_2d(-depth, kernel_size, stride=1, padding=half) + depth
        )
    else:
        # Masked-out pixels become -inf so they never win the max.
        gap_to_max = (
            max_pool_2d(
                np.where(mask, depth, -np.inf),
                kernel_size,
                stride=1,
                padding=half,
            )
            - depth
        )
        gap_to_min = (
            max_pool_2d(
                np.where(mask, -depth, -np.inf),
                kernel_size,
                stride=1,
                padding=half,
            )
            + depth
        )
    # Distance from this pixel to the nearer of the two local extremes.
    diff = np.minimum(gap_to_max, gap_to_min)

    edge = np.zeros_like(depth, dtype=bool)
    if atol is not None:
        edge |= diff > atol
    if rtol is not None:
        edge |= diff / depth > rtol
    return edge
|
|
|
|
|
|
|
|
@no_warnings(category=RuntimeWarning)
def normals_edge(
    normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None
) -> np.ndarray:
    """
    Compute the edge mask from normal map.

    A pixel is an edge when the largest angle between its normal and any
    neighbor's normal inside a kernel_size x kernel_size window exceeds the
    tolerance. The angle map is then max-pooled once more, dilating the
    result so both sides of a crease are marked.

    Args:
        normals (np.ndarray): shape (..., height, width, 3), normal map
        tol (float): tolerance in degrees
        kernel_size (int): side length of the comparison window
        mask (np.ndarray): optional (..., height, width) boolean validity
            mask; invalid neighbors contribute zero angle

    Returns:
        edge (np.ndarray): shape (..., height, width) of dtype bool
    """
    assert normals.ndim >= 3 and normals.shape[-1] == 3, (
        "normal should be of shape (..., height, width, 3)"
    )
    normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12)

    padding = kernel_size // 2
    normals_window = sliding_window_2d(
        np.pad(
            normals,
            (
                *([(0, 0)] * (normals.ndim - 3)),
                (padding, padding),
                (padding, padding),
                (0, 0),
            ),
            mode="edge",
        ),
        window_size=kernel_size,
        stride=1,
        axis=(-3, -2),
    )
    # Cosine similarity between each pixel's normal and every window neighbor.
    # Fix: clip into arccos' valid domain [-1, 1] — rounding in the
    # normalization above can push dot products slightly outside it, which
    # previously produced NaNs (silenced by @no_warnings) that propagated
    # through .max() and caused genuine edges to be missed.
    cos_sim = np.clip(
        (normals[..., None, None] * normals_window).sum(axis=-3), -1.0, 1.0
    )
    if mask is None:
        angle_diff = np.arccos(cos_sim).max(axis=(-2, -1))
    else:
        mask_window = sliding_window_2d(
            np.pad(
                mask,
                (*([(0, 0)] * (mask.ndim - 3)), (padding, padding), (padding, padding)),
                mode="edge",
            ),
            window_size=kernel_size,
            stride=1,
            axis=(-3, -2),
        )
        # Invalid neighbors contribute an angle of 0 (never trigger an edge).
        angle_diff = np.where(mask_window, np.arccos(cos_sim), 0).max(axis=(-2, -1))

    # Dilate by one pooling pass so both sides of a discontinuity are flagged.
    angle_diff = max_pool_2d(
        angle_diff, kernel_size, stride=1, padding=kernel_size // 2
    )
    edge = angle_diff > np.deg2rad(tol)
    return edge
|
|
|