File size: 2,183 Bytes
b678162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
import torch
from torch import Tensor
import torch.nn.functional as F
from jaxtyping import Float


def transform_points(
    T: Float[Tensor, "... d d"],
    pts: Float[Tensor, "... n c"]
) -> Float[Tensor, "... n 3"]:
    """
    Apply a (possibly batched) transformation matrix to a set of points.

    When the points have exactly one coordinate fewer than the matrix size
    (e.g. 3-D points with a 4x4 matrix), a trailing 1 is appended so the
    points are in homogeneous coordinates. Each point p is mapped to T @ p,
    and the result is truncated to its first three coordinates; no division
    by the homogeneous coordinate is performed.

    Args:
        T (torch.Tensor): transformation matrix of shape (..., d, d)
        pts (torch.Tensor): input points of shape (..., n, c), c == d - 1 or c == d

    Returns:
        torch.Tensor: transformed points of shape (..., n, 3)
    """
    matrix_dim = T.shape[-1]
    is_short_by_one = pts.shape[-1] + 1 == matrix_dim
    homogeneous = F.pad(pts, (0, 1), value=1) if is_short_by_one else pts
    # out[..., n, j] = sum_i T[..., j, i] * pts[..., n, i]  (i.e. T @ p per point)
    mapped = torch.einsum("...ji,...ni->...nj", T, homogeneous)
    return mapped[..., :3]


def transform_points_np(
    T: Float[np.ndarray, "... d d"],
    pts: Float[np.ndarray, "... n c"]
) -> Float[np.ndarray, "... n 3"]:
    """
    Apply a (possibly batched) transformation matrix to a set of points (NumPy).

    When the points have exactly one coordinate fewer than the matrix size
    (e.g. 3-D points with a 4x4 matrix), a trailing 1 is appended so the
    points are in homogeneous coordinates. Each point p is mapped to T @ p,
    and the result is truncated to its first three coordinates; no division
    by the homogeneous coordinate is performed.

    Args:
        T (np.ndarray): transformation matrix of shape (..., d, d)
        pts (np.ndarray): input points of shape (..., n, c) with c == d - 1
            or c == d; a single point of shape (c,) is also accepted

    Returns:
        np.ndarray: transformed points of shape (..., n, 3) (or (3,) for a
        single point)
    """
    T = np.asarray(T)
    pts = np.asarray(pts)
    if pts.shape[-1] == T.shape[-1] - 1:
        # Pad only the coordinate axis. The previous reshape(-1, 3) assumed
        # exactly 3 columns and flattened the batch axes, which scrambled
        # homogeneous input and broke broadcasting against a batched T.
        pad_width = [(0, 0)] * (pts.ndim - 1) + [(0, 1)]
        pts = np.pad(pts, pad_width, constant_values=1)
    # pts @ T^T == einsum("...ji,...ni->...nj", T, pts); matmul broadcasts the
    # leading axes of T against those of pts and also accepts a 1-D point.
    transformed = np.matmul(pts, np.swapaxes(T, -1, -2))
    return transformed[..., :3]


def invert_intrinsics(
    K: Float[Tensor, "... 3 3"]
) -> Float[Tensor, "... 3 3"]:
    """
    Invert pinhole intrinsics in closed form.

    Uses the analytic inverse of [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
    Only fx = K[0, 0], fy = K[1, 1], cx = K[0, 2] and cy = K[1, 2] are read;
    any skew term K[0, 1] and the bottom row of K are ignored, so the result
    is the exact inverse only for zero-skew intrinsics.

    Args:
        K (torch.Tensor): intrinsics of shape (..., 3, 3)

    Returns:
        torch.Tensor: inverse intrinsics of shape (..., 3, 3), with K's
        device and (floating) dtype; autograd history through K is kept
    """
    if not K.is_floating_point():
        # integer intrinsics: true division yields the default float dtype
        K = K.to(torch.get_default_dtype())
    fx, fy = K[..., 0, 0], K[..., 1, 1]
    cx, cy = K[..., 0, 2], K[..., 1, 2]
    # Fill with tensor ops instead of torch.tensor([...]): the latter copies
    # the scalars to the host (a sync on GPU) and drops the autograd graph.
    K_inv = torch.zeros_like(K)
    K_inv[..., 0, 0] = 1 / fx
    K_inv[..., 0, 2] = -cx / fx
    K_inv[..., 1, 1] = 1 / fy
    K_inv[..., 1, 2] = -cy / fy
    K_inv[..., 2, 2] = 1
    return K_inv


