|
|
import copy |
|
|
import math |
|
|
from typing import List, Optional |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
def identity_mat(x=None, device="cpu", is_numpy=False):
    """Return a 4x4 identity matrix, optionally tiled to ``x``'s leading dims.

    Args:
        x: optional torch.Tensor or np.ndarray. When given, its type selects the
            backend and ``x.shape[:-2]`` is used as the batch shape.
        device: torch device for tensor results (unused for numpy results).
        is_numpy: only consulted when ``x`` is None; selects a float32 numpy
            array instead of a torch tensor.

    Returns:
        Identity of shape [4, 4] (``x`` is None) or [*x.shape[:-2], 4, 4].

    Raises:
        ValueError: if ``x`` is given but is neither a torch.Tensor nor an
            np.ndarray.
    """
    if x is None:
        if is_numpy:
            return np.eye(4, dtype=np.float32)
        return torch.eye(4, device=device)

    if isinstance(x, torch.Tensor):
        reps = tuple(x.shape[:-2]) + (1, 1)
        return torch.eye(4, device=device).repeat(reps)

    if isinstance(x, np.ndarray):
        lead = tuple(x.shape[:-2])
        # Give the eye one leading singleton axis per batch dim, then tile.
        template = np.eye(4, dtype=np.float32).reshape((1,) * len(lead) + (4, 4))
        return np.tile(template, lead + (1, 1))

    raise ValueError
|
|
|
|
|
|
|
|
def vec2mat(vec):
    """Build a 4x4 world matrix from a flat 12-vector.

    Args:
        vec (tensor | ndarray): [12], laid out as pos (0:3), forward (3:6),
            up (6:9) and right (9:12).

    Returns:
        mat_world: [4, 4] whose columns are right, up, forward, pos over a
        bottom row of [0, 0, 0, 1], passed through ``normalized_matrix`` so the
        three rotation columns have unit length.

    Raises:
        ValueError: if ``vec`` is neither a torch.Tensor nor an np.ndarray.
    """
    pos, forward, up, right = vec[:3], vec[3:6], vec[6:9], vec[9:12]
    columns = [right, up, forward, pos]
    bottom_row = np.array([[0, 0, 0, 1]])

    if isinstance(vec, torch.Tensor):
        bottom_row = torch.tensor(bottom_row, device=vec.device, dtype=vec.dtype)
        upper = torch.stack(columns, dim=-1)
        mat_world = torch.cat([upper, bottom_row], dim=-2)
    elif isinstance(vec, np.ndarray):
        upper = np.stack(columns, axis=-1)
        mat_world = np.concatenate([upper, bottom_row], axis=-2)
    else:
        raise ValueError

    return normalized_matrix(mat_world)
|
|
|
|
|
|
|
|
def mat2vec(mat):
    """Flatten a 4x4 world matrix into a 12-vector.

    Args:
        mat (tensor | ndarray): [4, 4]

    Returns:
        vec: [12] = translation column (rows 0..2 of column 3), then the
        unit-normalized column 2 (forward), column 1 (up) and column 0 (right).

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    pieces = (
        mat[:-1, 3],
        normalized(mat[:-1, 2]),
        normalized(mat[:-1, 1]),
        normalized(mat[:-1, 0]),
    )
    if isinstance(mat, torch.Tensor):
        return torch.cat(pieces)
    if isinstance(mat, np.ndarray):
        return np.concatenate(pieces)
    raise ValueError
|
|
|
|
|
|
|
|
def vec2mat_batch(vec):
    """Batched ``vec2mat``: build 4x4 world matrices from 12-vectors.

    Args:
        vec (tensor | ndarray): [B, 12] (any number of leading batch dims is
            accepted), laid out as pos, forward, up and right.

    Returns:
        mat_world: [B, 4, 4] (same leading dims as ``vec``) with columns right,
        up, forward, pos over a [0, 0, 0, 1] bottom row, passed through
        ``normalized_matrix``.

    Raises:
        ValueError: if ``vec`` is neither a torch.Tensor nor an np.ndarray.
    """
    batch_shape = tuple(vec.shape[:-1])
    # One [0, 0, 0, 1] row per matrix, built for every leading dim. Tiling by
    # vec.shape[0] alone broke inputs with more than one batch dimension.
    bottom = np.zeros(batch_shape + (1, 4), dtype=np.float32)
    bottom[..., 3] = 1.0
    if isinstance(vec, torch.Tensor):
        bottom = torch.tensor(bottom, device=vec.device, dtype=vec.dtype)

    pos = vec[..., :3]
    forward = vec[..., 3:6]
    up = vec[..., 6:9]
    right = vec[..., 9:12]

    if isinstance(vec, torch.Tensor):
        mat_world = torch.stack([right, up, forward, pos], dim=-1)
        mat_world = torch.cat([mat_world, bottom], dim=-2)
    elif isinstance(vec, np.ndarray):
        mat_world = np.stack([right, up, forward, pos], axis=-1)
        mat_world = np.concatenate([mat_world, bottom], axis=-2)
    else:
        raise ValueError

    return normalized_matrix(mat_world)
|
|
|
|
|
|
|
|
def rotmat2tan_norm(mat):
    """Encode rotation matrices as a 6-D "tan norm" vector.

    NOTE: another ``rotmat2tan_norm`` is defined further down in this module.
    That later definition replaces this one at import time.

    Args:
        mat (tensor | ndarray): [B, 3, 3]

    Returns:
        vec: [B, 6]. The first three values are row 2 reversed with its last
        entry negated ("tan"). The last three are row 0 reversed with its last
        entry negated ("norm").

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(mat, np.ndarray):
        tan = mat[..., 2, ::-1].copy()
        norm = mat[..., 0, ::-1].copy()
    elif isinstance(mat, torch.Tensor):
        # torch rejects negative-step slicing; flip the same rows instead.
        tan = torch.flip(mat[..., 2, :], dims=[-1])
        norm = torch.flip(mat[..., 0, :], dims=[-1])
    else:
        raise ValueError

    tan[..., -1] *= -1
    norm[..., -1] *= -1

    if isinstance(mat, np.ndarray):
        return np.concatenate((tan, norm), axis=-1)
    return torch.cat((tan, norm), dim=-1)
|
|
|
|
|
|
|
|
def mat2tan_norm(mat):
    """Encode the rotation part of 4x4 matrices as a 6-D tan-norm vector.

    Args:
        mat (tensor | ndarray): [B, 4, 4]

    Returns:
        vec: [B, 6] as produced by ``rotmat2tan_norm`` on ``mat[..., :-1, :-1]``.
    """
    upper_left = mat[..., :-1, :-1]
    return rotmat2tan_norm(upper_left)
