File size: 3,232 Bytes
0cfefd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""SE(3) 与 6D 表示之间的转换。

约定:6D = ``[tx, ty, tz, rx, ry, rz]``,rotation 为轴角向量(``angle * axis``)。
平移单位为米;旋转角弧度。
"""

from __future__ import annotations

import numpy as np
import torch


def rotation_matrix_to_axis_angle(R: torch.Tensor | np.ndarray) -> torch.Tensor:
    """3x3 旋转矩阵 -> 轴角向量 ``[3]`` (=angle * axis),支持 batch。

    使用 Rodrigues 公式数值反求。
    """
    if isinstance(R, np.ndarray):
        R = torch.from_numpy(R).float()
    if R.dim() == 2:
        R = R.unsqueeze(0)
        single = True
    else:
        single = False

    trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
    cos_theta = ((trace - 1.0) * 0.5).clamp(-1.0 + 1e-7, 1.0 - 1e-7)
    theta = torch.acos(cos_theta)  # [B]

    # 提取轴向量
    rx = R[..., 2, 1] - R[..., 1, 2]
    ry = R[..., 0, 2] - R[..., 2, 0]
    rz = R[..., 1, 0] - R[..., 0, 1]
    axis = torch.stack([rx, ry, rz], dim=-1)
    sin_theta = torch.sin(theta).clamp_min(1e-7)
    axis = axis / (2.0 * sin_theta).unsqueeze(-1)

    aa = axis * theta.unsqueeze(-1)
    if single:
        aa = aa.squeeze(0)
    return aa


def axis_angle_to_rotation_matrix(aa: torch.Tensor) -> torch.Tensor:
    """Convert axis-angle vectors ``[..., 3]`` to rotation matrices ``[..., 3, 3]``.

    Uses the Rodrigues formula
    ``R = cos(t) I + sin(t) [u]_x + (1 - cos(t)) u u^T`` with ``t = |aa|`` and
    ``u = aa / t``. The norm is clamped to ``1e-9`` so a zero vector yields the
    identity instead of dividing by zero.
    """
    norm = aa.norm(dim=-1, keepdim=True).clamp_min(1e-9)  # [..., 1]
    ux, uy, uz = (aa / norm).unbind(dim=-1)
    angle = norm.squeeze(-1)
    s = torch.sin(angle)
    c = torch.cos(angle)
    v = 1.0 - c  # versine, 1 - cos(t)

    # Symmetric off-diagonal terms, shared by the upper and lower triangles.
    xy = ux * uy * v
    xz = ux * uz * v
    yz = uy * uz * v
    # Skew-symmetric terms from sin(t) [u]_x.
    sx, sy, sz = ux * s, uy * s, uz * s

    flat = torch.stack(
        (
            c + ux * ux * v, xy - sz, xz + sy,
            xy + sz, c + uy * uy * v, yz - sx,
            xz - sy, yz + sx, c + uz * uz * v,
        ),
        dim=-1,
    )
    return flat.reshape(*aa.shape[:-1], 3, 3)


def matrix_to_6d(T: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Convert 4x4 SE(3) matrices to 6D vectors ``[tx, ty, tz, rx, ry, rz]``.

    Accepts a single ``[4, 4]`` matrix or a batch ``[..., 4, 4]`` (torch tensor
    or NumPy array; NumPy input is converted to ``float32``). The translation
    is the last column of the top 3 rows; the rotation block is converted with
    ``rotation_matrix_to_axis_angle``. Returns ``[6]`` or ``[..., 6]``.
    """
    if isinstance(T, np.ndarray):
        T = torch.from_numpy(T).float()

    is_single = T.dim() == 2
    batched = T.unsqueeze(0) if is_single else T

    translation = batched[..., :3, 3]
    rotvec = rotation_matrix_to_axis_angle(batched[..., :3, :3])
    out = torch.cat([translation, rotvec], dim=-1)

    return out.squeeze(0) if is_single else out


def six_d_to_matrix(six: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Convert 6D vectors ``[tx, ty, tz, rx, ry, rz]`` to 4x4 SE(3) matrices.

    Accepts ``[6]`` or ``[..., 6]`` as a torch tensor or NumPy array. NumPy
    input is converted to ``float32``, mirroring ``matrix_to_6d`` so the two
    functions accept the same input types. Returns ``[4, 4]`` or
    ``[..., 4, 4]`` with the (converted) input's dtype and device.
    """
    if isinstance(six, np.ndarray):
        six = torch.from_numpy(six).float()
    if six.dim() == 1:
        six = six.unsqueeze(0)
        single = True
    else:
        single = False
    t = six[..., :3]
    aa = six[..., 3:]
    R = axis_angle_to_rotation_matrix(aa)
    # Fill a zero matrix: rotation block, translation column, and the
    # homogeneous 1 in the bottom-right corner.
    T = torch.zeros(*six.shape[:-1], 4, 4, dtype=six.dtype, device=six.device)
    T[..., :3, :3] = R
    T[..., :3, 3] = t
    T[..., 3, 3] = 1.0
    if single:
        T = T.squeeze(0)
    return T


def invert_se3(T: torch.Tensor) -> torch.Tensor:
    """Invert SE(3) matrices ``[..., 4, 4]`` in closed form.

    For ``T = [[R, t], [0, 1]]`` the inverse is ``[[R^T, -R^T t], [0, 1]]``.
    Only the top 3x4 block of ``T`` is read; the bottom row of the result is
    always ``[0, 0, 0, 1]``.
    """
    rot_t = T[..., :3, :3].transpose(-2, -1)
    trans = T[..., :3, 3:4]
    top = torch.cat([rot_t, (-rot_t) @ trans], dim=-1)  # [..., 3, 4]

    bottom = torch.zeros_like(T[..., 3:, :])  # [..., 1, 4]
    bottom[..., 0, 3] = 1.0

    return torch.cat([top, bottom], dim=-2)