from __future__ import annotations

import torch
from torch import Tensor


def _copysign(a: Tensor, b: Tensor) -> Tensor:
    """Return ``|a|`` carrying the sign of ``b``, elementwise.

    Only ``b < 0`` selects the negative branch, so ``b == -0.0`` yields
    ``|a|`` (unlike :func:`torch.copysign`).
    """
    return torch.where(b < 0, -a.abs(), a.abs())


def quaternion_to_matrix(quaternions: Tensor) -> Tensor:
    """Convert quaternions to rotation matrices.

    Args:
        quaternions: Tensor of shape ``(..., 4)`` laid out as ``(r, i, j, k)``
            with the real part first. Inputs are rescaled to unit norm before
            conversion, so non-unit quaternions are accepted.

    Returns:
        Tensor of shape ``(..., 3, 3)``.

    Raises:
        ValueError: If the last dimension of ``quaternions`` is not 4.
    """
    if quaternions.shape[-1] != 4:
        raise ValueError("Quaternion tensor must have shape (..., 4).")
    # The clamp only prevents a 0/0 in this division; an all-zero quaternion
    # describes no rotation and still produces NaNs through ``two_s`` below.
    quaternions = quaternions / quaternions.norm(dim=-1, keepdim=True).clamp_min(1e-8)
    r, i, j, k = torch.unbind(quaternions, dim=-1)
    # 2 / |q|^2, which is ~2 after the normalisation above.
    two_s = 2.0 / (quaternions * quaternions).sum(dim=-1)
    entries = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        dim=-1,
    )
    return entries.reshape(quaternions.shape[:-1] + (3, 3))


def _axis_angle_rotation(axis: str, angle: Tensor) -> Tensor:
    """Return matrices rotating by ``angle`` radians about one coordinate axis.

    Args:
        axis: One of ``"X"``, ``"Y"`` or ``"Z"`` (upper case only).
        angle: Tensor of angles in radians, of any shape.

    Returns:
        Tensor of shape ``angle.shape + (3, 3)``.

    Raises:
        ValueError: If ``axis`` is not one of the accepted letters.
    """
    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)
    # Row-major entries of the 3x3 matrix.
    if axis == "X":
        values = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        values = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        values = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        raise ValueError(f"Invalid axis {axis}.")
    return torch.stack(values, dim=-1).reshape(angle.shape + (3, 3))


def euler_angles_to_matrix(euler_angles: Tensor, convention: str) -> Tensor:
    """Convert Euler angles to rotation matrices.

    Args:
        euler_angles: Tensor of shape ``(..., 3)`` holding angles in radians.
        convention: Three axis letters, e.g. ``"XYZ"``. The result is
            ``R(convention[0], a0) @ R(convention[1], a1) @ R(convention[2], a2)``
            where ``a_n = euler_angles[..., n]``.

    Returns:
        Tensor of shape ``(..., 3, 3)``.

    Raises:
        ValueError: If the trailing dimension is not 3, if ``convention`` is
            not three characters long, or if it contains a letter other than
            ``"X"``, ``"Y"`` or ``"Z"``.
    """
    if euler_angles.shape[-1] != 3:
        raise ValueError("Euler angle tensor must have shape (..., 3).")
    if len(convention) != 3:
        raise ValueError("Convention must have three characters.")
    matrices = [
        _axis_angle_rotation(axis, angle)
        for axis, angle in zip(convention, torch.unbind(euler_angles, dim=-1))
    ]
    result = matrices[0]
    for matrix in matrices[1:]:
        result = result @ matrix
    return result


def _sqrt_positive_part(x: Tensor) -> Tensor:
    """Return ``sqrt(max(x, 0))`` with a zero (not NaN) gradient where ``x <= 0``."""
    positive = x > 0
    # sqrt's backward divides by its output; feeding it 1 where x <= 0 keeps
    # that finite, and the outer where then zeroes the gradient there.
    safe = torch.where(positive, x, torch.ones_like(x))
    return torch.where(positive, torch.sqrt(safe), torch.zeros_like(x))


def matrix_to_quaternion(matrix: Tensor) -> Tensor:
    """Convert rotation matrices to unit quaternions ``(r, i, j, k)``.

    Four candidate quaternions are formed, each divided by a different
    component magnitude. The candidate with the largest divisor (hence the
    best conditioned) is kept.

    The result is standardised so that its real part is non-negative. Because
    ``q`` and ``-q`` describe the same rotation, this makes the output unique
    unless the real part is exactly 0.

    Args:
        matrix: Tensor of shape ``(..., 3, 3)``; assumed to be a proper
            rotation matrix (this is not checked).

    Returns:
        Tensor of shape ``(..., 4)``, normalised to unit norm.

    Raises:
        ValueError: If the trailing dimensions are not ``(3, 3)``.
    """
    if matrix.shape[-2:] != (3, 3):
        raise ValueError("Rotation matrix tensor must have shape (..., 3, 3).")
    m00, m01, m02 = matrix[..., 0, 0], matrix[..., 0, 1], matrix[..., 0, 2]
    m10, m11, m12 = matrix[..., 1, 0], matrix[..., 1, 1], matrix[..., 1, 2]
    m20, m21, m22 = matrix[..., 2, 0], matrix[..., 2, 1], matrix[..., 2, 2]

    # For a rotation matrix these are 4*r^2, 4*i^2, 4*j^2, 4*k^2, so
    # q_abs = 2 * |(r, i, j, k)| componentwise.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # Row n equals 4 * q_n * (r, i, j, k); dividing it by 2 * q_abs[n]
    # (= 4 * |q_n|) recovers the quaternion up to sign.
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m21 + m12], dim=-1),
            torch.stack([m10 - m01, m02 + m20, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )
    # The floor only guards the rows that are discarded below: the kept row
    # has the largest q_abs, which is >= 1 for a unit quaternion.
    denom = (2.0 * q_abs).clamp_min(0.1)[..., None]
    quat_candidates = quat_by_rijk / denom

    best = torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4).bool()
    quaternion = quat_candidates[best].reshape(matrix.shape[:-2] + (4,))

    # Flip the whole quaternion (not just its real part) so that r >= 0.
    # This is done out of place: assigning into a slice would invalidate
    # tensors autograd saved for backward.
    quaternion = torch.where(quaternion[..., :1] < 0, -quaternion, quaternion)
    return quaternion / quaternion.norm(dim=-1, keepdim=True).clamp_min(1e-8)