|
|
|
|
|
|
|
|
def rotmat2tan_norm(mat):
    """Encode rotation matrices as a 6-D "tan norm" vector.

    Args:
        mat (tensor | ndarray): [B, 3, 3]

    Returns:
        vec: [B, 6]. The first three values are row 2 reversed with its last
        entry negated ("tan"). The last three are row 0 reversed with its last
        entry negated ("norm"). ``tan_norm2rotmat`` writes these back as rows
        2 and 0.

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(mat, np.ndarray):
        tan = mat[..., 2, ::-1].copy()
        norm = mat[..., 0, ::-1].copy()
    elif isinstance(mat, torch.Tensor):
        # Index ROWS (mat[..., 2, :]) to match the numpy branch. The previous
        # mat[..., 2] selected column 2, so torch and numpy results disagreed
        # for non-symmetric matrices.
        tan = torch.flip(mat[..., 2, :], dims=[-1])
        norm = torch.flip(mat[..., 0, :], dims=[-1])
    else:
        raise ValueError

    tan[..., -1] *= -1
    norm[..., -1] *= -1

    if isinstance(mat, np.ndarray):
        return np.concatenate((tan, norm), axis=-1)
    return torch.cat((tan, norm), dim=-1)
|
|
|
|
|
|
|
|
def tan_norm2rotmat(tan_norm):
    """Rebuild rotation matrices from the 6-D tan-norm encoding.

    This is the inverse of ``rotmat2tan_norm``. The function negates the last
    entry of each half, reverses it, and writes it back as row 2 (tan) or
    row 0 (norm). Row 1 is then cross(row 2, row 0).

    Args:
        tan_norm (tensor | ndarray): [B, 6]

    Returns:
        rotmat: [B, 3, 3]. numpy results use np.zeros' default dtype; torch
        results use torch's default dtype on ``tan_norm.device``.

    Raises:
        ValueError: if ``tan_norm`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(tan_norm, np.ndarray):
        tan = tan_norm[..., :3].copy()
        norm = tan_norm[..., 3:].copy()
        tan[..., -1] *= -1
        norm[..., -1] *= -1
        tan = tan[..., ::-1]
        norm = norm[..., ::-1]
        other = np.cross(tan, norm)
        rotmat = np.zeros(tan_norm.shape[:-1] + (3, 3))
    elif isinstance(tan_norm, torch.Tensor):
        # clone() rather than copy.deepcopy: deepcopy raises on non-leaf
        # tensors that are part of an autograd graph.
        tan = tan_norm[..., :3].clone()
        norm = tan_norm[..., 3:].clone()
        tan[..., -1] *= -1
        norm[..., -1] *= -1
        tan = torch.flip(tan, dims=[-1])
        norm = torch.flip(norm, dims=[-1])
        # Explicit dim=-1. Without it torch.cross uses the FIRST dimension of
        # size 3, which is the batch dimension when B == 3.
        other = torch.cross(tan, norm, dim=-1)
        rotmat = torch.zeros(tan_norm.shape[:-1] + (3, 3), device=tan_norm.device)
    else:
        raise ValueError

    rotmat[..., 2, :] = tan
    rotmat[..., 0, :] = norm
    rotmat[..., 1, :] = other
    return rotmat
|
|
|
|
|
|
|
|
def rotmat332vec_batch(mat):
    """Flatten 3x3 rotation matrices into their forward/up/right axes.

    Args:
        mat (tensor | ndarray): [B, 3, 3]

    Returns:
        vec: [B, 9]. Columns 2 (forward), 1 (up) and 0 (right) of
        ``normalized_matrix(mat)``, concatenated in that order. Each axis has
        3 components, so the result has 9, not 6.

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    mat = normalized_matrix(mat)
    axes = (mat[..., :, 2], mat[..., :, 1], mat[..., :, 0])
    if isinstance(mat, torch.Tensor):
        return torch.cat(axes, dim=-1)
    if isinstance(mat, np.ndarray):
        return np.concatenate(axes, axis=-1)
    raise ValueError
|
|
|
|
|
|
|
|
def rotmat2vec_batch(mat):
    """Extract the normalized forward/up/right axes of 4x4 matrices.

    Args:
        mat (tensor | ndarray): [B, 4, 4]

    Returns:
        vec: [B, 9]. Rows 0..2 of columns 2 (forward), 1 (up) and 0 (right)
        of ``normalized_matrix(mat)``.

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    unit = normalized_matrix(mat)
    axes = [unit[..., :-1, col] for col in (2, 1, 0)]
    if isinstance(unit, torch.Tensor):
        return torch.cat(axes, dim=-1)
    if isinstance(unit, np.ndarray):
        return np.concatenate(axes, axis=-1)
    raise ValueError
|
|
|
|
|
|
|
|
def mat2vec_batch(mat):
    """Batched ``mat2vec``: flatten 4x4 matrices into 12-vectors.

    Args:
        mat (tensor | ndarray): [B, 4, 4]

    Returns:
        vec: [B, 12]. Rows 0..2 of columns 3 (pos), 2 (forward), 1 (up) and
        0 (right) of ``normalized_matrix(mat)``.

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    unit = normalized_matrix(mat)
    parts = [unit[..., :-1, col] for col in (3, 2, 1, 0)]
    if isinstance(unit, torch.Tensor):
        return torch.cat(parts, dim=-1)
    if isinstance(unit, np.ndarray):
        return np.concatenate(parts, axis=-1)
    raise ValueError
|
|
|
|
|
|
|
|
def mat2pose_batch(mat, returnvel=True):
    """Flatten 4x4 matrices into pos/forward/up pose vectors.

    Args:
        mat (tensor | ndarray): [B, 4, 4]
        returnvel (bool): when True, append a zero 3-vector placeholder.

    Returns:
        vec: [B, 12] (pos, forward, up, zeros) if ``returnvel`` else [B, 9]
        (pos, forward, up). Taken from columns 3, 2 and 1 of
        ``normalized_matrix(mat)``.

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    unit = normalized_matrix(mat)
    parts = [unit[..., :-1, 3], unit[..., :-1, 2], unit[..., :-1, 1]]
    if isinstance(unit, torch.Tensor):
        if returnvel:
            parts.append(torch.zeros_like(parts[-1]))
        return torch.cat(parts, dim=-1)
    if isinstance(unit, np.ndarray):
        if returnvel:
            parts.append(np.zeros_like(parts[-1]))
        return np.concatenate(parts, axis=-1)
    raise ValueError
|
|
|
|
|
|
|
|
def get_mat_BinA(matCtoA, matCtoB):
    """Express frame B in the coordinates of frame A.

    Both inputs describe the same object C, one in coordinate system A and
    one in B. The result is ``matCtoA @ inv(matCtoB)``, with
    ``normalized_matrix`` applied to the inverse and to the product.

    Args:
        matCtoA (tensor | ndarray): [4, 4] world matrix
        matCtoB (tensor | ndarray): [4, 4] world matrix

    Raises:
        ValueError: if ``matCtoA`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(matCtoA, torch.Tensor):
        invert, multiply = torch.inverse, torch.matmul
    elif isinstance(matCtoA, np.ndarray):
        invert, multiply = np.linalg.inv, np.matmul
    else:
        raise ValueError
    b_inverse = normalized_matrix(invert(matCtoB))
    return normalized_matrix(multiply(matCtoA, b_inverse))
|
|
|
|
|
|
|
|
def get_mat_BtoA(matA, matB):
    """Express world matrix B in the coordinates of A.

    Computes ``inv(matA) @ matB``, with ``normalized_matrix`` applied to the
    inverse and to the product.

    Args:
        matA (tensor | ndarray): [4, 4] world matrix
        matB (tensor | ndarray): [4, 4] world matrix

    Raises:
        ValueError: if ``matA`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(matA, torch.Tensor):
        invert, multiply = torch.inverse, torch.matmul
    elif isinstance(matA, np.ndarray):
        invert, multiply = np.linalg.inv, np.matmul
    else:
        raise ValueError
    a_inverse = normalized_matrix(invert(matA))
    return normalized_matrix(multiply(a_inverse, matB))
|
|
|
|
|
|
|
|
def get_mat_BfromA(matA, matBtoA):
    """Return world matrix B given world matrix A and B relative to A.

    Computes ``matA @ matBtoA`` followed by ``normalized_matrix``.

    Args:
        matA (tensor | ndarray): [4, 4] world matrix
        matBtoA (tensor | ndarray): [4, 4] matrix B relative to A

    Raises:
        ValueError: if ``matA`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(matA, torch.Tensor):
        matB = torch.matmul(matA, matBtoA)
    elif isinstance(matA, np.ndarray):
        matB = np.matmul(matA, matBtoA)
    else:
        # Without this branch, unsupported inputs hit an UnboundLocalError on
        # matB. Raise the same ValueError as the sibling helpers instead.
        raise ValueError
    return normalized_matrix(matB)
|
|
|
|
|
|
|
|
def get_relative_position_to(pos, mat):
    """Map positions into the local frame of ``mat`` (apply its inverse).

    Args:
        pos (tensor | ndarray): [N, M, 3] or [N, 3]
        mat (tensor | ndarray): [N, 4, 4] or [4, 4]

    Returns:
        Positions with the same layout as ``pos``. The rotation block of
        ``normalized_matrix(inv(mat))`` is applied to each point, then its
        translation is added.

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(mat, torch.Tensor):
        inverse = torch.inverse(mat)
    elif isinstance(mat, np.ndarray):
        inverse = np.linalg.inv(mat)
    else:
        raise ValueError
    inverse = normalized_matrix(inverse)
    rot_inverse = inverse[..., :-1, :-1]

    if isinstance(mat, torch.Tensor):
        rotated = torch.matmul(rot_inverse, pos.transpose(-1, -2)).transpose(-1, -2)
    else:
        rotated = np.matmul(rot_inverse, pos.swapaxes(-1, -2)).swapaxes(-1, -2)
    # Insert a point axis so the translation broadcasts over all points.
    return rotated + inverse[..., None, :-1, 3]
|
|
|
|
|
|
|
|
def get_rotation(mat):
    """Return the upper-left rotation block of homogeneous matrices.

    Args:
        mat: [..., 4, 4]

    Returns:
        ``mat[..., :-1, :-1]``, i.e. [..., 3, 3] for 4x4 input.
    """
    rotation_block = mat[..., :-1, :-1]
    return rotation_block
|
|
|
|
|
|
|
|
def set_rotation(mat, rotmat):
    """Write ``rotmat`` into the rotation block of ``mat`` in place.

    Args:
        mat: [..., 4, 4]; modified in place.
        rotmat: values assigned to ``mat[..., :-1, :-1]``.

    Returns:
        The same ``mat`` object, for chaining.
    """
    target = mat
    target[..., :-1, :-1] = rotmat
    return target
|
|
|
|
|
|
|
|
def set_position(mat, pos):
    """Write ``pos`` into the translation column of ``mat`` in place.

    Args:
        mat: [..., 4, 4]; modified in place.
        pos: values assigned to ``mat[..., :-1, 3]``.

    Returns:
        The same ``mat`` object, for chaining.
    """
    target = mat
    target[..., :-1, 3] = pos
    return target
|
|
|
|
|
|
|
|
def get_position(mat):
    """Return the translation column of homogeneous matrices.

    Args:
        mat: [..., 4, 4]

    Returns:
        ``mat[..., :-1, 3]``, i.e. [..., 3] for 4x4 input.
    """
    translation = mat[..., :-1, 3]
    return translation
|
|
|
|
|
|
|
|
def get_position_from(pos, mat):
    """Transform local positions by ``mat`` (rotate, then translate).

    Args:
        pos (tensor | ndarray): [N, M, 3] or [N, 3]
        mat (tensor | ndarray): [N, 4, 4] or [4, 4]

    Returns:
        ``R @ p + t`` for each point, with ``pos``'s layout. R and t are the
        rotation block and translation column of ``mat``.

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    rotation = mat[..., :-1, :-1]
    if isinstance(mat, torch.Tensor):
        rotated = torch.matmul(rotation, pos.transpose(-1, -2)).transpose(-1, -2)
    elif isinstance(mat, np.ndarray):
        rotated = np.matmul(rotation, pos.swapaxes(-1, -2)).swapaxes(-1, -2)
    else:
        raise ValueError
    # The extra None axis lets the translation broadcast over all points.
    offset = mat[..., None, :-1, 3]
    return rotated + offset
|
|
|
|
|
|
|
|
def get_position_from_rotmat(pos, mat):
    """Rotate positions by a 3x3 rotation matrix (no translation).

    Args:
        pos (tensor | ndarray): [N, M, 3] or [N, 3]
        mat (tensor | ndarray): [N, 3, 3] or [3, 3]. A pure rotation block:
            the whole matrix is multiplied against ``pos``, so 4x4
            homogeneous matrices do not fit here (the old docstring said
            [N, 4, 4]).

    Returns:
        ``mat @ p`` for each point, with ``pos``'s layout.

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(mat, torch.Tensor):
        return torch.matmul(mat, pos.transpose(-1, -2)).transpose(-1, -2)
    if isinstance(mat, np.ndarray):
        return np.matmul(mat, pos.swapaxes(-1, -2)).swapaxes(-1, -2)
    raise ValueError
|
|
|
|
|
|
|
|
def get_relative_direction_to(dir, mat):
    """Rotate direction vectors into the local frame of ``mat``.

    Only the rotation block of ``normalized_matrix(inv(mat))`` is applied.
    Translation does not affect directions.

    Args:
        dir (tensor | ndarray): [N, M, 3] or [N, 3]
        mat (tensor | ndarray): [N, 4, 4] or [4, 4]

    Returns:
        Directions with the same layout as ``dir``.

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(mat, torch.Tensor):
        mat_inv = torch.inverse(mat)
    elif isinstance(mat, np.ndarray):
        mat_inv = np.linalg.inv(mat)
    else:
        raise ValueError

    rot_mat_inv = normalized_matrix(mat_inv)[..., :3, :3]

    # The type was validated above, so two branches suffice. The old trailing
    # else/return were unreachable.
    if isinstance(mat, torch.Tensor):
        return torch.matmul(rot_mat_inv, dir.transpose(-1, -2)).transpose(-1, -2)
    return np.matmul(rot_mat_inv, dir.swapaxes(-1, -2)).swapaxes(-1, -2)
|
|
|
|
|
|
|
|
def get_direction_from(dir, mat):
    """Rotate local direction vectors by the rotation block of ``mat``.

    Args:
        dir (tensor | ndarray): [N, M, 3] or [N, 3]
        mat (tensor | ndarray): [N, 4, 4] or [4, 4]; only ``mat[..., :3, :3]``
            is used, so translation is ignored.

    Returns:
        tensor: [N, M, 3] or [N, 3]

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    rot_mat = mat[..., :3, :3]
    if isinstance(mat, torch.Tensor):
        return torch.matmul(rot_mat, dir.transpose(-1, -2)).transpose(-1, -2)
    if isinstance(mat, np.ndarray):
        return np.matmul(rot_mat, dir.swapaxes(-1, -2)).swapaxes(-1, -2)
    # The old code had an unreachable `return` after this raise.
    raise ValueError
|
|
|
|
|
|
|
|
def get_coord_vis(pos, rot_mat, scale=1.0):
    """Return endpoints of the three axis segments drawn from ``pos``.

    Args:
        pos: origin point(s).
        rot_mat: [..., 3, 3]; column 0 is right, 1 is up, 2 is forward.
        scale: length of each segment.

    Returns:
        Tuple ``(pos + right*scale, pos + up*scale, pos + forward*scale)``.
    """
    return tuple(pos + rot_mat[..., :, col] * scale for col in (0, 1, 2))
|
|
|
|
|
|
|
|
def project_vec(vec):
    """Project a 12-D pose vector onto the xz plane.

    Args:
        vec (tensor | ndarray): [*, 12], pos, forward, up and right

    Returns:
        proj_vec: [*, 4] = (posx, posz, forwardx, forwardz), taken from
        indices 0, 2, 3 and 5.

    Raises:
        ValueError: if ``vec`` is neither a torch.Tensor nor an np.ndarray.
    """
    picks = tuple(vec[..., i:i + 1] for i in (0, 2, 3, 5))
    if isinstance(vec, torch.Tensor):
        return torch.cat(picks, dim=-1)
    if isinstance(vec, np.ndarray):
        return np.concatenate(picks, axis=-1)
    raise ValueError
|
|
|
|
|
|
|
|
def xz2xyz(vec):
    """Lift (x, z) pairs to (x, 0, z) triples.

    Args:
        vec (tensor | ndarray): [..., 2]

    Returns:
        [..., 3] with a zero y component. The torch zeros use the default
        dtype on ``vec.device``. The numpy zeros use np.zeros' default dtype.

    Raises:
        ValueError: if ``vec`` is neither a torch.Tensor nor an np.ndarray.
    """
    x_part = vec[..., 0:1]
    z_part = vec[..., 1:2]
    zero_shape = vec.shape[:-1] + (1,)
    if isinstance(vec, torch.Tensor):
        y_part = torch.zeros(zero_shape, device=vec.device)
        return torch.cat((x_part, y_part, z_part), dim=-1)
    if isinstance(vec, np.ndarray):
        y_part = np.zeros(zero_shape)
        return np.concatenate((x_part, y_part, z_part), axis=-1)
    raise ValueError
|
|
|
|
|
|
|
|
def normalized(vec):
    """Scale vectors along the last axis to unit L2 length.

    A 1e-9 epsilon in the denominator keeps zero vectors finite (they stay 0).

    Raises:
        ValueError: if ``vec`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(vec, torch.Tensor):
        length = vec.norm(2, dim=-1, keepdim=True)
    elif isinstance(vec, np.ndarray):
        length = np.linalg.norm(vec, ord=2, axis=-1, keepdims=True)
    else:
        raise ValueError
    return vec / (length + 1e-9)
