Dexter's picture
Upload folder using huggingface_hub
36c95ba verified
"""Module to perform projective transformations to tensors."""
from typing import List, Tuple
import torch
from kornia.geometry.conversions import angle_axis_to_rotation_matrix, convert_affinematrix_to_homography3d, deg2rad
from kornia.testing import check_is_tensor
from kornia.utils import eye_like
from kornia.utils.helpers import _torch_inverse_cast, _torch_solve_cast
from .homography_warper import homography_warp3d, normalize_homography3d
# Names exported by `from <module> import *`; `transform_warp_impl3d` and the
# underscore-prefixed helper are deliberately not listed here.
__all__ = [
    "warp_affine3d",
    "get_projective_transform",
    "projection_from_Rt",
    "get_perspective_transform3d",
    "warp_perspective3d",
]
def warp_affine3d(
    src: torch.Tensor,
    M: torch.Tensor,
    dsize: Tuple[int, int, int],
    flags: str = 'bilinear',
    padding_mode: str = 'zeros',
    align_corners: bool = True,
) -> torch.Tensor:
    r"""Apply an affine transformation to a 3d tensor.

    .. warning::
        This API signature it is experimental and might suffer some changes in the future.

    Args:
        src: input tensor of shape :math:`(B, C, D, H, W)`.
        M: affine transformation matrix of shape :math:`(B, 3, 4)` mapping source pixel
          coordinates to destination pixel coordinates (it is inverted internally).
        dsize: size of the output volume (depth, height, width).
        flags: interpolation mode to calculate output values
          ``'bilinear'`` | ``'nearest'``.
        padding_mode: padding mode for outside grid values
          ``'zeros'`` | ``'border'`` | ``'reflection'``.
        align_corners: forwarded to both ``affine_grid`` and ``grid_sample``.

    Returns:
        the warped 3d tensor with shape :math:`(B, C, D', H', W')` where
        :math:`(D', H', W')` is ``dsize``.

    Raises:
        AssertionError: if ``src`` is not 5-dimensional, ``M`` is not :math:`(B, 3, 4)`
          or ``dsize`` does not have exactly three entries.

    .. note::
        This function is often used in conjunction with :func:`get_projective_transform`,
        which produces matrices of the expected :math:`(B, 3, 4)` shape.
    """
    if len(src.shape) != 5:
        raise AssertionError(src.shape)
    if not (len(M.shape) == 3 and M.shape[-2:] == (3, 4)):
        raise AssertionError(M.shape)
    if len(dsize) != 3:
        raise AssertionError(dsize)
    B, C, D, H, W = src.size()

    size_src: Tuple[int, int, int] = (D, H, W)
    size_out: Tuple[int, int, int] = dsize

    M_4x4 = convert_affinematrix_to_homography3d(M)  # Bx4x4

    # we need to normalize the transformation since grid sample needs -1/1 coordinates
    dst_norm_trans_src_norm: torch.Tensor = normalize_homography3d(M_4x4, size_src, size_out)  # Bx4x4

    # sampling needs the destination -> source mapping, hence the inverse
    src_norm_trans_dst_norm = _torch_inverse_cast(dst_norm_trans_src_norm)
    P_norm: torch.Tensor = src_norm_trans_dst_norm[:, :3]  # Bx3x4

    # compute meshgrid and apply to input
    dsize_out: List[int] = [B, C] + list(size_out)
    grid = torch.nn.functional.affine_grid(P_norm, dsize_out, align_corners=align_corners)
    return torch.nn.functional.grid_sample(
        src, grid, align_corners=align_corners, mode=flags, padding_mode=padding_mode
    )
