| import numpy as np |
| import torch |
|
|
| if __name__ != "__main__": |
| from .konia_transform import ( |
| angle_axis_to_quaternion, |
| angle_axis_to_rotation_matrix, |
| quaternion_to_angle_axis, |
| quaternion_to_rotation_matrix, |
| rotation_matrix_to_angle_axis, |
| rotation_matrix_to_quaternion, |
| ) |
| else: |
| from konia_transform import ( |
| angle_axis_to_quaternion, |
| angle_axis_to_rotation_matrix, |
| quaternion_to_rotation_matrix, |
| rotation_matrix_to_angle_axis, |
| rotation_matrix_to_quaternion, |
| ) |
|
|
|
|
def normalize(x, eps: float = 1e-9):
    """
    Scale `x` to unit L2 length along its last axis.

    The length is floored at `eps` before dividing, so an all-zero vector
    stays zero instead of producing NaN.
    """
    length = x.norm(p=2, dim=-1, keepdim=True)
    return x / length.clamp(min=eps)
|
|
|
|
@torch.jit.script
def quat_mul(a, b):
    """
    Hamilton product a * b of scalar-first (w, x, y, z) quaternions.

    `a` and `b` must have identical shapes with a last axis of size 4; the
    result has the same shape. The product is evaluated with a
    reduced-multiplication scheme that shares partial products between the
    four output components.
    """
    assert a.shape == b.shape
    out_shape = a.shape
    lhs = a.reshape(-1, 4)
    rhs = b.reshape(-1, 4)

    aw, ax, ay, az = lhs[:, 0], lhs[:, 1], lhs[:, 2], lhs[:, 3]
    bw, bx, by, bz = rhs[:, 0], rhs[:, 1], rhs[:, 2], rhs[:, 3]

    # Partial products reused by several components below.
    p_w = (az + ax) * (bx + by)
    p_y = (aw - ay) * (bw + bz)
    p_z = (aw + ay) * (bw - bz)
    p_x = p_w + p_y + p_z
    half = 0.5 * (p_x + (az - ax) * (bx - by))

    out_w = half - p_w + (az - ay) * (by - bz)
    out_x = half - p_x + (ax + aw) * (bx + bw)
    out_y = half - p_y + (aw - ax) * (by + bz)
    out_z = half - p_z + (az + ay) * (bw - bx)
    return torch.stack([out_w, out_x, out_y, out_z], dim=-1).view(out_shape)
|
|
|
|
@torch.jit.script
def quat_conjugate(a):
    """
    Conjugate of scalar-first quaternions: component 0 is kept and
    components 1..3 are negated. The input shape is preserved.
    """
    flat = a.reshape(-1, 4)
    scalar_part = flat[:, 0:1]
    vector_part = flat[:, 1:]
    return torch.cat((scalar_part, -vector_part), dim=-1).view(a.shape)
|
|
|
|
@torch.jit.script
def quat_apply(a, b):
    """
    Rotate 3-vectors `b` by scalar-first quaternions `a`.

    `a` is flattened to (-1, 4) and `b` to (-1, 3); the result is returned
    in the original shape of `b`. Uses the identity
    v' = v + w * t + xyz x t, with t = 2 * (xyz x v).
    """
    out_shape = b.shape
    quat = a.reshape(-1, 4)
    vec = b.reshape(-1, 3)
    q_vec = quat[:, 1:].clone()
    q_w = quat[:, 0:1].clone()
    twice_cross = q_vec.cross(vec, dim=-1) * 2
    rotated = vec + q_w * twice_cross + q_vec.cross(twice_cross, dim=-1)
    return rotated.view(out_shape)
|
|
|
|
@torch.jit.script
def quat_angle(a, eps: float = 1e-6):
    """
    Rotation angle encoded by scalar-first quaternions.

    Computes acos(2 * w^2 - 1), with the cosine clamped to
    [-1 + eps, 1 - eps] before `acos`. The output drops the last
    (size-4) axis of the input.
    """
    flat = a.reshape(-1, 4)
    cos_angle = 2 * (flat[:, 0] ** 2) - 1
    cos_angle = cos_angle.clamp(-1 + eps, 1 - eps)
    return cos_angle.acos().view(a.shape[:-1])