|
|
|
|
|
|
|
|
def normalized_matrix(mat):
    """Rescale each rotation column to unit length.

    When the last dimension is 4, the matrix is treated as homogeneous. The
    upper-left block's columns are normalized, the translation column is
    copied, and the last row becomes [0, 0, 0, 1]. Otherwise the whole
    matrix's columns are normalized. A 1e-9 epsilon guards against zero
    columns.

    Raises:
        ValueError: if ``mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    homogeneous = mat.shape[-1] == 4
    block = mat[..., :-1, :-1] if homogeneous else mat

    if isinstance(mat, torch.Tensor):
        col_len = block.norm(2, dim=-2, keepdim=True)
        result = torch.zeros_like(mat)
    elif isinstance(mat, np.ndarray):
        col_len = np.linalg.norm(block, ord=2, axis=-2, keepdims=True)
        result = np.zeros_like(mat)
    else:
        raise ValueError

    unit_block = block / (col_len + 1e-9)
    if not homogeneous:
        return unit_block

    result[..., :-1, :-1] = unit_block
    result[..., :-1, -1] = mat[..., :-1, -1]
    result[..., -1, -1] = 1.0
    return result
|
|
|
|
|
|
|
|
def get_rot_mat_from_forward(forward):
    """Build rotation matrices from forward directions, with up fixed to +y.

    Column 2 is the normalized forward vector. Column 0 is the normalized
    (fz, 0, -fx), i.e. +y cross forward. Column 1 keeps the identity's
    (0, 1, 0), so the result is orthonormal only when forward has no y
    component.

    Args:
        forward (tensor | ndarray): [N, M, 3]

    Returns:
        mat: [N, M, 3, 3] (torch default dtype / numpy float32).

    Raises:
        ValueError: if ``forward`` is neither a torch.Tensor nor an np.ndarray.
    """
    lead = forward.shape[:-1]
    if isinstance(forward, torch.Tensor):
        mat = torch.eye(3, device=forward.device).repeat(lead + (1, 1))
        right = torch.zeros_like(forward)
    elif isinstance(forward, np.ndarray):
        template = np.eye(3, dtype=np.float32).reshape((1,) * (forward.ndim - 1) + (3, 3))
        mat = np.tile(template, lead + (1, 1))
        right = np.zeros_like(forward)
    else:
        raise ValueError

    right[..., 0] = forward[..., 2]
    right[..., 1] = 0.0
    right[..., 2] = -forward[..., 0]

    mat[..., 2] = normalized(forward)
    mat[..., 0] = normalized(right)
    return mat
|
|
|
|
|
|
|
|
def get_rot_mat_from_forward_up(forward, up):
    """Build rotation matrices from forward and up directions.

    Columns are normalized(up x forward) (right), normalized(up) and
    normalized(forward).
    NOTE(review): ``up`` is not re-orthogonalized against ``forward``, so the
    result is orthonormal only if the inputs already are perpendicular.

    Args:
        forward (tensor | ndarray): [N, M, 3]
        up (tensor | ndarray): [N, M, 3]

    Returns:
        mat: [N, M, 3, 3] (torch default dtype / numpy float32).

    Raises:
        ValueError: if ``forward`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(forward, torch.Tensor):
        mat = torch.eye(3, device=forward.device).repeat(forward.shape[:-1] + (1, 1))
        # Explicit dim=-1. Without it torch.cross uses the first dimension of
        # size 3, which is a batch dimension whenever N or M equals 3.
        right = torch.cross(up, forward, dim=-1)
    elif isinstance(forward, np.ndarray):
        mat = np.eye(3, dtype=np.float32)
        for _ in range(len(forward.shape) - 1):
            mat = mat[None]
        mat = np.tile(mat, forward.shape[:-1] + (1, 1))
        right = np.cross(up, forward)
    else:
        raise ValueError

    mat[..., 2] = normalized(forward)
    mat[..., 1] = normalized(up)
    mat[..., 0] = normalized(right)
    return mat
|
|
|
|
|
|
|
|
def get_rot_mat_from_pose_vec(vec):
    """Build rotation matrices from concatenated forward/up vectors.

    Args:
        vec (tensor | ndarray): [N, M, 6]; forward in 0:3, up in 3:6.

    Returns:
        mat: [N, M, 3, 3] from ``get_rot_mat_from_forward_up``.
    """
    forward_part, up_part = vec[..., :3], vec[..., 3:6]
    return get_rot_mat_from_forward_up(forward_part, up_part)
|
|
|
|
|
|
|
|
def get_TRS(rot_mat, pos):
    """Assemble homogeneous matrices from rotations and translations.

    Args:
        rot_mat (tensor | ndarray): [N, 3, 3]
        pos (tensor | ndarray): [N, 3]

    Returns:
        mat: [N, 4, 4] with ``rot_mat`` in the upper-left block and ``pos`` in
        the last column, passed through ``normalized_matrix``.

    Raises:
        ValueError: if ``rot_mat`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(rot_mat, torch.Tensor):
        mat = torch.eye(4, device=pos.device).repeat(pos.shape[:-1] + (1, 1))
    elif isinstance(rot_mat, np.ndarray):
        template = np.eye(4, dtype=np.float32).reshape((1,) * max(len(pos.shape) - 1, 0) + (4, 4))
        mat = np.tile(template, pos.shape[:-1] + (1, 1))
    else:
        raise ValueError

    mat[..., :3, :3] = rot_mat
    mat[..., :3, 3] = pos
    return normalized_matrix(mat)
