# lsnu's picture
# Add files using upload-large-folder tool
# 6fa1956 verified
from __future__ import annotations
import torch
from torch import Tensor
def _copysign(a: Tensor, b: Tensor) -> Tensor:
    """Return ``|a|`` carrying the sign of ``b``, elementwise.

    Only strictly negative entries of ``b`` produce a negative result, so a
    ``b`` of ``-0.0`` (or NaN) yields the non-negative magnitude ``|a|``.
    This differs from ``torch.copysign``, which honours the sign bit of ``-0.0``.
    """
    magnitude = a.abs()
    return torch.where(b < 0, -magnitude, magnitude)
def quaternion_to_matrix(quaternions: Tensor) -> Tensor:
    """Convert quaternions (real part first) to 3x3 rotation matrices.

    The input is rescaled by its norm (clamped below at 1e-8) before the
    standard quaternion-to-matrix formula is applied.

    Args:
        quaternions: Tensor of shape ``(..., 4)`` ordered ``(w, x, y, z)``.

    Returns:
        Tensor of shape ``(..., 3, 3)``.

    Raises:
        ValueError: If the last dimension is not 4.
    """
    if quaternions.shape[-1] != 4:
        raise ValueError("Quaternion tensor must have shape (..., 4).")
    length = quaternions.norm(dim=-1, keepdim=True).clamp_min(1e-8)
    unit = quaternions / length
    w, x, y, z = torch.unbind(unit, dim=-1)
    # 2 / |q|^2 after rescaling: ~2 for ordinary inputs, larger only when the
    # norm clamp above kicked in and the rescaled vector is not unit length.
    scale = 2.0 / (unit * unit).sum(dim=-1)

    row0 = torch.stack(
        (
            1 - scale * (y * y + z * z),
            scale * (x * y - z * w),
            scale * (x * z + y * w),
        ),
        dim=-1,
    )
    row1 = torch.stack(
        (
            scale * (x * y + z * w),
            1 - scale * (x * x + z * z),
            scale * (y * z - x * w),
        ),
        dim=-1,
    )
    row2 = torch.stack(
        (
            scale * (x * z - y * w),
            scale * (y * z + x * w),
            1 - scale * (x * x + y * y),
        ),
        dim=-1,
    )
    return torch.stack((row0, row1, row2), dim=-2)
def _axis_angle_rotation(axis: str, angle: Tensor) -> Tensor:
    """Rotation matrix for a rotation by ``angle`` (radians) about one axis.

    Args:
        axis: One of ``"X"``, ``"Y"`` or ``"Z"``.
        angle: Tensor of any shape.

    Returns:
        Tensor of shape ``angle.shape + (3, 3)``.

    Raises:
        ValueError: If ``axis`` is not one of the three letters above.
    """
    c = torch.cos(angle)
    s = torch.sin(angle)
    neg_s = -s
    unit = torch.ones_like(angle)
    null = torch.zeros_like(angle)
    # Row-major 3x3 entries for each elementary rotation.
    layouts = {
        "X": (unit, null, null, null, c, neg_s, null, s, c),
        "Y": (c, null, s, null, unit, null, neg_s, null, c),
        "Z": (c, neg_s, null, s, c, null, null, null, unit),
    }
    entries = layouts.get(axis)
    if entries is None:
        raise ValueError(f"Invalid axis {axis}.")
    return torch.stack(entries, dim=-1).reshape(*angle.shape, 3, 3)
def euler_angles_to_matrix(euler_angles: Tensor, convention: str) -> Tensor:
    """Compose three single-axis rotations into one rotation matrix.

    ``convention[n]`` names the axis (``"X"``, ``"Y"`` or ``"Z"``) for the
    angle ``euler_angles[..., n]``. The result is the left-to-right product
    ``R(convention[0]) @ R(convention[1]) @ R(convention[2])``.

    Args:
        euler_angles: Tensor of shape ``(..., 3)`` holding angles in radians.
        convention: Three-character axis string, e.g. ``"XYZ"``.

    Returns:
        Tensor of shape ``(..., 3, 3)``.

    Raises:
        ValueError: If the last dimension is not 3, if ``convention`` does not
            have exactly three characters, or if ``_axis_angle_rotation``
            rejects one of its letters.
    """
    if euler_angles.shape[-1] != 3:
        raise ValueError("Euler angle tensor must have shape (..., 3).")
    if len(convention) != 3:
        raise ValueError("Convention must have three characters.")
    angles = torch.unbind(euler_angles, dim=-1)
    factors = [_axis_angle_rotation(letter, theta) for letter, theta in zip(convention, angles)]
    product, *remaining = factors
    for factor in remaining:
        product = product @ factor
    return product
