Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import functools | |
| from typing import Tuple, Optional | |
| ########################## | |
| #### from pytorch3d #### | |
| ########################## | |
| def _axis_angle_rotation(axis: str, angle): | |
| """ | |
| Return the rotation matrices for one of the rotations about an axis | |
| of which Euler angles describe, for each value of the angle given. | |
| Args: | |
| axis: Axis label "X" or "Y or "Z". | |
| angle: any shape tensor of Euler angles in radians | |
| Returns: | |
| Rotation matrices as tensor of shape (..., 3, 3). | |
| """ | |
| cos = torch.cos(angle) | |
| sin = torch.sin(angle) | |
| one = torch.ones_like(angle) | |
| zero = torch.zeros_like(angle) | |
| if axis == "X": | |
| R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) | |
| if axis == "Y": | |
| R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) | |
| if axis == "Z": | |
| R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) | |
| return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) | |
def euler_angles_to_matrix(euler_angles, convention: str):
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    # validate the angles tensor and the convention string
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    # build one elementary rotation per (axis, angle) pair ...
    per_axis = [
        _axis_angle_rotation(axis, ang)
        for axis, ang in zip(convention, torch.unbind(euler_angles, -1))
    ]
    # ... and compose them left-to-right
    composed = per_axis[0]
    for rot in per_axis[1:]:
        composed = torch.matmul(composed, rot)
    return composed
| ########################### | |
| #### from torchgeometry #### | |
| ########################### | |
def get_perspective_transform(src, dst):
    r"""Calculates a perspective transform from four pairs of the corresponding
    points.

    The function calculates the matrix of a perspective transform so that:

    .. math ::
        \begin{bmatrix}
        t_{i}x_{i}^{'} \\
        t_{i}y_{i}^{'} \\
        t_{i} \\
        \end{bmatrix}
        =
        \textbf{map_matrix} \cdot
        \begin{bmatrix}
        x_{i} \\
        y_{i} \\
        1 \\
        \end{bmatrix}

    where

    .. math ::
        dst(i) = (x_{i}^{'},y_{i}^{'}), src(i) = (x_{i}, y_{i}), i = 0,1,2,3

    Args:
        src (Tensor): coordinates of quadrangle vertices in the source image.
        dst (Tensor): coordinates of the corresponding quadrangle vertices in
            the destination image.

    Returns:
        Tensor: the perspective transformation.

    Raises:
        TypeError: if either input is not a torch.Tensor.
        ValueError: if the inputs are not Bx4x2 or have mismatched shapes.

    Shape:
        - Input: :math:`(B, 4, 2)` and :math:`(B, 4, 2)`
        - Output: :math:`(B, 3, 3)`
    """
    if not torch.is_tensor(src):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(src)))
    if not torch.is_tensor(dst):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(dst)))
    if not src.shape[-2:] == (4, 2):
        raise ValueError("Inputs must be a Bx4x2 tensor. Got {}"
                         .format(src.shape))
    if not src.shape == dst.shape:
        raise ValueError("Inputs must have the same shape. Got {}"
                         .format(dst.shape))
    if not (src.shape[0] == dst.shape[0]):
        # BUGFIX: message had one placeholder but two format arguments, so
        # dst.shape was silently dropped from the error text.
        raise ValueError("Inputs must have same batch size dimension. "
                         "Got {} and {}".format(src.shape, dst.shape))

    def ax(p, q):
        # DLT row constraining the x-coordinate of correspondence (p -> q)
        ones = torch.ones_like(p)[..., 0:1]
        zeros = torch.zeros_like(p)[..., 0:1]
        return torch.cat(
            [p[:, 0:1], p[:, 1:2], ones, zeros, zeros, zeros,
             -p[:, 0:1] * q[:, 0:1], -p[:, 1:2] * q[:, 0:1]
             ], dim=1)

    def ay(p, q):
        # DLT row constraining the y-coordinate of correspondence (p -> q)
        ones = torch.ones_like(p)[..., 0:1]
        zeros = torch.zeros_like(p)[..., 0:1]
        return torch.cat(
            [zeros, zeros, zeros, p[:, 0:1], p[:, 1:2], ones,
             -p[:, 0:1] * q[:, 1:2], -p[:, 1:2] * q[:, 1:2]], dim=1)

    # we build matrix A by using only 4 point correspondence. The linear
    # system is solved with the least square method, so here
    # we could even pass more correspondence
    p = []
    p.append(ax(src[:, 0], dst[:, 0]))
    p.append(ay(src[:, 0], dst[:, 0]))
    p.append(ax(src[:, 1], dst[:, 1]))
    p.append(ay(src[:, 1], dst[:, 1]))
    p.append(ax(src[:, 2], dst[:, 2]))
    p.append(ay(src[:, 2], dst[:, 2]))
    p.append(ax(src[:, 3], dst[:, 3]))
    p.append(ay(src[:, 3], dst[:, 3]))

    # A is Bx8x8
    A = torch.stack(p, dim=1)

    # b is a Bx8x1
    b = torch.stack([
        dst[:, 0:1, 0], dst[:, 0:1, 1],
        dst[:, 1:2, 0], dst[:, 1:2, 1],
        dst[:, 2:3, 0], dst[:, 2:3, 1],
        dst[:, 3:4, 0], dst[:, 3:4, 1],
    ], dim=1)

    # solve the system Ax = b for the first 8 homography entries
    X = torch.linalg.solve(A, b)

    # assemble the full 3x3 matrix; the last entry is fixed to 1
    batch_size = src.shape[0]
    M = torch.ones(batch_size, 9, device=src.device, dtype=src.dtype)
    M[..., :8] = torch.squeeze(X, dim=-1)
    return M.view(-1, 3, 3)  # Bx3x3
