English
File size: 2,311 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import numpy as np


# Public API of this module.
__all__ = [
    'cross_product_matrix',
    'rodrigues_rotation_matrix',
    'base_vectors_3d',
]


def cross_product_matrix(k):
    """Compute the cross-product (skew-symmetric) matrix of a vector k.

    The returned matrix K satisfies ``K @ v == torch.linalg.cross(k, v)``
    for any 3-vector v.

    Credit: https://github.com/torch-points3d/torch-points3d

    :param k: 1-D tensor with 3 elements.
    :return: 3x3 tensor with the dtype and device of ``k``.
    """
    # Assemble with torch.stack rather than torch.tensor(list_of_tensors):
    # the latter copies every element through the host (a sync per element
    # on GPU) and detaches the result from k, which kills gradients.
    zero = torch.zeros((), dtype=k.dtype, device=k.device)
    return torch.stack([
        torch.stack([zero, -k[2], k[1]]),
        torch.stack([k[2], zero, -k[0]]),
        torch.stack([-k[1], k[0], zero]),
    ])


def rodrigues_rotation_matrix(axis, theta_degrees):
    """Given an axis and a rotation angle, compute the rotation matrix
    using the Rodrigues formula R = I + sin(t) K + (1 - cos(t)) K^2, where
    K is the cross-product matrix of the normalized axis and t the angle
    in radians.

    Source : https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
    Credit: https://github.com/torch-points3d/torch-points3d

    :param axis: floating point tensor of 3 elements; it does not need to
        be normalized but must be non-zero (a zero axis yields NaNs).
    :param theta_degrees: rotation angle in degrees, as a Python number or
        a 0-d tensor.
    :return: 3x3 rotation matrix on the device of ``axis``.
    """
    axis = axis / axis.norm()
    K = cross_product_matrix(axis)
    # Convert to radians in double precision (for Python numbers), then
    # materialize the angle in the axis' own dtype. Using the default
    # float32 here would silently cap the precision of float64 rotations.
    theta_radians = theta_degrees / 180. * np.pi
    t = torch.as_tensor(theta_radians, dtype=axis.dtype, device=axis.device)
    eye = torch.eye(3, dtype=axis.dtype, device=axis.device)
    R = eye + torch.sin(t) * K + (1 - torch.cos(t)) * K.mm(K)
    return R


def base_vectors_3d(x):
    """Compute orthonormal bases for a set of 3D vectors. The 1st base
    vector is the normalized input vector, while the 2nd and 3rd vectors
    are constructed in the corresponding orthogonal plane. Note that
    this problem is underconstrained and, as such, any rotation of the
    output base around the 1st vector is a valid orthonormal base.

    :param x: (N, 3) floating point tensor. It is not modified.
    :return: (N, 3, 3) tensor where ``out[i, j]`` is the j-th base vector
        built for ``x[i]``. All-zero rows of ``x`` arbitrarily get
        (1, 0, 0) as their 1st base vector.
    """
    assert x.dim() == 2
    assert x.shape[1] == 3

    # First direction is along x. All-zero rows have no direction, so
    # arbitrarily use (1, 0, 0) for them. torch.where builds a new tensor
    # instead of writing into x: the caller's input stays untouched and
    # autograd through x keeps working.
    x_is_zero = (x.norm(dim=1) == 0).view(-1, 1)
    a = torch.where(
        x_is_zero, torch.tensor([1, 0, 0], dtype=x.dtype, device=x.device), x)

    # Normalize a (no row has zero norm anymore)
    a = a / a.norm(dim=1, keepdim=True)

    # b = a x (1, 1, 1), hence orthogonal to a by construction
    b = torch.stack(
        (a[:, 1] - a[:, 2], a[:, 2] - a[:, 0], a[:, 0] - a[:, 1]), dim=1)

    # b vanishes when a is parallel to (1, 1, 1), i.e. of type (v, v, v).
    # Fall back to (2, -1, -1), which is orthogonal to (1, 1, 1) and thus
    # to a in that case.
    b_is_zero = (b.norm(dim=1) == 0).view(-1, 1)
    b = torch.where(
        b_is_zero, torch.tensor([2, -1, -1], dtype=x.dtype, device=x.device), b)

    # Normalize b out of place (in-place division would break autograd)
    b = b / b.norm(dim=1, keepdim=True)

    # Cross product of a and b to build the 3rd base vector
    c = torch.linalg.cross(a, b)

    return torch.stack((a, b, c), dim=1)