def projection_from_Rt(rmat: torch.Tensor, tvec: torch.Tensor) -> torch.Tensor:
    r"""Stack rotations and translations into projection matrices :math:`P = [R | t]`.

    .. warning::
        This API signature it is experimental and might suffer some changes in the future.

    Args:
        rmat: rotation matrices with shape :math:`(*, 3, 3)`.
        tvec: translation column vectors with shape :math:`(*, 3, 1)`.

    Returns:
        the projection matrices with shape :math:`(*, 3, 4)`; the last column is ``tvec``.

    Raises:
        AssertionError: if ``rmat`` or ``tvec`` does not end in the expected ``(3, 3)`` /
          ``(3, 1)`` shape.
    """
    # Guard clauses: reject anything without the required trailing matrix shape.
    if len(rmat.shape) < 2 or rmat.shape[-2:] != (3, 3):
        raise AssertionError(rmat.shape)
    if len(tvec.shape) < 2 or tvec.shape[-2:] != (3, 1):
        raise AssertionError(tvec.shape)

    # Append t as a fourth column to R.
    return torch.cat((rmat, tvec), dim=-1)  # (*, 3, 4)
def get_projective_transform(center: torch.Tensor, angles: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    r"""Calculate the projection matrix for a 3D rotation and scaling about a center point.

    .. warning::
        This API signature it is experimental and might suffer some changes in the future.

    The function builds :math:`T(c) \cdot R \cdot S \cdot T(-c)` in homogeneous coordinates,
    where :math:`T(v)` is a translation by :math:`v`, :math:`R` the rotation given by
    ``angles`` and :math:`S` the diagonal scaling given by ``scales``, and returns its
    top three rows.

    Args:
        center: center of the rotation (x,y,z) in the source with shape :math:`(B, 3)`.
        angles: angle axis vector containing the rotation angles in degrees in the form
          of (rx, ry, rz) with shape :math:`(B, 3)`. Internally it calls Rodrigues to compute
          the rotation matrix from axis-angle.
        scales: scale factor for x-y-z-directions with shape :math:`(B, 3)`. Its shape is not
          validated here; it is broadcast against a batch of 3x3 identity matrices.

    Returns:
        the projection matrix of 3D rotation with shape :math:`(B, 3, 4)`.

    Raises:
        AssertionError: if ``center`` or ``angles`` is not :math:`(B, 3)`, or if they differ
          in device or dtype.

    .. note::
        This function is often used in conjunction with :func:`warp_affine3d`.
    """
    if not (len(center.shape) == 2 and center.shape[-1] == 3):
        raise AssertionError(center.shape)
    if not (len(angles.shape) == 2 and angles.shape[-1] == 3):
        raise AssertionError(angles.shape)
    if center.device != angles.device:
        raise AssertionError(center.device, angles.device)
    if center.dtype != angles.dtype:
        raise AssertionError(center.dtype, angles.dtype)

    # create rotation matrix
    angle_axis_rad: torch.Tensor = deg2rad(angles)
    rmat: torch.Tensor = angle_axis_to_rotation_matrix(angle_axis_rad)  # Bx3x3
    # scales (B, 3) -> (B, 1, 3) multiplies the identity column-wise, giving diag(scales)
    scaling_matrix: torch.Tensor = eye_like(3, rmat)
    scaling_matrix = scaling_matrix * scales.unsqueeze(dim=1)
    rmat = rmat @ scaling_matrix.to(rmat)  # R @ S: scale first, then rotate

    # Homogeneous translations to move the center to the origin and back. The inverse of a
    # pure translation is the translation by the negated vector, so both are built exactly
    # (directly on the center's device/dtype) instead of through a generic matrix inversion.
    batch_size = rmat.shape[0]
    eye4 = torch.eye(4, device=center.device, dtype=center.dtype)
    from_origin_mat = eye4[None].repeat(batch_size, 1, 1)  # Bx4x4, translate by +center
    from_origin_mat[..., :3, -1] += center
    to_origin_mat = eye4[None].repeat(batch_size, 1, 1)  # Bx4x4, translate by -center
    to_origin_mat[..., :3, -1] -= center

    # append translation with zeros
    proj_mat = projection_from_Rt(rmat, torch.zeros_like(center)[..., None])  # Bx3x4

    # chain 4x4 transforms: move center to origin, rotate/scale, move back
    proj_mat = convert_affinematrix_to_homography3d(proj_mat)  # Bx4x4
    proj_mat = from_origin_mat @ proj_mat @ to_origin_mat

    return proj_mat[..., :3, :]  # Bx3x4
def get_perspective_transform3d(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    r"""Calculate a 3d perspective transform from five pairs of corresponding points.

    The function calculates the matrix of a perspective transform so that:

    .. math ::

        \begin{bmatrix}
        t_{i}x_{i}^{'} \\
        t_{i}y_{i}^{'} \\
        t_{i}z_{i}^{'} \\
        t_{i} \\
        \end{bmatrix}
        =
        \textbf{map_matrix} \cdot
        \begin{bmatrix}
        x_{i} \\
        y_{i} \\
        z_{i} \\
        1 \\
        \end{bmatrix}

    where

    .. math ::

        dst(i) = (x_{i}^{'},y_{i}^{'},z_{i}^{'}), src(i) = (x_{i}, y_{i}, z_{i}), i = 0,1,2,5,7

    Writing :math:`(u_i, v_i, w_i) = (x_{i}^{'}, y_{i}^{'}, z_{i}^{'})` and fixing
    :math:`c_{33} = 1`, the coefficients satisfy:

    .. math ::

        u_i = \frac{c_{00} x_i + c_{01} y_i + c_{02} z_i + c_{03}}
                   {c_{30} x_i + c_{31} y_i + c_{32} z_i + c_{33}}

        v_i = \frac{c_{10} x_i + c_{11} y_i + c_{12} z_i + c_{13}}
                   {c_{30} x_i + c_{31} y_i + c_{32} z_i + c_{33}}

        w_i = \frac{c_{20} x_i + c_{21} y_i + c_{22} z_i + c_{23}}
                   {c_{30} x_i + c_{31} y_i + c_{32} z_i + c_{33}}

    Multiplying out the denominators gives three linear equations per point, i.e. a square
    15x15 system in the 15 unknowns :math:`c_{00} \dots c_{32}` whose coefficient matrix is
    (rows grouped by axis here for readability):

    .. math ::

        \begin{pmatrix}
        x_0 & y_0 & z_0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_0 u_0 & -y_0 u_0 & -z_0 u_0 \\
        x_1 & y_1 & z_1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_1 u_1 & -y_1 u_1 & -z_1 u_1 \\
        x_2 & y_2 & z_2 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_2 u_2 & -y_2 u_2 & -z_2 u_2 \\
        x_5 & y_5 & z_5 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_5 u_5 & -y_5 u_5 & -z_5 u_5 \\
        x_7 & y_7 & z_7 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_7 u_7 & -y_7 u_7 & -z_7 u_7 \\
        0 & 0 & 0 & 0 & x_0 & y_0 & z_0 & 1 & 0 & 0 & 0 & 0 & -x_0 v_0 & -y_0 v_0 & -z_0 v_0 \\
        0 & 0 & 0 & 0 & x_1 & y_1 & z_1 & 1 & 0 & 0 & 0 & 0 & -x_1 v_1 & -y_1 v_1 & -z_1 v_1 \\
        0 & 0 & 0 & 0 & x_2 & y_2 & z_2 & 1 & 0 & 0 & 0 & 0 & -x_2 v_2 & -y_2 v_2 & -z_2 v_2 \\
        0 & 0 & 0 & 0 & x_5 & y_5 & z_5 & 1 & 0 & 0 & 0 & 0 & -x_5 v_5 & -y_5 v_5 & -z_5 v_5 \\
        0 & 0 & 0 & 0 & x_7 & y_7 & z_7 & 1 & 0 & 0 & 0 & 0 & -x_7 v_7 & -y_7 v_7 & -z_7 v_7 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_0 & y_0 & z_0 & 1 & -x_0 w_0 & -y_0 w_0 & -z_0 w_0 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_1 & y_1 & z_1 & 1 & -x_1 w_1 & -y_1 w_1 & -z_1 w_1 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_2 & y_2 & z_2 & 1 & -x_2 w_2 & -y_2 w_2 & -z_2 w_2 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_5 & y_5 & z_5 & 1 & -x_5 w_5 & -y_5 w_5 & -z_5 w_5 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_7 & y_7 & z_7 & 1 & -x_7 w_7 & -y_7 w_7 & -z_7 w_7 \\
        \end{pmatrix}

    Args:
        src: coordinates of the cuboid vertices in the source image with shape :math:`(B, 8, 3)`.
        dst: coordinates of the corresponding cuboid vertices in
          the destination image with shape :math:`(B, 8, 3)`.

    Returns:
        the perspective transformation with shape :math:`(B, 4, 4)`.

    Raises:
        TypeError: if ``src`` or ``dst`` is not a tensor.
        ValueError: if ``src`` is not :math:`(B, 8, 3)` or ``dst`` differs in shape.
        AssertionError: if ``src`` and ``dst`` differ in device or dtype.

    .. note::
        This function is often used in conjunction with :func:`warp_perspective3d`.
    """
    if not isinstance(src, (torch.Tensor)):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(src)}")

    if not isinstance(dst, (torch.Tensor)):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(dst)}")

    # the indexing below (src[:, i]) requires exactly one leading batch dimension
    if not (len(src.shape) == 3 and src.shape[-2:] == (8, 3)):
        raise ValueError(f"Inputs must be a Bx8x3 tensor. Got {src.shape}")

    # equal shapes also guarantee equal batch sizes
    if not src.shape == dst.shape:
        raise ValueError(f"Inputs must have the same shape. Got {dst.shape}")

    if not (src.device == dst.device and src.dtype == dst.dtype):
        raise AssertionError(
            f"Expect `src` and `dst` to be in the same device (Got {src.device}, {dst.device}) "
            f"with the same dtype (Got {src.dtype}, {dst.dtype})."
        )

    # Only five of the eight correspondences are used: 3 equations each give exactly the
    # 15 equations needed for the 15 unknowns. The original labels for these vertices are
    # 000, 100, 110, 101, 011 (assumes the caller's corner ordering).
    point_indices = [0, 1, 2, 5, 7]

    rows = []
    for i in point_indices:
        for axis in ('x', 'y', 'z'):
            rows.append(_build_perspective_param3d(src[:, i], dst[:, i], axis))

    # A is Bx15x15
    A = torch.stack(rows, dim=1)

    # b is Bx15x1, ordered like the rows of A: (x', y', z') of each selected point in turn
    batch_size = src.shape[0]
    b = dst[:, point_indices].reshape(batch_size, 15, 1)

    # solve the system Ax = b
    X, _ = _torch_solve_cast(b, A)

    # c_33 is fixed to 1; the 15 solved coefficients fill the rest row-major
    M = torch.ones(batch_size, 16, device=src.device, dtype=src.dtype)
    M[..., :15] = torch.squeeze(X, dim=-1)

    return M.view(-1, 4, 4)  # Bx4x4
