Argus / argus /utils /geometry.py
lixi042
Initial commit: Argus metric panoramic 3D reconstruction demo
510e990
Raw
History Blame Contribute Delete
7.55 kB
import torch
import numpy as np
def closed_form_inverse_se3(se3, R=None, T=None):
    """
    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.

    The inverse of a rigid transform [R | t] is [R^T | -R^T t], so it is
    assembled directly from the transposed rotation instead of calling a
    general matrix inverse.

    If `R` and `T` are provided, they must correspond to the rotation and
    translation components of `se3`. Otherwise, they are extracted from `se3`.

    Args:
        se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
        R (optional): Nx3x3 array or tensor of rotation matrices.
        T (optional): Nx3x1 array or tensor of translation vectors.

    Returns:
        Nx4x4 inverted SE3 matrices. A NumPy array when `se3` is a NumPy
        array (floating dtype follows the inputs; non-floating inputs yield
        float64), otherwise a torch tensor with the dtype and device of `R`.

    Raises:
        ValueError: If the last two dimensions of `se3` are neither (4, 4)
            nor (3, 4).
    """
    # The container type of `se3` selects the NumPy or the torch code path.
    is_numpy = isinstance(se3, np.ndarray)

    # Validate the trailing matrix shape (both 4x4 and 3x4 are accepted).
    if tuple(se3.shape[-2:]) not in ((4, 4), (3, 4)):
        raise ValueError(
            f"se3 must be of shape (N,4,4) or (N,3,4), got {se3.shape}."
        )

    # Extract R and T if not provided
    if R is None:
        R = se3[:, :3, :3]  # (N,3,3)
    if T is None:
        T = se3[:, :3, 3:]  # (N,3,1)

    if is_numpy:
        R_transposed = np.transpose(R, (0, 2, 1))
        # -R^T t
        top_right = -np.matmul(R_transposed, T)
        # Keep the floating precision of the inputs instead of silently
        # promoting everything to float64; non-floating inputs still get the
        # float64 result they got before.
        out_dtype = np.result_type(R_transposed, top_right)
        if not np.issubdtype(out_dtype, np.floating):
            out_dtype = np.float64
        inverted_matrix = np.tile(np.eye(4, dtype=out_dtype), (len(R), 1, 1))
    else:
        R_transposed = R.transpose(1, 2)  # (N,3,3)
        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
        # Build the identity directly on the target device, then match dtype.
        identity = torch.eye(4, device=R.device).to(R.dtype)
        inverted_matrix = identity[None].repeat(len(R), 1, 1)

    # The bottom row stays [0, 0, 0, 1] from the identity template.
    inverted_matrix[:, :3, :3] = R_transposed
    inverted_matrix[:, :3, 3:] = top_right
    return inverted_matrix
def pano_depth_to_points(depth_map, original_pano_shape=(560, 280), crop_ratio=0.15):
    """
    Lift batched, vertically cropped panoramic depth maps to 3D points.

    Each pixel is mapped to a latitude/longitude on the original (uncropped)
    panorama, and the depth value is used as the radial distance along that
    direction. The input is assumed to have had `crop_ratio` of the original
    height removed from the top (and bottom) already.

    Args:
        depth_map (torch.Tensor): Cropped depth map, shape [B, S, H_crop, W, 1].
        original_pano_shape (tuple): Size of the uncropped panorama as
            (W_ori, H_ori), default (560, 280).
        crop_ratio (float): Fraction of the original height cropped at the
            top and at the bottom, default 0.15.

    Returns:
        torch.Tensor: Points (x, y, z) with shape [B, S, H_crop, W, 3].
    """
    assert depth_map.dim() == 5 and depth_map.shape[-1] == 1, \
        f"Input must be [B, S, H_crop, W, 1], got {depth_map.shape}"

    n_rows, n_cols = depth_map.shape[2], depth_map.shape[3]
    pano_w, pano_h = original_pano_shape
    dev = depth_map.device

    # Row indices expressed in the uncropped panorama (shifted by the top crop).
    row_shift = int(crop_ratio * pano_h)
    rows_ori = torch.arange(n_rows, device=dev) + row_shift
    cols = torch.arange(n_cols, device=dev)

    # Per-row latitude as a column vector and per-column longitude as a row
    # vector; broadcasting against [B, S, H_crop, W] covers the full grid.
    latitude = ((rows_ori / pano_h - 0.5) * torch.pi).unsqueeze(-1)  # (H_crop, 1)
    longitude = (cols / pano_w - 0.5) * 2 * torch.pi  # (W,)

    radius = depth_map[..., 0]  # [B, S, H_crop, W]

    # Spherical -> Cartesian.
    y = radius * torch.sin(latitude)
    horizontal = radius * torch.cos(latitude)
    x = horizontal * torch.sin(longitude)
    z = horizontal * torch.cos(longitude)

    return torch.stack((x, y, z), dim=-1)
