NMR / src /utils /torch_utils.py
RayZhao's picture
update visualization
4417ac0
# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from typing import Tuple
import torch
from torch import Tensor
def euler_from_quaternion(quat_angle: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Convert quaternions into Euler angles (roll, pitch, yaw).

    roll is the rotation around x in radians (counterclockwise),
    pitch is the rotation around y in radians (counterclockwise),
    yaw is the rotation around z in radians (counterclockwise).

    Args:
        quat_angle: Tensor of shape (..., 4) holding quaternions in
            (x, y, z, w) order. Any number of leading batch dimensions,
            including none, is accepted.

    Returns:
        Tuple ``(roll_x, pitch_y, yaw_z)`` of tensors of shape (...),
        all in radians.
    """
    # Index the last dimension so a single quaternion and arbitrarily
    # nested batches are handled the same way as an (N, 4) batch.
    x = quat_angle[..., 0]
    y = quat_angle[..., 1]
    z = quat_angle[..., 2]
    w = quat_angle[..., 3]

    t0 = +2.0 * (w * x + y * z)
    t1 = +1.0 - 2.0 * (x * x + y * y)
    roll_x = torch.atan2(t0, t1)

    t2 = +2.0 * (w * y - z * x)
    # Keep the asin argument inside [-1, 1] so rounding error cannot yield NaN.
    t2 = torch.clip(t2, -1, 1)
    pitch_y = torch.asin(t2)

    t3 = +2.0 * (w * z + x * y)
    t4 = +1.0 - 2.0 * (y * y + z * z)
    yaw_z = torch.atan2(t3, t4)

    return roll_x, pitch_y, yaw_z  # in radians
@torch.jit.script
def normalize(x, eps: float = 1e-9):
    """Scale ``x`` to unit L2 norm along its last dimension.

    The norm is clamped from below by ``eps`` before dividing, so an
    all-zero vector stays zero instead of producing NaN.
    """
    length = x.norm(p=2, dim=-1)
    safe_length = length.clamp(min=eps, max=None)
    return x / safe_length.unsqueeze(-1)
@torch.jit.script
def normalize_angle(x):
    """Wrap angles given in radians into the principal range of ``atan2``."""
    sin_x = torch.sin(x)
    cos_x = torch.cos(x)
    return torch.atan2(sin_x, cos_x)
@torch.jit.script
def quat_rotate(q, v):
    """Rotate vectors ``v`` by quaternions ``q``.

    Args:
        q: Tensor of shape (..., 4), quaternions in (x, y, z, w) order.
            The closed form used here is the unit-quaternion rotation
            formula, so ``q`` is expected to be normalized.
        v: Tensor of shape (..., 3) with the same leading dimensions as ``q``.

    Returns:
        Tensor of shape (..., 3) holding the rotated vectors.
    """
    # Index the last dimension so any number of leading batch dimensions
    # (including none) is accepted, not only an (N, 4) batch.
    q_w = q[..., -1]
    q_vec = q[..., :3]
    a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)
    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
    # Dot product q_vec . v, kept as a trailing singleton dimension so it
    # broadcasts against q_vec.
    dot = torch.sum(q_vec * v, dim=-1, keepdim=True)
    c = q_vec * dot * 2.0
    return a + b + c
@torch.jit.script
def quat_rotate_inverse(q, v):
    """Rotate vectors ``v`` by the inverse of quaternions ``q``.

    Identical to ``quat_rotate`` except that the cross-product term is
    subtracted, which corresponds to negating the vector part of ``q``.

    Args:
        q: Tensor of shape (..., 4), quaternions in (x, y, z, w) order,
            expected to be normalized.
        v: Tensor of shape (..., 3) with the same leading dimensions as ``q``.

    Returns:
        Tensor of shape (..., 3) holding the inversely rotated vectors.
    """
    # Index the last dimension so any number of leading batch dimensions
    # (including none) is accepted, not only an (N, 4) batch.
    q_w = q[..., -1]
    q_vec = q[..., :3]
    a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)
    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
    # Dot product q_vec . v, kept as a trailing singleton dimension so it
    # broadcasts against q_vec.
    dot = torch.sum(q_vec * v, dim=-1, keepdim=True)
    c = q_vec * dot * 2.0
    return a - b + c
@torch.jit.script
def quat_from_euler_xyz(roll, pitch, yaw):
    """Build (x, y, z, w) quaternions from roll, pitch and yaw angles in radians.

    The three inputs are combined elementwise; the result stacks the four
    quaternion components along a new last dimension.
    """
    half_yaw = yaw * 0.5
    half_roll = roll * 0.5
    half_pitch = pitch * 0.5

    cos_y = torch.cos(half_yaw)
    sin_y = torch.sin(half_yaw)
    cos_r = torch.cos(half_roll)
    sin_r = torch.sin(half_roll)
    cos_p = torch.cos(half_pitch)
    sin_p = torch.sin(half_pitch)

    w = cos_y * cos_r * cos_p + sin_y * sin_r * sin_p
    x = cos_y * sin_r * cos_p - sin_y * cos_r * sin_p
    y = cos_y * cos_r * sin_p + sin_y * sin_r * cos_p
    z = sin_y * cos_r * cos_p - cos_y * sin_r * sin_p

    return torch.stack([x, y, z, w], dim=-1)
@torch.jit.script
def quat_unit(a):
    """Return ``a`` rescaled to unit length by delegating to ``normalize``."""
    unit = normalize(a)
    return unit
@torch.jit.script
def quat_from_angle_axis(angle, axis):
    """Build (x, y, z, w) quaternions from rotation angles and axes.

    ``angle`` gains a trailing dimension so it broadcasts against ``axis``;
    ``axis`` is normalized before use and the result is re-normalized.
    """
    half_angle = (angle / 2).unsqueeze(-1)
    unit_axis = normalize(axis)
    vec_part = unit_axis * half_angle.sin()
    scalar_part = half_angle.cos()
    return quat_unit(torch.cat([vec_part, scalar_part], dim=-1))
@torch.jit.script
def quat_mul(a, b):
    """Multiply quaternions ``a * b`` stored in (x, y, z, w) order.

    Both inputs must have the same shape with a last dimension of 4; they
    are flattened to (-1, 4) for the computation and the result is viewed
    back to the input shape.
    """
    assert a.shape == b.shape
    out_shape = a.shape
    lhs = a.reshape(-1, 4)
    rhs = b.reshape(-1, 4)

    x1 = lhs[:, 0]
    y1 = lhs[:, 1]
    z1 = lhs[:, 2]
    w1 = lhs[:, 3]
    x2 = rhs[:, 0]
    y2 = rhs[:, 1]
    z2 = rhs[:, 2]
    w2 = rhs[:, 3]

    # Shared partial products reused by all four output components.
    part_w = (z1 + x1) * (x2 + y2)
    part_y = (w1 - y1) * (w2 + z2)
    part_z = (w1 + y1) * (w2 - z2)
    part_x = part_w + part_y + part_z
    half_sum = 0.5 * (part_x + (z1 - x1) * (x2 - y2))

    w = half_sum - part_w + (z1 - y1) * (y2 - z2)
    x = half_sum - part_x + (x1 + w1) * (x2 + w2)
    y = half_sum - part_y + (w1 - x1) * (y2 + z2)
    z = half_sum - part_z + (z1 + y1) * (w2 - x2)

    stacked = torch.stack([x, y, z, w], dim=-1)
    return stacked.view(out_shape)
@torch.jit.script
def quat_conjugate(a):
    """Return the conjugate of (x, y, z, w) quaternions.

    The first three components are negated and the last is kept; the
    output has the same shape as the input.
    """
    out_shape = a.shape
    flat = a.reshape(-1, 4)
    negated_vec = -flat[:, :3]
    scalar = flat[:, -1:]
    return torch.cat([negated_vec, scalar], dim=-1).view(out_shape)
@torch.jit.script
def quat_to_angle_axis(q):
    """Compute the axis-angle representation of (x, y, z, w) quaternions.

    ``q`` must be normalized. Returns ``(angle, axis)``; the angle is
    wrapped by ``normalize_angle``. Where ``|sin(theta / 2)|`` is not
    above 1e-5 the angle is set to 0 and the axis to (0, 0, 1).
    """
    min_theta = 1e-5

    w = q[..., 3]
    sin_theta = torch.sqrt(1 - w * w)
    angle = normalize_angle(2 * torch.acos(w))

    # This division may produce inf/NaN for tiny sin_theta; those entries
    # are replaced by the fallback axis below.
    axis = q[..., 0:3] / sin_theta.unsqueeze(-1)

    valid = torch.abs(sin_theta) > min_theta
    fallback_axis = torch.zeros_like(axis)
    fallback_axis[..., -1] = 1

    angle = torch.where(valid, angle, torch.zeros_like(angle))
    axis = torch.where(valid.unsqueeze(-1), axis, fallback_axis)
    return angle, axis
@torch.jit.script
def angle_axis_to_exp_map(angle, axis):
    """Compute the exponential map from axis-angle: each axis scaled by its angle."""
    return angle.unsqueeze(-1) * axis
@torch.jit.script
def quat_to_exp_map(q):
    """Compute the exponential map of quaternions ``q``.

    ``q`` must be normalized. The conversion goes through the axis-angle
    representation.
    """
    theta, direction = quat_to_angle_axis(q)
    return angle_axis_to_exp_map(theta, direction)
@torch.jit.script
def quat_to_tan_norm(q):
    """Represent rotations by their rotated tangent and normal vectors.

    The tangent is the rotated unit x axis and the normal is the rotated
    unit z axis; they are concatenated (tangent first) along the last
    dimension.
    """
    x_axis = torch.zeros_like(q[..., 0:3])
    x_axis[..., 0] = 1
    z_axis = torch.zeros_like(q[..., 0:3])
    z_axis[..., -1] = 1

    tangent = quat_rotate(q, x_axis)
    normal = quat_rotate(q, z_axis)

    return torch.cat([tangent, normal], dim=-1)
@torch.jit.script
def euler_xyz_to_exp_map(roll, pitch, yaw):
    """Convert roll, pitch and yaw Euler angles to an exponential map via a quaternion."""
    return quat_to_exp_map(quat_from_euler_xyz(roll, pitch, yaw))
@torch.jit.script
def exp_map_to_angle_axis(exp_map):
    """Split exponential-map vectors into ``(angle, axis)``.

    The angle is the vector norm wrapped by ``normalize_angle``; the axis
    is the vector divided by its norm. Where the wrapped angle's magnitude
    is not above 1e-5 the angle is set to 0 and the axis to (0, 0, 1).
    """
    min_theta = 1e-5

    magnitude = torch.norm(exp_map, dim=-1)
    # This division may produce NaN for zero-length vectors; those entries
    # are replaced by the fallback axis below.
    axis = exp_map / magnitude.unsqueeze(-1)
    angle = normalize_angle(magnitude)

    fallback_axis = torch.zeros_like(exp_map)
    fallback_axis[..., -1] = 1

    valid = torch.abs(angle) > min_theta
    angle = torch.where(valid, angle, torch.zeros_like(angle))
    axis = torch.where(valid.unsqueeze(-1), axis, fallback_axis)
    return angle, axis
@torch.jit.script
def exp_map_to_quat(exp_map):
    """Convert exponential-map vectors to quaternions via axis-angle."""
    theta, direction = exp_map_to_angle_axis(exp_map)
    return quat_from_angle_axis(theta, direction)
@torch.jit.script
def slerp(q0, q1, t):
    """Spherical linear interpolation between quaternions ``q0`` and ``q1``.

    Args:
        q0: Tensor of shape (..., 4), start quaternions.
        q1: Tensor of shape (..., 4), end quaternions.
        t: Interpolation parameter with one fewer dimension than ``q0``
            (asserted below); a trailing dimension is added internally.

    Returns:
        Tensor of shape (..., 4) with the interpolated quaternions. ``q1``
        is negated where its dot product with ``q0`` is negative so the
        shorter arc is followed.
    """
    assert len(t.shape) == len(q0.shape) - 1

    cos_half_theta = torch.sum(q0 * q1, dim=-1)

    # q and -q encode the same rotation; flip q1 onto q0's hemisphere.
    neg_mask = cos_half_theta < 0
    q1 = torch.where(neg_mask.unsqueeze(-1), -q1, q1)
    cos_half_theta = torch.abs(cos_half_theta)
    cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1)

    half_theta = torch.acos(cos_half_theta)
    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)

    t = t.unsqueeze(-1)
    ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta
    ratioB = torch.sin(t * half_theta) / sin_half_theta
    new_q = ratioA * q0 + ratioB * q1

    # Nearly parallel inputs make the slerp ratios ill-conditioned. Fall
    # back to linear interpolation, which still honors t (t=0 gives q0 and
    # t=1 gives q1) instead of always returning the midpoint.
    lerp_q = (1 - t) * q0 + t * q1
    new_q = torch.where(torch.abs(sin_half_theta) < 0.001, lerp_q, new_q)
    # Identical orientations (this also masks NaN from acos/sqrt when the
    # dot product exceeds 1 through rounding).
    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)
    return new_q
@torch.jit.script
def slerp2(q0, q1, t):
    """Spherical linear interpolation between quaternions ``q0`` and ``q1``.

    Unlike ``slerp``, ``t`` is used as given (no trailing dimension is
    added), so it must already broadcast against tensors of shape (..., 1).

    Args:
        q0: Tensor of shape (..., 4), start quaternions.
        q1: Tensor of shape (..., 4), end quaternions (not modified).
        t: Interpolation parameter broadcastable against (..., 1).

    Returns:
        Tensor of shape (..., 4) with the interpolated quaternions. ``q1``
        is negated where its dot product with ``q0`` is negative so the
        shorter arc is followed.
    """
    cos_half_theta = torch.sum(q0 * q1, dim=-1)

    # q and -q encode the same rotation; flip q1 onto q0's hemisphere.
    # torch.where builds a new tensor, so the caller's q1 is untouched
    # without needing clone() plus an in-place masked write.
    neg_mask = cos_half_theta < 0
    q1 = torch.where(neg_mask.unsqueeze(-1), -q1, q1)
    cos_half_theta = torch.abs(cos_half_theta)
    cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1)

    half_theta = torch.acos(cos_half_theta)
    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)

    ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta
    ratioB = torch.sin(t * half_theta) / sin_half_theta
    new_q = ratioA * q0 + ratioB * q1

    # Nearly parallel inputs make the slerp ratios ill-conditioned. Fall
    # back to linear interpolation, which still honors t (t=0 gives q0 and
    # t=1 gives q1) instead of always returning the midpoint.
    lerp_q = (1 - t) * q0 + t * q1
    new_q = torch.where(torch.abs(sin_half_theta) < 0.001, lerp_q, new_q)
    # Identical orientations (this also masks NaN from acos/sqrt when the
    # dot product exceeds 1 through rounding).
    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)
    return new_q
@torch.jit.script
def calc_heading(q):
    """Calculate the heading angle of quaternions ``q``.

    The heading is the direction on the xy plane: the unit x axis is
    rotated by ``q`` and the angle of the result's (x, y) components is
    returned. ``q`` must be normalized.
    """
    forward = torch.zeros_like(q[..., 0:3])
    forward[..., 0] = 1
    rotated = quat_rotate(q, forward)
    return torch.atan2(rotated[..., 1], rotated[..., 0])
@torch.jit.script
def calc_heading_quat(q):
    """Calculate the heading rotation of quaternions ``q``.

    The heading is the direction on the xy plane; the result is a rotation
    about the z axis by the heading angle. ``q`` must be normalized.
    """
    z_axis = torch.zeros_like(q[..., 0:3])
    z_axis[..., 2] = 1
    heading_angle = calc_heading(q)
    return quat_from_angle_axis(heading_angle, z_axis)
@torch.jit.script
def calc_heading_quat_inv(q):
    """Calculate the inverse heading rotation of quaternions ``q``.

    The heading is the direction on the xy plane; the result is a rotation
    about the z axis by the negated heading angle. ``q`` must be normalized.
    """
    z_axis = torch.zeros_like(q[..., 0:3])
    z_axis[..., 2] = 1
    heading_angle = calc_heading(q)
    return quat_from_angle_axis(-heading_angle, z_axis)
@torch.jit.script
def quat_pos(x):
    """Flip the sign of quaternions whose last (w) component is negative.

    Quaternions with non-negative w are returned unchanged in value.
    """
    # 1.0 where w < 0, else 0.0; kept with a trailing dimension to broadcast.
    is_negative = (x[..., 3:] < 0).float()
    # Maps the indicator to a -1 / +1 multiplier.
    sign = 1 - 2 * is_negative
    return sign * x
@torch.jit.script
def quat_to_axis_angle(q):
    """Convert (x, y, z, w) quaternions to ``(axis, angle)``.

    Note the return order is axis first, then angle. Where the vector
    part's length is not above 1e-5 the angle is set to 0 and the axis to
    (0, 0, 1).
    """
    eps = 1e-5

    # need to make sure w is not negative to calculate geodesic distance
    q = quat_pos(q)

    vec = q[..., 0:3]
    vec_len = torch.norm(vec, dim=-1, p=2)
    angle = 2.0 * torch.atan2(vec_len, q[..., 3])
    # This division may produce NaN for zero-length vector parts; those
    # entries are replaced by the fallback axis below.
    axis = vec / vec_len.unsqueeze(-1)

    fallback_axis = torch.zeros_like(axis)
    fallback_axis[..., -1] = 1

    valid = vec_len > eps
    angle = torch.where(valid, angle, torch.zeros_like(angle))
    axis = torch.where(valid.unsqueeze(-1), axis, fallback_axis)
    return axis, angle
@torch.jit.script
def quat_diff(q0, q1):
    """Return the quaternion product of ``q1`` with the conjugate of ``q0``."""
    return quat_mul(q1, quat_conjugate(q0))
@torch.jit.script
def quat_diff_angle(q0, q1):
    """Return the rotation angle of the difference quaternion between ``q0`` and ``q1``."""
    axis_and_angle = quat_to_axis_angle(quat_diff(q0, q1))
    # quat_to_axis_angle returns (axis, angle); only the angle is needed.
    return axis_and_angle[1]
@torch.jit.script
def axis_angle_to_quat(axis, angle):
    # type: (Tensor, Tensor) -> Tensor
    """Build (x, y, z, w) quaternions from rotation axes and angles.

    Same computation as ``quat_from_angle_axis`` with the arguments in
    (axis, angle) order: ``axis`` is normalized before use and the result
    is re-normalized.
    """
    half_angle = (angle / 2).unsqueeze(-1)
    unit_axis = normalize(axis)
    vec_part = unit_axis * half_angle.sin()
    scalar_part = half_angle.cos()
    return quat_unit(torch.cat([vec_part, scalar_part], dim=-1))