|
|
|
|
|
|
|
|
def xzvec2mat(vec):
    """Build 4x4 matrices from xz root vectors.

    Args:
        vec (tensor | ndarray): [N, 4] = (posx, posz, forwardx, forwardz).

    Returns:
        mat (tensor | ndarray): [N, 4, 4]. The y components are zero, and the
        rotation comes from ``get_rot_mat_from_forward``.

    Raises:
        ValueError: if ``vec`` is neither a torch.Tensor nor an np.ndarray.
    """
    vec_shape = vec.shape[:-1]
    if isinstance(vec, torch.Tensor):
        # Allocate on vec's device (and in its floating dtype). Bare
        # torch.zeros used to place the buffers on CPU in float32, so CUDA
        # inputs produced CPU matrices.
        dtype = vec.dtype if vec.is_floating_point() else None
        pos = torch.zeros(vec_shape + (3,), device=vec.device, dtype=dtype)
        forward = torch.zeros(vec_shape + (3,), device=vec.device, dtype=dtype)
    elif isinstance(vec, np.ndarray):
        pos = np.zeros(vec_shape + (3,))
        forward = np.zeros(vec_shape + (3,))
    else:
        raise ValueError

    pos[..., 0] = vec[..., 0]
    pos[..., 2] = vec[..., 1]
    forward[..., 0] = vec[..., 2]
    forward[..., 2] = vec[..., 3]

    rot_mat = get_rot_mat_from_forward(forward)
    return get_TRS(rot_mat, pos)
|
|
|
|
|
|
|
|
def distance(vec1, vec2):
    """Euclidean distance over ALL elements of ``vec1 - vec2``.

    The squared difference is summed over every element, not per row, so
    batched inputs yield a single scalar.
    """
    diff = vec1 - vec2
    squared_total = (diff ** 2).sum()
    return squared_total ** 0.5
|
|
|
|
|
|
|
|
def get_relative_pose_from_vec(pose, root, N):
    """Transform per-joint pose vectors by an xz root frame and flatten them.

    Builds ``root_p_mat = xzvec2mat(root)``. Joint positions are mapped with
    ``get_position_from`` and the three direction slots (3:6, 6:9, 9:12) with
    ``get_direction_from``.

    Args:
        pose (tensor | ndarray): reshaped to [B, N, 12] (per joint: position,
            then three 3-vectors). Not modified.
        root (tensor | ndarray): [B, 4] root vectors accepted by ``xzvec2mat``.
        N (int): number of joints.

    Returns:
        [B, 3 + 6 * N]: transformed position of joint 0, followed by the
        transformed 3:9 components of every joint.
    """
    root_p_mat = xzvec2mat(root)
    is_torch = isinstance(pose, torch.Tensor)

    # reshape() may return a view. Copy so the slice assignments below don't
    # write through into the caller's array.
    pose = pose.reshape(-1, N, 12)
    pose = pose.clone() if is_torch else pose.copy()

    pose[..., :3] = get_position_from(pose[..., :3], root_p_mat)
    for start in (3, 6, 9):
        pose[..., start:start + 3] = get_direction_from(pose[..., start:start + 3], root_p_mat)

    pos = pose[..., 0, :3]
    rot = pose[..., 3:9].reshape(-1, N * 6)
    if is_torch:
        return torch.cat((pos, rot), dim=-1)
    return np.concatenate((pos, rot), axis=-1)
|
|
|
|
|
|
|
|
def get_forward_from_pos(pos):
    """Estimate a facing direction from hip and shoulder joint positions.

    Uses the first frame (``pos[..., 0, j, :]``). The across-body vector is
    (joint 2 - joint 1) + (joint 17 - joint 16), presumably right/left hip and
    shoulder in the skeleton's joint order (TODO confirm against the
    skeleton definition). The result is normalized(+y cross across).

    Args:
        pos (tensor): joint positions, indexed as [..., frames, joints, 3].

    Returns:
        tensor: [..., 3] unit forward vector.
    """
    # Build +y in pos's dtype and device so torch.cross never mixes dtypes
    # (the old version always used float32).
    pos_y_vec = torch.tensor([0, 1, 0], dtype=pos.dtype, device=pos.device)
    face_joint_indx = [2, 1, 17, 16]
    r_hip, l_hip, r_sdr, l_sdr = face_joint_indx

    cross_hip = pos[..., 0, r_hip, :] - pos[..., 0, l_hip, :]
    cross_sdr = pos[..., 0, r_sdr, :] - pos[..., 0, l_sdr, :]
    cross_vec = cross_hip + cross_sdr

    forward_vec = torch.cross(pos_y_vec, cross_vec, dim=-1)
    return normalized(forward_vec)
|
|
|
|
|
|
|
|
def project_point_along_ray(p, ray, keepnorm=False):
    """Move points onto the line spanned by ``ray`` through the origin.

    Args:
        p (*, 3): point positions (torch tensors)
        ray (*, 3): ray direction; normalized internally
        keepnorm: False -> orthogonal projection (dot(p, ray_unit) * ray_unit).
            True -> ray_unit scaled by |p|, keeping the point's length.

    Returns:
        (*, 3) projected points.
    """
    unit_ray = normalized(ray)
    if keepnorm:
        scale = p.norm(dim=-1, keepdim=True)
    else:
        scale = torch.sum(p * unit_ray, dim=-1, keepdim=True)
    return scale * unit_ray
|
|
|
|
|
|
|
|
def solve_point_along_ray_with_constraint(c, ray, p, constraint="x"):
    """Find the point on ray ``p + t * ray_unit`` whose chosen axis equals ``c``.

    Args:
        c (*,): constraint value
        ray (*, 3): ray direction; normalized internally
        p (*, 3): start point of the ray
        constraint: "x", "y" or "z", the axis the constraint applies to.

    Returns:
        (*, 3) points. No guard is applied: a ray component of 0 on the
        constrained axis divides by zero.

    Raises:
        ValueError: for any other ``constraint``.
    """
    unit_ray = normalized(ray)
    axis_names = ("x", "y", "z")
    if constraint not in axis_names:
        raise ValueError
    axis = axis_names.index(constraint)

    t = (c - p[..., axis]) / unit_ray[..., axis]
    return unit_ray * t[..., None] + p
|
|
|
|
|
|
|
|
def calc_cosine(vec1, vec2, return_angle=False):
    """Cosine (or angle) between vectors along the last axis.

    Args:
        vec1 (*, 3): vector
        vec2 (*, 3): vector
        return_angle: True -> return the angle in radians, False -> return
            the cosine.

    Returns:
        (*,) tensor of cosines, or angles in [0, pi].
    """
    vec1 = normalized(vec1)
    vec2 = normalized(vec2)
    cosine = torch.sum(vec1 * vec2, dim=-1)
    if return_angle:
        # Rounding can push |cosine| slightly past 1 for (anti)parallel
        # vectors, where acos returns NaN. Clamp only for the angle path.
        return torch.acos(cosine.clamp(-1.0, 1.0))
    return cosine
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def quat_xyzw2wxyz(quat):
    """Reorder quaternions from (x, y, z, w) to (w, x, y, z)."""
    real = quat[..., 3:4]
    imag = quat[..., :3]
    return torch.cat([real, imag], dim=-1)
|
|
|
|
|
|
|
|
def quat_wxyz2xyzw(quat):
    """Reorder quaternions from (w, x, y, z) to (x, y, z, w)."""
    real = quat[..., :1]
    imag = quat[..., 1:4]
    return torch.cat([imag, real], dim=-1)
|
|
|
|
|
|
|
|
def quat_mul(a, b):
    """
    Hamilton product ``a * b`` of quaternions stored as (x, y, z, w).
    """
    ax, ay, az, aw = (a[..., k] for k in range(4))
    bx, by, bz, bw = (b[..., k] for k in range(4))

    real = aw * bw - ax * bx - ay * by - az * bz
    i_part = aw * bx + ax * bw + ay * bz - az * by
    j_part = aw * by + ay * bw + az * bx - ax * bz
    k_part = aw * bz + az * bw + ax * by - ay * bx

    return torch.stack([i_part, j_part, k_part, real], dim=-1)
|
|
|
|
|
|
|
|
def quat_pos(x):
    """
    Flip quaternions whose real part (index 3) is negative, so w >= 0.
    q and -q represent the same rotation.
    """
    negative = (x[..., 3:] < 0).float()
    sign = 1 - 2 * negative
    return sign * x
|
|
|
|
|
|
|
|
def quat_abs(x):
    """
    L2 norm of each quaternion over the last axis (1 for unit quaternions).
    """
    return torch.linalg.vector_norm(x, ord=2, dim=-1) if False else x.norm(p=2, dim=-1)
|
|
|
|
|
|
|
|
def quat_unit(x):
    """
    Scale quaternions to unit norm. The norm is clamped to at least 1e-4, so
    near-zero inputs are not blown up.
    """
    magnitude = x.norm(p=2, dim=-1).unsqueeze(-1)
    return x / magnitude.clamp(min=1e-4)
|
|
|
|
|
|
|
|
def quat_conjugate(x):
    """
    Conjugate quaternion: negate the imaginary part (x, y, z), keep w.
    """
    imag, real = x[..., :3], x[..., 3:]
    return torch.cat([-imag, real], dim=-1)
|
|
|
|
|
|
|
|
def quat_real(x):
    """
    Real (w) component of (x, y, z, w) quaternions.
    """
    real_index = 3
    return x[..., real_index]
|
|
|
|
|
|
|
|
def quat_imaginary(x):
    """
    Imaginary (x, y, z) components of (x, y, z, w) quaternions.
    """
    imag_end = 3
    return x[..., :imag_end]
|
|
|
|
|
|
|
|
def quat_norm_check(x):
    """
    Assert that every quaternion has norm 1 (within 1e-3) and a non-negative
    real part. Raises AssertionError otherwise.
    """
    deviation = abs(x.norm(p=2, dim=-1) - 1)
    assert bool((deviation < 1e-3).all()), (
        "the quaternion is has non-1 norm: {}".format(deviation)
    )
    assert bool((x[..., 3] >= 0).all()), "the quaternion has negative real part"
|
|
|
|
|
|
|
|
def quat_normalize(q):
    """
    Canonicalize quaternions: flip sign so w >= 0 (``quat_pos``), then scale
    to unit norm (``quat_unit``). The input need not be normalized.
    """
    non_negative = quat_pos(q)
    return quat_unit(non_negative)