def se3_from_Rt(
    R: Float[Tensor, "... 3 3"],
    t: Float[Tensor, "... 3"]
) -> Float[Tensor, "... 4 4"]:
    """
    Assemble a homogeneous 4x4 transform from a rotation and a translation.

    Args:
        R (torch.Tensor): rotation block of shape (..., 3, 3); its leading
            shape, dtype and device define those of the result
        t (torch.Tensor): translation of shape (..., 3), broadcastable to R's
            leading shape (values are cast to R's dtype on assignment)

    Returns:
        torch.Tensor: transform of shape (..., 4, 4) with R in the top-left
        block, t in the last column and bottom row [0, 0, 0, 1]
    """
    # Allocate on R's device: without device= the result was always built on
    # the CPU, so GPU inputs silently produced a CPU transform.
    T = torch.eye(4, dtype=R.dtype, device=R.device).repeat(*R.shape[:-2], 1, 1)
    T[..., :3, :3] = R
    T[..., :3, 3] = t
    return T


def invert_se3(
    T: Float[Tensor, "... 4 4"]
) -> Float[Tensor, "... 4 4"]:
    """
    Invert a rigid (SE(3)) transform in closed form.

    Uses [R | t]^-1 = [R^T | -R^T t]. Only the top 3x4 part of T is read;
    the result's bottom row is always [0, 0, 0, 1]. R is assumed to be
    orthonormal (a rotation), otherwise R^T is not its inverse.

    Args:
        T (torch.Tensor): transform of shape (..., 4, 4)

    Returns:
        torch.Tensor: inverse transform of shape (..., 4, 4), with T's dtype
        and device
    """
    R_inv = T[..., :3, :3].transpose(-1, -2)
    t = T[..., :3, 3]
    t_inv = -torch.einsum("...ij,...j->...i", R_inv, t)
    # Allocate on T's device (was always CPU before, breaking GPU inputs).
    T_inv = torch.eye(4, dtype=T.dtype, device=T.device).repeat(*T.shape[:-2], 1, 1)
    T_inv[..., :3, :3] = R_inv
    T_inv[..., :3, 3] = t_inv
    return T_inv


def to_4x4(
    m: Float[Tensor, "... 3 3"]
) -> Float[Tensor, "... 4 4"]:
    """
    Embed a 3x3 matrix in the top-left block of a 4x4 identity.

    Args:
        m (torch.Tensor): matrix of shape (..., 3, 3)

    Returns:
        torch.Tensor: matrix of shape (..., 4, 4) with m in the top-left
        block, zeros in the rest of the last row/column and a 1 in the
        bottom-right corner; same dtype and device as m
    """
    # Allocate on m's device (was always CPU before, breaking GPU inputs).
    out = torch.eye(4, dtype=m.dtype, device=m.device).repeat(*m.shape[:-2], 1, 1)
    out[..., :3, :3] = m
    return out


def project_points(
    K: Float[Tensor, "... d d"],
    pts: Float[Tensor, "... n c"]
) -> Float[Tensor, "... n 2"]:
    """
    Project points to image coordinates with a perspective divide.

    Args:
        K (torch.Tensor): intrinsics of shape (..., 3, 3), or a (..., 4, 4)
            projection matrix applied via transform_points
        pts (torch.Tensor): points of shape (..., n, c). With a 3x3 K only the
            first three coordinates are used, so c == 4 homogeneous input has
            its last coordinate ignored (as the former 4x4 padding of K did)

    Returns:
        torch.Tensor: (x / z, y / z) of the transformed points, shape (..., n, 2)

    There is no guard against z == 0: such points yield inf/nan, and gradients
    grow without bound as z -> 0. The original docstring called this function
    "non-differentiable". NOTE(review): confirm whether callers rely on that.
    """
    if K.shape[-1] == 3:
        # K @ xyz equals the first three rows of [[K, 0], [0, 1]] @ [xyz, w],
        # so slicing the points replaces the padded 4x4 matrix. That matrix was
        # always built on the CPU (breaking GPU inputs) and could not be batched.
        pts = pts[..., :3]
    xyz = transform_points(K, pts)
    return xyz[..., :2] / xyz[..., 2:]