File size: 10,876 Bytes
36c95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
from typing import Tuple

import torch
import torch.nn as nn

from kornia.utils import create_meshgrid
from kornia.utils.helpers import _torch_solve_cast

# Public API of this module; the underscore-prefixed helpers defined below are internal.
__all__ = ["get_tps_transform", "warp_points_tps", "warp_image_tps"]

# utilities for computing thin plate spline transforms


def _pair_square_euclidean(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor:
    r"""Compute the pairwise squared euclidean distance matrices :math:`(B, N, M)` between two tensors
    with shapes (B, N, C) and (B, M, C)."""
    # ||t1-t2||^2 = (t1-t2)^T(t1-t2) = t1^T*t1 + t2^T*t2 - 2*t1^T*t2
    t1_sq: torch.Tensor = tensor1.mul(tensor1).sum(dim=-1, keepdim=True)
    t2_sq: torch.Tensor = tensor2.mul(tensor2).sum(dim=-1, keepdim=True).transpose(1, 2)
    t1_t2: torch.Tensor = tensor1.matmul(tensor2.transpose(1, 2))
    square_dist: torch.Tensor = -2 * t1_t2 + t1_sq + t2_sq
    square_dist = square_dist.clamp(min=0)  # handle possible numerical errors
    return square_dist


def _kernel_distance(squared_distances: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    r"""Compute the TPS kernel distance function: :math:`r^2 log(r)`, where `r` is the euclidean distance.
    Since :math:`\log(r) = 1/2 \log(r^2)`, this function takes the squared distance matrix and calculates
    :math:`0.5 r^2 log(r^2)`."""
    # r^2 * log(r) = 1/2 * r^2 * log(r^2)
    return 0.5 * squared_distances * squared_distances.add(eps).log()


def get_tps_transform(points_src: torch.Tensor, points_dst: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute the TPS transform parameters that warp source points to target points.

    The input to this function is a tensor of :math:`(x, y)` source points :math:`(B, N, 2)` and a corresponding
    tensor of target :math:`(x, y)` points :math:`(B, N, 2)`.

    The kernel matrix of the linear system is built from the distances between every source point and every
    target point, so the target points play the role of the kernel centers of the resulting spline. To evaluate
    the spline, pass ``points_dst`` as ``kernel_centers`` to :func:`warp_points_tps` (or :func:`warp_image_tps`).

    Args:
        points_src: batch of source points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.
        points_dst: batch of target points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.

    Returns:
        :math:`(B, N, 2)` tensor of kernel weights and :math:`(B, 3, 2)`
            tensor of affine weights. The last dimension contains the x-transform and y-transform weights
            as separate columns. Row 0 of the affine weights is the constant term; rows 1 and 2 multiply the
            x and y coordinates respectively.

    Raises:
        TypeError: if ``points_src`` or ``points_dst`` is not a tensor.
        ValueError: if either input is not a :math:`(B, N, 2)` tensor, or if the two inputs differ in shape.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)

    .. note::
        This function is often used in conjunction with :func:`warp_points_tps`, :func:`warp_image_tps`.
    """
    if not isinstance(points_src, torch.Tensor):
        raise TypeError(f"Input points_src is not torch.Tensor. Got {type(points_src)}")

    if not isinstance(points_dst, torch.Tensor):
        raise TypeError(f"Input points_dst is not torch.Tensor. Got {type(points_dst)}")

    if points_src.dim() != 3 or points_src.shape[-1] != 2:
        raise ValueError(f"Invalid shape for points_src, expected BxNx2. Got {points_src.shape}")

    if points_dst.dim() != 3 or points_dst.shape[-1] != 2:
        raise ValueError(f"Invalid shape for points_dst, expected BxNx2. Got {points_dst.shape}")

    # The linear system below pairs source and target points row by row, so the shapes must agree exactly.
    # Mismatched shapes would otherwise fail later inside torch.cat with a much less helpful error.
    if points_src.shape != points_dst.shape:
        raise ValueError(
            f"points_src and points_dst must have the same shape. Got {points_src.shape} and {points_dst.shape}"
        )

    device, dtype = points_src.device, points_src.dtype
    batch_size, num_points = points_src.shape[:2]

    # set up and solve the (N + 3) x (N + 3) linear system
    # [K   P] [w] = [dst]
    # [P^T 0] [a]   [ 0 ]
    # with K_ij = U(||src_i - dst_j||) and P_i = [1, x_i, y_i] taken from the source points.
    pair_distance: torch.Tensor = _pair_square_euclidean(points_src, points_dst)
    k_matrix: torch.Tensor = _kernel_distance(pair_distance)  # (B, N, N)

    zero_mat: torch.Tensor = torch.zeros(batch_size, 3, 3, device=device, dtype=dtype)
    one_mat: torch.Tensor = torch.ones(batch_size, num_points, 1, device=device, dtype=dtype)
    # right-hand side: the targets padded with 3 zero rows for the side conditions -> (B, N + 3, 2)
    dest_with_zeros: torch.Tensor = torch.cat((points_dst, zero_mat[:, :, :2]), 1)
    p_matrix: torch.Tensor = torch.cat((one_mat, points_src), -1)  # (B, N, 3)
    # bottom block rows [P^T 0] -> (B, 3, N + 3)
    p_matrix_t: torch.Tensor = torch.cat((p_matrix, zero_mat), 1).transpose(1, 2)
    # top block rows [K P] -> (B, N, N + 3)
    l_matrix: torch.Tensor = torch.cat((k_matrix, p_matrix), -1)
    l_matrix = torch.cat((l_matrix, p_matrix_t), 1)  # (B, N + 3, N + 3)

    # NOTE(review): argument order assumed to mirror the legacy torch.solve(B, A), i.e. solve
    # l_matrix @ weights = dest_with_zeros; the second return value is not needed here.
    weights, _ = _torch_solve_cast(dest_with_zeros, l_matrix)
    kernel_weights: torch.Tensor = weights[:, :-3]  # (B, N, 2)
    affine_weights: torch.Tensor = weights[:, -3:]  # (B, 3, 2)

    return (kernel_weights, affine_weights)


def warp_points_tps(
    points_src: torch.Tensor, kernel_centers: torch.Tensor, kernel_weights: torch.Tensor, affine_weights: torch.Tensor
) -> torch.Tensor:
    r"""Warp a tensor of coordinate points using the thin plate spline defined by kernel points, kernel weights,
    and affine weights.

    The source points should be a :math:`(B, N, 2)` tensor of :math:`(x, y)` coordinates. The kernel centers are
    a :math:`(B, K, 2)` tensor of :math:`(x, y)` coordinates. The kernel weights are a :math:`(B, K, 2)` tensor,
    and the affine weights are a :math:`(B, 3, 2)` tensor.  For the weight tensors, tensor[..., 0] contains the
    weights for the x-transform and tensor[..., 1] the weights for the y-transform.

    Each output coordinate is computed as
    :math:`f(v) = a_0 + a_1 x + a_2 y + \sum_i w_i U(\|v - c_i\|)` with :math:`U(r) = r^2 \log r`.

    Args:
        points_src: tensor of source points :math:`(B, N, 2)`.
        kernel_centers: tensor of kernel center points :math:`(B, K, 2)`.
        kernel_weights: tensor of kernel weights :math:`(B, K, 2)`.
        affine_weights: tensor of affine weights :math:`(B, 3, 2)`.

    Returns:
        The :math:`(B, N, 2)` tensor of warped source points, from applying the TPS transform.

    Raises:
        TypeError: if any input is not a tensor.
        ValueError: if any input is not a 3-dimensional tensor.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)
        >>> warped = warp_points_tps(points_src, points_dst, kernel_weights, affine_weights)
        >>> warped_correct = torch.allclose(warped, points_dst)

    .. note::
        This function is often used in conjunction with :func:`get_tps_transform`.
    """
    if not isinstance(points_src, torch.Tensor):
        raise TypeError(f"Input points_src is not torch.Tensor. Got {type(points_src)}")

    if not isinstance(kernel_centers, torch.Tensor):
        raise TypeError(f"Input kernel_centers is not torch.Tensor. Got {type(kernel_centers)}")

    if not isinstance(kernel_weights, torch.Tensor):
        raise TypeError(f"Input kernel_weights is not torch.Tensor. Got {type(kernel_weights)}")

    if not isinstance(affine_weights, torch.Tensor):
        raise TypeError(f"Input affine_weights is not torch.Tensor. Got {type(affine_weights)}")

    if not len(points_src.shape) == 3:
        raise ValueError(f"Invalid shape for points_src, expected BxNx2. Got {points_src.shape}")

    if not len(kernel_centers.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_centers, expected BxNx2. Got {kernel_centers.shape}")

    if not len(kernel_weights.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_weights, expected BxNx2. Got {kernel_weights.shape}")

    if not len(affine_weights.shape) == 3:
        raise ValueError(f"Invalid shape for affine_weights, expected Bx3x2. Got {affine_weights.shape}")

    # f_{x|y}(v) = a_0 + [a_x a_y].v + \sum_i w_i * U(||v-u_i||)
    pair_distance: torch.Tensor = _pair_square_euclidean(points_src, kernel_centers)
    k_matrix: torch.Tensor = _kernel_distance(pair_distance)  # (B, N, K)

    # matmul, unlike element-wise ops, does not promote mixed dtypes; promote explicitly so e.g. float32 points
    # with float64 weights still yield a float64 result.
    dtype = k_matrix.dtype
    for tensor in (points_src, kernel_weights, affine_weights):
        dtype = torch.promote_types(dtype, tensor.dtype)
    k_matrix = k_matrix.to(dtype)
    points = points_src.to(dtype)
    kernel_weights = kernel_weights.to(dtype)
    affine_weights = affine_weights.to(dtype)

    # Batched matmuls contract over the kernels / coordinates for the x and y transforms at once, without
    # materialising a (B, N, K, 2) intermediate (N = H * W when called from warp_image_tps).
    k_mul_kernel = k_matrix.matmul(kernel_weights)  # (B, N, K) @ (B, K, 2) -> (B, N, 2)
    points_mul_affine = points.matmul(affine_weights[:, 1:])  # (B, N, 2) @ (B, 2, 2) -> (B, N, 2)
    warped: torch.Tensor = k_mul_kernel + points_mul_affine + affine_weights[:, None, 0]

    return warped


def warp_image_tps(
    image: torch.Tensor,
    kernel_centers: torch.Tensor,
    kernel_weights: torch.Tensor,
    affine_weights: torch.Tensor,
    align_corners: bool = False,
    mode: str = "bilinear",
    padding_mode: str = "zeros",
) -> torch.Tensor:
    r"""Warp an image tensor according to the thin plate spline transform defined by kernel centers,
    kernel weights, and affine weights.

    .. image:: _static/img/warp_image_tps.png

    The transform is applied to each pixel coordinate in the output image to obtain a point in the input
    image for interpolation of the output pixel. So the TPS parameters should correspond to a warp from
    output space to input space.

    Pixel coordinates are generated as a normalized grid in :math:`[-1, 1]` and the warped points are fed
    directly to :func:`torch.nn.functional.grid_sample`, so the kernel centers and weights must be expressed
    in that same normalized coordinate system.

    The input `image` is a :math:`(B, C, H, W)` tensor. The kernel centers, kernel weight and affine weights
    are the same as in `warp_points_tps`.

    Args:
        image: input image tensor :math:`(B, C, H, W)`.
        kernel_centers: kernel center points :math:`(B, K, 2)`.
        kernel_weights: tensor of kernel weights :math:`(B, K, 2)`.
        affine_weights: tensor of affine weights :math:`(B, 3, 2)`.
        align_corners: interpolation flag used by `grid_sample`.
        mode: interpolation mode passed to `grid_sample` (e.g. ``"bilinear"`` or ``"nearest"``).
        padding_mode: padding mode passed to `grid_sample` for samples outside the image
            (e.g. ``"zeros"``, ``"border"`` or ``"reflection"``).

    Returns:
        warped image tensor :math:`(B, C, H, W)`.

    Raises:
        TypeError: if any input is not a tensor.
        ValueError: if ``image`` is not 4-dimensional or any parameter tensor is not 3-dimensional.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> image = torch.rand(1, 3, 32, 32)
        >>> # note that we are getting the reverse transform: dst -> src
        >>> kernel_weights, affine_weights = get_tps_transform(points_dst, points_src)
        >>> warped_image = warp_image_tps(image, points_src, kernel_weights, affine_weights)

    .. note::
        This function is often used in conjunction with :func:`get_tps_transform`.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input image is not torch.Tensor. Got {type(image)}")

    if not isinstance(kernel_centers, torch.Tensor):
        raise TypeError(f"Input kernel_centers is not torch.Tensor. Got {type(kernel_centers)}")

    if not isinstance(kernel_weights, torch.Tensor):
        raise TypeError(f"Input kernel_weights is not torch.Tensor. Got {type(kernel_weights)}")

    if not isinstance(affine_weights, torch.Tensor):
        raise TypeError(f"Input affine_weights is not torch.Tensor. Got {type(affine_weights)}")

    if not len(image.shape) == 4:
        raise ValueError(f"Invalid shape for image, expected BxCxHxW. Got {image.shape}")

    if not len(kernel_centers.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_centers, expected BxNx2. Got {kernel_centers.shape}")

    if not len(kernel_weights.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_weights, expected BxNx2. Got {kernel_weights.shape}")

    if not len(affine_weights.shape) == 3:
        raise ValueError(f"Invalid shape for affine_weights, expected Bx3x2. Got {affine_weights.shape}")

    device, dtype = image.device, image.dtype
    batch_size, _, h, w = image.shape
    # normalized pixel grid, the coordinate convention grid_sample expects
    coords: torch.Tensor = create_meshgrid(h, w, normalized_coordinates=True, device=device).to(dtype=dtype)
    # flatten to a list of H * W points shared by every batch element (expand does not copy)
    coords = coords.reshape(-1, 2).expand(batch_size, -1, -1)
    warped: torch.Tensor = warp_points_tps(coords, kernel_centers, kernel_weights, affine_weights)
    # back to a (B, H, W, 2) sampling grid; reshape tolerates non-contiguous results
    warped = warped.reshape(-1, h, w, 2)
    warped_image: torch.Tensor = nn.functional.grid_sample(
        image, warped, mode=mode, padding_mode=padding_mode, align_corners=align_corners
    )

    return warped_image