| | """Module to perform projective transformations to tensors.""" |
| | from typing import List, Tuple |
| |
|
| | import torch |
| |
|
| | from kornia.geometry.conversions import angle_axis_to_rotation_matrix, convert_affinematrix_to_homography3d, deg2rad |
| | from kornia.testing import check_is_tensor |
| | from kornia.utils import eye_like |
| | from kornia.utils.helpers import _torch_inverse_cast, _torch_solve_cast |
| |
|
| | from .homography_warper import homography_warp3d, normalize_homography3d |
| |
|
# Public names exported by ``from <module> import *``.
__all__ = [
    "warp_affine3d", "get_projective_transform", "projection_from_Rt",
    "get_perspective_transform3d", "warp_perspective3d",
]
| |
|
| |
|
def warp_affine3d(
    src: torch.Tensor,
    M: torch.Tensor,
    dsize: Tuple[int, int, int],
    flags: str = 'bilinear',
    padding_mode: str = 'zeros',
    align_corners: bool = True,
) -> torch.Tensor:
    r"""Apply an affine transformation to a 3d tensor.

    .. warning::
        This API signature is experimental and might change in the future.

    Args:
        src: input tensor of shape :math:`(B, C, D, H, W)`.
        M: affine transformation matrix of shape :math:`(B, 3, 4)` in pixel coordinates,
          mapping source locations to destination locations.
        dsize: size of the output volume (depth, height, width).
        flags: interpolation mode to calculate output values
          ``'bilinear'`` | ``'nearest'``.
        padding_mode: padding mode for outside grid values
          ``'zeros'`` | ``'border'`` | ``'reflection'``.
        align_corners: mode for grid generation, forwarded to both
          ``torch.nn.functional.affine_grid`` and ``torch.nn.functional.grid_sample``.

    Returns:
        the warped 3d tensor with shape :math:`(B, C, D', H', W')` where ``(D', H', W') = dsize``.

    Raises:
        AssertionError: if ``src`` is not 5-dimensional, ``M`` is not of shape
          :math:`(B, 3, 4)`, or ``dsize`` does not have exactly three entries.

    .. note::
        This function is often used in conjunction with :func:`get_projective_transform`.
    """
    if len(src.shape) != 5:
        raise AssertionError(src.shape)
    if not (len(M.shape) == 3 and M.shape[-2:] == (3, 4)):
        raise AssertionError(M.shape)
    if len(dsize) != 3:
        raise AssertionError(dsize)
    B, C, D, H, W = src.size()

    size_src: Tuple[int, int, int] = (D, H, W)
    size_out: Tuple[int, int, int] = dsize

    # lift the Bx3x4 affine matrix to a Bx4x4 homography
    M_4x4 = convert_affinematrix_to_homography3d(M)

    # express the pixel-space transform in normalized coordinates
    dst_norm_trans_src_norm: torch.Tensor = normalize_homography3d(M_4x4, size_src, size_out)

    # grid_sample pulls values from the source, so the grid needs the dst -> src mapping
    src_norm_trans_dst_norm = _torch_inverse_cast(dst_norm_trans_src_norm)
    P_norm: torch.Tensor = src_norm_trans_dst_norm[:, :3]

    # build the sampling grid for the requested output size and resample
    dsize_out: List[int] = [B, C] + list(size_out)
    grid = torch.nn.functional.affine_grid(P_norm, dsize_out, align_corners=align_corners)
    return torch.nn.functional.grid_sample(
        src, grid, align_corners=align_corners, mode=flags, padding_mode=padding_mode
    )
| |
|
| |
|
def projection_from_Rt(rmat: torch.Tensor, tvec: torch.Tensor) -> torch.Tensor:
    r"""Stack a rotation and a translation side by side into a projection matrix.

    .. warning::
        This API signature is experimental and might change in the future.

    The output is the column-wise concatenation :math:`P = [R | t]`, applied over
    any leading batch dimensions.

    Args:
        rmat: the rotation matrix with shape :math:`(*, 3, 3)`.
        tvec: the translation vector with shape :math:`(*, 3, 1)`.

    Returns:
        the projection matrix with shape :math:`(*, 3, 4)`.

    Raises:
        AssertionError: if ``rmat`` or ``tvec`` does not have the expected trailing shape.
    """
    if len(rmat.shape) < 2 or rmat.shape[-2:] != (3, 3):
        raise AssertionError(rmat.shape)
    if len(tvec.shape) < 2 or tvec.shape[-2:] != (3, 1):
        raise AssertionError(tvec.shape)

    # append the translation as the fourth column
    return torch.cat((rmat, tvec), dim=-1)