|
|
|
|
@torch.jit.script
def quat_angle_diff(quat1, quat2):
    """
    Angle of the relative rotation quat1 * conj(quat2), computed with
    `quat_mul`, `quat_conjugate` and `quat_angle`.
    """
    relative = quat_mul(quat1, quat_conjugate(quat2))
    return quat_angle(relative)
|
|
|
|
@torch.jit.script
def torch_safe_atan2(y, x, eps: float = 1e-8):
    """
    `torch.atan2(y, x)` with the (0, 0) point nudged away from the origin.

    Wherever both |y| and |x| are below `eps`, `eps` is added to a copy of
    `y` before calling `atan2`. The caller's `y` is not modified.
    """
    near_origin = (y.abs() < eps) & (x.abs() < eps)
    nudged = y.clone()
    nudged[near_origin] += eps
    return torch.atan2(nudged, x)
|
|
|
|
@torch.jit.script
def ypr_euler_from_quat(
    q, handle_singularity: bool = False, eps: float = 1e-6, singular_eps: float = 1e-6
):
    """
    convert quaternion to yaw-pitch-roll euler angles

    `q` is indexed as (w, x, y, z) on its last axis. Despite the function
    name, the returned tensor stacks the angles on a new last axis in the
    order [roll, pitch, yaw].

    Args:
        q: quaternions, last axis of size 4.
        handle_singularity: when True, samples whose pitch is at +/-90
            degrees (within `singular_eps`) get roll = 0 and
            yaw = 2 * atan2(z, w).
        eps: margin used to clamp the asin argument and passed to
            `torch_safe_atan2`.
        singular_eps: tolerance used to detect the two pitch singularities.
    """
    yaw_atany = 2 * (q[..., 0] * q[..., 3] + q[..., 1] * q[..., 2])
    yaw_atanx = 1 - 2 * (q[..., 2] * q[..., 2] + q[..., 3] * q[..., 3])
    roll_atany = 2 * (q[..., 0] * q[..., 1] + q[..., 2] * q[..., 3])
    roll_atanx = 1 - 2 * (q[..., 1] * q[..., 1] + q[..., 2] * q[..., 2])
    yaw = torch_safe_atan2(yaw_atany, yaw_atanx, eps)
    # The asin argument is clamped to [-1 + eps, 1 - eps], so pitch always
    # lies strictly inside (-pi/2, pi/2).
    pitch = torch.asin(
        torch.clamp(
            2 * (q[..., 0] * q[..., 2] - q[..., 1] * q[..., 3]),
            min=-1 + eps,
            max=1 - eps,
        )
    )
    roll = torch_safe_atan2(roll_atany, roll_atanx, eps)

    if handle_singularity:
        # Half of sin(pitch); +/-0.5 marks the two singular poses.
        test = q[..., 0] * q[..., 2] - q[..., 1] * q[..., 3]

        # NOTE: earlier revisions also called
        # `pitch[mask].clamp_max_(pi / 2)` / `.clamp_min_(-pi / 2)` here.
        # Boolean-mask indexing returns a copy, so those in-place clamps
        # never touched `pitch` (and pitch is already within +/-pi/2).
        # They were dead code and have been removed; outputs are unchanged.

        # Pitch close to +90 degrees.
        np_ind = test > 0.5 - singular_eps
        if torch.any(np_ind):
            roll[np_ind] = 0.0
            yaw_atany = q[..., 3][np_ind]
            yaw_atanx = q[..., 0][np_ind]
            yaw[np_ind] = 2 * torch_safe_atan2(yaw_atany, yaw_atanx, eps)

        # Pitch close to -90 degrees.
        sp_ind = test < -0.5 + singular_eps
        if torch.any(sp_ind):
            roll[sp_ind] = 0.0
            yaw_atany = q[..., 3][sp_ind]
            yaw_atanx = q[..., 0][sp_ind]
            yaw[sp_ind] = 2 * torch_safe_atan2(yaw_atany, yaw_atanx, eps)

    return torch.stack([roll, pitch, yaw], dim=-1)
