# vdpm/util/transforms.py
# Uploaded by dxm21 via huggingface_hub (revision b678162, verified).
import numpy as np
import torch
from torch import Tensor
import torch.nn.functional as F
from jaxtyping import Float
def transform_points(
    T: Float[Tensor, "... d d"],
    pts: Float[Tensor, "... n c"]
) -> Float[Tensor, "... n 3"]:
    """Apply a (possibly batched) square transformation matrix to points.

    Points that are exactly one coordinate short of the matrix size are
    lifted to homogeneous form by appending a 1. No perspective division is
    performed: the transformed points are simply truncated to their leading
    three coordinates.

    Args:
        T (torch.Tensor): transformation matrix of shape (..., d, d).
        pts (torch.Tensor): points of shape (..., n, c), where c is d or d - 1.
            Leading dimensions of T and pts are broadcast by torch.einsum.

    Returns:
        torch.Tensor: T @ p for every point p, shape (..., n, min(d, 3)).
    """
    is_short_by_one = pts.shape[-1] + 1 == T.shape[-1]
    lifted = F.pad(pts, (0, 1), value=1) if is_short_by_one else pts
    # Row-vector form of T @ p: out[..., n, j] = sum_i T[..., j, i] * p[..., n, i]
    return torch.einsum("...ji,...ni->...nj", T, lifted)[..., :3]
def transform_points_np(
    T: Float[np.ndarray, "... d d"],
    pts: Float[np.ndarray, "... n c"]
) -> Float[np.ndarray, "... n 3"]:
    """Apply a (possibly batched) square transformation matrix to points (NumPy).

    NumPy counterpart of ``transform_points``. Points that are exactly one
    coordinate short of the matrix size are lifted to homogeneous form by
    appending a 1. No perspective division is performed: the transformed
    points are truncated to their leading three coordinates.

    Args:
        T (np.ndarray): transformation matrix of shape (..., d, d).
        pts (np.ndarray): points of shape (..., n, c) with c equal to d or
            d - 1, or a single point of shape (c,). Leading dimensions of T
            and pts are broadcast by np.einsum.

    Returns:
        np.ndarray: transformed points of shape (..., n, min(d, 3)); for a
        single input point of shape (c,), shape (..., min(d, 3)).
    """
    T = np.asarray(T)
    pts = np.asarray(pts)
    single_point = pts.ndim == 1
    if single_point:
        # Treat a lone point (c,) as a one-point set (1, c).
        pts = pts[np.newaxis]
    if pts.shape[-1] == T.shape[-1] - 1:
        # Pad only the coordinate axis so any leading batch/grid axes survive;
        # the old reshape(-1, 3) assumed c == 3 and broke batched T.
        pad_width = [(0, 0)] * (pts.ndim - 1) + [(0, 1)]
        pts = np.pad(pts, pad_width, constant_values=1)
    # Row-vector form of T @ p: out[..., n, j] = sum_i T[..., j, i] * p[..., n, i]
    out = np.einsum("...ji,...ni->...nj", T, pts)[..., :3]
    if single_point:
        # Drop the temporary point axis added above.
        out = out[..., 0, :]
    return out
def invert_intrinsics(
    K: Float[Tensor, "3 3"]
):
    """Invert a pinhole intrinsics matrix in closed form.

    Only fx = K[0, 0], skew = K[0, 1], cx = K[0, 2], fy = K[1, 1] and
    cy = K[1, 2] are read. The remaining entries are assumed to be those of a
    standard upper-triangular intrinsics matrix (zeros below the diagonal and
    K[2, 2] == 1). The result is built from tensor ops, so it keeps K's dtype
    and device and stays differentiable with respect to K.

    Args:
        K (torch.Tensor): intrinsics of shape (3, 3), or (..., 3, 3) for a batch.

    Returns:
        torch.Tensor: K^-1 with the same leading shape as K.
    """
    fx, fy = K[..., 0, 0], K[..., 1, 1]
    skew = K[..., 0, 1]
    cx, cy = K[..., 0, 2], K[..., 1, 2]
    inv_fx = 1 / fx
    inv_fy = 1 / fy
    zero = torch.zeros_like(inv_fx)
    one = torch.ones_like(inv_fx)
    # With skew == 0 the extra terms vanish and this reduces exactly to
    # [[1/fx, 0, -cx/fx], [0, 1/fy, -cy/fy], [0, 0, 1]].
    row0 = torch.stack(
        [inv_fx, -skew * inv_fx * inv_fy, -cx / fx + skew * cy * inv_fx * inv_fy],
        dim=-1,
    )
    row1 = torch.stack([zero, inv_fy, -cy / fy], dim=-1)
    row2 = torch.stack([zero, zero, one], dim=-1)
    K_inv = torch.stack([row0, row1, row2], dim=-2)
    return K_inv