| |
|
| |
|
def get_projective_transform(center: torch.Tensor, angles: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    r"""Calculate the projection matrix for a 3D rotation and scaling about a center point.

    .. warning::
        This API signature is experimental and might change in the future.

    The function computes the projection matrix given the center, the angles per axis and
    the per-axis scale factors. The resulting transform is
    :math:`T(c) \, R \, S \, T(-c)`: points are moved so that :math:`c` is at the origin,
    scaled, rotated, and moved back.

    Args:
        center: center of the rotation (x,y,z) in the source with shape :math:`(B, 3)`.
        angles: angle axis vector containing the rotation angles in degrees in the form
            of (rx, ry, rz) with shape :math:`(B, 3)`. Internally it calls Rodrigues to compute
            the rotation matrix from axis-angle.
        scales: scale factor for x-y-z-directions with shape :math:`(B, 3)`.

    Returns:
        the projection matrix of 3D rotation with shape :math:`(B, 3, 4)`.

    Raises:
        AssertionError: if ``center`` or ``angles`` is not of shape :math:`(B, 3)`, or if
          they differ in device or dtype.

    .. note::
        This function is often used in conjunction with :func:`warp_affine3d`.
    """
    if not (len(center.shape) == 2 and center.shape[-1] == 3):
        raise AssertionError(center.shape)
    if not (len(angles.shape) == 2 and angles.shape[-1] == 3):
        raise AssertionError(angles.shape)
    if center.device != angles.device:
        raise AssertionError(center.device, angles.device)
    if center.dtype != angles.dtype:
        raise AssertionError(center.dtype, angles.dtype)

    # rotation from the angle-axis vector (degrees -> radians -> Rodrigues), combined with
    # a per-axis scaling applied before the rotation
    angle_axis_rad: torch.Tensor = deg2rad(angles)
    rmat: torch.Tensor = angle_axis_to_rotation_matrix(angle_axis_rad)
    scaling_matrix: torch.Tensor = eye_like(3, rmat)
    scaling_matrix = scaling_matrix * scales.unsqueeze(dim=1)
    rmat = rmat @ scaling_matrix.to(rmat)

    # homogeneous translations moving the origin to `center` and back
    batch_size = rmat.shape[0]
    from_origin_mat = torch.eye(4, device=center.device, dtype=center.dtype)[None].repeat(batch_size, 1, 1)
    from_origin_mat[..., :3, -1] += center

    # the inverse of a pure translation is the translation by the negated offset,
    # so no general matrix inverse is required
    to_origin_mat = from_origin_mat.clone()
    to_origin_mat[..., :3, -1] = -center

    # rotation + scaling about the origin, lifted to a Bx4x4 homography
    proj_mat = projection_from_Rt(rmat, torch.zeros_like(center)[..., None])
    proj_mat = convert_affinematrix_to_homography3d(proj_mat)

    # conjugate by the translations so the rotation/scaling happens about `center`
    proj_mat = from_origin_mat @ proj_mat @ to_origin_mat

    return proj_mat[..., :3, :]
| |
|
| |
|
def get_perspective_transform3d(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    r"""Calculate a 3d perspective transform from five pairs of corresponding points.

    Of the eight input points per batch element, the vertices with indices 0, 1, 2, 5
    and 7 are used. The function calculates the matrix of a perspective transform so that:

    .. math ::

        \begin{bmatrix}
        t_{i}x_{i}^{'} \\
        t_{i}y_{i}^{'} \\
        t_{i}z_{i}^{'} \\
        t_{i} \\
        \end{bmatrix}
        =
        \textbf{map\_matrix} \cdot
        \begin{bmatrix}
        x_{i} \\
        y_{i} \\
        z_{i} \\
        1 \\
        \end{bmatrix}

    where

    .. math ::
        dst(i) = (x_{i}^{'},y_{i}^{'},z_{i}^{'}), src(i) = (x_{i}, y_{i}, z_{i}), i = 0,1,2,5,7

    Writing :math:`(u_i, v_i, w_i)` for the destination coordinates and fixing
    :math:`c_{33} = 1`, each correspondence gives three equations:

    .. math ::

        u_i = \frac{c_{00} x_i + c_{01} y_i + c_{02} z_i + c_{03}}
                   {c_{30} x_i + c_{31} y_i + c_{32} z_i + c_{33}}

        v_i = \frac{c_{10} x_i + c_{11} y_i + c_{12} z_i + c_{13}}
                   {c_{30} x_i + c_{31} y_i + c_{32} z_i + c_{33}}

        w_i = \frac{c_{20} x_i + c_{21} y_i + c_{22} z_i + c_{23}}
                   {c_{30} x_i + c_{31} y_i + c_{32} z_i + c_{33}}

    which yields the 15x15 linear system in the remaining 15 unknowns (shown with
    rows grouped by axis; the implementation orders them point by point, which gives
    the same solution):

    .. math ::

        \begin{pmatrix}
        x_0 & y_0 & z_0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_0*u_0 & -y_0*u_0 & -z_0 * u_0 \\
        x_1 & y_1 & z_1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_1*u_1 & -y_1*u_1 & -z_1 * u_1 \\
        x_2 & y_2 & z_2 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_2*u_2 & -y_2*u_2 & -z_2 * u_2 \\
        x_5 & y_5 & z_5 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_5*u_5 & -y_5*u_5 & -z_5 * u_5 \\
        x_7 & y_7 & z_7 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & -x_7*u_7 & -y_7*u_7 & -z_7 * u_7 \\
        0 & 0 & 0 & 0 & x_0 & y_0 & z_0 & 1 & 0 & 0 & 0 & 0 & -x_0*v_0 & -y_0*v_0 & -z_0 * v_0 \\
        0 & 0 & 0 & 0 & x_1 & y_1 & z_1 & 1 & 0 & 0 & 0 & 0 & -x_1*v_1 & -y_1*v_1 & -z_1 * v_1 \\
        0 & 0 & 0 & 0 & x_2 & y_2 & z_2 & 1 & 0 & 0 & 0 & 0 & -x_2*v_2 & -y_2*v_2 & -z_2 * v_2 \\
        0 & 0 & 0 & 0 & x_5 & y_5 & z_5 & 1 & 0 & 0 & 0 & 0 & -x_5*v_5 & -y_5*v_5 & -z_5 * v_5 \\
        0 & 0 & 0 & 0 & x_7 & y_7 & z_7 & 1 & 0 & 0 & 0 & 0 & -x_7*v_7 & -y_7*v_7 & -z_7 * v_7 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_0 & y_0 & z_0 & 1 & -x_0*w_0 & -y_0*w_0 & -z_0 * w_0 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_1 & y_1 & z_1 & 1 & -x_1*w_1 & -y_1*w_1 & -z_1 * w_1 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_2 & y_2 & z_2 & 1 & -x_2*w_2 & -y_2*w_2 & -z_2 * w_2 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_5 & y_5 & z_5 & 1 & -x_5*w_5 & -y_5*w_5 & -z_5 * w_5 \\
        0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & x_7 & y_7 & z_7 & 1 & -x_7*w_7 & -y_7*w_7 & -z_7 * w_7 \\
        \end{pmatrix}

    Args:
        src: coordinates of the eight vertices in the source volume with shape :math:`(B, 8, 3)`.
        dst: coordinates of the corresponding eight vertices in
            the destination volume with shape :math:`(B, 8, 3)`.

    Returns:
        the perspective transformation with shape :math:`(B, 4, 4)`, whose last entry is 1.

    Raises:
        TypeError: if ``src`` or ``dst`` is not a tensor.
        ValueError: if ``src`` is not of shape :math:`(B, 8, 3)` or ``dst`` has a different shape.
        AssertionError: if ``src`` and ``dst`` differ in device or dtype.

    .. note::
        This function is often used in conjunction with :func:`warp_perspective3d`.
    """
    if not isinstance(src, (torch.Tensor)):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(src)}")

    if not isinstance(dst, (torch.Tensor)):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(dst)}")

    # the batch dimension is required: `src[:, i]` below indexes the point axis
    if not (len(src.shape) == 3 and src.shape[-2:] == (8, 3)):
        raise ValueError(f"Inputs must be a Bx8x3 tensor. Got {src.shape}")

    # equal shapes also guarantee equal batch sizes
    if not src.shape == dst.shape:
        raise ValueError(f"Inputs must have the same shape. Got {dst.shape}")

    if not (src.device == dst.device and src.dtype == dst.dtype):
        raise AssertionError(
            f"Expect `src` and `dst` to be in the same device (Got {src.device}, {dst.device}) "
            f"with the same dtype (Got {src.dtype}, {dst.dtype})."
        )

    # 16 matrix entries with the last one fixed to 1 leave 15 unknowns; five
    # correspondences with three equations each (x, y, z) determine them.
    point_ids = [0, 1, 2, 5, 7]

    # one row of the linear system per (point, axis) pair, ordered point-major
    rows = [
        _build_perspective_param3d(src[:, i], dst[:, i], axis)
        for i in point_ids
        for axis in ('x', 'y', 'z')
    ]
    A = torch.stack(rows, dim=1)  # Bx15x15

    # right-hand side in the same point-major order: (x_0', y_0', z_0', x_1', ...)
    batch_size = src.shape[0]
    b = dst[:, point_ids].reshape(batch_size, 15, 1)  # Bx15x1

    # solve A @ X = b for the 15 free entries of the transform
    X, _ = _torch_solve_cast(b, A)

    # assemble the 4x4 matrix, keeping the fixed last entry equal to 1
    M = torch.ones(batch_size, 16, device=src.device, dtype=src.dtype)
    M[..., :15] = torch.squeeze(X, dim=-1)
    return M.view(-1, 4, 4)