|
|
|
|
@torch.jit.script
def quat_from_ypr_euler(angles):
    """
    convert euler angles to quaternion

    The last axis of `angles` holds three angles (indices 0, 1, 2); the
    result stacks four quaternion components, scalar part first, on a new
    last axis. All-zero angles give (1, 0, 0, 0).
    """
    half = angles * 0.5
    s = torch.sin(half)
    c = torch.cos(half)
    s0, s1, s2 = s[..., 0], s[..., 1], s[..., 2]
    c0, c1, c2 = c[..., 0], c[..., 1], c[..., 2]

    q_w = c0 * c1 * c2 + s0 * s1 * s2
    q_x = s0 * c1 * c2 - c0 * s1 * s2
    q_y = c0 * s1 * c2 + s0 * c1 * s2
    q_z = c0 * c1 * s2 - s0 * s1 * c2
    return torch.stack([q_w, q_x, q_y, q_z], dim=-1)
|
|
|
|
def quat_between_two_vec(v1, v2, eps: float = 1e-6):
    """
    quaternion for rotating v1 to v2

    Both inputs are flattened to (-1, 3); the result has the leading shape
    of `v1` with a last axis of size 4 (scalar part first) and is
    normalised. The dot-product thresholds at +/-1 presume unit-length
    inputs (NOTE(review): confirm with callers).

    Special cases:
      * v1 ~ v2 (dot > 1 - eps): the identity quaternion.
      * v1 ~ -v2 (dot < -1 + eps): a half-turn about an axis perpendicular
        to v1, built from the x axis, or from the y axis when v1 is itself
        (anti)parallel to x.
    """
    orig_shape = v1.shape
    v1 = v1.reshape(-1, 3)
    v2 = v2.reshape(-1, 3)
    dot = (v1 * v2).sum(-1)
    cross = torch.cross(v1, v2, dim=-1)
    out = torch.cat([(1 + dot).unsqueeze(-1), cross], dim=-1)

    # Constants are created in the working dtype: masked assignment
    # (index_put_) requires source and destination dtypes to match, so a
    # default float32 constant would raise for float64/half inputs.
    sind = dot > 1 - eps
    out[sind] = torch.tensor(
        [1.0, 0.0, 0.0, 0.0], dtype=out.dtype, device=v1.device
    )

    nind = dot < -1 + eps
    if torch.any(nind):
        vx = torch.tensor([1.0, 0.0, 0.0], dtype=v1.dtype, device=v1.device)
        vxdot = (v1 * vx).sum(-1).abs()

        # Opposite vectors, v1 not aligned with x: rotate pi about x cross v1.
        nxind = nind & (vxdot < 1 - eps)
        if torch.any(nxind):
            axis = normalize(
                torch.cross(vx.expand_as(v1[nxind]), v1[nxind], dim=-1)
            )
            out[nxind] = angle_axis_to_quaternion(axis * np.pi).to(out.dtype)

        # Opposite vectors, v1 aligned with x: fall back to y cross v1.
        pind = nind & (vxdot >= 1 - eps)
        if torch.any(pind):
            vy = torch.tensor([0.0, 1.0, 0.0], dtype=v1.dtype, device=v1.device)
            axis = normalize(
                torch.cross(vy.expand_as(v1[pind]), v1[pind], dim=-1)
            )
            out[pind] = angle_axis_to_quaternion(axis * np.pi).to(out.dtype)

    out = normalize(out).view(orig_shape[:-1] + (4,))
    return out
|
|
|
|
@torch.jit.script
def get_yaw(q, eps: float = 1e-6):
    """
    Yaw angle of (w, x, y, z) quaternions:
    atan2(2 * (w*z + x*y), 1 - 2 * (y*y + z*z)), evaluated with
    `torch_safe_atan2`. The last axis of `q` is consumed.
    """
    sin_term = 2 * (q[..., 0] * q[..., 3] + q[..., 1] * q[..., 2])
    cos_term = 1 - 2 * (q[..., 2] * q[..., 2] + q[..., 3] * q[..., 3])
    return torch_safe_atan2(sin_term, cos_term, eps)
