| from __future__ import annotations |
|
|
| import torch |
| from torch import Tensor |
|
|
|
|
| def _copysign(a: Tensor, b: Tensor) -> Tensor: |
| return torch.where(b < 0, -a.abs(), a.abs()) |
|
|
|
|
def quaternion_to_matrix(quaternions: Tensor) -> Tensor:
    """Convert quaternions to 3x3 rotation matrices.

    Args:
        quaternions: Tensor of shape ``(..., 4)``, unpacked as ``(r, i, j, k)``
            with the real part first. Inputs are rescaled by their norm
            (floored at 1e-8) before conversion, so they need not be unit length.

    Returns:
        Tensor of shape ``(..., 3, 3)``. An all-zero quaternion maps to the
        identity matrix.

    Raises:
        ValueError: If the last dimension is not 4.
    """
    if quaternions.shape[-1] != 4:
        raise ValueError("Quaternion tensor must have shape (..., 4).")
    quaternions = quaternions / quaternions.norm(dim=-1, keepdim=True).clamp_min(1e-8)
    r, i, j, k = torch.unbind(quaternions, dim=-1)
    # After rescaling, the squared norm is 1 for ordinary inputs. It is exactly 0
    # for an all-zero quaternion, and 2 / 0 = inf would turn every entry into
    # NaN (0 * inf). Flooring at the smallest normal float keeps the factor
    # finite, so the zero quaternion produces the identity.
    squared_norm = (quaternions * quaternions).sum(dim=-1)
    two_s = 2.0 / squared_norm.clamp_min(torch.finfo(quaternions.dtype).tiny)
    return torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        dim=-1,
    ).reshape(quaternions.shape[:-1] + (3, 3))
|
|
|
|
| def _axis_angle_rotation(axis: str, angle: Tensor) -> Tensor: |
| cos = torch.cos(angle) |
| sin = torch.sin(angle) |
| one = torch.ones_like(angle) |
| zero = torch.zeros_like(angle) |
|
|
| if axis == "X": |
| values = (one, zero, zero, zero, cos, -sin, zero, sin, cos) |
| elif axis == "Y": |
| values = (cos, zero, sin, zero, one, zero, -sin, zero, cos) |
| elif axis == "Z": |
| values = (cos, -sin, zero, sin, cos, zero, zero, zero, one) |
| else: |
| raise ValueError(f"Invalid axis {axis}.") |
| return torch.stack(values, dim=-1).reshape(angle.shape + (3, 3)) |
|
|
|
|
def euler_angles_to_matrix(euler_angles: Tensor, convention: str) -> Tensor:
    """Compose three elementary axis rotations into one rotation matrix.

    Args:
        euler_angles: Tensor of shape ``(..., 3)``. Entry ``n`` of the last
            dimension is the angle used for axis ``convention[n]``.
        convention: Three axis letters, e.g. ``"XYZ"``. Each letter is forwarded
            to ``_axis_angle_rotation``, which rejects anything other than
            X/Y/Z. Repeated adjacent axes are not rejected here.

    Returns:
        Tensor of shape ``(..., 3, 3)`` equal to ``R0 @ R1 @ R2``, where ``Rn`` is
        the rotation for ``convention[n]``.

    Raises:
        ValueError: If the last dimension is not 3, the convention is not
            three characters long, or a letter is not a valid axis.
    """
    if euler_angles.shape[-1] != 3:
        raise ValueError("Euler angle tensor must have shape (..., 3).")
    if len(convention) != 3:
        raise ValueError("Convention must have three characters.")
    # Both checks above guarantee exactly three (letter, angle) pairs.
    first, second, third = (
        _axis_angle_rotation(letter, theta)
        for letter, theta in zip(convention, torch.unbind(euler_angles, dim=-1))
    )
    return first @ second @ third
|
|
|
|
def matrix_to_quaternion(matrix: Tensor) -> Tensor:
    """Convert 3x3 rotation matrices to unit quaternions.

    Four candidate quaternions are built, each by dividing by one of the four
    component-magnitude estimates. The candidate with the largest divisor is
    kept, which avoids dividing by a value near zero.

    Args:
        matrix: Tensor of shape ``(..., 3, 3)``.

    Returns:
        Tensor of shape ``(..., 4)`` ordered ``(r, i, j, k)``, normalized to unit
        length. ``q`` and ``-q`` encode the same rotation, so the result is
        standardized to have a non-negative real part ``r``.

    Raises:
        ValueError: If the trailing two dimensions are not ``(3, 3)``.
    """
    if matrix.shape[-2:] != (3, 3):
        raise ValueError("Rotation matrix tensor must have shape (..., 3, 3).")

    batch_shape = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_shape + (9,)), dim=-1
    )

    # 4 * r^2, 4 * i^2, 4 * j^2, 4 * k^2 (for a true rotation). Clamp before the
    # sqrt because rounding can push these slightly below zero.
    q_abs = torch.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        dim=-1,
    )
    q_abs = torch.sqrt(torch.clamp(q_abs, min=0.0))

    # Row n is the quaternion scaled by 2 * q_abs[n]. Dividing row n by that
    # factor gives candidate n.
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m21 + m12], dim=-1),
            torch.stack([m10 - m01, m02 + m20, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # The four pre-clamp values above always sum to 4, so the largest q_abs is
    # at least 1. The 0.1 floor only keeps the discarded candidates finite.
    denom = (2.0 * q_abs).clamp_min(0.1)[..., None]
    quat_candidates = quat_by_rijk / denom
    best = torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4).to(dtype=torch.bool)
    quaternion = quat_candidates[best].reshape(batch_shape + (4,))

    # Flip the whole quaternion (not just r) so the real part is non-negative.
    # The previous in-place copysign of r with itself was a no-op.
    quaternion = torch.where(quaternion[..., :1] < 0, -quaternion, quaternion)
    return quaternion / quaternion.norm(dim=-1, keepdim=True).clamp_min(1e-8)
|
|