Spaces:
Running on Zero
Running on Zero
| # Copyright (c) 2018-2022, NVIDIA Corporation | |
| # All rights reserved. | |
| # | |
| # Redistribution and use in source and binary forms, with or without | |
| # modification, are permitted provided that the following conditions are met: | |
| # | |
| # 1. Redistributions of source code must retain the above copyright notice, this | |
| # list of conditions and the following disclaimer. | |
| # | |
| # 2. Redistributions in binary form must reproduce the above copyright notice, | |
| # this list of conditions and the following disclaimer in the documentation | |
| # and/or other materials provided with the distribution. | |
| # | |
| # 3. Neither the name of the copyright holder nor the names of its | |
| # contributors may be used to endorse or promote products derived from | |
| # this software without specific prior written permission. | |
| # | |
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
| # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
| # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
| # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| from typing import Tuple | |
| import torch | |
| from torch import Tensor | |
def euler_from_quaternion(quat_angle):
    """Convert quaternions (x, y, z, w layout) into Euler angles.

    The last dimension of ``quat_angle`` holds the quaternion components;
    any number of leading batch dimensions (including none) is accepted.

    Returns:
        Tuple ``(roll_x, pitch_y, yaw_z)`` in radians, each shaped like
        ``quat_angle`` without its last dimension:
        roll is the rotation around x (counterclockwise),
        pitch is the rotation around y (counterclockwise),
        yaw is the rotation around z (counterclockwise).
    """
    # Index with an ellipsis so unbatched (4,) and multi-batch (..., 4)
    # inputs are handled the same way as the usual (N, 4) case.
    x = quat_angle[..., 0]
    y = quat_angle[..., 1]
    z = quat_angle[..., 2]
    w = quat_angle[..., 3]

    t0 = +2.0 * (w * x + y * z)
    t1 = +1.0 - 2.0 * (x * x + y * y)
    roll_x = torch.atan2(t0, t1)

    t2 = +2.0 * (w * y - z * x)
    # Keep the asin argument inside [-1, 1]; rounding can push it just outside.
    t2 = torch.clip(t2, -1, 1)
    pitch_y = torch.asin(t2)

    t3 = +2.0 * (w * z + x * y)
    t4 = +1.0 - 2.0 * (y * y + z * z)
    yaw_z = torch.atan2(t3, t4)

    return roll_x, pitch_y, yaw_z  # in radians
def normalize(x, eps: float = 1e-9):
    """Scale ``x`` to unit L2 norm along its last dimension.

    The norm is clamped from below by ``eps`` so that zero-length vectors
    are divided by ``eps`` instead of by zero.
    """
    length = torch.norm(x, p=2, dim=-1, keepdim=True)
    safe_length = torch.clamp(length, min=eps)
    return x / safe_length
def normalize_angle(x):
    """Wrap angles (radians) into the range returned by ``atan2``, i.e. [-pi, pi]."""
    sin_x = torch.sin(x)
    cos_x = torch.cos(x)
    return torch.atan2(sin_x, cos_x)
def quat_rotate(q, v):
    """Rotate vectors ``v`` by quaternions ``q`` (x, y, z, w layout).

    Args:
        q: quaternions of shape (..., 4). The formula used is the rotation
            formula for unit quaternions, so ``q`` is expected to be normalized.
        v: vectors of shape (..., 3) with the same leading dimensions as ``q``.

    Returns:
        Rotated vectors with the same shape as ``v``.
    """
    # Slicing (rather than integer indexing) keeps a trailing dimension of
    # size 1, so the scalar part broadcasts against 3-vectors for any number
    # of leading batch dimensions.
    q_w = q[..., -1:]
    q_vec = q[..., :3]
    a = v * (2.0 * q_w ** 2 - 1.0)
    b = torch.cross(q_vec, v, dim=-1) * q_w * 2.0
    # Per-row dot product q_vec . v, kept as (..., 1) for broadcasting.
    dot = (q_vec * v).sum(dim=-1, keepdim=True)
    c = q_vec * dot * 2.0
    return a + b + c
