| """SE(3) 与 6D 表示之间的转换。 |
| |
| 约定:6D = ``[tx, ty, tz, rx, ry, rz]``,rotation 为轴角向量(``angle * axis``)。 |
| 平移单位为米;旋转角弧度。 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
|
|
|
|
| def rotation_matrix_to_axis_angle(R: torch.Tensor | np.ndarray) -> torch.Tensor: |
| """3x3 旋转矩阵 -> 轴角向量 ``[3]`` (=angle * axis),支持 batch。 |
| |
| 使用 Rodrigues 公式数值反求。 |
| """ |
| if isinstance(R, np.ndarray): |
| R = torch.from_numpy(R).float() |
| if R.dim() == 2: |
| R = R.unsqueeze(0) |
| single = True |
| else: |
| single = False |
|
|
| trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2] |
| cos_theta = ((trace - 1.0) * 0.5).clamp(-1.0 + 1e-7, 1.0 - 1e-7) |
| theta = torch.acos(cos_theta) |
|
|
| |
| rx = R[..., 2, 1] - R[..., 1, 2] |
| ry = R[..., 0, 2] - R[..., 2, 0] |
| rz = R[..., 1, 0] - R[..., 0, 1] |
| axis = torch.stack([rx, ry, rz], dim=-1) |
| sin_theta = torch.sin(theta).clamp_min(1e-7) |
| axis = axis / (2.0 * sin_theta).unsqueeze(-1) |
|
|
| aa = axis * theta.unsqueeze(-1) |
| if single: |
| aa = aa.squeeze(0) |
| return aa |
|
|
|
|
def axis_angle_to_rotation_matrix(aa: torch.Tensor) -> torch.Tensor:
    """Convert axis-angle vectors ``[..., 3]`` to rotation matrices ``[..., 3, 3]``.

    Applies Rodrigues' formula ``R = cos(t) I + sin(t) [k]_x + (1 - cos(t)) k k^T``
    with ``t = |aa|`` and unit axis ``k = aa / t``. The angle is floored at
    ``1e-9`` before the division, so a zero vector yields the identity rather
    than NaN.
    """
    angle = aa.norm(dim=-1).clamp_min(1e-9)
    kx, ky, kz = (aa / angle.unsqueeze(-1)).unbind(-1)
    s = torch.sin(angle)
    c = torch.cos(angle)
    v = 1.0 - c  # versine, 1 - cos(t)

    row0 = torch.stack([c + kx * kx * v, kx * ky * v - kz * s, kx * kz * v + ky * s], dim=-1)
    row1 = torch.stack([ky * kx * v + kz * s, c + ky * ky * v, ky * kz * v - kx * s], dim=-1)
    row2 = torch.stack([kz * kx * v - ky * s, kz * ky * v + kx * s, c + kz * kz * v], dim=-1)
    return torch.stack([row0, row1, row2], dim=-2)
|
|
|
|
def matrix_to_6d(T: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Convert 4x4 SE(3) matrices to 6D vectors ``[tx, ty, tz, rx, ry, rz]``.

    A single ``[4, 4]`` matrix yields a ``[6]`` vector; a batch ``[..., 4, 4]``
    yields ``[..., 6]``. NumPy input is converted to a float32 tensor first.
    The translation is the last column's top three entries, and the rotation
    block is encoded via ``rotation_matrix_to_axis_angle``.
    """
    if isinstance(T, np.ndarray):
        T = torch.from_numpy(T).float()
    is_single = T.dim() == 2
    mats = T[None] if is_single else T

    translation = mats[..., :3, 3]
    rotvec = rotation_matrix_to_axis_angle(mats[..., :3, :3])
    out = torch.cat([translation, rotvec], dim=-1)
    return out[0] if is_single else out
|
|
|
|
| def six_d_to_matrix(six: torch.Tensor) -> torch.Tensor: |
| """6D -> 4x4 SE(3)。""" |
| if six.dim() == 1: |
| six = six.unsqueeze(0) |
| single = True |
| else: |
| single = False |
| t = six[..., :3] |
| aa = six[..., 3:] |
| R = axis_angle_to_rotation_matrix(aa) |
| T = torch.zeros(*six.shape[:-1], 4, 4, dtype=six.dtype, device=six.device) |
| T[..., :3, :3] = R |
| T[..., :3, 3] = t |
| T[..., 3, 3] = 1.0 |
| if single: |
| T = T.squeeze(0) |
| return T |
|
|
|
|
def invert_se3(T: torch.Tensor) -> torch.Tensor:
    """Invert rigid transforms given as ``[..., 4, 4]`` homogeneous matrices.

    Uses the closed form ``[R | t]^-1 = [R^T | -R^T t]`` rather than a general
    matrix inverse. Only the top 3x4 block of ``T`` is read; the bottom row of
    the result is always ``[0, 0, 0, 1]``. The output dtype/device match ``T``.
    """
    rot_t = T[..., :3, :3].transpose(-2, -1)
    trans = T[..., :3, 3:4]
    upper = torch.cat([rot_t, (-rot_t) @ trans], dim=-1)  # [..., 3, 4]
    lower = torch.zeros_like(T[..., 3:, :])  # [..., 1, 4]
    lower[..., 0, 3] = 1.0
    return torch.cat([upper, lower], dim=-2)
|
|