|
|
|
|
|
|
|
|
def quat_from_xyz(xyz):
    """
    Construct unit (x, y, z, w) quaternions from their imaginary part.

    The real part is w = sqrt(1 - |xyz|^2), computed per quaternion over the
    last axis. The old code used 1 - |xyz| with a norm taken over the WHOLE
    tensor, which gave a wrong w and broke batched input.

    Raises:
        AssertionError: if any |xyz| exceeds 1.
    """
    w_squared = 1.0 - (xyz ** 2).sum(dim=-1, keepdim=True)
    assert bool((w_squared >= 0).all()), "xyz has its norm greater than 1"
    w = w_squared.sqrt()
    return torch.cat([xyz, w], dim=-1)
|
|
|
|
|
|
|
|
def quat_identity(shape: List[int]):
    """
    Construct identity (x, y, z, w) = (0, 0, 0, 1) quaternions of the given
    batch shape; the result has shape [*shape, 4].

    ``shape`` may be a list (as annotated), a tuple or a torch.Size. The old
    code did ``shape + (1,)``, which raised TypeError for lists. The value is
    already unit-norm with w >= 0, so no extra normalization is needed.
    """
    shape = tuple(shape)
    w = torch.ones(shape + (1,))
    xyz = torch.zeros(shape + (3,))
    return torch.cat([xyz, w], dim=-1)
|
|
|
|
|
|
|
|
def tgm_quat_from_angle_axis(angle, axis, degree: bool = False):
    """Create a 3D rotation from an angle and an axis of rotation. The rotation
    is counter-clockwise about the axis.

    The rotation can be read as a_R_b, where frame "b" is the new frame,
    rotated counter-clockwise about the axis from frame "a".

    NOTE(review): the components are concatenated as (w, x, y, z), unlike the
    (x, y, z, w) layout used by the other quat_* helpers. ``quat_normalize``
    canonicalizes the sign using index 3, which here is z rather than w. The
    rotation itself is unchanged, since q and -q are equivalent.

    :param angle: angle of rotation
    :type angle: Tensor
    :param axis: axis of rotation (normalized internally, norm clamped >= 1e-4)
    :type axis: Tensor
    :param degree: put True here if the angle is given by degree
    :type degree: bool, optional, default=False
    """
    if degree:
        angle = angle / 180.0 * math.pi
    half_angle = (angle / 2).unsqueeze(-1)
    unit_axis = axis / (axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-4))
    wxyz = torch.cat([half_angle.cos(), unit_axis * half_angle.sin()], dim=-1)
    return quat_normalize(wxyz)