|
|
|
|
@torch.jit.script
def get_yaw_q(q):
    """
    Quaternion containing only the yaw of `q`.

    The yaw from `get_yaw` is placed in the third slot of an angle-axis
    vector (0, 0, yaw), which is converted with `angle_axis_to_quaternion`.
    """
    yaw_angle = get_yaw(q)
    zero_xy = torch.zeros(yaw_angle.shape + (2,), device=q.device)
    yaw_aa = torch.cat([zero_xy, yaw_angle.unsqueeze(-1)], dim=-1)
    return angle_axis_to_quaternion(yaw_aa)
|
|
|
|
@torch.jit.script
def get_heading(q, eps: float = 1e-6):
    """
    Heading angle of `q`, computed as 2 * atan2(q[..., 3], q[..., 0])
    using `torch_safe_atan2`.
    """
    half_heading = torch_safe_atan2(q[..., 3], q[..., 0], eps)
    return 2 * half_heading
|
|
|
|
def get_heading_q(q):
    """
    Heading-only quaternion: a copy of `q` with components 1 and 2 set to
    zero, re-normalised with `normalize`. `q` itself is left untouched.
    """
    heading = q.clone()
    for idx in (1, 2):
        heading[..., idx] = 0
    return normalize(heading)
|
|
|
|
def get_y_heading_q(q):
    """
    Copy of `q` with components 1 and 3 set to zero, re-normalised with
    `normalize` (only components 0 and 2 survive). `q` itself is left
    untouched.
    """
    heading = q.clone()
    for idx in (1, 3):
        heading[..., idx] = 0
    return normalize(heading)
|
|
|
|
@torch.jit.script
def heading_to_vec(h_theta):
    """
    Unit 2-D direction (cos, sin) for each heading angle, stacked on a new
    last axis.
    """
    cos_h = torch.cos(h_theta)
    sin_h = torch.sin(h_theta)
    return torch.stack([cos_h, sin_h], dim=-1)
|
|
|
|
@torch.jit.script
def vec_to_heading(h_vec):
    """
    Heading angle of 2-D direction vectors: atan2 of component 1 over
    component 0, evaluated with `torch_safe_atan2`.
    """
    y_comp = h_vec[..., 1]
    x_comp = h_vec[..., 0]
    return torch_safe_atan2(y_comp, x_comp)
|
|
|
|
@torch.jit.script
def heading_to_quat(h_theta):
    """
    Quaternion for a heading angle.

    Builds the angle-axis vector (0, 0, h_theta) and converts it with
    `angle_axis_to_quaternion`.
    """
    zero_xy = torch.zeros(h_theta.shape + (2,), device=h_theta.device)
    heading_aa = torch.cat([zero_xy, h_theta.unsqueeze(-1)], dim=-1)
    return angle_axis_to_quaternion(heading_aa)
|
|
|
|
def deheading_quat(q, heading_q=None):
    """
    Remove the heading from `q` by left-multiplying with the conjugate of
    the heading quaternion.

    When `heading_q` is not supplied it is derived from `q` with
    `get_heading_q`.
    """
    heading = get_heading_q(q) if heading_q is None else heading_q
    return quat_mul(quat_conjugate(heading), q)
|
|
|
|
@torch.jit.script
def rotmat_to_rot6d(mat):
    """
    6-D rotation representation: the first two columns of `mat`
    concatenated along the last axis (column 0 first).
    """
    first_col = mat[..., 0]
    second_col = mat[..., 1]
    return torch.cat([first_col, second_col], dim=-1)
