File size: 982 Bytes
a62d7b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch

def matrix_to_quaternion(R: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to unit quaternions.

    Accepts a single ``[3, 3]`` matrix or a batch ``[..., 3, 3]``; only the
    entries ``R[..., 0:3, 0:3]`` are read.

    Uses the "largest component" (Shepperd) method: four candidate
    quaternions are formed per matrix and the one derived from the component
    with the largest magnitude is kept. Unlike a sign-of-difference formula,
    this stays correct for rotations at or near 180 degrees, where the
    antisymmetric part of ``R`` vanishes.

    Args:
        R: Rotation matrix or batch of matrices. Converted with ``.float()``,
            so the result is float32.

    Returns:
        Tensor of shape ``[..., 4]`` in (x, y, z, w) order with ``w >= 0``.
        For an exact 180-degree rotation (``w == 0``), ``q`` and ``-q`` encode
        the same rotation and either may be returned.
    """
    R = R.float()
    m00, m01, m02 = R[..., 0, 0], R[..., 0, 1], R[..., 0, 2]
    m10, m11, m12 = R[..., 1, 0], R[..., 1, 1], R[..., 1, 2]
    m20, m21, m22 = R[..., 2, 0], R[..., 2, 1], R[..., 2, 2]

    # For a unit quaternion (w, x, y, z) these equal 4w^2, 4x^2, 4y^2, 4z^2.
    # For ANY 3x3 input they sum to exactly 4, so the largest is >= 1.
    sq = torch.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        dim=-1,
    )
    # q_abs = (2|w|, 2|x|, 2|y|, 2|z|). The inner ``where`` keeps sqrt away from
    # zero/negative inputs so backward never produces inf * 0 = NaN.
    positive = sq > 0
    q_abs = torch.where(
        positive,
        torch.sqrt(torch.where(positive, sq, torch.ones_like(sq))),
        torch.zeros_like(sq),
    )

    # Row i is 4|q_i| * (w, x, y, z), up to sign, reconstructed from component
    # i: differences of off-diagonal pairs give 4wx, 4wy, 4wz; sums give
    # 4xy, 4xz, 4yz.
    candidates = torch.stack(
        [
            torch.stack([sq[..., 0], m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, sq[..., 1], m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, sq[..., 2], m12 + m21], dim=-1),
            torch.stack([m10 - m01, m02 + m20, m12 + m21, sq[..., 3]], dim=-1),
        ],
        dim=-2,
    )
    # Dividing row i by 2 * q_abs_i = 4|q_i| normalizes it. The selected row
    # always has q_abs >= 1 (see above), so the 0.1 floor only prevents
    # division by zero in rows that are discarded.
    candidates = candidates / (2.0 * torch.clamp(q_abs, min=0.1)).unsqueeze(-1)

    # Keep, per matrix, the candidate built from the largest component.
    best = q_abs.argmax(dim=-1)
    index = best[..., None, None].expand(*best.shape, 1, 4)
    q_wxyz = torch.gather(candidates, -2, index).squeeze(-2)

    # Canonicalize to w >= 0 (the convention of the previous implementation).
    q_wxyz = torch.where(q_wxyz[..., :1] < 0, -q_wxyz, q_wxyz)
    # Reorder (w, x, y, z) -> (x, y, z, w).
    return q_wxyz[..., [1, 2, 3, 0]]