| import torch |
|
|
|
|
| |
def tilt_projection(taux: torch.Tensor, tauy: torch.Tensor, return_inverse: bool = False) -> torch.Tensor:
    r"""Estimate the tilt projection matrix or the inverse tilt projection matrix.

    With the sensor rotation :math:`R = R_y R_x` and the projection :math:`P_z` back onto
    the tilted plane, the tilt matrix is :math:`P_z R` and its inverse is :math:`R^T P_z^{-1}`.
    Because :math:`R` is orthogonal, the two returned matrices are exact inverses of each other.

    Args:
        taux: Rotation angle in radians around the :math:`x`-axis with shape :math:`(*, 1)`.
        tauy: Rotation angle in radians around the :math:`y`-axis with shape :math:`(*, 1)`.
        return_inverse: False to obtain the tilt projection matrix. True for the inverse matrix.

    Returns:
        The tilt projection matrix, or its inverse when ``return_inverse`` is True, with shape
        :math:`(B, 3, 3)` where :math:`B` is the number of elements of ``taux``. A 0-dim
        ``taux`` yields a single :math:`(3, 3)` matrix.

    Raises:
        ValueError: If ``taux`` and ``tauy`` do not have the same shape.
    """
    if taux.shape != tauy.shape:
        raise ValueError(f'Shape of taux {taux.shape} and tauy {tauy.shape} do not match.')

    ndim: int = taux.dim()
    # Flatten all batch dimensions: one 3x3 matrix is built per (taux, tauy) pair.
    taux = taux.reshape(-1)
    tauy = tauy.reshape(-1)

    cTx = torch.cos(taux)
    sTx = torch.sin(taux)
    cTy = torch.cos(tauy)
    sTy = torch.sin(tauy)
    zero = torch.zeros_like(cTx)
    one = torch.ones_like(cTx)

    # Row-major rotations about the x- and y-axes, composed as R = Ry @ Rx.
    Rx = torch.stack([one, zero, zero, zero, cTx, sTx, zero, -sTx, cTx], -1).reshape(-1, 3, 3)
    Ry = torch.stack([cTy, zero, -sTy, zero, one, zero, sTy, zero, cTy], -1).reshape(-1, 3, 3)
    R = Ry @ Rx

    if return_inverse:
        # Closed-form inverse of Pz (upper triangular with last row [0, 0, 1]).
        invR22 = 1 / R[..., 2, 2]
        invPz = torch.stack(
            [invR22, zero, R[..., 0, 2] * invR22, zero, invR22, R[..., 1, 2] * invR22, zero, zero, one], -1
        ).reshape(-1, 3, 3)

        # (Pz @ R)^-1 == R^T @ Pz^-1 since R is orthogonal.
        inv_tilt = R.transpose(-1, -2) @ invPz
        if ndim == 0:
            inv_tilt = torch.squeeze(inv_tilt)
        return inv_tilt

    Pz = torch.stack(
        [R[..., 2, 2], zero, -R[..., 0, 2], zero, R[..., 2, 2], -R[..., 1, 2], zero, zero, one], -1
    ).reshape(-1, 3, 3)

    # Must be Pz @ R (not Pz @ R^T): R is not symmetric in general, and only this form is
    # the inverse of the R^T @ Pz^-1 matrix returned by the branch above.
    tilt = Pz @ R
    if ndim == 0:
        tilt = torch.squeeze(tilt)
    return tilt