def warp_perspective(src, M, dsize, flags='bilinear', border_mode=None,
                     border_value=0):
    r"""Applies a perspective transformation to an image.

    The function warp_perspective transforms the source image using
    the specified matrix:

    .. math::
        \text{dst} (x, y) = \text{src} \left(
        \frac{M_{11} x + M_{12} y + M_{13}}{M_{31} x + M_{32} y + M_{33}} ,
        \frac{M_{21} x + M_{22} y + M_{23}}{M_{31} x + M_{32} y + M_{33}}
        \right )

    Args:
        src (torch.Tensor): input image.
        M (Tensor): transformation matrix.
        dsize (tuple): size of the output image (height, width).
        flags: accepted for OpenCV API compatibility; currently unused.
        border_mode: accepted for OpenCV API compatibility; currently unused.
        border_value: accepted for OpenCV API compatibility; currently unused.

    Returns:
        Tensor: the warped input image.

    Shape:
        - Input: :math:`(B, C, H, W)` and :math:`(B, 3, 3)`
        - Output: :math:`(B, C, H, W)`

    .. note::
        See a working example `here <https://github.com/arraiy/torchgeometry/
        blob/master/examples/warp_perspective.ipynb>`_.
    """
    if not torch.is_tensor(src):
        raise TypeError("Input src type is not a torch.Tensor. Got {}"
                        .format(type(src)))
    if not torch.is_tensor(M):
        raise TypeError("Input M type is not a torch.Tensor. Got {}"
                        .format(type(M)))
    if not len(src.shape) == 4:
        raise ValueError("Input src must be a BxCxHxW tensor. Got {}"
                         .format(src.shape))
    # BUGFIX: the original check used `or`, which wrongly accepted e.g. a
    # plain 3x3 matrix or a Bx4x4 tensor; M must be 3-D AND end in (3, 3).
    # Also report M.shape (not src.shape) in the error message.
    if not (len(M.shape) == 3 and M.shape[-2:] == (3, 3)):
        raise ValueError("Input M must be a Bx3x3 tensor. Got {}"
                         .format(M.shape))
    # launches the warper
    return transform_warp_impl(src, M, (src.shape[-2:]), dsize)
def transform_warp_impl(src, dst_pix_trans_src_pix, dsize_src, dsize_dst):
    """Warp `src` after re-expressing the pixel transform in normalized
    [-1, 1] coordinates, as required by the homography warper.
    """
    # pixel-space homography -> normalized-coordinate homography
    dst_norm_trans_dst_norm = dst_norm_to_dst_norm(
        dst_pix_trans_src_pix, dsize_src, dsize_dst)
    # the warper samples source locations, so it needs the inverse mapping
    src_norm_trans_dst_norm = torch.inverse(dst_norm_trans_dst_norm)
    return homography_warp(src, src_norm_trans_dst_norm, dsize_dst)
