| | |
| | import torch |
| | from torch.nn import functional as F |
| |
|
| |
|
def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.

    Based on Zhou et al., "On the Continuity of Rotation
    Representations in Neural Networks", CVPR 2019.

    The six numbers of each sample are read as a row-major 3x2 matrix.
    Its two columns ``a1`` and ``a2`` are orthonormalised with Gram-Schmidt.
    The third column is their cross product, so every output is a proper
    rotation (orthonormal, determinant +1).

    Input:
        (B,6) Batch of 6-D rotation representations
    Output:
        (B,3,3) Batch of corresponding rotation matrices
    """
    # reshape (rather than view) also accepts non-contiguous inputs.
    x = x.reshape(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    b1 = F.normalize(a1, dim=-1)
    # Remove the component of a2 along b1, then normalise (Gram-Schmidt).
    b2 = F.normalize(
        a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1, dim=-1)
    # dim must be explicit: without it torch.cross uses the FIRST axis of
    # size 3, which is the batch axis whenever B == 3.
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-1)
| |
|
| |
|
def batch_rodrigues(theta):
    """Convert a batch of axis-angle vectors to rotation matrices.

    Each vector's L2 norm is the rotation angle and its direction is the
    rotation axis. The vector is first turned into a quaternion
    (w, x, y, z) = (cos(angle/2), sin(angle/2) * axis) and then converted
    with ``quat_to_rotmat``.

    Args:
        theta: size = [B, 3], axis-angle vectors
    Returns:
        Rotation matrices corresponding to the axis-angle vectors
        -- size = [B, 3, 3]
    """
    # The small offset keeps the norm strictly positive, so the division
    # below does not produce NaN for an all-zero input vector.
    angle = torch.norm(theta + 1e-8, p=2, dim=1).unsqueeze(-1)
    axis = theta / angle

    half_angle = angle * 0.5
    quat = torch.cat([torch.cos(half_angle),
                      torch.sin(half_angle) * axis], dim=1)
    return quat_to_rotmat(quat)
| |
|
| |
|
def quat_to_rotmat(quat):
    """Convert quaternion coefficients to rotation matrix.

    The quaternion is normalised to unit length first, so any non-zero
    scaling of it gives the same rotation. An all-zero quaternion cannot be
    normalised and yields NaNs.

    Args:
        quat: size = [B, 4] 4 <===>(w, x, y, z). Any number of leading
            batch dimensions is accepted, i.e. size = [..., 4].
    Returns:
        Rotation matrix corresponding to the quaternion
        -- size = [B, 3, 3] (in general [..., 3, 3])
    """
    norm_quat = quat / quat.norm(p=2, dim=-1, keepdim=True)
    w, x, y, z = norm_quat.unbind(dim=-1)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    # Entries are listed row by row. They are stacked on the last axis and
    # then reshaped, so all leading batch dimensions are preserved.
    rot_mat = torch.stack([
        w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
        2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
        2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2,
    ], dim=-1)
    return rot_mat.reshape(*quat.shape[:-1], 3, 3)
| |
|