|
|
|
|
@torch.jit.script
def rot6d_to_rotmat(rot6d, eps: float = 1e-8):
    """
    Convert the 6-D rotation representation to rotation matrices.

    The first three values (a1) and last three values (a2) of the last
    axis are orthonormalised with Gram-Schmidt into b1 and b2; b3 is their
    cross product. b1, b2, b3 become the columns of the returned
    (..., 3, 3) matrix.

    Degenerate inputs: a1 with norm below `eps` is replaced by (1, 0, 0);
    a b2 with norm below `eps` (a2 parallel to b1, or zero) is replaced by
    (0, 1, 0).
    NOTE(review): the (0, 1, 0) fallback is not re-orthogonalised against
    b1, so the result is only a valid rotation when b1 is orthogonal to y.
    """
    a1 = rot6d[..., :3].clone()
    a2 = rot6d[..., 3:].clone()

    # Fallback constants use the input dtype: masked assignment requires
    # matching dtypes, so a default float32 constant would raise for
    # float64/half inputs.
    ind = torch.norm(a1, dim=-1) < eps
    a1[ind] = torch.tensor([1.0, 0.0, 0.0], dtype=a1.dtype, device=a1.device)
    b1 = normalize(a1)

    # Remove the b1 component from a2, then normalise.
    b2 = normalize(a2 - (b1 * a2).sum(dim=-1).unsqueeze(-1) * b1)
    ind = torch.norm(b2, dim=-1) < eps
    b2[ind] = torch.tensor([0.0, 1.0, 0.0], dtype=b2.dtype, device=b2.device)

    b3 = torch.cross(b1, b2, dim=-1)
    mat = torch.stack([b1, b2, b3], dim=-1)
    return mat
|
|
|
|
@torch.jit.script
def angle_axis_to_rot6d(aa):
    """
    Angle-axis -> 6-D rotation, going through a rotation matrix
    (`angle_axis_to_rotation_matrix` then `rotmat_to_rot6d`).
    """
    rot_mat = angle_axis_to_rotation_matrix(aa)
    return rotmat_to_rot6d(rot_mat)
|
|
|
|
@torch.jit.script
def rot6d_to_angle_axis(rot6d):
    """
    6-D rotation -> angle-axis, going through a rotation matrix
    (`rot6d_to_rotmat` then `rotation_matrix_to_angle_axis`).
    """
    rot_mat = rot6d_to_rotmat(rot6d)
    return rotation_matrix_to_angle_axis(rot_mat)
|
|
|
|
@torch.jit.script
def quat_to_rot6d(q):
    """
    Quaternion -> 6-D rotation, going through a rotation matrix
    (`quaternion_to_rotation_matrix` then `rotmat_to_rot6d`).
    """
    rot_mat = quaternion_to_rotation_matrix(q)
    return rotmat_to_rot6d(rot_mat)
|
|
|
|
@torch.jit.script
def rot6d_to_quat(rot6d):
    """
    6-D rotation -> quaternion, going through a rotation matrix
    (`rot6d_to_rotmat` then `rotation_matrix_to_quaternion`).
    """
    rot_mat = rot6d_to_rotmat(rot6d)
    return rotation_matrix_to_quaternion(rot_mat)
|
|
|
|
@torch.jit.script
def make_transform(rot, trans, rot_type: str = "rotmat"):
    """
    Assemble 4x4 homogeneous transforms from a rotation and a translation.

    Args:
        rot: rotation, interpreted according to `rot_type`.
        trans: translation, written into the last column of the top three
            rows.
        rot_type: "axis_angle" converts `rot` with
            `angle_axis_to_rotation_matrix`; "6d" converts it with
            `rot6d_to_rotmat`; any other value uses `rot` as a rotation
            matrix unchanged.

    Returns:
        Tensor of shape `rot.shape[:-2] + (4, 4)` (after conversion) on
        `trans.device`, with the dtype of the rotation matrix.
    """
    if rot_type == "axis_angle":
        rot = angle_axis_to_rotation_matrix(rot)
    elif rot_type == "6d":
        rot = rot6d_to_rotmat(rot)
    # Build the identity in the rotation's dtype; a default float32 identity
    # would silently truncate double-precision inputs on assignment.
    transform = torch.eye(4, dtype=rot.dtype, device=trans.device).repeat(
        rot.shape[:-2] + (1, 1)
    )
    transform[..., :3, :3] = rot
    transform[..., :3, 3] = trans
    return transform
|
|
|
|
@torch.jit.script
def transform_trans(transform_mat, trans):
    """
    Apply 4x4 homogeneous transforms to 3-D points.

    `trans` is extended with a homogeneous 1 and treated as a row vector;
    `transform_mat` gains leading singleton axes until its rank matches,
    so one matrix can be broadcast over many points. Returns the first
    three components of the transformed points.
    """
    ones = torch.ones_like(trans[..., :1])
    row_vecs = torch.cat((trans, ones), dim=-1)[..., None, :]
    while transform_mat.dim() < row_vecs.dim():
        transform_mat = transform_mat.unsqueeze(-3)
    # Row-vector form: p^T @ T^T == (T @ p)^T.
    moved = torch.matmul(row_vecs, transform_mat.transpose(-2, -1))
    return moved[..., 0, :3]