def dst_norm_to_dst_norm(dst_pix_trans_src_pix, dsize_src, dsize_dst):
    """Re-express a pixel-space homography in normalized [-1, 1] coordinates.

    Chains src-norm -> src-pix, the given pixel transform, and
    dst-pix -> dst-norm into one 3x3 matrix.
    """
    # unpack source and destination sizes
    src_h, src_w = dsize_src
    dst_h, dst_w = dsize_dst

    # match the device and dtype of the given transform
    device = dst_pix_trans_src_pix.device
    dtype = dst_pix_trans_src_pix.dtype

    # pixel <-> normalized conversion matrices for each image size
    src_norm_trans_src_pix = normal_transform_pixel(
        src_h, src_w).to(device).to(dtype)
    src_pix_trans_src_norm = torch.inverse(src_norm_trans_src_pix)
    dst_norm_trans_dst_pix = normal_transform_pixel(
        dst_h, dst_w).to(device).to(dtype)

    # dst_norm <- dst_pix <- src_pix <- src_norm
    dst_norm_trans_src_norm = torch.matmul(
        dst_norm_trans_dst_pix,
        torch.matmul(dst_pix_trans_src_pix, src_pix_trans_src_norm))
    return dst_norm_trans_src_norm
def normal_transform_pixel(height, width):
    """Build the 1x3x3 matrix mapping pixel coordinates to normalized ones.

    Maps x in [0, width - 1] and y in [0, height - 1] to [-1, 1], consistent
    with the convention of ``torch.nn.functional.grid_sample``.
    """
    # per-axis scale: pixel range -> length-2 normalized range
    scale_x = 2.0 / (width - 1.0)
    scale_y = 2.0 / (height - 1.0)
    mat = torch.Tensor([[scale_x, 0.0, -1.0],
                        [0.0, scale_y, -1.0],
                        [0.0, 0.0, 1.0]])
    return mat.unsqueeze(0)  # 1x3x3
def homography_warp(patch_src: torch.Tensor,
                    dst_homo_src: torch.Tensor,
                    dsize: Tuple[int, int],
                    mode: Optional[str] = 'bilinear',
                    padding_mode: Optional[str] = 'zeros') -> torch.Tensor:
    r"""Warp image patches or tensors by homographies.

    Thin functional wrapper around :class:`HomographyWarper`.

    Args:
        patch_src (torch.Tensor): The image or tensor to warp. Should be from
                                  source of shape :math:`(N, C, H, W)`.
        dst_homo_src (torch.Tensor): The homography or stack of homographies
                                     from source to destination of shape
                                     :math:`(N, 3, 3)`.
        dsize (Tuple[int, int]): The height and width of the image to warp.
        mode (Optional[str]): interpolation mode to calculate output values
          'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (Optional[str]): padding mode for outside grid values
          'zeros' | 'border' | 'reflection'. Default: 'zeros'.

    Return:
        torch.Tensor: Patch sampled at locations from source to destination.

    Example:
        >>> input = torch.rand(1, 3, 32, 32)
        >>> homography = torch.eye(3).view(1, 3, 3)
        >>> output = tgm.homography_warp(input, homography, (32, 32))  # NxCxHxW
    """
    out_height, out_width = dsize
    warper = HomographyWarper(out_height, out_width, mode, padding_mode)
    return warper(patch_src, dst_homo_src)