def quat_rotate_inverse(q, v):
    """Rotate vectors ``v`` by the inverse of quaternions ``q`` (x, y, z, w layout).

    This is the same formula as ``quat_rotate`` with the sign of the
    cross-product term flipped, which corresponds to rotating by the
    conjugate of ``q``.

    Args:
        q: quaternions of shape (..., 4); expected to be normalized.
        v: vectors of shape (..., 3) with the same leading dimensions as ``q``.

    Returns:
        Rotated vectors with the same shape as ``v``.
    """
    # Slicing keeps a trailing dimension of size 1 so the scalar part
    # broadcasts for any number of leading batch dimensions.
    q_w = q[..., -1:]
    q_vec = q[..., :3]
    a = v * (2.0 * q_w ** 2 - 1.0)
    b = torch.cross(q_vec, v, dim=-1) * q_w * 2.0
    # Per-row dot product q_vec . v, kept as (..., 1) for broadcasting.
    dot = (q_vec * v).sum(dim=-1, keepdim=True)
    c = q_vec * dot * 2.0
    return a - b + c
def quat_from_euler_xyz(roll, pitch, yaw):
    """Build quaternions (x, y, z, w layout) from roll, pitch and yaw angles.

    The three inputs are tensors of angles in radians; the result stacks
    the four quaternion components along a new last dimension.
    """
    half_roll = roll * 0.5
    half_pitch = pitch * 0.5
    half_yaw = yaw * 0.5

    cos_y, sin_y = torch.cos(half_yaw), torch.sin(half_yaw)
    cos_r, sin_r = torch.cos(half_roll), torch.sin(half_roll)
    cos_p, sin_p = torch.cos(half_pitch), torch.sin(half_pitch)

    components = [
        cos_y * sin_r * cos_p - sin_y * cos_r * sin_p,  # x
        cos_y * cos_r * sin_p + sin_y * sin_r * cos_p,  # y
        sin_y * cos_r * cos_p - cos_y * sin_r * sin_p,  # z
        cos_y * cos_r * cos_p + sin_y * sin_r * sin_p,  # w
    ]
    return torch.stack(components, dim=-1)
def quat_unit(a):
    """Return ``a`` rescaled to unit length along its last dimension."""
    unit = normalize(a)
    return unit
def quat_from_angle_axis(angle, axis):
    """Build quaternions (x, y, z, w layout) from rotation angles and axes.

    ``angle`` has one dimension fewer than ``axis``; ``axis`` is normalized
    here, so it does not need to be unit length on input.
    """
    half_angle = (angle / 2).unsqueeze(-1)
    vector_part = normalize(axis) * torch.sin(half_angle)
    scalar_part = torch.cos(half_angle)
    quat = torch.cat([vector_part, scalar_part], dim=-1)
    return quat_unit(quat)
def quat_mul(a, b):
    """Multiply quaternions ``a * b`` element-wise (x, y, z, w layout).

    Both inputs must have identical shapes whose element count is a multiple
    of 4; they are flattened to (-1, 4) for the computation and the result
    is returned in the original shape.
    """
    assert a.shape == b.shape
    out_shape = a.shape
    lhs = a.reshape(-1, 4)
    rhs = b.reshape(-1, 4)

    x1, y1, z1, w1 = lhs.unbind(dim=-1)
    x2, y2, z2, w2 = rhs.unbind(dim=-1)

    # Shared partial products; the four outputs are assembled from these so
    # that fewer multiplications are needed than in the textbook expansion.
    part_w = (z1 + x1) * (x2 + y2)
    part_y = (w1 - y1) * (w2 + z2)
    part_z = (w1 + y1) * (w2 - z2)
    part_x = part_w + part_y + part_z
    half_sum = 0.5 * (part_x + (z1 - x1) * (x2 - y2))

    out_w = half_sum - part_w + (z1 - y1) * (y2 - z2)
    out_x = half_sum - part_x + (x1 + w1) * (x2 + w2)
    out_y = half_sum - part_y + (w1 - x1) * (y2 + z2)
    out_z = half_sum - part_z + (z1 + y1) * (w2 - x2)

    stacked = torch.stack([out_x, out_y, out_z, out_w], dim=-1)
    return stacked.reshape(out_shape)
