from typing import Tuple
import torch
import torch.nn as nn
from kornia.utils import create_meshgrid
from kornia.utils.helpers import _torch_solve_cast
__all__ = ["get_tps_transform", "warp_points_tps", "warp_image_tps"]
# utilities for computing thin plate spline transforms
def _pair_square_euclidean(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor:
r"""Compute the pairwise squared euclidean distance matrices :math:`(B, N, M)` between two tensors
with shapes (B, N, C) and (B, M, C)."""
# ||t1-t2||^2 = (t1-t2)^T(t1-t2) = t1^T*t1 + t2^T*t2 - 2*t1^T*t2
t1_sq: torch.Tensor = tensor1.mul(tensor1).sum(dim=-1, keepdim=True)
t2_sq: torch.Tensor = tensor2.mul(tensor2).sum(dim=-1, keepdim=True).transpose(1, 2)
t1_t2: torch.Tensor = tensor1.matmul(tensor2.transpose(1, 2))
square_dist: torch.Tensor = -2 * t1_t2 + t1_sq + t2_sq
square_dist = square_dist.clamp(min=0) # handle possible numerical errors
return square_dist
def _kernel_distance(squared_distances: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
r"""Compute the TPS kernel distance function: :math:`r^2 log(r)`, where `r` is the euclidean distance.
Since :math:`\log(r) = 1/2 \log(r^2)`, this function takes the squared distance matrix and calculates
:math:`0.5 r^2 log(r^2)`."""
# r^2 * log(r) = 1/2 * r^2 * log(r^2)
return 0.5 * squared_distances * squared_distances.add(eps).log()
def get_tps_transform(points_src: torch.Tensor, points_dst: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute the TPS transform parameters that warp source points to target points.

    The input to this function is a tensor of :math:`(x, y)` source points :math:`(B, N, 2)` and a
    corresponding tensor of target :math:`(x, y)` points :math:`(B, N, 2)`.

    Args:
        points_src: batch of source points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.
        points_dst: batch of target points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.

    Returns:
        :math:`(B, N, 2)` tensor of kernel weights and :math:`(B, 3, 2)`
        tensor of affine weights. The last dimension contains the x-transform and y-transform weights
        as separate columns.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)

    .. note::
        This function is often used in conjunction with :func:`warp_points_tps`, :func:`warp_image_tps`.
    """
    # Validate types first, then shapes, to preserve the raising order.
    for name, pts in (("points_src", points_src), ("points_dst", points_dst)):
        if not isinstance(pts, torch.Tensor):
            raise TypeError(f"Input {name} is not torch.Tensor. Got {type(pts)}")
    for name, pts in (("points_src", points_src), ("points_dst", points_dst)):
        if not len(pts.shape) == 3:
            raise ValueError(f"Invalid shape for {name}, expected BxNx2. Got {pts.shape}")

    device, dtype = points_src.device, points_src.dtype
    batch_size, num_points = points_src.shape[:2]

    # Assemble and solve the TPS linear system:
    #   [K   P] [w]   [dst]
    #   [P^T 0] [a] = [ 0 ]
    # where K holds kernel distances (centers = points_dst, matching the warp
    # convention used by warp_points_tps) and P the homogeneous source coords.
    k_matrix: torch.Tensor = _kernel_distance(_pair_square_euclidean(points_src, points_dst))
    ones_col: torch.Tensor = torch.ones(batch_size, num_points, 1, device=device, dtype=dtype)
    zeros_3x3: torch.Tensor = torch.zeros(batch_size, 3, 3, device=device, dtype=dtype)
    p_matrix: torch.Tensor = torch.cat((ones_col, points_src), dim=-1)  # (B, N, 3)
    p_matrix_t: torch.Tensor = torch.cat((p_matrix, zeros_3x3), dim=1).transpose(1, 2)  # (B, 3, N+3)
    l_matrix: torch.Tensor = torch.cat((torch.cat((k_matrix, p_matrix), dim=-1), p_matrix_t), dim=1)
    rhs: torch.Tensor = torch.cat((points_dst, zeros_3x3[:, :, :2]), dim=1)  # (B, N+3, 2)

    weights, _ = _torch_solve_cast(rhs, l_matrix)
    return (weights[:, :-3], weights[:, -3:])
def warp_points_tps(
    points_src: torch.Tensor, kernel_centers: torch.Tensor, kernel_weights: torch.Tensor, affine_weights: torch.Tensor
) -> torch.Tensor:
    r"""Warp a tensor of coordinate points using the thin plate spline defined by kernel points,
    kernel weights, and affine weights.

    The source points should be a :math:`(B, N, 2)` tensor of :math:`(x, y)` coordinates. The kernel
    centers are a :math:`(B, K, 2)` tensor of :math:`(x, y)` coordinates. The kernel weights are a
    :math:`(B, K, 2)` tensor, and the affine weights are a :math:`(B, 3, 2)` tensor. For the weight
    tensors, tensor[..., 0] contains the weights for the x-transform and tensor[..., 1] the weights
    for the y-transform.

    Args:
        points_src: tensor of source points :math:`(B, N, 2)`.
        kernel_centers: tensor of kernel center points :math:`(B, K, 2)`.
        kernel_weights: tensor of kernel weights :math:`(B, K, 2)`.
        affine_weights: tensor of affine weights :math:`(B, 3, 2)`.

    Returns:
        The :math:`(B, N, 2)` tensor of warped source points, from applying the TPS transform.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)
        >>> warped = warp_points_tps(points_src, points_dst, kernel_weights, affine_weights)
        >>> warped_correct = torch.allclose(warped, points_dst)

    .. note::
        This function is often used in conjunction with :func:`get_tps_transform`.
    """
    named_args = (
        ("points_src", points_src),
        ("kernel_centers", kernel_centers),
        ("kernel_weights", kernel_weights),
        ("affine_weights", affine_weights),
    )
    # Validate types first, then shapes, to preserve the raising order.
    for name, arg in named_args:
        if not isinstance(arg, torch.Tensor):
            raise TypeError(f"Input {name} is not torch.Tensor. Got {type(arg)}")
    for name, arg in named_args:
        if not len(arg.shape) == 3:
            raise ValueError(f"Invalid shape for {name}, expected BxNx2. Got {arg.shape}")

    # f_{x|y}(v) = a_0 + [a_x a_y].v + \sum_i w_i * U(||v-u_i||)
    # Pairwise squared distances between points and kernel centers, via
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b (clamped against cancellation).
    src_sq = points_src.pow(2).sum(dim=-1, keepdim=True)  # (B, N, 1)
    ctr_sq = kernel_centers.pow(2).sum(dim=-1, keepdim=True).transpose(1, 2)  # (B, 1, K)
    cross = points_src @ kernel_centers.transpose(1, 2)  # (B, N, K)
    sq_dist = (src_sq + ctr_sq - 2 * cross).clamp(min=0)
    # TPS radial basis: r^2 log r = 0.5 * r^2 * log(r^2); eps keeps log finite at r = 0.
    k_matrix: torch.Tensor = 0.5 * sq_dist * sq_dist.add(1e-8).log()

    # Broadcast against the per-axis weight columns so the x- and y-transforms
    # are evaluated in a single pass.
    nonlinear_part = (k_matrix.unsqueeze(-1) * kernel_weights.unsqueeze(1)).sum(-2)
    linear_part = (points_src.unsqueeze(-1) * affine_weights[:, None, 1:]).sum(-2)
    return nonlinear_part + linear_part + affine_weights[:, None, 0]
def warp_image_tps(
    image: torch.Tensor,
    kernel_centers: torch.Tensor,
    kernel_weights: torch.Tensor,
    affine_weights: torch.Tensor,
    align_corners: bool = False,
) -> torch.Tensor:
    r"""Warp an image tensor according to the thin plate spline transform defined by kernel
    centers, kernel weights, and affine weights.

    .. image:: _static/img/warp_image_tps.png

    The transform is applied to each pixel coordinate in the output image to obtain a point in the
    input image for interpolation of the output pixel. So the TPS parameters should correspond to a
    warp from output space to input space.

    The input `image` is a :math:`(B, C, H, W)` tensor. The kernel centers, kernel weight and
    affine weights are the same as in `warp_points_tps`.

    Args:
        image: input image tensor :math:`(B, C, H, W)`.
        kernel_centers: kernel center points :math:`(B, K, 2)`.
        kernel_weights: tensor of kernel weights :math:`(B, K, 2)`.
        affine_weights: tensor of affine weights :math:`(B, 3, 2)`.
        align_corners: interpolation flag used by `grid_sample`.

    Returns:
        warped image tensor :math:`(B, C, H, W)`.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> image = torch.rand(1, 3, 32, 32)
        >>> # note that we are getting the reverse transform: dst -> src
        >>> kernel_weights, affine_weights = get_tps_transform(points_dst, points_src)
        >>> warped_image = warp_image_tps(image, points_src, kernel_weights, affine_weights)

    .. note::
        This function is often used in conjunction with :func:`get_tps_transform`.
    """
    point_args = (
        ("kernel_centers", kernel_centers),
        ("kernel_weights", kernel_weights),
        ("affine_weights", affine_weights),
    )
    # Validate types first, then shapes, to preserve the raising order.
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input image is not torch.Tensor. Got {type(image)}")
    for name, arg in point_args:
        if not isinstance(arg, torch.Tensor):
            raise TypeError(f"Input {name} is not torch.Tensor. Got {type(arg)}")
    if not len(image.shape) == 4:
        raise ValueError(f"Invalid shape for image, expected BxCxHxW. Got {image.shape}")
    for name, arg in point_args:
        if not len(arg.shape) == 3:
            raise ValueError(f"Invalid shape for {name}, expected BxNx2. Got {arg.shape}")

    batch_size, _, height, width = image.shape

    # Evaluate the TPS at every output pixel location (normalized grid coords),
    # then sample the input image at the warped locations.
    grid: torch.Tensor = create_meshgrid(height, width, device=image.device).to(dtype=image.dtype)
    flat_coords: torch.Tensor = grid.reshape(-1, 2).expand(batch_size, -1, -1)
    sample_points: torch.Tensor = warp_points_tps(flat_coords, kernel_centers, kernel_weights, affine_weights)
    sample_grid: torch.Tensor = sample_points.view(-1, height, width, 2)
    return nn.functional.grid_sample(image, sample_grid, align_corners=align_corners)