# worldgenTest / pytorch3d / transforms.py
# Uploaded by chahui: "Upload 5 files" (commit a62d7b6, verified)
import torch
def matrix_to_quaternion(R: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to unit quaternions.

    Minimal stand-in for ``pytorch3d.transforms.matrix_to_quaternion``. It
    uses only tensor ops, so it runs on whatever device ``R`` lives on.

    Args:
        R: Rotation matrices of shape ``(..., 3, 3)``; a single ``(3, 3)``
            matrix is accepted too. The input is cast to float32.

    Returns:
        Float32 quaternions of shape ``(..., 4)`` in ``(x, y, z, w)`` order
        (scalar part last), sign-standardized so that ``w >= 0``.

    Note:
        Upstream PyTorch3D returns real-part-first ``(w, x, y, z)``; this
        function returns scalar-last ``(x, y, z, w)`` -- do not mix the two.
    """
    R = R.float()
    m00, m01, m02 = R[..., 0, 0], R[..., 0, 1], R[..., 0, 2]
    m10, m11, m12 = R[..., 1, 0], R[..., 1, 1], R[..., 1, 2]
    m20, m21, m22 = R[..., 2, 0], R[..., 2, 1], R[..., 2, 2]

    # For a rotation matrix these four values equal 4*w^2, 4*x^2, 4*y^2, 4*z^2.
    diag = torch.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        dim=-1,
    )
    # q_abs[..., k] == 2 * |q_k| for k over (w, x, y, z). The double `where`
    # yields 0 for non-positive entries (rounding noise) while keeping the
    # sqrt's gradient finite there.
    positive = diag > 0
    safe_diag = torch.where(positive, diag, torch.ones_like(diag))
    q_abs = torch.where(positive, torch.sqrt(safe_diag), torch.zeros_like(diag))

    # Row k holds 4 * q_k * (w, x, y, z), built from off-diagonal sums and
    # differences. Dividing row k by 2 * q_abs[k] = 4 * |q_k| gives +/- the
    # quaternion, so a row is well conditioned only when its q_k is large.
    # Choosing the row with the largest q_abs also covers 180-degree rotations
    # (w == 0), where every antisymmetric difference vanishes and a
    # sign-of-off-diagonal scheme cannot recover the component signs.
    candidates = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m02 + m20, m12 + m21, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )
    # The floor only guards the rows we discard against 0/0. For a valid
    # rotation sum(q_abs**2) == 4, so the selected row has q_abs >= 1.
    candidates = candidates / (2.0 * q_abs.clamp(min=0.1)).unsqueeze(-1)

    best = q_abs.argmax(dim=-1)
    index = best[..., None, None].expand(*best.shape, 1, 4)
    quat_wxyz = torch.gather(candidates, -2, index).squeeze(-2)

    # q and -q encode the same rotation; keep the w >= 0 representative.
    quat_wxyz = torch.where(quat_wxyz[..., :1] < 0, -quat_wxyz, quat_wxyz)

    # Reorder (w, x, y, z) -> (x, y, z, w).
    return quat_wxyz[..., [1, 2, 3, 0]]