class HomographyWarper(nn.Module):
    r"""Warps image patches or tensors by homographies.

    .. math::
        X_{dst} = H_{src}^{\{dst\}} * X_{src}

    Args:
        height (int): The height of the image to warp.
        width (int): The width of the image to warp.
        mode (Optional[str]): interpolation mode to calculate output values
          'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (Optional[str]): padding mode for outside grid values
          'zeros' | 'border' | 'reflection'. Default: 'zeros'.
        normalized_coordinates (Optional[bool]): whether to use a grid with
          normalized coordinates.
    """

    def __init__(
            self,
            height: int,
            width: int,
            mode: Optional[str] = 'bilinear',
            padding_mode: Optional[str] = 'zeros',
            normalized_coordinates: Optional[bool] = True) -> None:
        super(HomographyWarper, self).__init__()
        self.width: int = width
        self.height: int = height
        self.mode: Optional[str] = mode
        self.padding_mode: Optional[str] = padding_mode
        self.normalized_coordinates: Optional[bool] = normalized_coordinates

        # create base grid to compute the flow
        # NOTE(review): the grid is a plain attribute, not a registered
        # buffer, so module.to(device) does not move it; warp_grid() re-casts
        # it to the homography's device/dtype on every call instead.
        self.grid: torch.Tensor = create_meshgrid(
            height, width, normalized_coordinates=normalized_coordinates)

    def warp_grid(self, dst_homo_src: torch.Tensor) -> torch.Tensor:
        r"""Computes the grid to warp the coordinates grid by an homography.

        Args:
            dst_homo_src (torch.Tensor): Homography or homographies (stacked) to
                              transform all points in the grid. Shape of the
                              homography has to be :math:`(N, 3, 3)`.

        Returns:
            torch.Tensor: the transformed grid of shape :math:`(N, H, W, 2)`.
        """
        batch_size: int = dst_homo_src.shape[0]
        device: torch.device = dst_homo_src.device
        dtype: torch.dtype = dst_homo_src.dtype
        # expand grid to match the input batch size (a view; no data copy)
        grid: torch.Tensor = self.grid.expand(batch_size, -1, -1, -1)  # NxHxWx2
        if len(dst_homo_src.shape) == 3:  # local homography case
            # insert a singleton dim so the 3x3 broadcasts over all grid points
            dst_homo_src = dst_homo_src.view(batch_size, 1, 3, 3)  # Nx1x3x3
        # perform the actual grid transformation,
        # the grid is copied to input device and casted to the same type
        flow: torch.Tensor = transform_points(
            dst_homo_src, grid.to(device).to(dtype))  # NxHxWx2
        return flow.view(batch_size, self.height, self.width, 2)  # NxHxWx2

    def forward(
            self,
            patch_src: torch.Tensor,
            dst_homo_src: torch.Tensor) -> torch.Tensor:
        r"""Warps an image or tensor from source into reference frame.

        Args:
            patch_src (torch.Tensor): The image or tensor to warp.
                                      Should be from source.
            dst_homo_src (torch.Tensor): The homography or stack of homographies
             from source to destination. The homography assumes normalized
             coordinates [-1, 1].

        Return:
            torch.Tensor: Patch sampled at locations from source to destination.

        Shape:
            - Input: :math:`(N, C, H, W)` and :math:`(N, 3, 3)`
            - Output: :math:`(N, C, H, W)`

        Example:
            >>> input = torch.rand(1, 3, 32, 32)
            >>> homography = torch.eye(3).view(1, 3, 3)
            >>> warper = tgm.HomographyWarper(32, 32)
            >>> output = warper(input, homography)  # NxCxHxW
        """
        if not dst_homo_src.device == patch_src.device:
            raise TypeError("Patch and homography must be on the same device. \
                            Got patch.device: {} dst_H_src.device: {}."
                            .format(patch_src.device, dst_homo_src.device))
        # sample the source patch at the homography-warped grid locations
        return F.grid_sample(patch_src, self.warp_grid(dst_homo_src),
                             mode=self.mode, padding_mode=self.padding_mode)
def create_meshgrid(
        height: int,
        width: int,
        normalized_coordinates: Optional[bool] = True):
    """Generates a coordinate grid for an image.

    When the flag `normalized_coordinates` is set to True, the grid is
    normalized to be in the range [-1,1] to be consistent with the pytorch
    function grid_sample.
    http://pytorch.org/docs/master/nn.html#torch.nn.functional.grid_sample

    Args:
        height (int): the image height (rows).
        width (int): the image width (cols).
        normalized_coordinates (Optional[bool]): whether to normalize
          coordinates in the range [-1, 1] in order to be consistent with the
          PyTorch function grid_sample.

    Return:
        torch.Tensor: returns a grid tensor with shape :math:`(1, H, W, 2)`,
        where the last dimension holds (x, y).
    """
    # generate per-axis coordinates
    if normalized_coordinates:
        xs = torch.linspace(-1, 1, width)
        ys = torch.linspace(-1, 1, height)
    else:
        xs = torch.linspace(0, width - 1, width)
        ys = torch.linspace(0, height - 1, height)
    # generate grid by stacking coordinates.
    # BUGFIX: pass indexing="ij" explicitly — it is the historical default
    # and calling torch.meshgrid without it emits a deprecation warning.
    base_grid: torch.Tensor = torch.stack(
        torch.meshgrid([xs, ys], indexing="ij")).transpose(1, 2)  # 2xHxW
    return torch.unsqueeze(base_grid, dim=0).permute(0, 2, 3, 1)  # 1xHxWx2