def _build_perspective_param3d(p: torch.Tensor, q: torch.Tensor, axis: str) -> torch.Tensor:
    r"""Build one row of the 15x15 linear system solved by :func:`get_perspective_transform3d`.

    Args:
        p: source points, one per batch element, indexed as ``p[:, k]`` (e.g. :math:`(B, 3)`).
        q: destination points with the same layout as ``p``.
        axis: which destination coordinate the row constrains: ``'x'``, ``'y'`` or ``'z'``.

    Returns:
        a tensor with 15 columns: the homogeneous source point :math:`(x, y, z, 1)` placed
        in the column group of ``axis`` (columns 0-3, 4-7 or 8-11), zeros in the other two
        groups, and :math:`-x q_a, -y q_a, -z q_a` in the last three columns, where
        :math:`q_a` is the ``axis`` coordinate of ``q``.

    Raises:
        NotImplementedError: if ``axis`` is not ``'x'``, ``'y'`` or ``'z'``.
    """
    one_col = torch.ones_like(p)[..., 0:1]
    zero_col = torch.zeros_like(p)[..., 0:1]

    # Map the axis name to its column group / destination coordinate index.
    for group, name in enumerate(('x', 'y', 'z')):
        if axis == name:
            break
    else:
        raise NotImplementedError(f"perspective params for axis `{axis}` is not implemented.")

    columns = [zero_col] * 12
    columns[4 * group:4 * group + 4] = [p[:, 0:1], p[:, 1:2], p[:, 2:3], one_col]

    target = q[:, group:group + 1]
    columns.extend(-p[:, k:k + 1] * target for k in range(3))

    return torch.cat(columns, dim=1)