| |
|
| |
|
| | def _build_perspective_param3d(p: torch.Tensor, q: torch.Tensor, axis: str) -> torch.Tensor: |
| | ones = torch.ones_like(p)[..., 0:1] |
| | zeros = torch.zeros_like(p)[..., 0:1] |
| |
|
| | if axis == 'x': |
| | return torch.cat( |
| | [ |
| | p[:, 0:1], |
| | p[:, 1:2], |
| | p[:, 2:3], |
| | ones, |
| | zeros, |
| | zeros, |
| | zeros, |
| | zeros, |
| | zeros, |
| | zeros, |
| | zeros, |
| | zeros, |
| | -p[:, 0:1] * q[:, 0:1], |
| | -p[:, 1:2] * q[:, 0:1], |
| | -p[:, 2:3] * q[:, 0:1], |
| | ], |
| | dim=1, |
| | ) |
| |
|
| | if axis == 'y': |
| | return torch.cat( |
| | [ |
| | zeros, |
| | zeros, |
| | zeros, |
| | zeros, |
| | p[:, 0:1], |
| | p[:, 1:2], |
| | p[:, 2:3], |
| | ones, |
| | zeros, |
| | zeros, |
| | zeros, |
| | zeros, |
| | -p[:, 0:1] * q[:, 1:2], |
| | -p[:, 1:2] * q[:, 1:2], |
| | -p[:, 2:3] * q[:, 1:2], |
| | ], |
| | dim=1, |
| | ) |
| |
|
| | if axis == 'z': |
| | return torch.cat( |
| | [ |
| | zeros, |
| | zeros, |
| | zeros, |
| | zeros, |
| | zeros, |
| | zeros, |
| | zeros, |
| | zeros, |
| | p[:, 0:1], |
| | p[:, 1:2], |
| | p[:, 2:3], |
| | ones, |
| | -p[:, 0:1] * q[:, 2:3], |
| | -p[:, 1:2] * q[:, 2:3], |
| | -p[:, 2:3] * q[:, 2:3], |
| | ], |
| | dim=1, |
| | ) |
| |
|
| | raise NotImplementedError(f"perspective params for axis `{axis}` is not implemented.") |
| |
|
| |
|
def warp_perspective3d(
    src: torch.Tensor,
    M: torch.Tensor,
    dsize: Tuple[int, int, int],
    flags: str = 'bilinear',
    border_mode: str = 'zeros',
    align_corners: bool = False,
) -> torch.Tensor:
    r"""Apply a perspective transformation to a 3d volume.

    The matrix ``M`` maps homogeneous source pixel coordinates to destination pixel
    coordinates. Each output voxel at :math:`(x, y, z)` is therefore sampled from the
    source at the location obtained by applying :math:`M^{-1}` and dividing by the
    homogeneous coordinate:

    .. math::
        \text{dst}(x, y, z) = \text{src}\left( M^{-1}
        \begin{bmatrix} x & y & z & 1 \end{bmatrix}^\top \right)

    Args:
        src: input volume with shape :math:`(B, C, D, H, W)`.
        M: transformation matrix with shape :math:`(B, 4, 4)`.
        dsize: size of the output volume (depth, height, width).
        flags: interpolation mode to calculate output values
          ``'bilinear'`` | ``'nearest'``.
        border_mode: padding mode for outside grid values
          ``'zeros'`` | ``'border'`` | ``'reflection'``.
        align_corners: interpolation flag.

    Returns:
        the warped input volume with shape :math:`(B, C, D', H', W')` where
        ``(D', H', W') = dsize``.

    Raises:
        ValueError: if ``src`` is not 5-dimensional, ``M`` is not of shape
          :math:`(B, 4, 4)`, or ``dsize`` does not have exactly three entries.

    .. note::
        This function is often used in conjunction with :func:`get_perspective_transform3d`.
    """
    check_is_tensor(src)
    check_is_tensor(M)

    if not len(src.shape) == 5:
        raise ValueError(f"Input src must be a BxCxDxHxW tensor. Got {src.shape}")

    # both the rank and the trailing 4x4 shape must match
    if not (len(M.shape) == 3 and M.shape[-2:] == (4, 4)):
        raise ValueError(f"Input M must be a Bx4x4 tensor. Got {M.shape}")

    if len(dsize) != 3:
        raise ValueError(f"Output size must be a (D, H, W) triple. Got {dsize}")

    # spatial size (depth, height, width) of the source volume
    d, h, w = src.shape[-3:]
    return transform_warp_impl3d(src, M, (d, h, w), dsize, flags, border_mode, align_corners)
