ml-sharp / src /sharp /utils /linalg.py
amael-apple's picture
Initial commit
c20d7cc
raw
history blame
3.77 kB
"""Contains linear algebra related utility functions.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
import torch
import torch.nn.functional as F
from scipy.spatial.transform import Rotation
def rotation_matrices_from_quaternions(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert a batch of quaternions into rotation matrices.

    Quaternions are laid out scalar-first, ``(w, x, y, z)``: element 0 is the
    real part and elements 1..3 are the vector part. They are normalized before
    conversion, so they need not have unit length. A zero quaternion divides by
    a zero norm and therefore yields NaNs.

    Args:
        quaternions: Tensor of shape ``(..., 4)`` holding the quaternions to
            convert to matrices.

    Returns:
        Tensor of shape ``(..., 3, 3)`` with the rotation matrices corresponding
        to the (normalized) quaternions.

    Raises:
        ValueError: If the last dimension of ``quaternions`` is not 4.
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(f"quaternions have invalid shape {quaternions.shape}")

    quaternions = quaternions / torch.linalg.norm(quaternions, dim=-1, keepdim=True)
    real_part = quaternions[..., 0]
    vector_part = quaternions[..., 1:]
    vector_cross = get_cross_product_matrix(vector_part)

    # Trailing singleton dims let the scalar part broadcast against 3x3 matrices.
    real_part = real_part[..., None, None]

    # Build the identity in the input's dtype/device so this term does not
    # force a promotion to the default dtype; it broadcasts over the batch.
    identity = torch.eye(3, dtype=quaternions.dtype, device=quaternions.device)

    # R = v v^T + w^2 I + 2 w [v]_x + [v]_x^2. Since [v]_x^2 = v v^T - |v|^2 I,
    # this equals the textbook (w^2 - |v|^2) I + 2 v v^T + 2 w [v]_x.
    matrix_outer = vector_part[..., :, None] * vector_part[..., None, :]
    matrix_diag = real_part.square() * identity
    matrix_cross_1 = 2 * real_part * vector_cross
    matrix_cross_2 = vector_cross @ vector_cross
    return matrix_outer + matrix_diag + matrix_cross_1 + matrix_cross_2
def quaternions_from_rotation_matrices(matrices: torch.Tensor) -> torch.Tensor:
"""Convert batch of rotation matrices to quaternions.
Args:
matrices: The matrices to convert to quaternions.
Returns:
The quaternions corresponding to the rotation matrices.
Note: this operation is not differentiable and will be performed on the CPU.
"""
if not matrices.shape[-2:] == (3, 3):
raise ValueError(f"matrices have invalid shape {matrices.shape}")
matrices_np = matrices.detach().cpu().numpy()
quaternions_np = Rotation.from_matrix(matrices_np.reshape(-1, 3, 3)).as_quat()
# We use a convention where the w component is at the start of the quaternion.
quaternions_np = quaternions_np[:, [3, 0, 1, 2]]
quaternions_np = quaternions_np.reshape(matrices_np.shape[:-2] + (4,))
return torch.as_tensor(quaternions_np, device=matrices.device, dtype=matrices.dtype)
def get_cross_product_matrix(vectors: torch.Tensor) -> torch.Tensor:
    """Generate the skew-symmetric cross-product matrix of 3-vectors.

    For a vector ``v = (x, y, z)`` the result ``[v]_x`` satisfies
    ``[v]_x @ w == cross(v, w)``::

        [[ 0, -z,  y],
         [ z,  0, -x],
         [-y,  x,  0]]

    Args:
        vectors: Tensor of shape ``(..., 3)``.

    Returns:
        Tensor of shape ``(..., 3, 3)`` with the same dtype and device as
        ``vectors``.

    Raises:
        ValueError: If the last dimension of ``vectors`` is not 3.
    """
    if not vectors.shape[-1] == 3:
        raise ValueError("Only 3-dimensional vectors are supported")
    x, y, z = vectors.unbind(dim=-1)
    zero = torch.zeros_like(x)
    # Assemble the matrix row by row from the components; unlike crossing with
    # a separately-allocated identity, this keeps the input's dtype and device.
    entries = torch.stack(
        [
            zero, -z, y,
            z, zero, -x,
            -y, x, zero,
        ],
        dim=-1,
    )
    return entries.reshape(*vectors.shape[:-1], 3, 3)
def eyes(
    dim: int,
    shape: tuple[int, ...],
    device: torch.device | str | None = None,
    *,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Create a batch of identity matrices.

    Args:
        dim: Size of each square identity matrix.
        shape: Batch shape placed in front of the two matrix dimensions. Any
            sequence of ints (tuple, list, ``torch.Size``) is accepted.
        device: Device on which to allocate the result.
        dtype: Data type of the result; ``None`` uses torch's default dtype.

    Returns:
        Tensor of shape ``(*shape, dim, dim)``. It is cloned from a broadcast
        view, so it owns its memory and can be modified in place.
    """
    identity = torch.eye(dim, device=device, dtype=dtype)
    return identity.expand(*shape, dim, dim).clone()
def quaternion_product(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Compute the Hamilton product ``q1 * q2`` of two quaternions.

    Quaternions are laid out scalar-first, ``(w, x, y, z)``, along the last
    dimension. With ``q = (a, u)`` and ``p = (b, v)``::

        q * p = (a b - u . v,  a v + b u + u x v)

    Args:
        q1: Left factor(s), last dimension of size 4.
        q2: Right factor(s), last dimension of size 4.

    Returns:
        The product quaternion(s), last dimension of size 4.
    """
    real_1 = q1[..., :1]
    real_2 = q2[..., :1]
    vector_1 = q1[..., 1:]
    vector_2 = q2[..., 1:]
    real_out = real_1 * real_2 - (vector_1 * vector_2).sum(dim=-1, keepdim=True)
    # dim=-1 is required: without it torch.cross picks the FIRST dimension of
    # size 3, which is a batch axis for inputs such as shape (3, 4).
    vector_out = (
        real_1 * vector_2
        + real_2 * vector_1
        + torch.cross(vector_1, vector_2, dim=-1)
    )
    return torch.cat([real_out, vector_out], dim=-1)
def quaternion_conj(q):
    """Return the conjugate of quaternion(s) ``q``.

    The first entry of the last dimension (the scalar part) is kept as-is and
    the remaining entries (the vector part) are negated.
    """
    scalar, imaginary = q[..., :1], q[..., 1:]
    return torch.cat((scalar, imaginary.neg()), dim=-1)
def project(u: torch.Tensor, basis: torch.Tensor) -> torch.Tensor:
    """Scale ``u`` by the inner product of its direction with ``basis``.

    Computes ``<u / |u|, basis> * u`` along the last dimension. When ``basis``
    has unit length, the scale factor is the cosine of the angle between ``u``
    and ``basis``. The result is always parallel to ``u``.

    NOTE(review): this is neither the orthogonal projection of ``u`` onto
    ``basis`` (``<u, b> b``) nor of ``basis`` onto ``u`` (``<b, u> / |u|^2 u``)
    unless ``u`` has unit length; confirm the intended semantics with callers.
    """
    direction = F.normalize(u, dim=-1)
    scale = torch.sum(direction * basis, dim=-1, keepdim=True)
    return scale * u