def warp_perspective3d(
    src: torch.Tensor,
    M: torch.Tensor,
    dsize: Tuple[int, int, int],
    flags: str = 'bilinear',
    border_mode: str = 'zeros',
    align_corners: bool = False,
) -> torch.Tensor:
    r"""Apply a perspective transformation to a 3d image (volume).

    ``M`` maps source pixel coordinates to destination pixel coordinates; it is inverted
    internally so that every destination voxel is sampled from the source:

    .. math::

        \text{dst}(x, y, z) = \text{src}\left(\pi\left(M^{-1}
        \begin{bmatrix} x & y & z & 1 \end{bmatrix}^\top\right)\right)

    where :math:`\pi` divides the first three homogeneous coordinates by the fourth.

    Args:
        src: input image with shape :math:`(B, C, D, H, W)`.
        M: transformation matrix with shape :math:`(B, 4, 4)`.
        dsize: size of the output image (depth, height, width).
        flags: interpolation mode to calculate output values
          ``'bilinear'`` | ``'nearest'``.
        border_mode: padding mode for outside grid values
          ``'zeros'`` | ``'border'`` | ``'reflection'``.
        align_corners: interpolation flag.

    Returns:
        the warped input image :math:`(B, C, D', H', W')` with :math:`(D', H', W')` = ``dsize``.

    Raises:
        ValueError: if ``src`` is not 5-dimensional or ``M`` is not :math:`(B, 4, 4)`.

    .. note::
        This function is often used in conjunction with :func:`get_perspective_transform3d`.
    """
    check_is_tensor(src)
    check_is_tensor(M)

    if not len(src.shape) == 5:
        raise ValueError(f"Input src must be a BxCxDxHxW tensor. Got {src.shape}")

    # both conditions must hold: a 3D batch AND 4x4 matrices (previously `or` let e.g. Bx3x3 through)
    if not (len(M.shape) == 3 and M.shape[-2:] == (4, 4)):
        raise ValueError(f"Input M must be a Bx4x4 tensor. Got {M.shape}")

    # launches the warper
    d, h, w = src.shape[-3:]
    return transform_warp_impl3d(src, M, (d, h, w), dsize, flags, border_mode, align_corners)