| |
|
| |
|
def transform_warp_impl3d(
    src: torch.Tensor,
    dst_pix_trans_src_pix: torch.Tensor,
    dsize_src: Tuple[int, int, int],
    dsize_dst: Tuple[int, int, int],
    grid_mode: str,
    padding_mode: str,
    align_corners: bool,
) -> torch.Tensor:
    """Compute the transform in normalized coordinates and perform the warping.

    Args:
        src: input tensor, passed through to :func:`homography_warp3d`.
        dst_pix_trans_src_pix: homography mapping source pixel coordinates to
          destination pixel coordinates.
        dsize_src: spatial size (depth, height, width) of ``src``.
        dsize_dst: spatial size (depth, height, width) of the output.
        grid_mode: interpolation mode forwarded to :func:`homography_warp3d`.
        padding_mode: padding mode forwarded to :func:`homography_warp3d`.
        align_corners: interpolation flag forwarded to :func:`homography_warp3d`.

    Returns:
        the warped tensor produced by :func:`homography_warp3d`.
    """
    dst_norm_trans_src_norm: torch.Tensor = normalize_homography3d(dst_pix_trans_src_pix, dsize_src, dsize_dst)

    # The warper samples the source at each output location, so it needs the
    # destination -> source mapping. `_torch_inverse_cast` is the same helper
    # `warp_affine3d` uses; it casts the input for `torch.inverse`, so
    # low-precision dtypes (e.g. float16) do not fail here.
    src_norm_trans_dst_norm = _torch_inverse_cast(dst_norm_trans_src_norm)
    # NOTE(review): the trailing positional `True` presumably selects normalized
    # coordinates in homography_warp3d (the matrix was normalized above) — confirm.
    return homography_warp3d(src, src_norm_trans_dst_norm, dsize_dst, grid_mode, padding_mode, align_corners, True)
| |
|