def quat_conjugate(a):
    """Return the conjugate of quaternions ``a`` (x, y, z, w layout).

    The x, y, z components are negated and w is kept; the output has the
    same shape as the input.
    """
    original_shape = a.shape
    flat = a.reshape(-1, 4)
    vector_part = flat[:, :3]
    scalar_part = flat[:, 3:]
    conjugated = torch.cat((-vector_part, scalar_part), dim=-1)
    return conjugated.view(original_shape)
def quat_to_angle_axis(q):
    """Compute the axis-angle representation of quaternions ``q``.

    ``q`` uses the x, y, z, w layout and must be normalized.

    Returns:
        Tuple ``(angle, axis)``. Where the sine of the half angle is not
        above 1e-5 the angle is set to 0 and the axis to (0, 0, 1).
    """
    min_theta = 1e-5
    w = q[..., 3]

    # w = cos(angle / 2), hence sqrt(1 - w^2) is the sine of the half angle.
    sin_theta = torch.sqrt(1 - w * w)
    angle = normalize_angle(2 * torch.acos(w))
    axis = q[..., 0:3] / sin_theta.unsqueeze(-1)

    # Entries with a (near-)zero sine would divide by ~0 above; replace them
    # with the zero rotation about the z axis.
    valid = torch.abs(sin_theta) > min_theta
    fallback_axis = torch.zeros_like(axis)
    fallback_axis[..., -1] = 1

    angle = torch.where(valid, angle, torch.zeros_like(angle))
    axis = torch.where(valid.unsqueeze(-1), axis, fallback_axis)
    return angle, axis
def angle_axis_to_exp_map(angle, axis):
    """Compute the exponential map (axis scaled by angle) from axis-angle.

    ``angle`` is given a trailing dimension of size 1 so that it broadcasts
    against the components of ``axis``.
    """
    return angle[..., None] * axis
def quat_to_exp_map(q):
    """Compute the exponential map of quaternions ``q``.

    ``q`` must be normalized; the conversion goes through the axis-angle
    representation.
    """
    rot_angle, rot_axis = quat_to_angle_axis(q)
    return angle_axis_to_exp_map(rot_angle, rot_axis)
def quat_to_tan_norm(q):
    """Represent rotations ``q`` by their rotated tangent and normal vectors.

    The reference tangent is (1, 0, 0) and the reference normal is
    (0, 0, 1). The output concatenates the rotated tangent followed by the
    rotated normal along the last dimension.
    """
    tangent_ref = torch.zeros_like(q[..., 0:3])
    normal_ref = torch.zeros_like(q[..., 0:3])
    tangent_ref[..., 0] = 1
    normal_ref[..., -1] = 1

    rotated_tangent = quat_rotate(q, tangent_ref)
    rotated_normal = quat_rotate(q, normal_ref)

    return torch.cat([rotated_tangent, rotated_normal], dim=-1)
def euler_xyz_to_exp_map(roll, pitch, yaw):
    """Convert roll/pitch/yaw angles to an exponential map via a quaternion."""
    return quat_to_exp_map(quat_from_euler_xyz(roll, pitch, yaw))
