"""Conversions between SE(3) matrices and a 6D vector representation.

Convention: 6D = ``[tx, ty, tz, rx, ry, rz]``, where the rotation part is an
axis-angle vector (``angle * axis``). Translations are in meters; rotation
angles are in radians.
"""

from __future__ import annotations

import numpy as np
import torch


def _as_tensor(x: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Return *x* as a tensor; NumPy arrays are converted to a float32 tensor."""
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x).float()
    return x


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """Return ``sqrt(max(x, 0))`` with a zero (not inf/NaN) gradient where ``x <= 0``."""
    positive = x > 0
    safe = torch.where(positive, x, torch.ones_like(x))
    return torch.where(positive, torch.sqrt(safe), torch.zeros_like(x))


def _matrix_to_quaternion(R: torch.Tensor) -> torch.Tensor:
    """Rotation matrices ``[..., 3, 3]`` -> unit quaternions ``[..., 4]`` as ``(w, x, y, z)``.

    The returned quaternion has ``w >= 0``.
    """
    m00, m01, m02 = R[..., 0, 0], R[..., 0, 1], R[..., 0, 2]
    m10, m11, m12 = R[..., 1, 0], R[..., 1, 1], R[..., 1, 2]
    m20, m21, m22 = R[..., 2, 0], R[..., 2, 1], R[..., 2, 2]

    # For a proper rotation these are (2|w|, 2|x|, 2|y|, 2|z|).
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # Row k equals 4 * q_k * (w, x, y, z). Dividing it by 2 * q_abs[k] recovers
    # the quaternion up to sign. Every row is exact in theory, but a row is only
    # well conditioned when its q_abs[k] is large.
    candidates = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )
    # The four stacked radicands always sum to 4, so the largest q_abs is >= 1.
    # The floor therefore only protects rows that are never selected.
    candidates = candidates / (2.0 * q_abs.unsqueeze(-1).clamp_min(0.1))

    best = q_abs.argmax(dim=-1)
    index = best[..., None, None].expand(*best.shape, 1, 4)
    q = candidates.gather(-2, index).squeeze(-2)

    # Renormalize to absorb slight non-orthonormality of the input matrix.
    q = q / q.norm(dim=-1, keepdim=True)
    # q and -q encode the same rotation; w >= 0 keeps the angle in [0, pi].
    return torch.where(q[..., :1] < 0, -q, q)


def _quaternion_to_axis_angle(q: torch.Tensor) -> torch.Tensor:
    """Unit quaternions ``[..., 4]`` (``w >= 0``) -> axis-angle vectors ``[..., 3]``."""
    xyz = q[..., 1:]
    sin_half = xyz.norm(dim=-1, keepdim=True)
    half = torch.atan2(sin_half, q[..., :1])
    angle = 2.0 * half  # in [0, pi] because w >= 0
    # angle * axis == xyz * angle / sin(angle / 2). Near zero, use the Taylor
    # expansion sin(a / 2) / a ~= 1/2 - a^2 / 48 to avoid 0 / 0.
    small = angle < 1e-6
    safe_angle = torch.where(small, torch.ones_like(angle), angle)
    ratio = torch.where(small, 0.5 - angle * angle / 48.0, torch.sin(half) / safe_angle)
    return xyz / ratio


def rotation_matrix_to_axis_angle(R: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Convert rotation matrices ``[..., 3, 3]`` to axis-angle vectors ``[..., 3]``.

    The result is ``angle * axis`` with ``angle`` in ``[0, pi]``. A single
    ``[3, 3]`` matrix yields a ``[3]`` vector. NumPy input is converted to a
    float32 tensor.

    The conversion goes through a unit quaternion, choosing per matrix the
    best-conditioned of the four standard extraction formulas. This keeps it
    accurate for angles near ``pi``. Reading the axis only from ``R - R^T``
    degenerates there, because that part vanishes as ``sin(angle) -> 0``: a
    rotation by exactly ``pi`` would come back as the zero vector.
    """
    R = _as_tensor(R)
    return _quaternion_to_axis_angle(_matrix_to_quaternion(R))


def axis_angle_to_rotation_matrix(aa: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Convert axis-angle vectors ``[..., 3]`` to rotation matrices ``[..., 3, 3]``.

    Uses the Rodrigues formula
    ``R = cos(t) I + sin(t) [k]_x + (1 - cos(t)) k k^T``.
    NumPy input is converted to a float32 tensor.
    """
    aa = _as_tensor(aa)
    # Clamp so the zero vector yields axis 0 instead of 0 / 0. All
    # axis-dependent terms then vanish and R is (numerically) the identity.
    theta = aa.norm(dim=-1, keepdim=True).clamp_min(1e-9)  # [..., 1]
    axis = aa / theta
    x, y, z = axis[..., 0], axis[..., 1], axis[..., 2]
    sin_t = torch.sin(theta.squeeze(-1))
    cos_t = torch.cos(theta.squeeze(-1))
    one_c = 1.0 - cos_t
    R = torch.stack(
        [
            cos_t + x * x * one_c,
            x * y * one_c - z * sin_t,
            x * z * one_c + y * sin_t,
            y * x * one_c + z * sin_t,
            cos_t + y * y * one_c,
            y * z * one_c - x * sin_t,
            z * x * one_c - y * sin_t,
            z * y * one_c + x * sin_t,
            cos_t + z * z * one_c,
        ],
        dim=-1,
    ).reshape(*aa.shape[:-1], 3, 3)
    return R


def matrix_to_6d(T: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Convert SE(3) matrices ``[..., 4, 4]`` to 6D ``[..., 6]`` = ``[tx, ty, tz, rx, ry, rz]``.

    A single ``[4, 4]`` matrix yields a ``[6]`` vector. NumPy input is
    converted to a float32 tensor. Only the top 3x4 block is read.
    """
    T = _as_tensor(T)
    t = T[..., :3, 3]
    aa = rotation_matrix_to_axis_angle(T[..., :3, :3])
    return torch.cat([t, aa], dim=-1)


def six_d_to_matrix(six: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Convert 6D vectors ``[..., 6]`` to SE(3) matrices ``[..., 4, 4]``.

    A single ``[6]`` vector yields a ``[4, 4]`` matrix. The bottom row is set
    to ``[0, 0, 0, 1]``. NumPy input is converted to a float32 tensor.
    """
    six = _as_tensor(six)
    t = six[..., :3]
    R = axis_angle_to_rotation_matrix(six[..., 3:])
    T = torch.zeros(*six.shape[:-1], 4, 4, dtype=six.dtype, device=six.device)
    T[..., :3, :3] = R
    T[..., :3, 3] = t
    T[..., 3, 3] = 1.0
    return T


def invert_se3(T: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Invert SE(3) matrices ``[..., 4, 4]`` in closed form: ``[R^T | -R^T t]``.

    Assumes the top-left 3x3 block is orthonormal. The output bottom row is
    ``[0, 0, 0, 1]``. NumPy input is converted to a float32 tensor.
    """
    T = _as_tensor(T)
    R = T[..., :3, :3]
    t = T[..., :3, 3:4]
    Rt = R.transpose(-2, -1)
    inv = torch.zeros_like(T)
    inv[..., :3, :3] = Rt
    inv[..., :3, 3:4] = -Rt @ t
    inv[..., 3, 3] = 1.0
    return inv