def transform_points(trans_01: torch.Tensor,
                     points_1: torch.Tensor) -> torch.Tensor:
    r"""Function that applies transformations to a set of points.

    Args:
        trans_01 (torch.Tensor): tensor for transformations of shape
            :math:`(B, D+1, D+1)`.
        points_1 (torch.Tensor): tensor of points of shape :math:`(B, N, D)`.

    Returns:
        torch.Tensor: tensor of N-dimensional points.

    Shape:
        - Output: :math:`(B, N, D)`

    Examples:
        >>> points_1 = torch.rand(2, 4, 3)  # BxNx3
        >>> trans_01 = torch.eye(4).view(1, 4, 4).expand(2, -1, -1)  # Bx4x4
        >>> points_0 = tgm.transform_points(trans_01, points_1)  # BxNx3
    """
    if not torch.is_tensor(trans_01) or not torch.is_tensor(points_1):
        raise TypeError("Input type is not a torch.Tensor")
    if not trans_01.device == points_1.device:
        raise TypeError("Tensor must be in the same device")
    # batch sizes must match exactly (no broadcasting of a singleton batch);
    # the docstring example was fixed accordingly to expand its transform.
    if not trans_01.shape[0] == points_1.shape[0]:
        raise ValueError("Input batch size must be the same for both tensors")
    # BUGFIX: typo "differe" in the original error message.
    if not trans_01.shape[-1] == (points_1.shape[-1] + 1):
        raise ValueError("Last input dimensions must differ by one unit")
    # to homogeneous: append a trailing 1 to every point -> BxNx(D+1)
    points_1_h = convert_points_to_homogeneous(points_1)  # BxNxD+1
    # transform coordinates: (B,1,D+1,D+1) @ (...,D+1,1) broadcasts over N
    points_0_h = torch.matmul(
        trans_01.unsqueeze(1), points_1_h.unsqueeze(-1))
    points_0_h = torch.squeeze(points_0_h, dim=-1)
    # to euclidean: divide by the homogeneous coordinate and drop it
    points_0 = convert_points_from_homogeneous(points_0_h)  # BxNxD
    return points_0
def convert_points_to_homogeneous(points):
    r"""Convert points from Euclidean to homogeneous space.

    Appends a constant 1.0 as the last coordinate of every point.
    See :class:`~torchgeometry.ConvertPointsToHomogeneous` for details.

    Examples::

        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> output = tgm.convert_points_to_homogeneous(input)  # BxNx4
    """
    # validate input type and rank before touching the data
    if not torch.is_tensor(points):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(points)))
    if points.dim() < 2:
        raise ValueError("Input must be at least a 2D tensor. Got {}".format(
            points.shape))
    # pad one element of value 1.0 at the end of the last dimension
    return F.pad(points, (0, 1), "constant", 1.0)
def convert_points_from_homogeneous(points):
    r"""Convert points from homogeneous to Euclidean space.

    Divides every point by its last (homogeneous) coordinate and drops it.
    See :class:`~torchgeometry.ConvertPointsFromHomogeneous` for details.

    Examples::

        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> output = tgm.convert_points_from_homogeneous(input)  # BxNx2
    """
    # validate input type and rank before touching the data
    if not torch.is_tensor(points):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(points)))
    if points.dim() < 2:
        raise ValueError("Input must be at least a 2D tensor. Got {}".format(
            points.shape))
    # divide by the last coordinate, keeping dims so broadcasting aligns
    w = points[..., -1:]
    return points[..., :-1] / w