def matrix_to_quaternion(matrix: Tensor) -> Tensor:
    """Convert 3x3 rotation matrices to unit quaternions ordered ``(w, x, y, z)``.

    Four candidate quaternions are formed, each obtained by dividing by a
    different component's magnitude. The candidate with the largest divisor is
    kept for numerical stability. The result is standardized so the real part
    is non-negative: the *whole* quaternion is negated, since ``q`` and ``-q``
    encode the same rotation. Finally it is renormalized to unit length.

    Args:
        matrix: Tensor of shape ``(..., 3, 3)``.

    Returns:
        Tensor of shape ``(..., 4)``.

    Raises:
        ValueError: If the trailing dimensions are not ``(3, 3)``.
    """
    if matrix.shape[-2:] != (3, 3):
        raise ValueError("Rotation matrix tensor must have shape (..., 3, 3).")
    m = matrix
    # Equals 4 * (w^2, x^2, y^2, z^2) when the input is a proper rotation.
    q_abs_sq = torch.stack(
        [
            1.0 + m[..., 0, 0] + m[..., 1, 1] + m[..., 2, 2],
            1.0 + m[..., 0, 0] - m[..., 1, 1] - m[..., 2, 2],
            1.0 - m[..., 0, 0] + m[..., 1, 1] - m[..., 2, 2],
            1.0 - m[..., 0, 0] - m[..., 1, 1] + m[..., 2, 2],
        ],
        dim=-1,
    )
    # sqrt of the positive part. The double `where` keeps gradients finite
    # where the argument is <= 0. A plain sqrt(clamp(x, 0)) back-propagates
    # 0 * inf = NaN there, e.g. for the identity matrix.
    positive = q_abs_sq > 0
    q_abs = torch.where(
        positive,
        torch.sqrt(torch.where(positive, q_abs_sq, torch.ones_like(q_abs_sq))),
        torch.zeros_like(q_abs_sq),
    )
    quat_by_rijk = torch.stack(
        [
            torch.stack(
                [q_abs[..., 0] ** 2, m[..., 2, 1] - m[..., 1, 2], m[..., 0, 2] - m[..., 2, 0], m[..., 1, 0] - m[..., 0, 1]],
                dim=-1,
            ),
            torch.stack(
                [m[..., 2, 1] - m[..., 1, 2], q_abs[..., 1] ** 2, m[..., 1, 0] + m[..., 0, 1], m[..., 0, 2] + m[..., 2, 0]],
                dim=-1,
            ),
            torch.stack(
                [m[..., 0, 2] - m[..., 2, 0], m[..., 1, 0] + m[..., 0, 1], q_abs[..., 2] ** 2, m[..., 2, 1] + m[..., 1, 2]],
                dim=-1,
            ),
            torch.stack(
                [m[..., 1, 0] - m[..., 0, 1], m[..., 0, 2] + m[..., 2, 0], m[..., 2, 1] + m[..., 1, 2], q_abs[..., 3] ** 2],
                dim=-1,
            ),
        ],
        dim=-2,
    )
    # For a proper rotation the largest q_abs is >= 1, so the selected divisor
    # is never near the floor. The floor only keeps unused candidates finite.
    denom = (2.0 * q_abs).clamp_min(0.1)[..., None]
    quat_candidates = quat_by_rijk / denom
    best = q_abs.argmax(dim=-1)
    mask = torch.nn.functional.one_hot(best, num_classes=4).to(dtype=torch.bool)
    quaternion = quat_candidates[mask].reshape(matrix.shape[:-2] + (4,))
    # Flip the entire quaternion (not just w) so that w >= 0. This is done out
    # of place so that autograd's saved tensors are not overwritten.
    quaternion = torch.where(quaternion[..., :1] < 0, -quaternion, quaternion)
    return quaternion / quaternion.norm(dim=-1, keepdim=True).clamp_min(1e-8)