|
|
|
|
|
|
|
|
def quat_from_rotation_matrix(m):
    """
    Construct a 3D rotation from valid 3x3 rotation matrices.
    Reference can be found here:
    http://www.cg.info.hiroshima-cu.ac.jp/~miyazaki/knowledge/teche52.html

    Component magnitudes come from the diagonal. The signs of the other
    components are then fixed relative to whichever component is largest.
    The result is stacked as (x, y, z, w) and passed through
    ``quat_normalize``.

    :param m: 3x3 orthogonal rotation matrices ([..., 3, 3]).
    :type m: Tensor

    :rtype: Tensor, (x, y, z, w) quaternions with the batch shape of ``m``.
    """
    # Add a leading axis so the boolean-mask updates below always index at
    # least a 1-D tensor. It is removed again by squeeze(0) at the end.
    m = m.unsqueeze(0)
    diag0 = m[..., 0, 0]
    diag1 = m[..., 1, 1]
    diag2 = m[..., 2, 2]

    # Magnitudes of w, x, y, z from the trace identities. The clamp keeps the
    # sqrt argument non-negative under rounding.
    w = (((diag0 + diag1 + diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5
    x = (((diag0 - diag1 - diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5
    y = (((-diag0 + diag1 - diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5
    z = (((-diag0 - diag1 + diag2 + 1.0) / 4.0).clamp(0.0, None)) ** 0.5

    # Case w largest: keep w positive and take the other signs from the
    # antisymmetric off-diagonal differences.
    c0 = (w >= x) & (w >= y) & (w >= z)
    x[c0] *= (m[..., 2, 1][c0] - m[..., 1, 2][c0]).sign()
    y[c0] *= (m[..., 0, 2][c0] - m[..., 2, 0][c0]).sign()
    z[c0] *= (m[..., 1, 0][c0] - m[..., 0, 1][c0]).sign()

    # Case x largest. NOTE(review): the masks are computed after the previous
    # in-place updates, and exact ties can match more than one case, so
    # both sign updates apply. Confirm this is acceptable for tied inputs.
    c1 = (x >= w) & (x >= y) & (x >= z)
    w[c1] *= (m[..., 2, 1][c1] - m[..., 1, 2][c1]).sign()
    y[c1] *= (m[..., 1, 0][c1] + m[..., 0, 1][c1]).sign()
    z[c1] *= (m[..., 0, 2][c1] + m[..., 2, 0][c1]).sign()

    # Case y largest.
    c2 = (y >= w) & (y >= x) & (y >= z)
    w[c2] *= (m[..., 0, 2][c2] - m[..., 2, 0][c2]).sign()
    x[c2] *= (m[..., 1, 0][c2] + m[..., 0, 1][c2]).sign()
    z[c2] *= (m[..., 2, 1][c2] + m[..., 1, 2][c2]).sign()

    # Case z largest.
    c3 = (z >= w) & (z >= x) & (z >= y)
    w[c3] *= (m[..., 1, 0][c3] - m[..., 0, 1][c3]).sign()
    x[c3] *= (m[..., 2, 0][c3] + m[..., 0, 2][c3]).sign()
    y[c3] *= (m[..., 2, 1][c3] + m[..., 1, 2][c3]).sign()

    # quat_normalize flips to w >= 0 and rescales to unit norm.
    return quat_normalize(torch.stack([x, y, z, w], dim=-1)).squeeze(0)
|
|
|
|
|
|
|
|
def quat_mul_norm(x, y):
    """
    Compose two sets of rotations (``quat_mul``) and canonicalize the result
    with ``quat_normalize``. The shapes need to be broadcastable.
    """
    product = quat_mul(x, y)
    return quat_normalize(product)
|
|
|
|
|
|
|
|
def quat_rotate(rot, vec):
    """
    Rotate 3D vectors by (x, y, z, w) quaternions via q * (v, 0) * conj(q).
    """
    pure_quat = torch.cat([vec, torch.zeros_like(vec[..., :1])], dim=-1)
    rotated = quat_mul(quat_mul(rot, pure_quat), quat_conjugate(rot))
    return quat_imaginary(rotated)
|
|
|
|
|
|
|
|
def quat_inverse(x):
    """
    Inverse rotation, computed as the conjugate (exact for unit quaternions).
    """
    inverse = quat_conjugate(x)
    return inverse
|
|
|
|
|
|
|
|
def quat_identity_like(x):
    """
    Identity quaternions with the same batch shape as ``x`` (``x.shape[:-1]``).
    """
    batch_shape = x.shape[:-1]
    return quat_identity(batch_shape)
|
|
|
|
|
|
|
|
def quat_angle_axis(x):
    """
    The (angle, axis) representation of (x, y, z, w) quaternions.

    The angle is arccos(2 w^2 - 1), clamped into [0, pi]. The axis is the
    imaginary part scaled to unit length (norm clamped >= 1e-4). ``x`` is not
    modified; the old in-place ``axis /= ...`` wrote through the slice
    into the caller's quaternion.
    """
    s = 2 * (x[..., 3] ** 2) - 1
    angle = s.clamp(-1, 1).arccos()
    imag = x[..., :3]
    axis = imag / imag.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-4)
    return angle, axis
|
|
|
|
|
|
|
|
def quat_yaw_rotation(x, z_up: bool = True):
    """
    Keep only the yaw part of (x, y, z, w) quaternions.

    With ``z_up`` the x and y components are zeroed (rotation about z).
    Otherwise x and z are zeroed (rotation about y). The result is then
    re-canonicalized with ``quat_normalize``.
    """
    zero = torch.zeros_like(x[..., 0:1])
    if z_up:
        parts = [zero, zero, x[..., 2:3], x[..., 3:]]
    else:
        parts = [zero, x[..., 1:2], zero, x[..., 3:4]]
    return quat_normalize(torch.cat(parts, dim=-1))
|
|
|
|
|
|
|
|
def transform_from_rotation_translation(
    r: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None
):
    """
    Construct [..., 7] transforms (quaternion [..., 4] + translation [..., 3]).
    At most one of ``r`` and ``t`` may be None.

    A missing part is filled with identity/zeros matching the other part's
    batch shape (``shape[:-1]``), device and dtype. The old code used the
    full shape, which produced tensors of the wrong rank.
    """
    assert r is not None or t is not None, "rotation and translation can't be all None"
    if r is None:
        assert t is not None
        r = quat_identity(t.shape[:-1]).to(device=t.device, dtype=t.dtype)
    if t is None:
        t = torch.zeros(r.shape[:-1] + (3,), device=r.device, dtype=r.dtype)
    return torch.cat([r, t], dim=-1)
|
|
|
|
|
|
|
|
def transform_identity(shape: List[int]):
    """
    Identity transforms of batch shape ``shape`` (result: [*shape, 7]).

    ``shape`` is converted to a tuple first. Previously a list was passed to
    ``quat_identity``, whose ``shape + (1,)`` raised TypeError.
    """
    shape = tuple(shape)
    r = quat_identity(shape)
    t = torch.zeros(shape + (3,))
    return transform_from_rotation_translation(r, t)
|
|
|
|
|
|
|
|
def transform_rotation(x):
    """Get rotation (quaternion, first 4 values) from transform"""
    quat_len = 4
    return x[..., :quat_len]
|
|
|
|
|
|
|
|
def transform_translation(x):
    """Get translation (values from index 4 on) from transform"""
    quat_len = 4
    return x[..., quat_len:]
|
|
|
|
|
|
|
|
def transform_inverse(x):
    """
    Inverse transformation: rotation q^-1 and translation q^-1 * (-t).
    """
    rot_inv = quat_inverse(transform_rotation(x))
    trans_inv = quat_rotate(rot_inv, -transform_translation(x))
    return transform_from_rotation_translation(r=rot_inv, t=trans_inv)
|
|
|
|
|
|
|
|
def transform_identity_like(x):
    """
    Identity transformations with the same shape as ``x``.

    Args:
        x: transforms, shape [..., 7] (quaternion followed by translation).

    Returns:
        ``transform_identity`` built from the batch shape of ``x``, i.e. ``x.shape``
        without its trailing component dimension.
    """
    # x.shape is a torch.Size (tuple), which cannot be concatenated with lists downstream.
    # The trailing component dim must be dropped, otherwise the result gains an extra axis.
    return transform_identity(list(x.shape[:-1]))
|
|
|
|
|
|
|
|
def transform_mul(x, y):
    """
    Combine two transformation together (apply ``y`` first, then ``x``).

    Rotation: normalized product of the two quaternions.
    Translation: y's translation rotated by x's rotation, plus x's translation.
    """
    rot_x = transform_rotation(x)
    combined_rot = quat_mul_norm(rot_x, transform_rotation(y))
    combined_trans = quat_rotate(rot_x, transform_translation(y)) + transform_translation(x)
    return transform_from_rotation_translation(r=combined_rot, t=combined_trans)
|
|
|
|
|
|
|
|
def transform_apply(rot, vec):
    """
    Transform a 3D vector: rotate it by the transform's quaternion, then add its translation.

    Raises:
        AssertionError: if ``vec`` is not a ``torch.Tensor``.
    """
    assert isinstance(vec, torch.Tensor)
    rotated = quat_rotate(transform_rotation(rot), vec)
    return rotated + transform_translation(rot)
|
|
|
|
|
|
|
|
def rot_matrix_det(x):
    """
    Return the determinant of the 3x3 matrices in ``x`` (cofactor expansion along row 0).

    Args:
        x: tensor of shape [..., 3, 3] (only the top-left 3x3 entries are read).

    Returns:
        Tensor of shape [...], the batch shape of ``x``.
    """
    row0, row1, row2 = x[..., 0, :], x[..., 1, :], x[..., 2, :]
    a, b, c = row0[..., 0], row0[..., 1], row0[..., 2]
    d, e, f = row1[..., 0], row1[..., 1], row1[..., 2]
    g, h, i = row2[..., 0], row2[..., 1], row2[..., 2]
    return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
|
|
|
|
|
|
|
|
def rot_matrix_integrity_check(x):
    """
    Verify that a rotation matrix has a determinant of one and is orthogonal.

    Args:
        x: tensor of shape [..., 3, 3].

    Raises:
        AssertionError: if any determinant differs from 1 by 1e-3 or more, or if any
            entry of ``x @ x^T`` differs from the identity by 1e-3 or more.
    """
    det = rot_matrix_det(x)
    assert bool((torch.abs(det - 1) < 1e-3).all()), "the matrix has non-one determinant"
    # Transpose only the last two dims; batch dims are left untouched.
    rtr = x @ x.transpose(-1, -2)
    rtr_gt = torch.zeros_like(rtr)
    rtr_gt[..., 0, 0] = 1
    rtr_gt[..., 1, 1] = 1
    rtr_gt[..., 2, 2] = 1
    # Compare absolute deviation; without abs, large negative errors would pass.
    assert bool((torch.abs(rtr - rtr_gt) < 1e-3).all()), "the matrix is not orthogonal"
|
|
|
|
|
|
|
|
def rot_matrix_from_quaternion(q):
    """
    Construct rotation matrix from quaternion.

    Args:
        q: quaternions in (x, y, z, w) order, shape [..., 4].

    Returns:
        Rotation matrices of shape [..., 3, 3].
    """
    qi, qj, qk, qr = torch.unbind(q, dim=-1)
    rows = (
        (
            1.0 - 2.0 * (qj**2 + qk**2),
            2 * (qi * qj - qk * qr),
            2 * (qi * qk + qj * qr),
        ),
        (
            2 * (qi * qj + qk * qr),
            1.0 - 2.0 * (qi**2 + qk**2),
            2 * (qj * qk - qi * qr),
        ),
        (
            2 * (qi * qk - qj * qr),
            2 * (qj * qk + qi * qr),
            1.0 - 2.0 * (qi**2 + qj**2),
        ),
    )
    return torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)
|
|
|
|
|
|
|
|
def euclidean_to_rotation_matrix(x):
    """
    Get the rotation matrix on the top-left corner of a Euclidean transformation matrix.

    Returns the [..., 3, 3] slice (a view) of ``x``.
    """
    top_left = x[..., 0:3, 0:3]
    return top_left
|
|
|
|
|
|
|
|
def euclidean_integrity_check(x):
    """
    Check that the last row of Euclidean transformation matrices is (0, 0, 0, 1).

    Args:
        x: tensor of shape [..., 4, 4].

    Raises:
        AssertionError: if the last row is not exactly (0, 0, 0, 1).
    """
    # NOTE(review): the 3x3 rotation block is not validated here; the original only
    # sliced it and discarded the result. Consider calling rot_matrix_integrity_check
    # on x[..., :3, :3] once that checker is known to work.
    last_row = x[..., 3, :]
    assert bool((last_row[..., :3] == 0).all()), "the last row is illegal"
    assert bool((last_row[..., 3] == 1).all()), "the last row is illegal"
|
|
|
|
|
|
|
|
def euclidean_translation(x):
    """
    Get the translation vector located at the last column of the matrix.

    Returns the first three entries of column 3, shape [..., 3].
    """
    last_column = x[..., 0:3, 3]
    return last_column
|
|
|
|
|
|
|
|
def euclidean_inverse(x):
    """
    Compute the inverse of Euclidean (rigid) transformation matrices.

    For ``x = [[R, t], [0, 1]]`` with orthogonal ``R``, the inverse is
    ``[[R^T, -R^T t], [0, 1]]``.

    Args:
        x: tensor of shape [..., 4, 4]; the top-left 3x3 block is assumed orthogonal.

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    rot_inv = x[..., :3, :3].transpose(-1, -2)  # R^-1 == R^T for a rotation
    trans = x[..., :3, 3]
    s = torch.zeros_like(x)
    s[..., :3, :3] = rot_inv
    s[..., :3, 3] = -(rot_inv @ trans.unsqueeze(-1)).squeeze(-1)
    s[..., 3, 3] = 1
    return s
|
|
|
|
|
|
|
|
def euclidean_to_transform(transformation_matrix):
    """
    Construct a transform from a Euclidean transformation matrix.

    The rotation block is converted with ``quat_from_rotation_matrix`` and the last
    column provides the translation.
    """
    rot_mat = euclidean_to_rotation_matrix(transformation_matrix)
    translation = euclidean_translation(transformation_matrix)
    rotation = quat_from_rotation_matrix(m=rot_mat)
    return transform_from_rotation_translation(r=rotation, t=translation)
|
|
|
|
|
|
|
|
def to_torch(x, dtype=torch.float, device="cuda:0", requires_grad=False):
    """Copy ``x`` into a new tensor with the given dtype, device and grad flag (default device "cuda:0")."""
    options = dict(dtype=dtype, device=device, requires_grad=requires_grad)
    return torch.tensor(x, **options)
|
|
|
|
|
|
|
|
def quat_mul(a, b):
    """
    Hamilton product ``a * b`` of quaternions in (x, y, z, w) order.

    Args:
        a, b: tensors of identical shape [..., 4].

    Returns:
        Tensor with the same shape as ``a``.
    """
    assert a.shape == b.shape
    out_shape = a.shape
    x1, y1, z1, w1 = a.reshape(-1, 4).unbind(dim=-1)
    x2, y2, z2, w2 = b.reshape(-1, 4).unbind(dim=-1)

    # Factored form of the product: fewer multiplications than the textbook expansion.
    p_ww = (z1 + x1) * (x2 + y2)
    p_yy = (w1 - y1) * (w2 + z2)
    p_zz = (w1 + y1) * (w2 - z2)
    p_xx = p_ww + p_yy + p_zz
    p_qq = 0.5 * (p_xx + (z1 - x1) * (x2 - y2))

    w = p_qq - p_ww + (z1 - y1) * (y2 - z2)
    x = p_qq - p_xx + (x1 + w1) * (x2 + w2)
    y = p_qq - p_yy + (w1 - x1) * (y2 + z2)
    z = p_qq - p_zz + (z1 + y1) * (w2 - x2)
    return torch.stack([x, y, z, w], dim=-1).view(out_shape)
|
|
|
|
|
|
|
|
def normalize(x, eps: float = 1e-9):
    """Scale ``x`` to unit L2 norm along the last dim; norms are clamped to at least ``eps``."""
    length = x.norm(p=2, dim=-1, keepdim=True).clamp(min=eps)
    return x / length
|
|
|
|
|
|
|
|
def quat_apply(a, b):
    """
    Rotate 3D vectors ``b`` by quaternions ``a`` in (x, y, z, w) order.

    Both inputs are flattened to [M, 4] and [M, 3] (so they must hold the same number
    of quaternions and vectors). The result is reshaped back to ``b``'s shape.
    """
    out_shape = b.shape
    quat = a.reshape(-1, 4)
    vec = b.reshape(-1, 3)
    q_xyz, q_w = quat[:, :3], quat[:, 3:]
    # v' = v + w * t + q_xyz x t, with t = 2 * (q_xyz x v)
    twice_cross = torch.cross(q_xyz, vec, dim=-1) * 2
    rotated = vec + q_w * twice_cross + torch.cross(q_xyz, twice_cross, dim=-1)
    return rotated.view(out_shape)
|
|
|
|
|
|
|
|
def quat_rotate(q, v):
    """
    Rotate 3D vectors by quaternions in (x, y, z, w) order.

    Works for any number of leading dimensions, not just [N, 4] / [N, 3].

    Args:
        q: quaternions, shape [..., 4] (expected unit-norm).
        v: vectors, shape [..., 3] with leading dims matching ``q``'s.

    Returns:
        Rotated vectors, shape [..., 3].
    """
    q_w = q[..., -1:]  # keep a trailing dim of size 1 so it broadcasts against v
    q_vec = q[..., :3]
    # v' = (2 w^2 - 1) v + 2 w (q_vec x v) + 2 (q_vec . v) q_vec
    a = v * (2.0 * q_w**2 - 1.0)
    b = torch.cross(q_vec, v, dim=-1) * q_w * 2.0
    c = q_vec * (q_vec * v).sum(dim=-1, keepdim=True) * 2.0
    return a + b + c
|
|
|
|
|
|
|
|
def quat_rotate_inverse(q, v):
    """
    Rotate 3D vectors by the inverse of quaternions in (x, y, z, w) order.

    Works for any number of leading dimensions, not just [N, 4] / [N, 3].

    Args:
        q: quaternions, shape [..., 4] (expected unit-norm).
        v: vectors, shape [..., 3] with leading dims matching ``q``'s.

    Returns:
        Inversely rotated vectors, shape [..., 3].
    """
    q_w = q[..., -1:]  # keep a trailing dim of size 1 so it broadcasts against v
    q_vec = q[..., :3]
    # Same as quat_rotate with the cross term negated (conjugate quaternion).
    a = v * (2.0 * q_w**2 - 1.0)
    b = torch.cross(q_vec, v, dim=-1) * q_w * 2.0
    c = q_vec * (q_vec * v).sum(dim=-1, keepdim=True) * 2.0
    return a - b + c
|
|
|
|
|
|
|
|
def quat_conjugate(a):
    """Conjugate of quaternions in (x, y, z, w) order: negate x, y, z and keep w; shape is preserved."""
    flat = a.reshape(-1, 4)
    conj = torch.cat((-flat[:, :3], flat[:, -1:]), dim=-1)
    return conj.view(a.shape)
|
|
|
|
|
|
|
|
def quat_unit(a):
    """Normalize quaternions to unit length along the last dim (delegates to ``normalize``)."""
    unit = normalize(a)
    return unit
|
|
|
|
|
|
|
|
def quat_from_angle_axis(angle, axis):
    """
    Build unit quaternions in (x, y, z, w) order from rotation angles and axes.

    Args:
        angle: rotation angles, shape [...].
        axis: rotation axes, shape [..., 3]; normalized before use.

    Returns:
        Quaternions of shape [..., 4].
    """
    half_angle = (angle / 2).unsqueeze(-1)
    vector_part = normalize(axis) * torch.sin(half_angle)
    scalar_part = torch.cos(half_angle)
    return quat_unit(torch.cat([vector_part, scalar_part], dim=-1))
|
|
|
|
|
|
|
|
def normalize_angle(x):
    """Wrap angles into [-pi, pi] via ``atan2(sin x, cos x)``."""
    return torch.atan2(x.sin(), x.cos())
|
|
|
|
|
|
|
|
def tf_inverse(q, t):
    """Invert a (rotation, translation) pair: returns (conj(q), -conj(q) applied to t)."""
    q_inv = quat_conjugate(q)
    t_inv = -quat_apply(q_inv, t)
    return q_inv, t_inv
|
|
|
|
|
|
|
|
def tf_apply(q, t, v):
    """Apply the transform (q, t) to points v: rotate by q, then translate by t."""
    rotated = quat_apply(q, v)
    return rotated + t
|
|
|
|
|
|
|
|
def tf_vector(q, v):
    """Rotate direction vectors v by q (no translation is applied)."""
    rotated = quat_apply(q, v)
    return rotated
|
|
|
|
|
|
|
|
def tf_combine(q1, t1, q2, t2):
    """Compose (q1, t1) after (q2, t2): returns (q1 * q2, q1 applied to t2 plus t1)."""
    q_combined = quat_mul(q1, q2)
    t_combined = quat_apply(q1, t2) + t1
    return q_combined, t_combined
|
|
|
|
|
|
|
|
def get_basis_vector(q, v):
    """Rotate the vectors v (typically basis vectors) by quaternions q using ``quat_rotate``."""
    rotated = quat_rotate(q, v)
    return rotated
|
|
|
|
|
|
|
|
def get_axis_params(value, axis_idx, x_value=0.0, dtype=float, n_dims=3):
    """
    Construct arguments to `Vec` according to axis index.

    Returns a length-``n_dims`` list holding ``value`` at ``axis_idx`` and zeros
    elsewhere. Component 0 is then overwritten with ``x_value``, even when
    ``axis_idx`` is 0.
    """
    one_hot = np.zeros((n_dims,))
    assert axis_idx < n_dims, "the axis dim should be within the vector dimensions"
    one_hot[axis_idx] = 1.0
    params = np.where(one_hot == 1.0, value, one_hot)
    params[0] = x_value
    return [*params.astype(dtype)]
|
|
|
|
|
|
|
|
def copysign(a, b):
    """
    Return a float tensor with the magnitude of ``a`` and the sign of ``b``.

    Args:
        a: scalar magnitude.
        b: tensor of any shape, including 0-d.

    Returns:
        Tensor broadcast to ``b.shape`` on ``b``'s device. Because ``torch.sign(0) == 0``,
        entries where ``b == 0`` are 0.
    """
    # Fill to b's full shape; the old repeat(b.shape[0]) only matched 1-D inputs.
    magnitude = torch.full(b.shape, a, dtype=torch.float, device=b.device)
    return torch.abs(magnitude) * torch.sign(b)
|
|
|
|
|
|
|
|
def get_euler_xyz(q):
    """
    Convert quaternions in (x, y, z, w) order to roll, pitch and yaw.

    Args:
        q: quaternions, shape [N, 4].

    Returns:
        (roll, pitch, yaw), each of shape [N], wrapped into [0, 2*pi) with ``%``.
        Pitch saturates at +/- pi/2 when |sin(pitch)| >= 1.
    """
    x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]

    # roll (rotation about x)
    roll = torch.atan2(2.0 * (w * x + y * z), w * w - x * x - y * y + z * z)

    # pitch (rotation about y); asin is undefined outside [-1, 1]
    sinp = 2.0 * (w * y - z * x)
    pitch = torch.where(torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp))

    # yaw (rotation about z)
    yaw = torch.atan2(2.0 * (w * z + x * y), w * w + x * x - y * y - z * z)

    full_turn = 2 * np.pi
    return roll % full_turn, pitch % full_turn, yaw % full_turn
|
|
|
|
|
|
|
|
def quat_from_euler_xyz(roll, pitch, yaw):
    """
    Build quaternions in (x, y, z, w) order from roll, pitch and yaw angles.

    Args:
        roll, pitch, yaw: tensors of matching shape [...].

    Returns:
        Quaternions of shape [..., 4].
    """
    half_yaw, half_roll, half_pitch = yaw * 0.5, roll * 0.5, pitch * 0.5
    cy, sy = torch.cos(half_yaw), torch.sin(half_yaw)
    cr, sr = torch.cos(half_roll), torch.sin(half_roll)
    cp, sp = torch.cos(half_pitch), torch.sin(half_pitch)

    components = (
        cy * sr * cp - sy * cr * sp,  # x
        cy * cr * sp + sy * sr * cp,  # y
        sy * cr * cp - cy * sr * sp,  # z
        cy * cr * cp + sy * sr * sp,  # w
    )
    return torch.stack(components, dim=-1)
|
|
|
|
|
|
|
|
def torch_rand_float(lower, upper, shape, device):
    """Uniform random floats in [lower, upper) with the given shape and device."""
    span = upper - lower
    return span * torch.rand(*shape, device=device) + lower
|
|
|
|
|
|
|
|
def torch_random_dir_2(shape, device):
    """
    Random 2D unit vectors (cos a, sin a) with a uniform in [-pi, pi).

    The trailing dim of ``shape`` is squeezed away and replaced by the 2 components.
    """
    angle = torch_rand_float(-np.pi, np.pi, shape, device).squeeze(-1)
    components = (torch.cos(angle), torch.sin(angle))
    return torch.stack(components, dim=-1)
|
|
|
|
|
|
|
|
def tensor_clamp(t, min_t, max_t):
    """Element-wise clamp with tensor bounds; the upper bound is applied first, so ``min_t`` wins on conflict."""
    upper_bounded = torch.min(t, max_t)
    return torch.max(upper_bounded, min_t)
|
|
|
|
|
|
|
|
def scale(x, lower, upper):
    """Map x from [-1, 1] linearly onto [lower, upper]."""
    unit = 0.5 * (x + 1.0)  # [-1, 1] -> [0, 1]
    return unit * (upper - lower) + lower
|
|
|
|
|
|
|
|
def unscale(x, lower, upper):
    """Map x from [lower, upper] linearly onto [-1, 1] (inverse of ``scale``)."""
    numerator = 2.0 * x - upper - lower
    return numerator / (upper - lower)
|
|
|
|
|
|
|
|
def unscale_np(x, lower, upper):
    """NumPy variant of ``unscale``: map x from [lower, upper] onto [-1, 1]."""
    numerator = 2.0 * x - upper - lower
    return numerator / (upper - lower)
|
|
|
|
|
|
|
|
def quat_to_angle_axis(q):
    """
    Convert quaternions in (x, y, z, w) order to an (angle, axis) pair.

    Args:
        q: quaternions, shape [..., 4]; assumed unit-norm.

    Returns:
        angle: shape [...], wrapped into [-pi, pi] via atan2(sin, cos).
        axis: shape [..., 3]. Where sin(angle / 2) <= 1e-5 the axis is undefined
            (near-identity rotation); angle 0 and axis (0, 0, 1) are returned there.
    """
    min_theta = 1e-5
    qx, qy, qz, qw = 0, 1, 2, 3

    # Guard acos/sqrt against |w| drifting slightly above 1 through rounding.
    w = q[..., qw].clamp(-1.0, 1.0)
    sin_theta = torch.sqrt(1 - w * w)
    mask = torch.abs(sin_theta) > min_theta

    # Substitute a harmless w where the axis is undefined. torch.where still
    # back-propagates through the unselected branch, so without this the 0/0 axis
    # division and the infinite slopes of acos/sqrt at |w| = 1 yield NaN gradients.
    w_safe = torch.where(mask, w, torch.zeros_like(w))
    sin_theta_safe = torch.sqrt(1 - w_safe * w_safe)

    angle = 2 * torch.acos(w_safe)
    angle = torch.atan2(torch.sin(angle), torch.cos(angle))  # wrap to [-pi, pi]
    axis = q[..., qx:qw] / sin_theta_safe.unsqueeze(-1)

    default_axis = torch.zeros_like(axis)
    default_axis[..., -1] = 1

    angle = torch.where(mask, angle, torch.zeros_like(angle))
    axis = torch.where(mask.unsqueeze(-1), axis, default_axis)
    return angle, axis
|
|
|
|
|
|
|
|
def angle_axis_to_exp_map(angle, axis):
    """Exponential map: axis [..., 3] scaled by angle [...], giving shape [..., 3]."""
    return axis * angle[..., None]
|
|
|
|
|
|
|
|
def quat_to_exp_map(q):
    """Convert quaternions to exponential maps by way of ``quat_to_angle_axis``."""
    return angle_axis_to_exp_map(*quat_to_angle_axis(q))
|
|
|
|
|
|
|
|
def quat_to_tan_norm(q):
    """
    Encode rotations as the rotated x-axis (tangent) and z-axis (normal).

    Returns the two rotated vectors concatenated along the last dim, shape [..., 6].
    """

    def _basis(index):
        # one-hot reference direction with q's batch shape
        ref = torch.zeros_like(q[..., 0:3])
        ref[..., index] = 1
        return ref

    tan = quat_rotate(q, _basis(0))
    norm = quat_rotate(q, _basis(-1))
    return torch.cat([tan, norm], dim=-1)
|
|
|
|
|
|
|
|
def euler_xyz_to_exp_map(roll, pitch, yaw):
    """Convert roll/pitch/yaw angles to exponential maps via an intermediate quaternion."""
    return quat_to_exp_map(quat_from_euler_xyz(roll, pitch, yaw))
|
|
|
|
|
|
|
|
def exp_map_to_angle_axis(exp_map):
    """
    Convert exponential maps [..., 3] to (angle [...], axis [..., 3]).

    A bias of 1e-6 is added to the norm so the axis division never divides by zero.
    Angles are wrapped into [-pi, pi]. Where |angle| <= 1e-5 the result is angle 0
    and axis (0, 0, 1).
    """
    min_theta = 1e-5

    raw_angle = torch.norm(exp_map, dim=-1) + 1e-6
    axis = exp_map / raw_angle.unsqueeze(-1)
    angle = torch.atan2(torch.sin(raw_angle), torch.cos(raw_angle))  # wrap to [-pi, pi]

    fallback_axis = torch.zeros_like(exp_map)
    fallback_axis[..., -1] = 1

    valid = torch.abs(angle) > min_theta
    angle = torch.where(valid, angle, torch.zeros_like(angle))
    axis = torch.where(valid.unsqueeze(-1), axis, fallback_axis)
    return angle, axis
|
|
|
|
|
|
|
|
def exp_map_to_quat(exp_map):
    """Convert exponential maps to quaternions by way of ``exp_map_to_angle_axis``."""
    return quat_from_angle_axis(*exp_map_to_angle_axis(exp_map))
|
|
|
|
|
|
|
|
def slerp(q0, q1, t):
    """
    Spherical linear interpolation between quaternions.

    Args:
        q0, q1: quaternions, shape [..., 4].
        t: interpolation weights broadcastable to [..., 1] (0 -> q0, 1 -> q1).

    Returns:
        Interpolated quaternions, shape [..., 4].
    """
    cos_half_theta = torch.sum(q0 * q1, dim=-1)

    # q and -q encode the same rotation; flip q1 so we interpolate along the shorter arc.
    neg_mask = cos_half_theta < 0
    q1 = q1.clone()
    q1[neg_mask] = -q1[neg_mask]
    cos_half_theta = torch.abs(cos_half_theta).unsqueeze(-1)

    half_theta = torch.acos(cos_half_theta)
    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)

    ratio_a = torch.sin((1 - t) * half_theta) / sin_half_theta
    ratio_b = torch.sin(t * half_theta) / sin_half_theta
    new_q = ratio_a * q0 + ratio_b * q1

    # Nearly identical rotations: the slerp weights divide by a tiny sine, so fall back
    # to linear interpolation (which agrees with slerp to first order) and still honor t.
    new_q = torch.where(torch.abs(sin_half_theta) < 0.001, (1 - t) * q0 + t * q1, new_q)
    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)
    return new_q
|
|
|
|
|
|
|
|
def calc_heading_vec(q, head_ind=0):
    """
    Rotate the ``head_ind``-th basis vector by ``q``.

    Args:
        q: quaternions, shape [..., 4].
        head_ind: index of the reference basis vector (0 = x, 1 = y, 2 = z).

    Returns:
        Rotated direction vectors, shape [..., 3].
    """
    ref_dir = torch.zeros_like(q[..., :3])
    ref_dir.select(-1, head_ind).fill_(1)
    return quat_rotate(q, ref_dir)
|
|
|
|
|
|
|
|
def calc_heading(q, head_ind=0, gravity_axis="z"):
    """
    Heading angle of quaternions about the gravity axis.

    The ``head_ind``-th basis vector is rotated by ``q``. The angle of its projection
    onto the plane orthogonal to ``gravity_axis`` is returned via ``atan2``.

    Args:
        q: quaternions, shape [..., 4].
        head_ind: index (0, 1 or 2) of the reference basis vector.
        gravity_axis: "x", "y" or "z".

    Returns:
        Heading angles, shape [...].

    Raises:
        ValueError: if ``gravity_axis`` is not "x", "y" or "z".
    """
    # (numerator, denominator) vector components passed to atan2 for each gravity axis
    atan2_components = {"z": (1, 0), "y": (0, 2), "x": (2, 1)}
    if gravity_axis not in atan2_components:
        raise ValueError(f"gravity_axis must be 'x', 'y' or 'z', got {gravity_axis!r}")
    num_idx, den_idx = atan2_components[gravity_axis]

    ref_dir = torch.zeros_like(q[..., 0:3])
    ref_dir[..., head_ind] = 1

    # Rotate flattened [N, 4] / [N, 3] inputs, then restore the batch shape.
    shape = ref_dir.shape[:-1]
    rot_dir = quat_rotate(q.reshape((-1, 4)), ref_dir.reshape(-1, 3))
    rot_dir = rot_dir.reshape(shape + (3,))
    return torch.atan2(rot_dir[..., num_idx], rot_dir[..., den_idx])
|
|
|
|
|
|
|
|
def calc_heading_quat(q, head_ind=0, gravity_axis="z"):
    """
    Quaternion of the heading rotation of ``q`` about the gravity axis.

    Args:
        q: quaternions, shape [..., 4].
        head_ind: index (0, 1 or 2) of the reference basis vector passed to ``calc_heading``.
        gravity_axis: "x", "y" or "z".

    Returns:
        Quaternions of shape [..., 4] rotating by the heading angle about ``gravity_axis``.

    Raises:
        ValueError: if ``gravity_axis`` is not "x", "y" or "z".
    """
    axis_index = {"x": 0, "y": 1, "z": 2}
    if gravity_axis not in axis_index:
        raise ValueError(f"gravity_axis must be 'x', 'y' or 'z', got {gravity_axis!r}")

    heading = calc_heading(q, head_ind, gravity_axis=gravity_axis)
    axis = torch.zeros_like(q[..., 0:3])
    axis[..., axis_index[gravity_axis]] = 1
    return quat_from_angle_axis(heading, axis)
|
|
|
|
|
|
|
|
def calc_heading_quat_inv(q, head_ind=0, gravity_axis="z"):
    """
    Quaternion of the inverse heading rotation of ``q`` about the gravity axis.

    Mirrors ``calc_heading_quat`` but rotates by ``-heading``. ``gravity_axis``
    defaults to "z", matching the previous fixed behavior.

    Args:
        q: quaternions, shape [..., 4].
        head_ind: index (0, 1 or 2) of the reference basis vector passed to ``calc_heading``.
        gravity_axis: "x", "y" or "z".

    Returns:
        Quaternions of shape [..., 4].

    Raises:
        ValueError: if ``gravity_axis`` is not "x", "y" or "z".
    """
    axis_index = {"x": 0, "y": 1, "z": 2}
    if gravity_axis not in axis_index:
        raise ValueError(f"gravity_axis must be 'x', 'y' or 'z', got {gravity_axis!r}")

    heading = calc_heading(q, head_ind, gravity_axis=gravity_axis)
    axis = torch.zeros_like(q[..., 0:3])
    axis[..., axis_index[gravity_axis]] = 1
    return quat_from_angle_axis(-heading, axis)
|
|
|
|
|
|
|
|
def forward_kinematics(mat, parent):
    """
    Accumulate per-joint local matrices into global ones along a kinematic tree.

    Args:
        mat ([..., N, D, D]): local matrix of each of the N joints (torch or numpy).
        parent (sequence of int): parent index of each joint, -1 for a root. Parents
            are expected to come before their children. A parent that has not been
            processed yet contributes an identity matrix.

    Returns:
        Global matrices with the same [..., N, D, D] layout. A root's entry is its
        local matrix. Any other joint is
        ``get_mat_BfromA(global[parent[i]], mat[..., i, :, :])``. The numpy branch
        writes into a float32 buffer.
    """
    num_joints = mat.shape[-3]

    if isinstance(mat, torch.Tensor):
        if num_joints == 0:
            return mat.clone()
        # Collect results in a list and stack once: avoids in-place writes (autograd-safe)
        # and the O(N^2) copying of re-concatenating the whole tensor for every joint.
        identity = torch.eye(mat.shape[-1], device=mat.device).expand(
            mat.shape[:-3] + mat.shape[-2:]
        )
        world = [identity] * num_joints
        for i in range(num_joints):
            if parent[i] != -1:
                world[i] = get_mat_BfromA(world[parent[i]], mat[..., i, :, :])
            else:
                # Only joint i is a root; earlier joints keep their accumulated values.
                world[i] = mat[..., i, :, :]
        return torch.stack(world, dim=-3)

    rotations = np.eye(mat.shape[-1], dtype=np.float32)
    rotations = np.tile(rotations, mat.shape[:-2] + (1, 1))
    for i in range(num_joints):
        if parent[i] != -1:
            rotations[..., i, :, :] = get_mat_BfromA(
                rotations[..., parent[i], :, :], mat[..., i, :, :]
            )
        else:
            rotations[..., i, :, :] = mat[..., i, :, :]
    return rotations
|
|
|