AniGen / third_parties /dsine /utils /rotation.py
Yihua7's picture
Initial commit: AniGen - Animatable 3D Generation
6b92ff7
import numpy as np
import torch
import torch.nn.functional as F
def get_r_yaw(yaw):
    """Return the rotation matrix for an angle ``yaw`` (radians) about the y-axis.

    Returns:
        A float32 ``np.ndarray``; shape (3, 3) for a scalar ``yaw``.
    """
    c, s = np.cos(yaw), np.sin(yaw)
    rows = (
        (c, 0, s),
        (0, 1, 0),
        (-s, 0, c),
    )
    return np.array(rows, dtype=np.float32)
def get_r_pitch(pitch):
    """Return the rotation matrix for an angle ``pitch`` (radians) about the x-axis.

    Returns:
        A float32 ``np.ndarray``; shape (3, 3) for a scalar ``pitch``.
    """
    c = np.cos(pitch)
    s = np.sin(pitch)
    rows = (
        (1, 0, 0),
        (0, c, -s),
        (0, s, c),
    )
    return np.array(rows, dtype=np.float32)
def get_r_roll(roll):
    """Return the rotation matrix for an angle ``roll`` (radians) about the z-axis.

    Returns:
        A float32 ``np.ndarray``; shape (3, 3) for a scalar ``roll``.
    """
    c = np.cos(roll)
    s = np.sin(roll)
    rows = (
        (c, -s, 0),
        (s, c, 0),
        (0, 0, 1),
    )
    return np.array(rows, dtype=np.float32)
def get_R(yaw, pitch, roll):
    """Compose yaw, pitch and roll rotations; return the matrix and its inverse.

    The forward rotation is ``R_pitch @ R_roll @ R_yaw``, so for a column
    vector the yaw rotation is applied first and the pitch rotation last.
    The inverse is built from the negated-angle rotations multiplied in the
    reverse order, i.e. ``R_yaw(-yaw) @ R_roll(-roll) @ R_pitch(-pitch)``.

    Returns:
        Tuple ``(R, R_inv)`` of float32 ``np.ndarray`` matrices.
    """
    forward = get_r_pitch(pitch) @ get_r_roll(roll) @ get_r_yaw(yaw)
    # Undo each elementary rotation in reverse order: (ABC)^-1 = C^-1 B^-1 A^-1.
    inverse = get_r_yaw(-yaw) @ get_r_roll(-roll) @ get_r_pitch(-pitch)
    return forward, inverse
# NOTE: the code below is copied from PyTorch3D
# (https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py)
# See the license at https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as axis/angle to quaternions.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.
    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    # The vector part of the quaternion is axis_angle * sin(angle/2) / angle.
    # torch.sinc is the normalized sinc, sin(pi*x) / (pi*x), and equals
    # exactly 1 at x == 0. Therefore 0.5 * sinc(angle / (2*pi)) == sin(angle/2) / angle
    # holds everywhere, including the zero rotation. No masked small-angle
    # Taylor branch or boolean-index scatter into an uninitialized buffer is
    # needed.
    sin_half_angles_over_angles = 0.5 * torch.sinc(angles / (2.0 * np.pi))
    return torch.cat(
        [torch.cos(angles * 0.5), axis_angle * sin_half_angles_over_angles], dim=-1
    )
# NOTE: the code below is copied from PyTorch3D
# (https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py)
# See the license at https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Every term is scaled by ``2 / |q|^2`` rather than assuming ``|q| == 1``,
    so non-unit quaternions are implicitly normalized.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    w, x, y, z = torch.unbind(quaternions, -1)
    scale = 2.0 / (quaternions * quaternions).sum(-1)
    # Row-major entries of the 3x3 matrix, flattened.
    entries = [
        1 - scale * (y * y + z * z), scale * (x * y - z * w), scale * (x * z + y * w),
        scale * (x * y + z * w), 1 - scale * (x * x + z * z), scale * (y * z - x * w),
        scale * (x * z - y * w), scale * (y * z + x * w), 1 - scale * (x * x + y * y),
    ]
    flat = torch.stack(entries, dim=-1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
# NOTE: the code below is copied from PyTorch3D
# (https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py)
# See the license at https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as axis/angle to rotation matrices.

    The conversion goes through the quaternion representation:
    axis/angle -> quaternion -> 3x3 matrix.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    quaternions = axis_angle_to_quaternion(axis_angle)
    return quaternion_to_matrix(quaternions)
# NOTE: the code below is copied from PyTorch3D
# (https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py)
# See the license at https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Return the elementary rotation matrices about one coordinate axis,
    one matrix per element of ``angle``.

    Args:
        axis: Axis label, one of "X", "Y" or "Z".
        angle: Tensor of any shape holding Euler angles in radians.
    Returns:
        Rotation matrices as tensor of shape ``angle.shape + (3, 3)``.
    Raises:
        ValueError: if ``axis`` is not one of "X", "Y" or "Z".
    """
    cos_a, sin_a = torch.cos(angle), torch.sin(angle)
    one, zero = torch.ones_like(angle), torch.zeros_like(angle)
    if axis not in ("X", "Y", "Z"):
        raise ValueError("letter must be either X, Y or Z.")
    if axis == "X":
        rows = ((one, zero, zero), (zero, cos_a, -sin_a), (zero, sin_a, cos_a))
    elif axis == "Y":
        rows = ((cos_a, zero, sin_a), (zero, one, zero), (-sin_a, zero, cos_a))
    else:  # "Z"
        rows = ((cos_a, -sin_a, zero), (sin_a, cos_a, zero), (zero, zero, one))
    # Flatten row-major, stack on a new trailing dim, then fold into 3x3.
    flat = [entry for row in rows for entry in row]
    return torch.stack(flat, -1).reshape(angle.shape + (3, 3))
# NOTE: the code below is copied from PyTorch3D
# (https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py)
# See the license at https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    The result is the product ``R(c0, a0) @ R(c1, a1) @ R(c2, a2)``. Here ``ci``
    is the i-th letter of ``convention`` and ``ai`` is the i-th angle along the
    last dimension of ``euler_angles``.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}; the middle letter must differ from both ends.
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    Raises:
        ValueError: on a malformed angle tensor or convention string.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    middle = convention[1]
    if middle in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    bad_letters = [letter for letter in convention if letter not in ("X", "Y", "Z")]
    if bad_letters:
        raise ValueError(f"Invalid letter {bad_letters[0]} in convention string.")

    # Left-to-right accumulation: ((R0 @ R1) @ R2).
    product = None
    for axis, angle in zip(convention, torch.unbind(euler_angles, -1)):
        step = _axis_angle_rotation(axis, angle)
        product = step if product is None else torch.matmul(product, step)
    return product