def points_to_pano_depth(points):
    """
    Convert a 3D point cloud back to a ray (radial) panoramic depth map.

    Only the distance of each point from the origin is kept; the direction of
    each point is ignored.

    Args:
        points (torch.Tensor): Input 3D point cloud, shape [B, S, H, W, 3].

    Returns:
        torch.Tensor: Panoramic depth map, shape [B, S, H, W, 1].
    """
    # Validate input shape
    assert points.dim() == 5 and points.shape[-1] == 3, \
        f"Input point cloud must be [B, S, H, W, 3], got {points.shape}"
    # Radial depth = sqrt(x^2 + y^2 + z^2). torch.norm is deprecated, so the
    # explicit vector 2-norm from torch.linalg is used instead.
    return torch.linalg.vector_norm(points, ord=2, dim=-1, keepdim=True)
def camera_points_to_rotated_points(cam_points, R):
    """
    Apply one rotation matrix per (batch, sequence) entry to every point of
    the matching camera point cloud.

    Args:
        cam_points (torch.Tensor): Camera-frame points, shape [B, S, H, W, 3].
        R (torch.Tensor): Rotation matrices, shape [B, S, 3, 3].

    Returns:
        torch.Tensor: Rotated points, shape [B, S, H, W, 3].
    """
    assert cam_points.dim() == 5 and cam_points.shape[-1] == 3, \
        f"Camera points must be [B, S, H, W, 3], got {cam_points.shape}"
    assert R.dim() == 4 and R.shape[2:] == (3, 3), \
        f"Rotation matrices R must be [B, S, 3, 3], got {R.shape}"
    assert cam_points.shape[:2] == R.shape[:2], \
        f"Batch/Sequence dim mismatch: cam_points {cam_points.shape[:2]} vs R {R.shape[:2]}"

    # Insert singleton H and W axes so one matrix broadcasts over all pixels.
    rot = R[:, :, None, None]  # [B, S, 1, 1, 3, 3]
    # Treat every point as a 3x1 column vector.
    columns = cam_points[..., None]  # [B, S, H, W, 3, 1]

    # R @ p for every pixel, then drop the trailing column axis.
    return (rot @ columns)[..., 0]
def rotated_points_to_world_points(rotated_points, t):
    """
    Shift already-rotated camera points into world coordinates by adding the
    per-(batch, sequence) translation vector.

    Args:
        rotated_points (torch.Tensor): Rotated points, shape [B, S, H, W, 3].
        t (torch.Tensor): Translation vectors, shape [B, S, 3].

    Returns:
        torch.Tensor: World-coordinate points, shape [B, S, H, W, 3].
    """
    assert rotated_points.dim() == 5 and rotated_points.shape[-1] == 3, \
        f"Rotated points must be [B, S, H, W, 3], got {rotated_points.shape}"
    assert t.dim() == 3 and t.shape[-1] == 3, \
        f"Translation t must be [B, S, 3], got {t.shape}"
    assert rotated_points.shape[:2] == t.shape[:2], \
        f"Batch/Sequence dim mismatch: rotated_points {rotated_points.shape[:2]} vs t {t.shape[:2]}"

    # [B, S, 3] -> [B, S, 1, 1, 3]: the singleton axes broadcast over H and W,
    # so the same offset is added to every pixel of a given (batch, frame).
    offset = t[:, :, None, None, :]
    return rotated_points + offset
def unproject_depth_to_world_points(depth, extrinsic, size=560, crop_ratio=0.15):
    """
    Unproject cropped panoramic depth maps into world-coordinate point clouds.

    The depth is first lifted to camera-frame points, then each frame's
    extrinsic is applied directly as ``R @ p + t`` (i.e. the matrix is used
    as a camera-to-world transform).

    Args:
        depth (torch.Tensor): Cropped panoramic depth, shape
            [B, S, H_crop, W, 1].
        extrinsic (torch.Tensor): Per-frame transforms, shape [B, S, 4, 4].
            Only the top three rows are read.
        size (int): Width of the original uncropped panorama; its height is
            taken as ``size // 2``.
        crop_ratio (float): Fraction of the original panorama height cropped
            at the top and bottom of `depth`; forwarded to
            `pano_depth_to_points`. Default 0.15.

    Returns:
        torch.Tensor: World points, shape [B, S, H_crop, W, 3].
    """
    camera_points = pano_depth_to_points(
        depth,
        original_pano_shape=(size, size // 2),
        crop_ratio=crop_ratio,
    )
    rotation = extrinsic[:, :, :3, :3]    # [B, S, 3, 3]
    translation = extrinsic[:, :, :3, 3]  # [B, S, 3]
    rotated_points = camera_points_to_rotated_points(camera_points, rotation)
    world_points = rotated_points_to_world_points(rotated_points, translation)
    return world_points