# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from typing import Tuple

import torch
from torch import Tensor


def euler_from_quaternion(quat_angle: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Convert a quaternion into euler angles (roll, pitch, yaw)
    roll is rotation around x in radians (counterclockwise)
    pitch is rotation around y in radians (counterclockwise)
    yaw is rotation around z in radians (counterclockwise)

    ``quat_angle`` is indexed as ``[:, 0..3]`` and read in (x, y, z, w)
    order, so it must be 2-D with the components on the last axis.
    """
    x = quat_angle[:, 0]
    y = quat_angle[:, 1]
    z = quat_angle[:, 2]
    w = quat_angle[:, 3]

    t0 = +2.0 * (w * x + y * z)
    t1 = +1.0 - 2.0 * (x * x + y * y)
    roll_x = torch.atan2(t0, t1)

    t2 = +2.0 * (w * y - z * x)
    # keep asin's argument inside [-1, 1] so rounding error cannot yield NaN
    t2 = torch.clip(t2, -1, 1)
    pitch_y = torch.asin(t2)

    t3 = +2.0 * (w * z + x * y)
    t4 = +1.0 - 2.0 * (y * y + z * z)
    yaw_z = torch.atan2(t3, t4)

    return roll_x, pitch_y, yaw_z  # in radians


@torch.jit.script
def normalize(x: Tensor, eps: float = 1e-9) -> Tensor:
    """Scale ``x`` to unit L2 norm along the last axis.

    The norm is clamped from below by ``eps`` so an all-zero vector yields
    zeros instead of a division by zero.
    """
    return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1)


@torch.jit.script
def normalize_angle(x: Tensor) -> Tensor:
    """Wrap angles (radians) through ``atan2(sin(x), cos(x))``."""
    return torch.atan2(torch.sin(x), torch.cos(x))


@torch.jit.script
def quat_rotate(q: Tensor, v: Tensor) -> Tensor:
    """Rotate vectors ``v`` by quaternions ``q`` stored as (x, y, z, w).

    ``q`` is ``(..., 4)`` and ``v`` is ``(..., 3)`` with matching leading
    dimensions. Any number of leading dimensions is accepted, which is what
    callers that index with ``...`` (e.g. ``calc_heading``) rely on.
    """
    q_w = q[..., -1]
    q_vec = q[..., :3]
    a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)
    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
    if q_vec.dim() == 2:
        # 2-D input keeps the original bmm formulation so this case runs
        # exactly the same operations as before
        dot = torch.bmm(q_vec.view(q.shape[0], 1, 3),
                        v.view(q.shape[0], 3, 1)).squeeze(-1)
    else:
        # general case: per-element dot product over the last axis
        dot = torch.sum(q_vec * v, dim=-1, keepdim=True)
    c = q_vec * dot * 2.0
    return a + b + c


@torch.jit.script
def quat_rotate_inverse(q: Tensor, v: Tensor) -> Tensor:
    """Rotate vectors ``v`` by the inverse of quaternions ``q`` (x, y, z, w).

    Same computation as ``quat_rotate`` with the sign of the cross-product
    term flipped. ``q`` is ``(..., 4)`` and ``v`` is ``(..., 3)`` with
    matching leading dimensions.
    """
    q_w = q[..., -1]
    q_vec = q[..., :3]
    a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1)
    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
    if q_vec.dim() == 2:
        # 2-D input keeps the original bmm formulation so this case runs
        # exactly the same operations as before
        dot = torch.bmm(q_vec.view(q.shape[0], 1, 3),
                        v.view(q.shape[0], 3, 1)).squeeze(-1)
    else:
        # general case: per-element dot product over the last axis
        dot = torch.sum(q_vec * v, dim=-1, keepdim=True)
    c = q_vec * dot * 2.0
    return a - b + c


@torch.jit.script
def quat_from_euler_xyz(roll: Tensor, pitch: Tensor, yaw: Tensor) -> Tensor:
    """Build (x, y, z, w) quaternions from roll, pitch and yaw in radians.

    The three inputs are combined element-wise and the four components are
    stacked on a new last axis.
    """
    cy = torch.cos(yaw * 0.5)
    sy = torch.sin(yaw * 0.5)
    cr = torch.cos(roll * 0.5)
    sr = torch.sin(roll * 0.5)
    cp = torch.cos(pitch * 0.5)
    sp = torch.sin(pitch * 0.5)

    qw = cy * cr * cp + sy * sr * sp
    qx = cy * sr * cp - sy * cr * sp
    qy = cy * cr * sp + sy * sr * cp
    qz = sy * cr * cp - cy * sr * sp

    return torch.stack([qx, qy, qz, qw], dim=-1)


@torch.jit.script
def quat_unit(a: Tensor) -> Tensor:
    """Return ``a`` normalized to unit length along the last axis."""
    return normalize(a)


@torch.jit.script
def quat_from_angle_axis(angle: Tensor, axis: Tensor) -> Tensor:
    """Build (x, y, z, w) quaternions from rotation angles and axes.

    ``axis`` is normalized before use; ``angle`` has one dimension fewer
    than ``axis`` (it is unsqueezed on the last axis).
    """
    theta = (angle / 2).unsqueeze(-1)
    xyz = normalize(axis) * theta.sin()
    w = theta.cos()
    return quat_unit(torch.cat([xyz, w], dim=-1))


@torch.jit.script
def quat_mul(a: Tensor, b: Tensor) -> Tensor:
    """Multiply quaternions ``a * b`` element-wise, both stored as (x, y, z, w).

    Inputs must have identical shapes; they are flattened to ``(-1, 4)`` for
    the computation and the result is returned in the original shape.
    """
    assert a.shape == b.shape
    shape = a.shape
    a = a.reshape(-1, 4)
    b = b.reshape(-1, 4)

    x1, y1, z1, w1 = a[:, 0], a[:, 1], a[:, 2], a[:, 3]
    x2, y2, z2, w2 = b[:, 0], b[:, 1], b[:, 2], b[:, 3]
    # product expanded with shared sub-expressions (ww, yy, zz, xx, qq)
    ww = (z1 + x1) * (x2 + y2)
    yy = (w1 - y1) * (w2 + z2)
    zz = (w1 + y1) * (w2 - z2)
    xx = ww + yy + zz
    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
    w = qq - ww + (z1 - y1) * (y2 - z2)
    x = qq - xx + (x1 + w1) * (x2 + w2)
    y = qq - yy + (w1 - x1) * (y2 + z2)
    z = qq - zz + (z1 + y1) * (w2 - x2)

    quat = torch.stack([x, y, z, w], dim=-1).view(shape)

    return quat


@torch.jit.script
def quat_conjugate(a: Tensor) -> Tensor:
    """Negate the vector part of (x, y, z, w) quaternions, keeping ``w``."""
    shape = a.shape
    a = a.reshape(-1, 4)
    return torch.cat((-a[:, :3], a[:, -1:]), dim=-1).view(shape)


@torch.jit.script
def quat_to_angle_axis(q: Tensor) -> Tuple[Tensor, Tensor]:
    """Compute the axis-angle representation of quaternions ``q``.

    ``q`` must be normalized. Returns ``(angle, axis)`` in that order. Where
    ``|sin(theta / 2)|`` is at most ``1e-5`` the angle is set to 0 and the
    axis to (0, 0, 1).
    """
    min_theta = 1e-5
    qx, qy, qz, qw = 0, 1, 2, 3

    sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw])
    angle = 2 * torch.acos(q[..., qw])
    angle = normalize_angle(angle)
    sin_theta_expand = sin_theta.unsqueeze(-1)
    axis = q[..., qx:qw] / sin_theta_expand

    # entries with a (near) zero or NaN sin_theta fall back to the defaults
    mask = torch.abs(sin_theta) > min_theta
    default_axis = torch.zeros_like(axis)
    default_axis[..., -1] = 1

    angle = torch.where(mask, angle, torch.zeros_like(angle))
    mask_expand = mask.unsqueeze(-1)
    axis = torch.where(mask_expand, axis, default_axis)
    return angle, axis


@torch.jit.script
def angle_axis_to_exp_map(angle: Tensor, axis: Tensor) -> Tensor:
    """Compute the exponential map ``angle * axis`` from axis-angle."""
    angle_expand = angle.unsqueeze(-1)
    exp_map = angle_expand * axis
    return exp_map


@torch.jit.script
def quat_to_exp_map(q: Tensor) -> Tensor:
    """Compute the exponential map from quaternions; ``q`` must be normalized."""
    angle, axis = quat_to_angle_axis(q)
    exp_map = angle_axis_to_exp_map(angle, axis)
    return exp_map


@torch.jit.script
def quat_to_tan_norm(q: Tensor) -> Tensor:
    """Represent a rotation using its tangent and normal vectors.

    The tangent is (1, 0, 0) rotated by ``q`` and the normal is (0, 0, 1)
    rotated by ``q``; they are concatenated on the last axis as
    ``[tangent, normal]``.
    """
    ref_tan = torch.zeros_like(q[..., 0:3])
    ref_tan[..., 0] = 1
    tan = quat_rotate(q, ref_tan)

    ref_norm = torch.zeros_like(q[..., 0:3])
    ref_norm[..., -1] = 1
    norm = quat_rotate(q, ref_norm)

    norm_tan = torch.cat([tan, norm], dim=-1)
    return norm_tan


@torch.jit.script
def euler_xyz_to_exp_map(roll: Tensor, pitch: Tensor, yaw: Tensor) -> Tensor:
    """Convert roll, pitch and yaw (radians) to an exponential map."""
    q = quat_from_euler_xyz(roll, pitch, yaw)
    exp_map = quat_to_exp_map(q)
    return exp_map


@torch.jit.script
def exp_map_to_angle_axis(exp_map: Tensor) -> Tuple[Tensor, Tensor]:
    """Split an exponential map into ``(angle, axis)``.

    The angle is the norm of ``exp_map`` wrapped by ``normalize_angle``.
    Where the wrapped angle's magnitude is at most ``1e-5`` the angle is set
    to 0 and the axis to (0, 0, 1).
    """
    min_theta = 1e-5

    angle = torch.norm(exp_map, dim=-1)
    angle_exp = torch.unsqueeze(angle, dim=-1)
    axis = exp_map / angle_exp
    angle = normalize_angle(angle)

    default_axis = torch.zeros_like(exp_map)
    default_axis[..., -1] = 1

    mask = torch.abs(angle) > min_theta
    angle = torch.where(mask, angle, torch.zeros_like(angle))
    mask_expand = mask.unsqueeze(-1)
    axis = torch.where(mask_expand, axis, default_axis)

    return angle, axis


@torch.jit.script
def exp_map_to_quat(exp_map: Tensor) -> Tensor:
    """Convert an exponential map to (x, y, z, w) quaternions."""
    angle, axis = exp_map_to_angle_axis(exp_map)
    q = quat_from_angle_axis(angle, axis)
    return q


@torch.jit.script
def slerp(q0: Tensor, q1: Tensor, t: Tensor) -> Tensor:
    """Spherically interpolate from ``q0`` to ``q1`` at fraction ``t``.

    ``t`` must have exactly one dimension fewer than ``q0``. ``q1`` is
    negated where its dot product with ``q0`` is negative. Where
    ``|sin(half_theta)| < 0.001`` the result is ``0.5 * q0 + 0.5 * q1``
    (independent of ``t``), and where ``|cos(half_theta)| >= 1`` it is ``q0``.
    """
    assert len(t.shape) == len(q0.shape) - 1
    cos_half_theta = torch.sum(q0 * q1, dim=-1)

    neg_mask = cos_half_theta < 0
    q1 = torch.where(neg_mask.unsqueeze(-1), -q1, q1)
    cos_half_theta = torch.abs(cos_half_theta)
    cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1)

    half_theta = torch.acos(cos_half_theta)
    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)

    t = t.unsqueeze(-1)
    ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta
    ratioB = torch.sin(t * half_theta) / sin_half_theta

    new_q = ratioA * q0 + ratioB * q1

    # the ratios are ill-conditioned for nearly identical inputs; fall back
    new_q = torch.where(torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q)
    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)

    return new_q


@torch.jit.script
def slerp2(q0: Tensor, q1: Tensor, t: Tensor) -> Tensor:
    """Variant of ``slerp`` that leaves the shape of ``t`` untouched.

    Unlike ``slerp``, ``t`` is not unsqueezed or shape-checked here, so it
    must already broadcast against the ``(..., 1)`` half-angle tensor.
    ``q1`` is cloned before its entries are negated, so the caller's tensor
    is not modified.
    """
    cos_half_theta = torch.sum(q0 * q1, dim=-1)

    neg_mask = cos_half_theta < 0
    q1 = q1.clone()
    q1[neg_mask] = -q1[neg_mask]
    cos_half_theta = torch.abs(cos_half_theta)
    cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1)

    half_theta = torch.acos(cos_half_theta)
    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)

    ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta
    ratioB = torch.sin(t * half_theta) / sin_half_theta

    new_q = ratioA * q0 + ratioB * q1

    new_q = torch.where(torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q)
    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)

    return new_q


@torch.jit.script
def calc_heading(q: Tensor) -> Tensor:
    """Calculate the heading direction from quaternions.

    The heading is the direction on the xy plane: (1, 0, 0) is rotated by
    ``q`` and the angle ``atan2(y, x)`` of the result is returned.
    ``q`` must be normalized.
    """
    ref_dir = torch.zeros_like(q[..., 0:3])
    ref_dir[..., 0] = 1
    rot_dir = quat_rotate(q, ref_dir)

    heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0])
    return heading


@torch.jit.script
def calc_heading_quat(q: Tensor) -> Tensor:
    """Calculate the heading rotation from quaternions.

    Returns the quaternion rotating by ``calc_heading(q)`` about (0, 0, 1).
    ``q`` must be normalized.
    """
    heading = calc_heading(q)
    axis = torch.zeros_like(q[..., 0:3])
    axis[..., 2] = 1

    heading_q = quat_from_angle_axis(heading, axis)
    return heading_q


@torch.jit.script
def calc_heading_quat_inv(q: Tensor) -> Tensor:
    """Calculate the inverse heading rotation from quaternions.

    Returns the quaternion rotating by ``-calc_heading(q)`` about (0, 0, 1).
    ``q`` must be normalized.
    """
    heading = calc_heading(q)
    axis = torch.zeros_like(q[..., 0:3])
    axis[..., 2] = 1

    heading_q = quat_from_angle_axis(-heading, axis)
    return heading_q


@torch.jit.script
def quat_pos(x: Tensor) -> Tensor:
    """Negate every quaternion whose ``w`` component is negative."""
    q = x
    # z is 1.0 where w < 0, so (1 - 2 * z) is -1 there and +1 elsewhere
    z = (q[..., 3:] < 0).float()
    q = (1 - 2 * z) * q
    return q


@torch.jit.script
def quat_to_axis_angle(q: Tensor) -> Tuple[Tensor, Tensor]:
    """Compute ``(axis, angle)`` from quaternions stored as (x, y, z, w).

    Note the return order is the reverse of ``quat_to_angle_axis``. Where
    the vector part's length is at most ``1e-5`` the angle is set to 0 and
    the axis to (0, 0, 1).
    """
    eps = 1e-5
    qx, qy, qz, qw = 0, 1, 2, 3

    # need to make sure w is not negative to calculate geodesic distance
    q = quat_pos(q)
    length = torch.norm(q[..., 0:3], dim=-1, p=2)
    angle = 2.0 * torch.atan2(length, q[..., qw])
    axis = q[..., qx:qw] / length.unsqueeze(-1)

    default_axis = torch.zeros_like(axis)
    default_axis[..., -1] = 1

    mask = length > eps
    angle = torch.where(mask, angle, torch.zeros_like(angle))
    mask_expand = mask.unsqueeze(-1)
    axis = torch.where(mask_expand, axis, default_axis)
    return axis, angle


@torch.jit.script
def quat_diff(q0: Tensor, q1: Tensor) -> Tensor:
    """Return ``q1 * conjugate(q0)``."""
    dq = quat_mul(q1, quat_conjugate(q0))
    return dq


@torch.jit.script
def quat_diff_angle(q0: Tensor, q1: Tensor) -> Tensor:
    """Return the rotation angle of ``quat_diff(q0, q1)``."""
    dq = quat_diff(q0, q1)
    _, angle = quat_to_axis_angle(dq)
    return angle


@torch.jit.script
def axis_angle_to_quat(axis: Tensor, angle: Tensor) -> Tensor:
    """Build (x, y, z, w) quaternions from rotation axes and angles.

    Same conversion as ``quat_from_angle_axis`` with the arguments swapped.
    """
    return quat_from_angle_axis(angle, axis)