|
|
|
|
def distort_points(points: torch.Tensor, K: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
    r"""Distortion of a set of 2D points based on the lens distortion model.

    Radial :math:`(k_1, k_2, k_3, k_4, k_5, k_6)`,
    tangential :math:`(p_1, p_2)`, thin prism :math:`(s_1, s_2, s_3, s_4)`, and tilt :math:`(\tau_x, \tau_y)`
    distortion models are considered in this function.

    Args:
        points: Input image points with shape :math:`(*, N, 2)`.
        K: Intrinsic camera matrix with shape :math:`(*, 3, 3)`.
        dist: Distortion coefficients
            :math:`(k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])`. This is
            a vector with 4, 5, 8, 12 or 14 elements with shape :math:`(*, n)`.

    Returns:
        Distorted 2D points with shape :math:`(*, N, 2)`.

    Raises:
        ValueError: If ``points`` is not at least 2-D with a last dimension of 2, if ``K`` does
            not end in a :math:`3 \times 3` matrix, or if ``dist`` does not hold 4, 5, 8, 12 or
            14 coefficients.

    Example:
        >>> points = torch.rand(1, 1, 2)
        >>> K = torch.eye(3)[None]
        >>> dist_coeff = torch.rand(1, 4)
        >>> points_dist = distort_points(points, K, dist_coeff)

    """
    # Either condition makes the input invalid; `or` also short-circuits before indexing
    # the shape of a 0-dim tensor.
    if points.dim() < 2 or points.shape[-1] != 2:
        raise ValueError(f'points shape is invalid. Got {points.shape}.')

    if K.shape[-2:] != (3, 3):
        raise ValueError(f'K matrix shape is invalid. Got {K.shape}.')

    if dist.shape[-1] not in [4, 5, 8, 12, 14]:
        raise ValueError(f'Invalid number of distortion coefficients. Got {dist.shape[-1]}')

    # Zero-pad to the full 14-coefficient layout so every term below can be indexed
    # uniformly; absent coefficients then contribute nothing.
    if dist.shape[-1] < 14:
        dist = torch.nn.functional.pad(dist, [0, 14 - dist.shape[-1]])

    # Intrinsics are sliced with shape (*, 1) so they broadcast over the N points.
    cx: torch.Tensor = K[..., 0:1, 2]
    cy: torch.Tensor = K[..., 1:2, 2]
    fx: torch.Tensor = K[..., 0:1, 0]
    fy: torch.Tensor = K[..., 1:2, 1]

    # Pixel coordinates -> normalized camera coordinates.
    x: torch.Tensor = (points[..., 0] - cx) / fx
    y: torch.Tensor = (points[..., 1] - cy) / fy

    r2 = x * x + y * y

    # Rational radial factor (1 + k1 r^2 + k2 r^4 + k3 r^6) / (1 + k4 r^2 + k5 r^4 + k6 r^6).
    rad_poly = (1 + dist[..., 0:1] * r2 + dist[..., 1:2] * r2 * r2 + dist[..., 4:5] * r2 ** 3) / (
        1 + dist[..., 5:6] * r2 + dist[..., 6:7] * r2 * r2 + dist[..., 7:8] * r2 ** 3
    )
    # Radial + tangential (p1, p2) + thin prism (s1..s4) terms.
    xd = (
        x * rad_poly
        + 2 * dist[..., 2:3] * x * y
        + dist[..., 3:4] * (r2 + 2 * x * x)
        + dist[..., 8:9] * r2
        + dist[..., 9:10] * r2 * r2
    )
    yd = (
        y * rad_poly
        + dist[..., 2:3] * (r2 + 2 * y * y)
        + 2 * dist[..., 3:4] * x * y
        + dist[..., 10:11] * r2
        + dist[..., 11:12] * r2 * r2
    )

    # Tilt distortion is only applied when at least one tilt angle is non-zero.
    if torch.any(dist[..., 12] != 0) or torch.any(dist[..., 13] != 0):
        tilt = tilt_projection(dist[..., 12], dist[..., 13])
        # tilt_projection flattens the batch dims of dist; restore them so the matrices
        # broadcast against points that carry more than one batch dimension.
        tilt = tilt.reshape(*dist.shape[:-1], 3, 3)

        # Row-vector form: [x, y, 1] @ tilt^T == (tilt @ [x, y, 1]^T)^T, then dehomogenize.
        points_untilt = torch.stack([xd, yd, torch.ones_like(xd)], -1) @ tilt.transpose(-2, -1)
        xd = points_untilt[..., 0] / points_untilt[..., 2]
        yd = points_untilt[..., 1] / points_untilt[..., 2]

    # Normalized camera coordinates -> pixel coordinates.
    x = fx * xd + cx
    y = fy * yd + cy

    return torch.stack([x, y], -1)
|
|