|
|
|
|
@torch.jit.script
def transform_rot(transform_mat, rot):
    """
    Left-multiply angle-axis rotations by the rotation part of a transform.

    `rot` is converted with `angle_axis_to_rotation_matrix`,
    `transform_mat` gains leading singleton axes until the ranks match,
    the upper-left 3x3 block is applied on the left, and the product is
    converted back with `rotation_matrix_to_angle_axis`.
    """
    rot_mats = angle_axis_to_rotation_matrix(rot)
    while transform_mat.dim() < rot_mats.dim():
        transform_mat = transform_mat.unsqueeze(-3)
    rotated = torch.matmul(transform_mat[..., :3, :3], rot_mats)
    return rotation_matrix_to_angle_axis(rotated)
|
|
|
|
@torch.jit.script
def inverse_transform(transform_mat):
    """
    Invert rigid 4x4 transforms [R | t] as [R^T | -R^T t].

    Only the upper-left 3x3 block and the translation column of the input
    are read; the bottom row of the result is always (0, 0, 0, 1).
    """
    rot = transform_mat[..., :3, :3]
    trans_row = transform_mat[..., :3, 3].unsqueeze(-2)

    inv = torch.zeros_like(transform_mat)
    inv[..., :3, :3] = rot.transpose(-2, -1)
    # t^T @ R is the row-vector form of R^T @ t.
    inv[..., :3, 3] = -torch.matmul(trans_row, rot).squeeze(-2)
    inv[..., 3, 3] = 1.0
    return inv
|
|
|
|
def batch_compute_similarity_transform_torch(S1, S2):
    """
    Align point sets S1 to S2 with the optimal similarity transform.

    Computes, per batch element, the scale s, rotation R and translation t
    minimising ||s * R * S1 + t - S2|| (orthogonal Procrustes problem) and
    returns the transformed S1.

    Args:
        S1: source points, shape (..., N, D) (point-major) or (..., D, N)
            (coordinate-major). Leading axes are treated as batch axes; a
            plain 2-D input is treated as a batch of one.
        S2: target points with the same layout as S1.

    Layout detection: the input is taken as coordinate-major only when the
    second-to-last axis has size 2 or 3 and the last axis does not;
    otherwise it is taken as point-major. The decision depends only on the
    last two axes, never on the batch size.

    Returns:
        Tensor with the shape and layout of S1 holding the aligned points.
    """
    if S1.dim() < 2:
        raise ValueError("S1 must have at least 2 dimensions")

    orig_shape = S1.shape
    # Collapse all leading axes into one batch axis (adds one for 2-D input).
    S1 = S1.reshape(-1, *S1.shape[-2:])
    S2 = S2.reshape(-1, *S2.shape[-2:])

    # Work internally on (B, D, N).
    rows, cols = S1.shape[-2], S1.shape[-1]
    coord_major = rows in (2, 3) and cols not in (2, 3)
    transposed = not coord_major
    if transposed:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove the centroids.
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Total variance of the source, used for the scale.
    var1 = (X1 ** 2).sum(dim=(1, 2))

    # 3. Cross-covariance K = X1 @ X2^T; its SVD gives the rotation that
    #    maximises trace(R K).
    K = X1.bmm(X2.permute(0, 2, 1))
    U, _, V = torch.svd(K)

    # Z flips the last singular direction when needed so that det(R) = +1.
    # It is created in U's dtype so the bmm below works for any float type.
    Z = torch.eye(U.shape[1], dtype=U.dtype, device=U.device).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))

    # 4. Rotation.
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Scale = trace(R K) / var1, evaluated for the whole batch at once.
    scale = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(-1) / var1
    scale = scale.unsqueeze(-1).unsqueeze(-1)

    # 6. Translation, then 7. apply the transform.
    t = mu2 - scale * R.bmm(mu1)
    S1_hat = scale * R.bmm(S1) + t

    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)

    return S1_hat.reshape(orig_shape)
|
|