File size: 2,311 Bytes
26225c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import torch
import numpy as np
# Public API of this module.
__all__ = [
    'cross_product_matrix',
    'rodrigues_rotation_matrix',
    'base_vectors_3d',
]
def cross_product_matrix(k):
    """Compute the cross-product (skew-symmetric) matrix of a 3D vector k.

    The returned matrix ``K`` satisfies ``K @ v == torch.linalg.cross(k, v)``
    for any 3D vector ``v``. Batched input of shape (..., 3) is accepted
    and yields a batch of matrices of shape (..., 3, 3); a plain (3,)
    vector yields a single (3, 3) matrix.

    Credit: https://github.com/torch-points3d/torch-points3d

    :param k: tensor whose last dimension has size 3
    :return: tensor of shape (..., 3, 3) with the same dtype and device as k
    :raises ValueError: if the last dimension of k is not of size 3
    """
    if k.dim() == 0 or k.shape[-1] != 3:
        raise ValueError(
            f"Expected a tensor whose last dimension has size 3, got shape "
            f"{tuple(k.shape)}")

    # Assemble the matrix with differentiable tensor ops (rather than
    # torch.tensor on a Python list), so that the result keeps k's dtype
    # and device and gradients flow back to k.
    k0, k1, k2 = k.unbind(-1)
    zero = torch.zeros_like(k0)
    rows = (
        torch.stack((zero, -k2, k1), dim=-1),
        torch.stack((k2, zero, -k0), dim=-1),
        torch.stack((-k1, k0, zero), dim=-1),
    )
    return torch.stack(rows, dim=-2)
def rodrigues_rotation_matrix(axis, theta_degrees):
    """Given an axis and a rotation angle, compute the rotation matrix
    using the Rodrigues formula:

        R = I + sin(t) * K + (1 - cos(t)) * K @ K

    where K is the cross-product matrix of the normalized axis and t is
    the angle in radians.

    Source : https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
    Credit: https://github.com/torch-points3d/torch-points3d

    :param axis: floating-point tensor of 3 elements; the rotation axis.
        It does not need to be unit length, but must be non-zero
    :param theta_degrees: rotation angle in degrees (Python number or
        0-d tensor)
    :return: (3, 3) rotation matrix
    :raises ValueError: if ``axis`` has zero norm (the rotation is
        undefined; it would otherwise produce a NaN matrix)
    """
    norm = axis.norm()
    if norm == 0:
        raise ValueError("Rotation axis must be a non-zero vector")
    axis = axis / norm
    K = cross_product_matrix(axis)

    # Build the angle and the identity in the axis' dtype/device, so that a
    # float64 axis is not silently mixed with float32 trigonometry.
    t = torch.deg2rad(
        torch.as_tensor(theta_degrees, dtype=axis.dtype, device=axis.device))
    eye = torch.eye(3, dtype=axis.dtype, device=axis.device)
    R = eye + torch.sin(t) * K + (1 - torch.cos(t)) * K.mm(K)
    return R
def base_vectors_3d(x):
    """Compute orthonormal bases for a set of 3D vectors. The 1st base
    vector is the normalized input vector, while the 2nd and 3rd vectors
    are constructed in the corresponding orthogonal plane. Note that
    this problem is underconstrained and, as such, any rotation of the
    output base around the 1st vector is a valid orthonormal base.

    The input tensor is not modified.

    :param x: (N, 3) floating-point tensor. Zero rows are treated as
        the vector (1, 0, 0)
    :return: (N, 3, 3) tensor where ``out[i]`` stacks the base vectors
        ``a, b, c`` of row i as rows, with ``c = a x b`` (right-handed)
    :raises ValueError: if x is not of shape (N, 3)
    """
    if x.dim() != 2 or x.shape[1] != 3:
        raise ValueError(
            f"Expected a tensor of shape (N, 3), got {tuple(x.shape)}")

    # First direction is along x. Work on a copy so the caller's tensor is
    # left untouched by the in-place fix-up below.
    a = x.clone()

    # If x is 0 vector (norm=0), arbitrarily put a to (1, 0, 0)
    a[a.norm(dim=1) == 0] = torch.tensor(
        [1, 0, 0], dtype=x.dtype, device=x.device)

    # Safely normalize a
    a = a / a.norm(dim=1).view(-1, 1)

    # Build a vector orthogonal to a: a . b = 0 holds identically for
    # b = (a1 - a2, a2 - a0, a0 - a1)
    b = torch.stack(
        (a[:, 1] - a[:, 2], a[:, 2] - a[:, 0], a[:, 0] - a[:, 1]), dim=1)

    # b is 0 exactly when a is of type (v, v, v), i.e. +/-(1, 1, 1)/sqrt(3).
    # Fall back to (1, -1, 0), which is orthogonal to (1, 1, 1)
    b[b.norm(dim=1) == 0] = torch.tensor(
        [1, -1, 0], dtype=x.dtype, device=x.device)

    # Safely normalize b
    b = b / b.norm(dim=1).view(-1, 1)

    # Cross product of a and b to build the 3rd base vector
    c = torch.linalg.cross(a, b, dim=1)

    return torch.stack((a, b, c), dim=1)
|