def exp_map_to_angle_axis(exp_map):
    """Split an exponential map into a rotation angle and a unit axis.

    The angle is the norm of ``exp_map`` along its last dimension, wrapped
    by ``normalize_angle``; the axis is ``exp_map`` divided by that norm.

    Returns:
        Tuple ``(angle, axis)``. Where the wrapped angle's magnitude is not
        above 1e-5 the angle is set to 0 and the axis to (0, 0, 1).
    """
    min_theta = 1e-5

    magnitude = torch.norm(exp_map, dim=-1)
    # Divide by the raw norm before the angle is wrapped.
    axis = exp_map / magnitude.unsqueeze(-1)
    angle = normalize_angle(magnitude)

    fallback_axis = torch.zeros_like(exp_map)
    fallback_axis[..., -1] = 1

    valid = torch.abs(angle) > min_theta
    angle = torch.where(valid, angle, torch.zeros_like(angle))
    axis = torch.where(valid.unsqueeze(-1), axis, fallback_axis)
    return angle, axis
def exp_map_to_quat(exp_map):
    """Convert an exponential map to quaternions via the axis-angle form."""
    rot_angle, rot_axis = exp_map_to_angle_axis(exp_map)
    return quat_from_angle_axis(rot_angle, rot_axis)
def slerp(q0, q1, t):
    """Spherical linear interpolation between quaternions ``q0`` and ``q1``.

    Args:
        q0: start quaternions, shape (..., 4).
        q1: end quaternions, same shape as ``q0``. Where the dot product
            with ``q0`` is negative, ``q1`` is negated so the shorter arc
            is taken.
        t: interpolation weights with one dimension fewer than ``q0``.

    Returns:
        Interpolated quaternions with the shape of ``q0``.
    """
    assert(len(t.shape) == len(q0.shape) - 1)

    cos_half_theta = torch.sum(q0 * q1, dim=-1, keepdim=True)
    q1 = torch.where(cos_half_theta < 0, -q1, q1)
    cos_half_theta = torch.abs(cos_half_theta)

    half_theta = torch.acos(cos_half_theta)
    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)

    t = t.unsqueeze(-1)
    ratio_a = torch.sin((1 - t) * half_theta) / sin_half_theta
    ratio_b = torch.sin(t * half_theta) / sin_half_theta
    new_q = ratio_a * q0 + ratio_b * q1

    # Nearly parallel inputs: the slerp weights divide by ~0, so fall back
    # to a linear blend. The blend is weighted by t (instead of a fixed
    # midpoint) so that t=0 yields q0 and t=1 yields q1.
    lerp_q = (1 - t) * q0 + t * q1
    new_q = torch.where(torch.abs(sin_half_theta) < 0.001, lerp_q, new_q)
    # Identical inputs (also covers NaNs from acos/sqrt when the dot product
    # rounds above 1): return q0 unchanged.
    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)
    return new_q
def slerp2(q0, q1, t):
    """Spherical linear interpolation between quaternions ``q0`` and ``q1``.

    Unlike ``slerp``, ``t`` is used as given (no trailing dimension is
    added), so it must already broadcast against tensors of shape (..., 1),
    e.g. a scalar or a tensor of shape (..., 1).

    Args:
        q0: start quaternions, shape (..., 4).
        q1: end quaternions, same shape as ``q0``. Where the dot product
            with ``q0`` is negative, ``q1`` is negated so the shorter arc
            is taken. The caller's tensor is not modified.
        t: interpolation weights (see above).

    Returns:
        Interpolated quaternions.
    """
    cos_half_theta = torch.sum(q0 * q1, dim=-1, keepdim=True)
    # Out-of-place sign flip; leaves the caller's q1 untouched.
    q1 = torch.where(cos_half_theta < 0, -q1, q1)
    cos_half_theta = torch.abs(cos_half_theta)

    half_theta = torch.acos(cos_half_theta)
    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)

    ratio_a = torch.sin((1 - t) * half_theta) / sin_half_theta
    ratio_b = torch.sin(t * half_theta) / sin_half_theta
    new_q = ratio_a * q0 + ratio_b * q1

    # Nearly parallel inputs: the slerp weights divide by ~0, so fall back
    # to a linear blend. The blend is weighted by t (instead of a fixed
    # midpoint) so that t=0 yields q0 and t=1 yields q1.
    lerp_q = (1 - t) * q0 + t * q1
    new_q = torch.where(torch.abs(sin_half_theta) < 0.001, lerp_q, new_q)
    # Identical inputs (also covers NaNs from acos/sqrt when the dot product
    # rounds above 1): return q0 unchanged.
    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)
    return new_q
