import torch


def matrix_to_quaternion(R: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to quaternions.

    Minimal CPU-only implementation supporting ``[..., 3, 3]`` rotation
    matrices; a single ``[3, 3]`` matrix is also accepted and yields a
    ``[4]`` result.

    Uses Shepperd's method. Four candidate quaternions are formed, each by
    dividing by a different component magnitude. The candidate whose divisor
    is largest is kept. This stays accurate at every rotation angle,
    including 180-degree rotations where the real part ``w`` is zero.
    Recovering signs from ``sign(R[2,1] - R[1,2])`` etc. breaks down exactly
    there.

    Args:
        R: Rotation matrices of shape ``[..., 3, 3]``. The input is cast to
            float32 before any computation.

    Returns:
        Quaternions of shape ``[..., 4]`` in ``(x, y, z, w)`` order, float32.
        ``q`` and ``-q`` encode the same rotation; the sign is chosen so that
        ``w >= 0``. When ``w == 0`` (a 180-degree rotation), the sign is
        whichever the selected candidate produced.
    """
    R = R.float()

    m00, m01, m02 = R[..., 0, 0], R[..., 0, 1], R[..., 0, 2]
    m10, m11, m12 = R[..., 1, 0], R[..., 1, 1], R[..., 1, 2]
    m20, m21, m22 = R[..., 2, 0], R[..., 2, 1], R[..., 2, 2]

    # For a matrix built from a unit quaternion (w, x, y, z):
    #   1 + m00 + m11 + m22 = 4w^2      1 + m00 - m11 - m22 = 4x^2
    #   1 - m00 + m11 - m22 = 4y^2      1 - m00 - m11 + m22 = 4z^2
    # These four always sum to 4, so the largest is >= 1: a safe divisor.
    tw = 1.0 + m00 + m11 + m22
    tx = 1.0 + m00 - m11 - m22
    ty = 1.0 - m00 + m11 - m22
    tz = 1.0 - m00 - m11 + m22

    # Off-diagonal combinations, for the same unit quaternion:
    #   m21 - m12 = 4xw   m02 - m20 = 4yw   m10 - m01 = 4zw
    #   m10 + m01 = 4xy   m02 + m20 = 4xz   m21 + m12 = 4yz
    d_x, d_y, d_z = m21 - m12, m02 - m20, m10 - m01
    s_xy, s_xz, s_yz = m10 + m01, m02 + m20, m21 + m12

    four_sq = torch.stack([tw, tx, ty, tz], dim=-1)
    # Row k is (w, x, y, z) scaled by 4*q_k, where q_k is the k-th component
    # in (w, x, y, z) order.
    rows = torch.stack(
        [
            torch.stack([tw, d_x, d_y, d_z], dim=-1),
            torch.stack([d_x, tx, s_xy, s_xz], dim=-1),
            torch.stack([d_y, s_xy, ty, s_yz], dim=-1),
            torch.stack([d_z, s_xz, s_yz, tz], dim=-1),
        ],
        dim=-2,
    )

    # 4*|q_k| = 2*sqrt(4*q_k^2). The floor never touches the selected row
    # (its four_sq >= 1). It keeps the discarded rows finite, so no inf/NaN
    # leaks into the output or its gradient.
    denom = 2.0 * torch.sqrt(torch.clamp(four_sq, min=0.01))
    candidates = rows / denom.unsqueeze(-1)

    # Keep the candidate with the largest divisor (best conditioned).
    best = four_sq.argmax(dim=-1)
    index = best[..., None, None].expand(*best.shape, 1, 4)
    wxyz = torch.gather(candidates, -2, index).squeeze(-2)

    # Canonical sign: w >= 0, matching the previous implementation.
    wxyz = torch.where(wxyz[..., :1] < 0, -wxyz, wxyz)

    # Reorder (w, x, y, z) -> (x, y, z, w).
    return wxyz[..., [1, 2, 3, 0]]