def transform_warp_impl3d(
    src: torch.Tensor,
    dst_pix_trans_src_pix: torch.Tensor,
    dsize_src: Tuple[int, int, int],
    dsize_dst: Tuple[int, int, int],
    grid_mode: str,
    padding_mode: str,
    align_corners: bool,
) -> torch.Tensor:
    """Compute the transform in normalized coordinates and perform the warping.

    Args:
        src: input volume to warp.
        dst_pix_trans_src_pix: Bx4x4 homography from source pixel to destination pixel
          coordinates.
        dsize_src: (depth, height, width) of ``src``.
        dsize_dst: (depth, height, width) of the output.
        grid_mode: interpolation mode forwarded to the warper.
        padding_mode: padding mode forwarded to the warper.
        align_corners: forwarded to the warper.

    Returns:
        the warped volume produced by ``homography_warp3d``.
    """
    dst_norm_trans_src_norm: torch.Tensor = normalize_homography3d(dst_pix_trans_src_pix, dsize_src, dsize_dst)
    # Sampling needs the destination -> source mapping. Use the same casting inverse helper
    # as warp_affine3d (instead of raw torch.inverse) so dtypes torch.inverse may not accept
    # are handled consistently.
    src_norm_trans_dst_norm = _torch_inverse_cast(dst_norm_trans_src_norm)
    # NOTE(review): the trailing `True` presumably tells homography_warp3d the homography is
    # already in normalized coordinates — confirm against its signature.
    return homography_warp3d(src, src_norm_trans_dst_norm, dsize_dst, grid_mode, padding_mode, align_corners, True)