def calc_heading(q):
    """Calculate the heading angle of quaternions ``q``.

    The heading is the direction on the xy plane: the x axis (1, 0, 0) is
    rotated by ``q`` and the angle of the result's (x, y) components is
    returned. ``q`` must be normalized.
    """
    forward = torch.zeros_like(q[..., 0:3])
    forward[..., 0] = 1
    rotated = quat_rotate(q, forward)
    return torch.atan2(rotated[..., 1], rotated[..., 0])
def calc_heading_quat(q):
    """Calculate the heading rotation of quaternions ``q``.

    The heading is the direction on the xy plane; the result is a rotation
    by the heading angle about the z axis (0, 0, 1). ``q`` must be
    normalized.
    """
    z_axis = torch.zeros_like(q[..., 0:3])
    z_axis[..., 2] = 1
    return quat_from_angle_axis(calc_heading(q), z_axis)
def calc_heading_quat_inv(q):
    """Calculate the inverse heading rotation of quaternions ``q``.

    The heading is the direction on the xy plane; the result is a rotation
    by the negated heading angle about the z axis (0, 0, 1). ``q`` must be
    normalized.
    """
    z_axis = torch.zeros_like(q[..., 0:3])
    z_axis[..., 2] = 1
    negated_heading = -calc_heading(q)
    return quat_from_angle_axis(negated_heading, z_axis)
def quat_pos(x):
    """Flip the sign of quaternions whose w component is negative.

    ``x`` uses the x, y, z, w layout; the result has a non-negative w.
    """
    # 1.0 where w < 0, else 0.0; kept with a trailing dimension for broadcasting.
    is_negative = (x[..., 3:] < 0).float()
    # Maps 1.0 -> -1 and 0.0 -> +1.
    sign = 1 - 2 * is_negative
    return sign * x
def quat_to_axis_angle(q):
    """Compute the axis and angle of quaternions ``q`` (x, y, z, w layout).

    Returns:
        Tuple ``(axis, angle)`` -- note the order. Where the length of the
        vector part is not above 1e-5 the angle is set to 0 and the axis to
        (0, 0, 1).
    """
    eps = 1e-5

    # need to make sure w is not negative to calculate geodesic distance
    q = quat_pos(q)
    vec = q[..., 0:3]
    vec_length = torch.norm(vec, dim=-1, p=2)

    angle = 2.0 * torch.atan2(vec_length, q[..., 3])
    axis = vec / vec_length.unsqueeze(-1)

    fallback_axis = torch.zeros_like(axis)
    fallback_axis[..., -1] = 1

    valid = vec_length > eps
    angle = torch.where(valid, angle, torch.zeros_like(angle))
    axis = torch.where(valid.unsqueeze(-1), axis, fallback_axis)
    return axis, angle
def quat_diff(q0, q1):
    """Return the quaternion product ``q1 * conjugate(q0)``."""
    q0_conj = quat_conjugate(q0)
    return quat_mul(q1, q0_conj)
def quat_diff_angle(q0, q1):
    """Return the rotation angle of the difference quaternion between ``q0`` and ``q1``."""
    delta = quat_diff(q0, q1)
    # quat_to_axis_angle returns (axis, angle); only the angle is needed.
    return quat_to_axis_angle(delta)[1]
def axis_angle_to_quat(axis, angle):
    # type: (Tensor, Tensor) -> Tensor
    """Build quaternions (x, y, z, w layout) from rotation axes and angles.

    Same computation as ``quat_from_angle_axis`` with the arguments in the
    opposite order.
    """
    return quat_from_angle_axis(angle, axis)