import torch
def convert_flow_to_deformation(flow):
    r"""Convert a flow field into a normalized deformation (sampling grid).

    Channel 0 of the flow is divided by ``w - 1`` and channel 1 by ``h - 1``,
    and both are doubled, so that an offset of ``w - 1`` (resp. ``h - 1``)
    spans the full width of the ``[-1, 1]`` coordinate range. The rescaled
    flow is then added to the grid returned by ``make_coordinate_grid``.

    Args:
        flow (tensor): Flow field obtained by the model, shaped
            ``(b, 2, h, w)``; channel 0 is the horizontal offset and
            channel 1 the vertical offset.
    Returns:
        deformation (tensor): The deformation used for warping, shaped
            ``(b, h, w, 2)``.
    Raises:
        ValueError: If ``flow`` does not have exactly two channels.
    """
    b, c, h, w = flow.shape
    if c != 2:
        # A 1-channel flow would otherwise broadcast silently onto both grid
        # axes; any other count fails later with an opaque shape error.
        raise ValueError(
            "flow must have 2 channels (x, y), got {}".format(c)
        )

    # Clamp the divisors so a 1-pixel-wide/high flow does not divide by zero.
    x_extent = max(w - 1, 1)
    y_extent = max(h - 1, 1)
    flow_norm = 2 * torch.cat(
        [flow[:, :1, ...] / x_extent, flow[:, 1:, ...] / y_extent], 1
    )

    grid = make_coordinate_grid(flow)
    # Move channels last to match the (b, h, w, 2) layout of the grid.
    deformation = grid + flow_norm.permute(0, 2, 3, 1)
    return deformation
def make_coordinate_grid(flow):
    r"""Obtain a coordinate grid with the same size as the flow field.

    Coordinates are normalized so that the first and last sample along each
    axis map to ``-1`` and ``1`` respectively. The last dimension of the
    result holds ``(x, y)``, in that order. Only the shape, dtype and device
    of ``flow`` are used; its values are ignored.

    Args:
        flow (tensor): Flow field obtained by the model, shaped
            ``(b, c, h, w)``.
    Returns:
        grid (tensor): The grid with the same size as the input flow, shaped
            ``(b, h, w, 2)``. The batch dimension is an expanded view, so all
            batch entries share the same memory.
    """
    b, _, h, w = flow.shape

    def axis_coords(size):
        # A single sample has no extent: place it at the centre (0) instead
        # of evaluating 0 / 0, which would fill the grid with NaNs.
        if size == 1:
            return flow.new_zeros(1)
        steps = torch.arange(size).to(flow)
        return 2 * (steps / (size - 1)) - 1

    xx = axis_coords(w).view(1, w).expand(h, w)
    yy = axis_coords(h).view(h, 1).expand(h, w)

    # stack materializes an (h, w, 2) tensor; the batch axis is then a
    # zero-copy broadcast.
    meshed = torch.stack((xx, yy), dim=2)
    return meshed.unsqueeze(0).expand(b, -1, -1, -1)
def warp_image(source_image, deformation, align_corners=False):
    r"""Warp the input image according to the deformation.

    If the spatial size of ``deformation`` differs from that of
    ``source_image``, the deformation is first resized bilinearly to the
    image size.

    Args:
        source_image (tensor): source images to be warped, shaped
            ``(b, c, h, w)``.
        deformation (tensor): deformations used to warp the images, shaped
            ``(b, h', w', 2)``; value in range (-1, 1).
        align_corners (bool): Forwarded to both ``interpolate`` and
            ``grid_sample``. The default ``False`` reproduces what the
            implicit PyTorch (>= 1.3) default did, so existing results are
            unchanged. Pass ``True`` when the deformation treats ``-1``/``1``
            as the centres of the corner pixels (i.e. coordinates computed as
            ``2 * i / (size - 1) - 1``); an identity grid then returns the
            source image exactly.
    Returns:
        output (tensor): the warped images
    """
    _, h_old, w_old, _ = deformation.shape
    _, _, h, w = source_image.shape

    if (h_old, w_old) != (h, w):
        # interpolate expects channels-first, so move the (x, y) pair to
        # dim 1 for resizing and back to the last dim afterwards.
        resized = torch.nn.functional.interpolate(
            deformation.permute(0, 3, 1, 2),
            size=(h, w),
            mode='bilinear',
            align_corners=align_corners,
        )
        deformation = resized.permute(0, 2, 3, 1)

    # Passing align_corners explicitly avoids the per-call UserWarning that
    # grid_sample emits when it is left unspecified.
    return torch.nn.functional.grid_sample(
        source_image, deformation, align_corners=align_corners
    )