WJAD / src /wjad /data /se3.py
fuzirui's picture
Sync WJAD codebase
0cfefd2 verified
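"""Conversions between SE(3) matrices and a 6D vector representation.
Convention: 6D = ``[tx, ty, tz, rx, ry, rz]``, where the rotation part is an
axis-angle vector (``angle * axis``).
Translation is in meters; rotation angles are in radians.
"""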
from __future__ import annotations
import numpy as np
import torch
def rotation_matrix_to_axis_angle(R: torch.Tensor | np.ndarray) -> torch.Tensor:
"""3x3 旋转矩阵 -> 轴角向量 ``[3]`` (=angle * axis),支持 batch。
使用 Rodrigues 公式数值反求。
"""
if isinstance(R, np.ndarray):
R = torch.from_numpy(R).float()
if R.dim() == 2:
R = R.unsqueeze(0)
single = True
else:
single = False
trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
cos_theta = ((trace - 1.0) * 0.5).clamp(-1.0 + 1e-7, 1.0 - 1e-7)
theta = torch.acos(cos_theta) # [B]
# 提取轴向量
rx = R[..., 2, 1] - R[..., 1, 2]
ry = R[..., 0, 2] - R[..., 2, 0]
rz = R[..., 1, 0] - R[..., 0, 1]
axis = torch.stack([rx, ry, rz], dim=-1)
sin_theta = torch.sin(theta).clamp_min(1e-7)
axis = axis / (2.0 * sin_theta).unsqueeze(-1)
aa = axis * theta.unsqueeze(-1)
if single:
aa = aa.squeeze(0)
return aa
def axis_angle_to_rotation_matrix(aa: torch.Tensor) -> torch.Tensor:
    """Convert axis-angle vectors ``[..., 3]`` to rotation matrices ``[..., 3, 3]``.

    Applies the Rodrigues formula in matrix form::

        R = cos(t) * I + (1 - cos(t)) * k k^T + sin(t) * [k]_x

    where ``t = |aa|``, ``k = aa / t`` and ``[k]_x`` is the cross-product
    (skew-symmetric) matrix of ``k``. The norm is floored at ``1e-9`` so a zero
    vector yields the identity instead of dividing by zero.
    """
    angle = aa.norm(dim=-1, keepdim=True).clamp_min(1e-9)  # [..., 1]
    k = aa / angle
    angle = angle.squeeze(-1)
    c = torch.cos(angle)[..., None, None]
    s = torch.sin(angle)[..., None, None]

    kx, ky, kz = k.unbind(-1)
    zero = torch.zeros_like(kx)
    skew = torch.stack(
        [zero, -kz, ky, kz, zero, -kx, -ky, kx, zero],
        dim=-1,
    ).reshape(*aa.shape[:-1], 3, 3)
    outer = k.unsqueeze(-1) * k.unsqueeze(-2)  # k k^T
    eye = torch.eye(3, dtype=aa.dtype, device=aa.device)
    return c * eye + outer * (1.0 - c) + s * skew
def matrix_to_6d(T: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Convert 4x4 SE(3) matrices to 6D vectors ``[tx, ty, tz, rx, ry, rz]``.

    Accepts a single ``[4, 4]`` matrix (returns ``[6]``) or a batch
    ``[..., 4, 4]`` (returns ``[..., 6]``). NumPy input is converted to a
    float32 tensor. The rotation block is encoded as an axis-angle vector.
    """
    if isinstance(T, np.ndarray):
        T = torch.from_numpy(T).float()
    squeeze_back = T.dim() == 2
    batched = T.unsqueeze(0) if squeeze_back else T
    translation = batched[..., :3, 3]
    rotvec = rotation_matrix_to_axis_angle(batched[..., :3, :3])
    out = torch.cat((translation, rotvec), dim=-1)
    return out.squeeze(0) if squeeze_back else out
def six_d_to_matrix(six: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Convert 6D vectors ``[tx, ty, tz, rx, ry, rz]`` to 4x4 SE(3) matrices.

    Accepts a single vector ``[6]`` (returns ``[4, 4]``) or a batch ``[..., 6]``
    (returns ``[..., 4, 4]``). The last three entries are an axis-angle
    rotation vector (``angle * axis``). NumPy input is converted to a float32
    tensor, mirroring ``matrix_to_6d``.

    Raises:
        ValueError: if the last dimension is not 6.
    """
    if isinstance(six, np.ndarray):
        six = torch.from_numpy(six).float()
    if six.shape[-1] != 6:
        raise ValueError(f"expected last dimension 6, got shape {tuple(six.shape)}")
    single = six.dim() == 1
    if single:
        six = six.unsqueeze(0)
    t = six[..., :3]
    aa = six[..., 3:]
    R = axis_angle_to_rotation_matrix(aa)
    T = torch.zeros(*six.shape[:-1], 4, 4, dtype=six.dtype, device=six.device)
    T[..., :3, :3] = R
    T[..., :3, 3] = t
    T[..., 3, 3] = 1.0  # homogeneous bottom row [0, 0, 0, 1]
    if single:
        T = T.squeeze(0)
    return T
def invert_se3(T: torch.Tensor) -> torch.Tensor:
    """Invert rigid transforms ``[..., 4, 4]``.

    For ``T = [[R, t], [0, 1]]`` the inverse is ``[[R^T, -R^T t], [0, 1]]``.
    This uses the transpose rather than a general matrix inverse. The bottom
    row of the result is always set to ``[0, 0, 0, 1]``.
    """
    rot_inv = T[..., :3, :3].transpose(-2, -1)
    trans_inv = -rot_inv @ T[..., :3, 3:4]
    out = torch.zeros_like(T)
    out[..., :3, :3] = rot_inv
    out[..., :3, 3:4] = trans_inv
    out[..., 3, 3] = 1.0
    return out