zirobtc's picture
Upload folder using huggingface_hub
fbb20ff verified
import copy
import math
from typing import List, Optional
import numpy as np
import torch
def identity_mat(x=None, device="cpu", is_numpy=False):
if x is not None:
if isinstance(x, torch.Tensor):
mat = torch.eye(4, device=device)
mat = mat.repeat(x.shape[:-2] + (1, 1))
elif isinstance(x, np.ndarray):
mat = np.eye(4, dtype=np.float32)
if x is not None:
for _ in range(len(x.shape) - 2):
mat = mat[None]
mat = np.tile(mat, x.shape[:-2] + (1, 1))
else:
raise ValueError
else:
# (4, 4)
if is_numpy:
mat = np.eye(4, dtype=np.float32)
else:
mat = torch.eye(4, device=device)
return mat
def vec2mat(vec):
"""_summary_
Args:
vec (tensor): [12], pos, forward, up and right
Returns:
mat_world(tensor): [4, 4]
"""
# Assume bs = 1
v = np.tile(np.array([[0, 0, 0, 1]]), (1, 1))
if isinstance(vec, torch.Tensor):
v = torch.tensor(
v,
device=vec.device,
dtype=vec.dtype,
)
pos = vec[:3]
forward = vec[3:6]
up = vec[6:9]
right = vec[9:12]
if isinstance(vec, torch.Tensor):
mat_world = torch.stack([right, up, forward, pos], dim=-1)
mat_world = torch.cat([mat_world, v], dim=-2)
elif isinstance(vec, np.ndarray):
mat_world = np.stack([right, up, forward, pos], axis=-1)
mat_world = np.concatenate([mat_world, v], axis=-2)
else:
raise ValueError
mat_world = normalized_matrix(mat_world)
return mat_world
def mat2vec(mat):
"""_summary_
Args:
mat(tensor): [4, 4]
Returns:
vec (tensor): [12], pos, forward, up and right
"""
# Assume bs = 1
pos = mat[:-1, 3]
forward = normalized(mat[:-1, 2])
up = normalized(mat[:-1, 1])
right = normalized(mat[:-1, 0])
if isinstance(mat, torch.Tensor):
vec = torch.cat((pos, forward, up, right))
elif isinstance(mat, np.ndarray):
vec = np.concatenate((pos, forward, up, right))
else:
raise ValueError
return vec
def vec2mat_batch(vec):
"""_summary_
Args:
vec (tensor): [B, 12], pos, forward, up and right
Returns:
mat_world(tensor): [B, 4, 4]
"""
# Assume bs = 1
v = np.tile(np.array([[0, 0, 0, 1]], dtype=np.float32), (vec.shape[0], 1, 1))
if isinstance(vec, torch.Tensor):
v = torch.tensor(
v,
device=vec.device,
dtype=vec.dtype,
)
pos = vec[..., :3]
forward = vec[..., 3:6]
up = vec[..., 6:9]
right = vec[..., 9:12]
if isinstance(vec, torch.Tensor):
mat_world = torch.stack([right, up, forward, pos], dim=-1)
mat_world = torch.cat([mat_world, v], dim=-2)
elif isinstance(vec, np.ndarray):
mat_world = np.stack([right, up, forward, pos], axis=-1)
mat_world = np.concatenate([mat_world, v], axis=-2)
else:
raise ValueError
mat_world = normalized_matrix(mat_world)
return mat_world
def rotmat2tan_norm(mat):
"""_summary_
Args:
mat(tensor): [B, 3, 3]
Returns:
vec (tensor): [B, 6], tan norm
"""
if isinstance(mat, np.ndarray):
tan = np.zeros_like(mat[..., 2])
norm = np.zeros_like(mat[..., 0])
elif isinstance(mat, torch.Tensor):
tan = torch.zeros_like(mat[..., 2])
norm = torch.zeros_like(mat[..., 0])
else:
raise ValueError
tan[...] = mat[..., 2, ::-1]
tan[..., -1] *= -1
norm[...] = mat[..., 0, ::-1]
norm[..., -1] *= -1
if isinstance(mat, np.ndarray):
tan_norm = np.concatenate((tan, norm), axis=-1)
elif isinstance(mat, torch.Tensor):
tan_norm = torch.cat((tan, norm), dim=-1)
else:
raise ValueError
return tan_norm
def mat2tan_norm(mat):
"""_summary_
Args:
mat(tensor): [B, 4, 4]
Returns:
vec (tensor): [B, 6], tan norm
"""
rot_mat = mat[..., :-1, :-1]
return rotmat2tan_norm(rot_mat)
def rotmat2tan_norm(mat):
"""_summary_
Args:
mat(tensor): [B, 3, 3]
Returns:
vec (tensor): [B, 6], tan norm
"""
if isinstance(mat, np.ndarray):
tan = np.zeros_like(mat[..., 2])
norm = np.zeros_like(mat[..., 0])
tan[...] = mat[..., 2, ::-1]
norm[...] = mat[..., 0, ::-1]
elif isinstance(mat, torch.Tensor):
tan = torch.zeros_like(mat[..., 2])
norm = torch.zeros_like(mat[..., 0])
tan[...] = torch.flip(mat[..., 2], dims=[-1])
norm[...] = torch.flip(mat[..., 0], dims=[-1])
else:
raise ValueError
tan[..., -1] *= -1
norm[..., -1] *= -1
if isinstance(mat, np.ndarray):
tan_norm = np.concatenate((tan, norm), axis=-1)
elif isinstance(mat, torch.Tensor):
tan_norm = torch.cat((tan, norm), dim=-1)
else:
raise ValueError
return tan_norm
def tan_norm2rotmat(tan_norm):
"""_summary_
Args:
mat(tensor): [B, 6]
Returns:
vec (tensor): [B, 3]
"""
tan = copy.deepcopy(tan_norm[..., :3])
norm = copy.deepcopy(tan_norm[..., 3:])
tan[..., -1] *= -1
norm[..., -1] *= -1
if isinstance(tan_norm, np.ndarray):
rotmat = np.zeros(tan_norm.shape[:-1] + (3, 3))
tan = tan[..., ::-1]
norm = norm[..., ::-1]
other = np.cross(tan, norm)
elif isinstance(tan_norm, torch.Tensor):
rotmat = torch.zeros(tan_norm.shape[:-1] + (3, 3), device=tan_norm.device)
tan = torch.flip(tan, dims=[-1])
norm = torch.flip(norm, dims=[-1])
other = torch.cross(tan, norm)
else:
raise ValueError
rotmat[..., 2, :] = tan
rotmat[..., 0, :] = norm
rotmat[..., 1, :] = other
return rotmat
def rotmat332vec_batch(mat):
"""_summary_
Args:
mat(tensor): [B, 3, 3]
Returns:
vec (tensor): [B, 6], forward, up, right
"""
# Assume bs = 1
mat = normalized_matrix(mat)
forward = mat[..., :, 2]
up = mat[..., :, 1]
right = mat[..., :, 0]
if isinstance(mat, torch.Tensor):
vec = torch.cat((forward, up, right), dim=-1)
elif isinstance(mat, np.ndarray):
vec = np.concatenate((forward, up, right), axis=-1)
else:
raise ValueError
return vec
def rotmat2vec_batch(mat):
"""_summary_
Args:
mat(tensor): [B, 4, 4]
Returns:
vec (tensor): [B, 9], forward, up, right
"""
# Assume bs = 1
mat = normalized_matrix(mat)
forward = mat[..., :-1, 2]
up = mat[..., :-1, 1]
right = mat[..., :-1, 0]
if isinstance(mat, torch.Tensor):
vec = torch.cat((forward, up, right), dim=-1)
elif isinstance(mat, np.ndarray):
vec = np.concatenate((forward, up, right), axis=-1)
else:
raise ValueError
return vec
def mat2vec_batch(mat):
"""_summary_
Args:
mat(tensor): [B, 4, 4]
Returns:
vec (tensor): [B, 12], pos, forward, up and right
"""
# Assume bs = 1
mat = normalized_matrix(mat)
pos = mat[..., :-1, 3]
forward = mat[..., :-1, 2]
up = mat[..., :-1, 1]
right = mat[..., :-1, 0]
if isinstance(mat, torch.Tensor):
vec = torch.cat((pos, forward, up, right), dim=-1)
elif isinstance(mat, np.ndarray):
vec = np.concatenate((pos, forward, up, right), axis=-1)
else:
raise ValueError
return vec
def mat2pose_batch(mat, returnvel=True):
"""_summary_
Args:
mat(tensor): [B, 4, 4]
Returns:
vec (tensor): [B, 12], pos, forward, up, zeros
"""
# Assume bs = 1
mat = normalized_matrix(mat)
pos = mat[..., :-1, 3]
forward = mat[..., :-1, 2]
up = mat[..., :-1, 1]
if isinstance(mat, torch.Tensor):
if returnvel:
vel = torch.zeros_like(up)
vec = torch.cat((pos, forward, up, vel), dim=-1)
else:
vec = torch.cat((pos, forward, up), dim=-1)
elif isinstance(mat, np.ndarray):
if returnvel:
vel = np.zeros_like(up)
vec = np.concatenate((pos, forward, up, vel), axis=-1)
else:
vec = np.concatenate((pos, forward, up), axis=-1)
else:
raise ValueError
return vec
def get_mat_BinA(matCtoA, matCtoB):
"""
given matrix of the same object in two coordinate A and B,
return matrix B in the coordinate of A
Args:
matCtoA (tensor): [4, 4] world matrix
matCtoB (tensor): [4, 4] world matrix
"""
if isinstance(matCtoA, torch.Tensor):
matCtoB_inv = torch.inverse(matCtoB)
elif isinstance(matCtoA, np.ndarray):
matCtoB_inv = np.linalg.inv(matCtoB)
else:
raise ValueError
matCtoB_inv = normalized_matrix(matCtoB_inv)
if isinstance(matCtoA, torch.Tensor):
mat_BtoA = torch.matmul(matCtoA, matCtoB_inv)
elif isinstance(matCtoA, np.ndarray):
mat_BtoA = np.matmul(matCtoA, matCtoB_inv)
mat_BtoA = normalized_matrix(mat_BtoA)
return mat_BtoA
def get_mat_BtoA(matA, matB):
"""
return matrix B in the coordinate of A
Args:
matA (tensor): [4, 4] world matrix
matB (tensor): [4, 4] world matrix
"""
if isinstance(matA, torch.Tensor):
matA_inv = torch.inverse(matA)
elif isinstance(matA, np.ndarray):
matA_inv = np.linalg.inv(matA)
else:
raise ValueError
matA_inv = normalized_matrix(matA_inv)
if isinstance(matA, torch.Tensor):
mat_BtoA = torch.matmul(matA_inv, matB)
elif isinstance(matA, np.ndarray):
mat_BtoA = np.matmul(matA_inv, matB)
mat_BtoA = normalized_matrix(mat_BtoA)
return mat_BtoA
def get_mat_BfromA(matA, matBtoA):
"""
return world matrix B given matrix A and mat B realtive to A
Args:
matA (_type_): [4, 4] world matrix
matBtoA (_type_): [4, 4] matrix B relative to A
"""
if isinstance(matA, torch.Tensor):
matB = torch.matmul(matA, matBtoA)
if isinstance(matA, np.ndarray):
matB = np.matmul(matA, matBtoA)
matB = normalized_matrix(matB)
return matB
def get_relative_position_to(pos, mat):
"""_summary_
Args:
pos (_type_): [N, M, 3] or [N, 3]
mat (_type_): [N, 4, 4] or [4, 4]
Returns:
_type_: _description_
"""
if isinstance(mat, torch.Tensor):
mat_inv = torch.inverse(mat)
elif isinstance(mat, np.ndarray):
mat_inv = np.linalg.inv(mat)
else:
raise ValueError
mat_inv = normalized_matrix(mat_inv)
if isinstance(mat, torch.Tensor):
rot_pos = torch.matmul(mat_inv[..., :-1, :-1], pos.transpose(-1, -2)).transpose(
-1, -2
)
elif isinstance(mat, np.ndarray):
rot_pos = np.matmul(mat_inv[..., :-1, :-1], pos.swapaxes(-1, -2)).swapaxes(
-1, -2
)
world_pos = rot_pos + mat_inv[..., None, :-1, 3]
return world_pos
def get_rotation(mat):
"""_summary_
Args:
mat (_type_): [..., 4, 4]
Returns:
_type_: _description_
"""
return mat[..., :-1, :-1]
def set_rotation(mat, rotmat):
"""_summary_
Args:
mat (_type_): [..., 4, 4]
Returns:
_type_: _description_
"""
mat[..., :-1, :-1] = rotmat
return mat
def set_position(mat, pos):
"""_summary_
Args:
mat (_type_): [..., 4, 4]
Returns:
_type_: _description_
"""
mat[..., :-1, 3] = pos
return mat
def get_position(mat):
"""_summary_
Args:
mat (_type_): [..., 4, 4]
Returns:
_type_: _description_
"""
return mat[..., :-1, 3]
def get_position_from(pos, mat):
"""_summary_
Args:
pos (_type_): [N, M, 3] or [N, 3]
mat (_type_): [N, 4, 4] or [4, 4]
Returns:
_type_: _description_
"""
if isinstance(mat, torch.Tensor):
rot_pos = torch.matmul(mat[..., :-1, :-1], pos.transpose(-1, -2)).transpose(
-1, -2
)
elif isinstance(mat, np.ndarray):
rot_pos = np.matmul(mat[..., :-1, :-1], pos.swapaxes(-1, -2)).swapaxes(-1, -2)
else:
raise ValueError
world_pos = rot_pos + mat[..., None, :-1, 3]
return world_pos
def get_position_from_rotmat(pos, mat):
"""_summary_
Args:
pos (_type_): [N, M, 3] or [N, 3]
mat (_type_): [N, 4, 4] or [4, 4]
Returns:
_type_: _description_
"""
if isinstance(mat, torch.Tensor):
rot_pos = torch.matmul(mat, pos.transpose(-1, -2)).transpose(-1, -2)
elif isinstance(mat, np.ndarray):
rot_pos = np.matmul(mat, pos.swapaxes(-1, -2)).swapaxes(-1, -2)
else:
raise ValueError
return rot_pos
def get_relative_direction_to(dir, mat):
"""_summary_
Args:
dir (_type_): [N, M, 3] or [N, 3]
mat (_type_): [N, 4, 4] or [4, 4]
Returns:
_type_: _description_
"""
if isinstance(mat, torch.Tensor):
mat_inv = torch.inverse(mat)
elif isinstance(mat, np.ndarray):
mat_inv = np.linalg.inv(mat)
else:
raise ValueError
mat_inv = normalized_matrix(mat_inv)
rot_mat_inv = mat_inv[..., :3, :3]
if isinstance(mat, torch.Tensor):
rel_dir = torch.matmul(rot_mat_inv, dir.transpose(-1, -2))
return rel_dir.transpose(-1, -2)
elif isinstance(mat, np.ndarray):
rel_dir = np.matmul(rot_mat_inv, dir.swapaxes(-1, -2))
return rel_dir.swapaxes(-1, -2)
else:
raise ValueError
return
def get_direction_from(dir, mat):
"""_summary_
Args:
dir (_type_): [N, M, 3] or [N, 3]
mat (_type_): [N, 4, 4] or [4, 4]
Returns:
tensor: [N, M, 3] or [N, 3]
"""
rot_mat = mat[..., :3, :3]
if isinstance(mat, torch.Tensor):
world_dir = torch.matmul(rot_mat, dir.transpose(-1, -2))
return world_dir.transpose(-1, -2)
elif isinstance(mat, np.ndarray):
world_dir = np.matmul(rot_mat, dir.swapaxes(-1, -2))
return world_dir.swapaxes(-1, -2)
else:
raise ValueError
return
def get_coord_vis(pos, rot_mat, scale=1.0):
forward = rot_mat[..., :, 2]
up = rot_mat[..., :, 1]
right = rot_mat[..., :, 0]
return pos + right * scale, pos + up * scale, pos + forward * scale
def project_vec(vec):
"""_summary_
Args:
vec (tensor): [*, 12], pos, forward, up and right
Returns:
proj_vec (tensor): [*, 4], posx, posz, forwardx, forwardz
"""
posx = vec[..., 0:1]
posz = vec[..., 2:3]
forwardx = vec[..., 3:4]
forwardz = vec[..., 5:6]
if isinstance(vec, torch.Tensor):
proj_vec = torch.cat((posx, posz, forwardx, forwardz), dim=-1)
elif isinstance(vec, np.ndarray):
proj_vec = np.concatenate((posx, posz, forwardx, forwardz), axis=-1)
else:
raise ValueError
return proj_vec
def xz2xyz(vec):
x = vec[..., 0:1]
z = vec[..., 1:2]
if isinstance(vec, torch.Tensor):
y = torch.zeros(vec.shape[:-1] + (1,), device=vec.device)
xyz_vec = torch.cat((x, y, z), dim=-1)
elif isinstance(vec, np.ndarray):
y = np.zeros(vec.shape[:-1] + (1,))
xyz_vec = np.concatenate((x, y, z), axis=-1)
else:
raise ValueError
return xyz_vec
def normalized(vec):
if isinstance(vec, torch.Tensor):
norm_vec = vec / (vec.norm(2, dim=-1, keepdim=True) + 1e-9)
elif isinstance(vec, np.ndarray):
norm_vec = vec / (np.linalg.norm(vec, ord=2, axis=-1, keepdims=True) + 1e-9)
else:
raise ValueError
return norm_vec
def normalized_matrix(mat):
if mat.shape[-1] == 4:
rot_mat = mat[..., :-1, :-1]
else:
rot_mat = mat
if isinstance(mat, torch.Tensor):
rot_mat_norm = rot_mat / (rot_mat.norm(2, dim=-2, keepdim=True) + 1e-9)
norm_mat = torch.zeros_like(mat)
elif isinstance(mat, np.ndarray):
rot_mat_norm = rot_mat / (
np.linalg.norm(rot_mat, ord=2, axis=-2, keepdims=True) + 1e-9
)
norm_mat = np.zeros_like(mat)
else:
raise ValueError
if mat.shape[-1] == 4:
norm_mat[..., :-1, :-1] = rot_mat_norm
norm_mat[..., :-1, -1] = mat[..., :-1, -1]
norm_mat[..., -1, -1] = 1.0
else:
norm_mat = rot_mat_norm
return norm_mat
def get_rot_mat_from_forward(forward):
"""_summary_
Args:
forward (tensor): [N, M, 3]
Returns:
mat (tensor): [N, M, 3, 3]
"""
if isinstance(forward, torch.Tensor):
mat = torch.eye(3, device=forward.device).repeat(forward.shape[:-1] + (1, 1))
right = torch.zeros_like(forward)
elif isinstance(forward, np.ndarray):
mat = np.eye(3, dtype=np.float32)
for _ in range(len(forward.shape) - 1):
mat = mat[None]
mat = np.tile(mat, forward.shape[:-1] + (1, 1))
right = np.zeros_like(forward)
else:
raise ValueError
right[..., 0] = forward[..., 2]
right[..., 1] = 0.0
right[..., 2] = -forward[..., 0]
# right = torch.cross(mat[..., 1], forward) # cannot backward
mat[..., 2] = normalized(forward)
right = normalized(right)
mat[..., 0] = right
return mat
def get_rot_mat_from_forward_up(forward, up):
"""_summary_
Args:
forward (tensor): [N, M, 3]
up (tensor): [N, M, 3]
Returns:
mat (tensor): [N, M, 3, 3]
"""
if isinstance(forward, torch.Tensor):
mat = torch.eye(3, device=forward.device).repeat(forward.shape[:-1] + (1, 1))
right = torch.cross(up, forward)
elif isinstance(forward, np.ndarray):
mat = np.eye(3, dtype=np.float32)
for _ in range(len(forward.shape) - 1):
mat = mat[None]
mat = np.tile(mat, forward.shape[:-1] + (1, 1))
right = np.cross(up, forward)
else:
raise ValueError
right = normalized(right)
mat[..., 2] = normalized(forward)
mat[..., 1] = normalized(up)
mat[..., 0] = right
return mat
def get_rot_mat_from_pose_vec(vec):
"""_summary_
Args:
vec (tensor): [N, M, 6]
Returns:
mat (tensor): [N, M, 3, 3]
"""
forward = vec[..., :3]
up = vec[..., 3:6]
return get_rot_mat_from_forward_up(forward, up)
def get_TRS(rot_mat, pos):
"""_summary_
Args:
rot_mat (tensor): [N, 3, 3]
pos (tensor): [N, 3]
Returns:
mat (tensor): [N, 4, 4]
"""
if isinstance(rot_mat, torch.Tensor):
mat = torch.eye(4, device=pos.device).repeat(pos.shape[:-1] + (1, 1))
elif isinstance(rot_mat, np.ndarray):
mat = np.eye(4, dtype=np.float32)
for _ in range(len(pos.shape) - 1):
mat = mat[None]
mat = np.tile(mat, pos.shape[:-1] + (1, 1))
else:
raise ValueError
mat[..., :3, :3] = rot_mat
mat[..., :3, 3] = pos
mat = normalized_matrix(mat)
return mat
def xzvec2mat(vec):
"""_summary_
Args:
vec (tensor): [N, 4]
Returns:
mat (tensor): [N, 4, 4]
"""
vec_shape = vec.shape[:-1]
if isinstance(vec, torch.Tensor):
pos = torch.zeros(vec_shape + (3,))
forward = torch.zeros(vec_shape + (3,))
elif isinstance(vec, np.ndarray):
pos = np.zeros(vec_shape + (3,))
forward = np.zeros(vec_shape + (3,))
else:
raise ValueError
pos[..., 0] = vec[..., 0]
pos[..., 2] = vec[..., 1]
forward[..., 0] = vec[..., 2]
forward[..., 2] = vec[..., 3]
rot_mat = get_rot_mat_from_forward(forward)
mat = get_TRS(rot_mat, pos)
return mat
def distance(vec1, vec2):
return ((vec1 - vec2) ** 2).sum() ** 0.5
def get_relative_pose_from_vec(pose, root, N):
root_p_mat = xzvec2mat(root)
pose = pose.reshape(-1, N, 12)
pose[..., :3] = get_position_from(pose[..., :3], root_p_mat)
pose[..., 3:6] = get_direction_from(pose[..., 3:6], root_p_mat)
pose[..., 6:9] = get_direction_from(pose[..., 6:9], root_p_mat)
pose[..., 9:] = get_direction_from(pose[..., 9:], root_p_mat)
pos = pose[..., 0, :3]
rot = pose[..., 3:9].reshape(-1, N * 6)
pose = np.concatenate((pos, rot), axis=-1)
return pose
def get_forward_from_pos(pos):
"""_summary_
Args:
pos (N, J, 3): joints positions of each frame
Returns:
_type_: _description_
"""
pos_y_vec = torch.tensor([0, 1, 0], dtype=torch.float32).to(pos.device)
face_joint_indx = [2, 1, 17, 16]
r_hip, l_hip, r_sdr, l_sdr = (
face_joint_indx # use hip and shoulder to get the cross vector
)
cross_hip = pos[..., 0, r_hip, :] - pos[..., 0, l_hip, :]
cross_sdr = pos[..., 0, r_sdr, :] - pos[..., 0, l_sdr, :]
cross_vec = cross_hip + cross_sdr # (3, )
forward_vec = torch.cross(pos_y_vec, cross_vec, dim=-1)
forward_vec = normalized(forward_vec)
return forward_vec
def project_point_along_ray(p, ray, keepnorm=False):
"""_summary_
Args:
p (*, 3): point positions
ray (*, 3): ray direction
keepnorm: False -> project point on the ray,
True -> project point on the ray and keep the point length
Returns:
_type_: _description_
"""
ray = normalized(ray)
if keepnorm:
new_p = ray * p.norm(dim=-1, keepdim=True)
else:
dot_product = torch.sum(p * ray, dim=-1, keepdim=True)
new_p = dot_product * ray
return new_p
def solve_point_along_ray_with_constraint(c, ray, p, constraint="x"):
"""_summary_
Args:
c (*,): constraint value
ray (*, 3): ray direction
p (*, 3): start point of the ray
Returns:
_type_: _description_
"""
ray = normalized(ray)
if constraint == "x":
ind = 0
elif constraint == "y":
ind = 1
elif constraint == "z":
ind = 2
else:
raise ValueError
t = (c - p[..., ind]) / ray[..., ind]
out_p = ray * t[..., None] + p
return out_p
def calc_cosine(vec1, vec2, return_angle=False):
"""_summary_
Args:
vec1 (*, 3): vector
vec2 (*, 3): vector
return_angle: True -> return angle, False -> return cosine
Returns:
_type_: _description_
"""
vec1 = normalized(vec1)
vec2 = normalized(vec2)
cosine = torch.sum(vec1 * vec2, dim=-1)
if return_angle:
return torch.acos(cosine)
return cosine
############################################
#
# quaternion assumes xyzw
#
############################################
def quat_xyzw2wxyz(quat):
new_quat = torch.cat([quat[..., 3:4], quat[..., :3]], dim=-1)
return new_quat
def quat_wxyz2xyzw(quat):
new_quat = torch.cat([quat[..., 1:4], quat[..., :1]], dim=-1)
return new_quat
def quat_mul(a, b):
"""
quaternion multiplication
"""
x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2
z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
return torch.stack([x, y, z, w], dim=-1)
def quat_pos(x):
"""
make all the real part of the quaternion positive
"""
q = x
z = (q[..., 3:] < 0).float()
q = (1 - 2 * z) * q
return q
def quat_abs(x):
"""
quaternion norm (unit quaternion represents a 3D rotation, which has norm of 1)
"""
x = x.norm(p=2, dim=-1)
return x
def quat_unit(x):
"""
normalized quaternion with norm of 1
"""
norm = quat_abs(x).unsqueeze(-1)
return x / (norm.clamp(min=1e-4))
def quat_conjugate(x):
"""
quaternion with its imaginary part negated
"""
return torch.cat([-x[..., :3], x[..., 3:]], dim=-1)
def quat_real(x):
"""
real component of the quaternion
"""
return x[..., 3]
def quat_imaginary(x):
"""
imaginary components of the quaternion
"""
return x[..., :3]
def quat_norm_check(x):
"""
verify that a quaternion has norm 1
"""
assert bool((abs(x.norm(p=2, dim=-1) - 1) < 1e-3).all()), (
"the quaternion is has non-1 norm: {}".format(abs(x.norm(p=2, dim=-1) - 1))
)
assert bool((x[..., 3] >= 0).all()), "the quaternion has negative real part"
def quat_normalize(q):
"""
Construct 3D rotation from quaternion (the quaternion needs not to be normalized).
"""
q = quat_unit(quat_pos(q)) # normalized to positive and unit quaternion
return q
def quat_from_xyz(xyz):
"""
Construct 3D rotation from the imaginary component
"""
w = (1.0 - xyz.norm()).unsqueeze(-1)
assert bool((w >= 0).all()), "xyz has its norm greater than 1"
return torch.cat([xyz, w], dim=-1)
def quat_identity(shape: List[int]):
"""
Construct 3D identity rotation given shape
"""
w = torch.ones(shape + (1,))
xyz = torch.zeros(shape + (3,))
q = torch.cat([xyz, w], dim=-1)
return quat_normalize(q)
def tgm_quat_from_angle_axis(angle, axis, degree: bool = False):
"""Create a 3D rotation from angle and axis of rotation. The rotation is counter-clockwise
along the axis.
The rotation can be interpreted as a_R_b where frame "b" is the new frame that
gets rotated counter-clockwise along the axis from frame "a"
:param angle: angle of rotation
:type angle: Tensor
:param axis: axis of rotation
:type axis: Tensor
:param degree: put True here if the angle is given by degree
:type degree: bool, optional, default=False
"""
if degree:
angle = angle / 180.0 * math.pi
theta = (angle / 2).unsqueeze(-1)
axis = axis / (axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-4))
xyz = axis * theta.sin()
w = theta.cos()
return quat_normalize(torch.cat([w, xyz], dim=-1))
def quat_from_rotation_matrix(m):
"""
Construct a 3D rotation from a valid 3x3 rotation matrices.
Reference can be found here:
http://www.cg.info.hiroshima-cu.ac.jp/~miyazaki/knowledge/teche52.html
:param m: 3x3 orthogonal rotation matrices.
:type m: Tensor
:rtype: Tensor
"""
m = m.unsqueeze(0)
diag0 = m[..., 0, 0]
diag1 = m[..., 1, 1]
diag2 = m[..., 2, 2]
# Math stuff.
w = (((diag0 + diag1 + diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5
x = (((diag0 - diag1 - diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5
y = (((-diag0 + diag1 - diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5
z = (((-diag0 - diag1 + diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5
# Only modify quaternions where w > x, y, z.
c0 = (w >= x) & (w >= y) & (w >= z)
x[c0] *= (m[..., 2, 1][c0] - m[..., 1, 2][c0]).sign()
y[c0] *= (m[..., 0, 2][c0] - m[..., 2, 0][c0]).sign()
z[c0] *= (m[..., 1, 0][c0] - m[..., 0, 1][c0]).sign()
# Only modify quaternions where x > w, y, z
c1 = (x >= w) & (x >= y) & (x >= z)
w[c1] *= (m[..., 2, 1][c1] - m[..., 1, 2][c1]).sign()
y[c1] *= (m[..., 1, 0][c1] + m[..., 0, 1][c1]).sign()
z[c1] *= (m[..., 0, 2][c1] + m[..., 2, 0][c1]).sign()
# Only modify quaternions where y > w, x, z.
c2 = (y >= w) & (y >= x) & (y >= z)
w[c2] *= (m[..., 0, 2][c2] - m[..., 2, 0][c2]).sign()
x[c2] *= (m[..., 1, 0][c2] + m[..., 0, 1][c2]).sign()
z[c2] *= (m[..., 2, 1][c2] + m[..., 1, 2][c2]).sign()
# Only modify quaternions where z > w, x, y.
c3 = (z >= w) & (z >= x) & (z >= y)
w[c3] *= (m[..., 1, 0][c3] - m[..., 0, 1][c3]).sign()
x[c3] *= (m[..., 2, 0][c3] + m[..., 0, 2][c3]).sign()
y[c3] *= (m[..., 2, 1][c3] + m[..., 1, 2][c3]).sign()
return quat_normalize(torch.stack([x, y, z, w], dim=-1)).squeeze(0)
def quat_mul_norm(x, y):
"""
Combine two set of 3D rotations together using \**\* operator. The shape needs to be
broadcastable
"""
return quat_normalize(quat_mul(x, y))
def quat_rotate(rot, vec):
"""
Rotate a 3D vector with the 3D rotation
"""
other_q = torch.cat([vec, torch.zeros_like(vec[..., :1])], dim=-1)
return quat_imaginary(quat_mul(quat_mul(rot, other_q), quat_conjugate(rot)))
def quat_inverse(x):
"""
The inverse of the rotation
"""
return quat_conjugate(x)
def quat_identity_like(x):
"""
Construct identity 3D rotation with the same shape
"""
return quat_identity(x.shape[:-1])
def quat_angle_axis(x):
"""
The (angle, axis) representation of the rotation. The axis is normalized to unit length.
The angle is guaranteed to be between [0, pi].
"""
s = 2 * (x[..., 3] ** 2) - 1
angle = s.clamp(-1, 1).arccos() # just to be safe
axis = x[..., :3]
axis /= axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-4)
return angle, axis
def quat_yaw_rotation(x, z_up: bool = True):
"""
Yaw rotation (rotation along z-axis)
"""
q = x
if z_up:
q = torch.cat([torch.zeros_like(q[..., 0:2]), q[..., 2:3], q[..., 3:]], dim=-1)
else:
q = torch.cat(
[
torch.zeros_like(q[..., 0:1]),
q[..., 1:2],
torch.zeros_like(q[..., 2:3]),
q[..., 3:4],
],
dim=-1,
)
return quat_normalize(q)
def transform_from_rotation_translation(
r: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None
):
"""
Construct a transform from a quaternion and 3D translation. Only one of them can be None.
"""
assert r is not None or t is not None, "rotation and translation can't be all None"
if r is None:
assert t is not None
r = quat_identity(list(t.shape))
if t is None:
t = torch.zeros(list(r.shape) + [3])
return torch.cat([r, t], dim=-1)
def transform_identity(shape: List[int]):
"""
Identity transformation with given shape
"""
r = quat_identity(shape)
t = torch.zeros(shape + [3])
return transform_from_rotation_translation(r, t)
def transform_rotation(x):
"""Get rotation from transform"""
return x[..., :4]
def transform_translation(x):
"""Get translation from transform"""
return x[..., 4:]
def transform_inverse(x):
"""
Inverse transformation
"""
inv_so3 = quat_inverse(transform_rotation(x))
return transform_from_rotation_translation(
r=inv_so3, t=quat_rotate(inv_so3, -transform_translation(x))
)
def transform_identity_like(x):
"""
identity transformation with the same shape
"""
return transform_identity(x.shape)
def transform_mul(x, y):
"""
Combine two transformation together
"""
z = transform_from_rotation_translation(
r=quat_mul_norm(transform_rotation(x), transform_rotation(y)),
t=quat_rotate(transform_rotation(x), transform_translation(y))
+ transform_translation(x),
)
return z
def transform_apply(rot, vec):
"""
Transform a 3D vector
"""
assert isinstance(vec, torch.Tensor)
return quat_rotate(transform_rotation(rot), vec) + transform_translation(rot)
def rot_matrix_det(x):
"""
Return the determinant of the 3x3 matrix. The shape of the tensor will be as same as the
shape of the matrix
"""
a, b, c = x[..., 0, 0], x[..., 0, 1], x[..., 0, 2]
d, e, f = x[..., 1, 0], x[..., 1, 1], x[..., 1, 2]
g, h, i = x[..., 2, 0], x[..., 2, 1], x[..., 2, 2]
t1 = a * (e * i - f * h)
t2 = b * (d * i - f * g)
t3 = c * (d * h - e * g)
return t1 - t2 + t3
def rot_matrix_integrity_check(x):
"""
Verify that a rotation matrix has a determinant of one and is orthogonal
"""
det = rot_matrix_det(x)
assert bool((abs(det - 1) < 1e-3).all()), "the matrix has non-one determinant"
rtr = x @ x.permute(torch.arange(x.dim() - 2), -1, -2)
rtr_gt = rtr.zeros_like()
rtr_gt[..., 0, 0] = 1
rtr_gt[..., 1, 1] = 1
rtr_gt[..., 2, 2] = 1
assert bool(((rtr - rtr_gt) < 1e-3).all()), "the matrix is not orthogonal"
def rot_matrix_from_quaternion(q):
"""
Construct rotation matrix from quaternion
"""
# Shortcuts for individual elements (using wikipedia's convention)
qi, qj, qk, qr = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
# Set individual elements
R00 = 1.0 - 2.0 * (qj**2 + qk**2)
R01 = 2 * (qi * qj - qk * qr)
R02 = 2 * (qi * qk + qj * qr)
R10 = 2 * (qi * qj + qk * qr)
R11 = 1.0 - 2.0 * (qi**2 + qk**2)
R12 = 2 * (qj * qk - qi * qr)
R20 = 2 * (qi * qk - qj * qr)
R21 = 2 * (qj * qk + qi * qr)
R22 = 1.0 - 2.0 * (qi**2 + qj**2)
R0 = torch.stack([R00, R01, R02], dim=-1)
R1 = torch.stack([R10, R11, R12], dim=-1)
R2 = torch.stack([R20, R21, R22], dim=-1)
R = torch.stack([R0, R1, R2], dim=-2)
return R
def euclidean_to_rotation_matrix(x):
"""
Get the rotation matrix on the top-left corner of a Euclidean transformation matrix
"""
return x[..., :3, :3]
def euclidean_integrity_check(x):
euclidean_to_rotation_matrix(x) # check 3d-rotation matrix
assert bool((x[..., 3, :3] == 0).all()), "the last row is illegal"
assert bool((x[..., 3, 3] == 1).all()), "the last row is illegal"
def euclidean_translation(x):
"""
Get the translation vector located at the last column of the matrix
"""
return x[..., :3, 3]
def euclidean_inverse(x):
"""
Compute the matrix that represents the inverse rotation
"""
s = x.zeros_like()
irot = quat_inverse(quat_from_rotation_matrix(x))
s[..., :3, :3] = irot
s[..., :3, 4] = quat_rotate(irot, -euclidean_translation(x))
return s
def euclidean_to_transform(transformation_matrix):
"""
Construct a transform from a Euclidean transformation matrix
"""
return transform_from_rotation_translation(
r=quat_from_rotation_matrix(
m=euclidean_to_rotation_matrix(transformation_matrix)
),
t=euclidean_translation(transformation_matrix),
)
def to_torch(x, dtype=torch.float, device="cuda:0", requires_grad=False):
return torch.tensor(x, dtype=dtype, device=device, requires_grad=requires_grad)
def quat_mul(a, b):
assert a.shape == b.shape
shape = a.shape
a = a.reshape(-1, 4)
b = b.reshape(-1, 4)
x1, y1, z1, w1 = a[:, 0], a[:, 1], a[:, 2], a[:, 3]
x2, y2, z2, w2 = b[:, 0], b[:, 1], b[:, 2], b[:, 3]
ww = (z1 + x1) * (x2 + y2)
yy = (w1 - y1) * (w2 + z2)
zz = (w1 + y1) * (w2 - z2)
xx = ww + yy + zz
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
w = qq - ww + (z1 - y1) * (y2 - z2)
x = qq - xx + (x1 + w1) * (x2 + w2)
y = qq - yy + (w1 - x1) * (y2 + z2)
z = qq - zz + (z1 + y1) * (w2 - x2)
quat = torch.stack([x, y, z, w], dim=-1).view(shape)
return quat
def normalize(x, eps: float = 1e-9):
return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1)
def quat_apply(a, b):
shape = b.shape
a = a.reshape(-1, 4)
b = b.reshape(-1, 3)
xyz = a[:, :3]
t = xyz.cross(b, dim=-1) * 2
return (b + a[:, 3:] * t + xyz.cross(t, dim=-1)).view(shape)
def quat_rotate(q, v):
shape = q.shape
q_w = q[:, -1]
q_vec = q[:, :3]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = (
q_vec
* torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1)
* 2.0
)
return a + b + c
def quat_rotate_inverse(q, v):
shape = q.shape
q_w = q[:, -1]
q_vec = q[:, :3]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = (
q_vec
* torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1)
* 2.0
)
return a - b + c
def quat_conjugate(a):
shape = a.shape
a = a.reshape(-1, 4)
return torch.cat((-a[:, :3], a[:, -1:]), dim=-1).view(shape)
def quat_unit(a):
return normalize(a)
def quat_from_angle_axis(angle, axis):
theta = (angle / 2).unsqueeze(-1)
xyz = normalize(axis) * torch.sin(theta.clone())
w = torch.cos(theta.clone())
return quat_unit(torch.cat([xyz, w], dim=-1))
def normalize_angle(x):
return torch.atan2(torch.sin(x.clone()), torch.cos(x.clone()))
def tf_inverse(q, t):
q_inv = quat_conjugate(q)
return q_inv, -quat_apply(q_inv, t)
def tf_apply(q, t, v):
return quat_apply(q, v) + t
def tf_vector(q, v):
return quat_apply(q, v)
def tf_combine(q1, t1, q2, t2):
return quat_mul(q1, q2), quat_apply(q1, t2) + t1
def get_basis_vector(q, v):
return quat_rotate(q, v)
def get_axis_params(value, axis_idx, x_value=0.0, dtype=float, n_dims=3):
"""construct arguments to `Vec` according to axis index."""
zs = np.zeros((n_dims,))
assert axis_idx < n_dims, "the axis dim should be within the vector dimensions"
zs[axis_idx] = 1.0
params = np.where(zs == 1.0, value, zs)
params[0] = x_value
return list(params.astype(dtype))
def copysign(a, b):
# type: (float, Tensor) -> Tensor
a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])
return torch.abs(a) * torch.sign(b)
def get_euler_xyz(q):
qx, qy, qz, qw = 0, 1, 2, 3
# roll (x-axis rotation)
sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz])
cosr_cosp = (
q[:, qw] * q[:, qw]
- q[:, qx] * q[:, qx]
- q[:, qy] * q[:, qy]
+ q[:, qz] * q[:, qz]
)
roll = torch.atan2(sinr_cosp, cosr_cosp)
# pitch (y-axis rotation)
sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx])
pitch = torch.where(
torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp)
)
# yaw (z-axis rotation)
siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy])
cosy_cosp = (
q[:, qw] * q[:, qw]
+ q[:, qx] * q[:, qx]
- q[:, qy] * q[:, qy]
- q[:, qz] * q[:, qz]
)
yaw = torch.atan2(siny_cosp, cosy_cosp)
return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi)
def quat_from_euler_xyz(roll, pitch, yaw):
cy = torch.cos(yaw * 0.5)
sy = torch.sin(yaw * 0.5)
cr = torch.cos(roll * 0.5)
sr = torch.sin(roll * 0.5)
cp = torch.cos(pitch * 0.5)
sp = torch.sin(pitch * 0.5)
qw = cy * cr * cp + sy * sr * sp
qx = cy * sr * cp - sy * cr * sp
qy = cy * cr * sp + sy * sr * cp
qz = sy * cr * cp - cy * sr * sp
return torch.stack([qx, qy, qz, qw], dim=-1)
def torch_rand_float(lower, upper, shape, device):
# type: (float, float, Tuple[int, int], str) -> Tensor
return (upper - lower) * torch.rand(*shape, device=device) + lower
def torch_random_dir_2(shape, device):
# type: (Tuple[int, int], str) -> Tensor
angle = torch_rand_float(-np.pi, np.pi, shape, device).squeeze(-1)
return torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1)
def tensor_clamp(t, min_t, max_t):
return torch.max(torch.min(t, max_t), min_t)
def scale(x, lower, upper):
return 0.5 * (x + 1.0) * (upper - lower) + lower
def unscale(x, lower, upper):
return (2.0 * x - upper - lower) / (upper - lower)
def unscale_np(x, lower, upper):
return (2.0 * x - upper - lower) / (upper - lower)
def quat_to_angle_axis(q):
# type: (Tensor) -> Tuple[Tensor, Tensor]
# computes axis-angle representation from quaternion q
# q must be normalized
min_theta = 1e-5
qx, qy, qz, qw = 0, 1, 2, 3
sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw])
angle = 2 * torch.acos(q[..., qw])
angle = normalize_angle(angle)
sin_theta_expand = sin_theta.unsqueeze(-1)
axis = q[..., qx:qw] / sin_theta_expand
mask = torch.abs(sin_theta) > min_theta
default_axis = torch.zeros_like(axis)
default_axis[..., -1] = 1
angle = torch.where(mask, angle, torch.zeros_like(angle))
mask_expand = mask.unsqueeze(-1)
axis = torch.where(mask_expand, axis, default_axis)
return angle, axis
def angle_axis_to_exp_map(angle, axis):
# type: (Tensor, Tensor) -> Tensor
# compute exponential map from axis-angle
angle_expand = angle.unsqueeze(-1)
exp_map = angle_expand * axis
return exp_map
def quat_to_exp_map(q):
# type: (Tensor) -> Tensor
# compute exponential map from quaternion
# q must be normalized
angle, axis = quat_to_angle_axis(q)
exp_map = angle_axis_to_exp_map(angle, axis)
return exp_map
def quat_to_tan_norm(q):
# type: (Tensor) -> Tensor
# represents a rotation using the tangent and normal vectors
ref_tan = torch.zeros_like(q[..., 0:3])
ref_tan[..., 0] = 1
tan = quat_rotate(q, ref_tan)
ref_norm = torch.zeros_like(q[..., 0:3])
ref_norm[..., -1] = 1
norm = quat_rotate(q, ref_norm)
norm_tan = torch.cat([tan, norm], dim=len(tan.shape) - 1)
return norm_tan
def euler_xyz_to_exp_map(roll, pitch, yaw):
# type: (Tensor, Tensor, Tensor) -> Tensor
q = quat_from_euler_xyz(roll, pitch, yaw)
exp_map = quat_to_exp_map(q)
return exp_map
def exp_map_to_angle_axis(exp_map):
min_theta = 1e-5
angle = torch.norm(exp_map.clone(), dim=-1) + 1e-6
angle_exp = torch.unsqueeze(angle, dim=-1)
axis = exp_map.clone() / angle_exp.clone()
angle = normalize_angle(angle)
default_axis = torch.zeros_like(exp_map)
default_axis[..., -1] = 1
mask = torch.abs(angle) > min_theta
angle = torch.where(mask, angle, torch.zeros_like(angle))
mask_expand = mask.unsqueeze(-1)
axis = torch.where(mask_expand, axis, default_axis)
return angle, axis
def exp_map_to_quat(exp_map):
angle, axis = exp_map_to_angle_axis(exp_map)
q = quat_from_angle_axis(angle, axis)
return q
def slerp(q0, q1, t):
# type: (Tensor, Tensor, Tensor) -> Tensor
cos_half_theta = torch.sum(q0 * q1, dim=-1)
neg_mask = cos_half_theta < 0
q1 = q1.clone()
q1[neg_mask] = -q1[neg_mask]
cos_half_theta = torch.abs(cos_half_theta)
cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1)
half_theta = torch.acos(cos_half_theta)
sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)
ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta
ratioB = torch.sin(t * half_theta) / sin_half_theta
new_q = ratioA * q0 + ratioB * q1
new_q = torch.where(torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q)
new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)
return new_q
def calc_heading_vec(q, head_ind=0):
# type: (Tensor, int) -> Tensor
# calculate heading direction from quaternion
# the heading is the direction vector
# q must be normalized
ref_dir = torch.zeros_like(q[..., 0:3])
ref_dir[..., head_ind] = 1
rot_dir = quat_rotate(q, ref_dir)
return rot_dir
def calc_heading(q, head_ind=0, gravity_axis="z"):
# type: (Tensor, int, str) -> Tensor
# calculate heading direction from quaternion
# the heading is the direction on the xy plane
# q must be normalized
ref_dir = torch.zeros_like(q[..., 0:3])
ref_dir[..., head_ind] = 1
# ref_dir[..., 0] = 1
shape = ref_dir.shape[:-1]
q = q.reshape((-1, 4))
ref_dir = ref_dir.reshape(-1, 3)
rot_dir = quat_rotate(q, ref_dir)
rot_dir = rot_dir.reshape(shape + (3,))
if gravity_axis == "z":
heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0])
elif gravity_axis == "y":
heading = torch.atan2(rot_dir[..., 0], rot_dir[..., 2])
elif gravity_axis == "x":
heading = torch.atan2(rot_dir[..., 2], rot_dir[..., 1])
return heading
def calc_heading_quat(q, head_ind=0, gravity_axis="z"):
# type: (Tensor, int, str) -> Tensor
# calculate heading rotation from quaternion
# the heading is the direction on the xy plane
# q must be normalized
heading = calc_heading(q, head_ind, gravity_axis=gravity_axis)
axis = torch.zeros_like(q[..., 0:3])
if gravity_axis == "z":
g_axis = 2
elif gravity_axis == "y":
g_axis = 1
elif gravity_axis == "x":
g_axis = 0
axis[..., g_axis] = 1
heading_q = quat_from_angle_axis(heading, axis)
return heading_q
def calc_heading_quat_inv(q, head_ind=0):
# type: (Tensor, int) -> Tensor
# calculate heading rotation from quaternion
# the heading is the direction on the xy plane
# q must be normalized
heading = calc_heading(q, head_ind)
axis = torch.zeros_like(q[..., 0:3])
axis[..., 2] = 1
heading_q = quat_from_angle_axis(-heading, axis)
return heading_q
def forward_kinematics(mat, parent):
"""_summary_
Args:
mat ([..., N, 3, 3]): _description_
parent (): _description_
"""
if isinstance(mat, torch.Tensor):
rotations = torch.eye(mat.shape[-1], device=mat.device)
rotations = rotations.repeat(mat.shape[:-2] + (1, 1))
else:
rotations = np.eye(mat.shape[-1], dtype=np.float32)
rotations = np.tile(rotations, mat.shape[:-2] + (1, 1))
for i in range(mat.shape[-3]):
if parent[i] != -1:
if isinstance(mat, torch.Tensor):
# this way make gradient flow
new_mat = get_mat_BfromA(
rotations[..., parent[i], :, :], mat[..., i, :, :]
)
rotations = torch.cat(
(
rotations[..., :i, :, :],
new_mat[..., None, :, :],
rotations[..., i + 1 :, :, :],
),
dim=-3,
)
else:
rotations[..., i, :, :] = get_mat_BfromA(
rotations[..., parent[i], :, :], mat[..., i, :, :]
)
else:
if isinstance(mat, torch.Tensor):
# this way make gradient flow
rotations = torch.cat(
(mat[..., : i + 1, :, :], rotations[..., i + 1 :, :, :]), dim=-3
)
else:
rotations[..., i, :, :] = mat[..., i, :, :]
return rotations