def se3_from_Rt(
    R: Float[Tensor, "3 3"],
    t: Float[Tensor, "3"]
) -> Float[Tensor, "4 4"]:
    """Assemble a homogeneous transform [[R, t], [0, 1]].

    Args:
        R (torch.Tensor): rotation block of shape (3, 3), or (..., 3, 3) for a batch.
        t (torch.Tensor): translation of shape (3,), or (..., 3), broadcastable
            against R's leading dimensions.

    Returns:
        torch.Tensor: matrix of shape (..., 4, 4) with R's dtype and device.
    """
    # Allocate on R's device; a bare torch.eye(4) always landed on the CPU,
    # silently moving GPU inputs off-device.
    T = torch.zeros(*R.shape[:-2], 4, 4, dtype=R.dtype, device=R.device)
    T[..., :3, :3] = R
    T[..., :3, 3] = t
    T[..., 3, 3] = 1
    return T
def invert_se3(
    T: Float[Tensor, "4 4"]
):
    """Invert a rigid transform using R^-1 = R^T.

    Only the rotation block T[..., :3, :3] and translation T[..., :3, 3] are
    read. The rotation block is assumed to be orthonormal, so the result is
    [[R^T, -R^T t], [0, 1]]. A (3, 4) [R | t] input also works.

    Args:
        T (torch.Tensor): transform of shape (4, 4), or (..., 4, 4) for a batch.

    Returns:
        torch.Tensor: inverse transform of shape (..., 4, 4) with T's dtype and device.
    """
    R_inv = T[..., :3, :3].transpose(-1, -2)
    t = T[..., :3, 3]
    t_inv = -torch.einsum("...ij,...j->...i", R_inv, t)
    # Allocate on T's device; torch.eye(4) alone would always live on the CPU.
    T_inv = torch.zeros(*T.shape[:-2], 4, 4, dtype=T.dtype, device=T.device)
    T_inv[..., :3, :3] = R_inv
    T_inv[..., :3, 3] = t_inv
    T_inv[..., 3, 3] = 1
    return T_inv
def to_4x4(
    m: Float[Tensor, "3 3"]
):
    """Embed a 3x3 matrix in the top-left of a 4x4 identity.

    Args:
        m (torch.Tensor): matrix of shape (3, 3), or (..., 3, 3) for a batch.

    Returns:
        torch.Tensor: matrix of shape (..., 4, 4) with m's dtype and device;
        the last row and column are those of the identity.
    """
    # Allocate on m's device (torch.eye(4) alone always lived on the CPU) and
    # keep leading batch dims, so batched intrinsics can be promoted too.
    m_ = torch.zeros(*m.shape[:-2], 4, 4, dtype=m.dtype, device=m.device)
    m_[..., :3, :3] = m
    m_[..., 3, 3] = 1
    return m_
def project_points(
    K: Float[Tensor, "... d d"],
    pts: Float[Tensor, "... n c"]
):
    """Project points through a camera matrix and divide by depth.

    A 3x3 ``K`` is first promoted to 4x4 with ``to_4x4``. The points are then
    mapped with ``transform_points``, and the first two coordinates are
    divided by the third. There is no guard against a zero third coordinate.

    NOTE(review): the original docstring only said "Non-differentiable".
    Every op used here is a regular torch op, so confirm what that refers to
    (perhaps the singularity at z == 0).

    Args:
        K (torch.Tensor): camera/intrinsics matrix of shape (..., 3, 3) or (..., 4, 4).
        pts (torch.Tensor): points of shape (..., n, c).

    Returns:
        torch.Tensor: projected coordinates (x / z, y / z) of shape (..., n, 2).
    """
    K_h = to_4x4(K) if K.shape[-1] == 3 else K
    cam = transform_points(K_h, pts)
    depth = cam[..., 2:]
    return cam[..., :2] / depth