| | |
| |
|
| | """ |
| | Warping field estimator (W) defined in the paper. It generates a warping field from the implicit |
| | keypoint representations x_s and x_d, and uses this flow field to warp the source feature volume f_s. |
| | """ |
| |
|
| | from torch import nn |
| | import torch.nn.functional as F |
| | from .util import SameBlock2d |
| | from .dense_motion import DenseMotionNetwork |
| |
|
| |
|
class WarpingNetwork(nn.Module):
    """Warp a 3-D source feature volume with a dense motion field and project it to 2-D.

    When a dense motion network is configured, it is queried with the source
    feature volume and both keypoint sets. The returned ``'deformation'`` grid is
    used to resample the volume with ``F.grid_sample``. The (possibly warped)
    volume's depth axis is then folded into the channel axis. The result passes
    through ``self.third`` and ``self.fourth`` and is optionally multiplied by
    the occlusion map.
    """

    def __init__(
        self,
        num_kp,
        block_expansion,
        max_features,
        num_down_blocks,
        reshape_channel,
        estimate_occlusion_map=False,
        dense_motion_params=None,
        **kwargs
    ):
        """Build the optional motion network and the 2-D projection layers.

        Args:
            num_kp: number of keypoints, forwarded to ``DenseMotionNetwork``.
            block_expansion: base channel multiplier. The 2-D output has
                ``block_expansion * 2 ** num_down_blocks`` channels.
            max_features: input channels of ``self.third``. Must equal ``C * D``
                of the feature volume passed to ``forward``, because the depth
                axis is folded into channels before that layer.
            num_down_blocks: exponent used in the output channel count above.
            reshape_channel: forwarded to ``DenseMotionNetwork`` as
                ``feature_channel``.
            estimate_occlusion_map: forwarded to ``DenseMotionNetwork`` and also
                stored on the instance.
            dense_motion_params: extra keyword arguments for
                ``DenseMotionNetwork``. If None, no motion network is built and
                ``forward`` performs no warping.
            **kwargs: recognized keys are ``upscale`` (default 1) and
                ``flag_use_occlusion_map`` (default True). Any other keys are
                ignored.
        """
        super().__init__()

        # NOTE(review): `upscale` is stored but never read inside this class.
        self.upscale = kwargs.get('upscale', 1)
        self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True)

        if dense_motion_params is not None:
            self.dense_motion_network = DenseMotionNetwork(
                num_kp=num_kp,
                feature_channel=reshape_channel,
                estimate_occlusion_map=estimate_occlusion_map,
                **dense_motion_params
            )
        else:
            self.dense_motion_network = None

        out_channels = block_expansion * (2 ** num_down_blocks)
        self.third = SameBlock2d(
            max_features, out_channels, kernel_size=(3, 3), padding=(1, 1), lrelu=True
        )
        self.fourth = nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels, kernel_size=1, stride=1
        )

        self.estimate_occlusion_map = estimate_occlusion_map

    def deform_input(self, inp, deformation):
        """Resample ``inp`` at the normalized grid coordinates in ``deformation``.

        Thin wrapper around ``F.grid_sample``. It uses the default interpolation
        mode and padding, with ``align_corners=False``.
        """
        return F.grid_sample(inp, deformation, align_corners=False)

    def forward(self, feature_3d, kp_driving, kp_source):
        """Warp ``feature_3d`` and project it to a 2-D feature map.

        Args:
            feature_3d: 5-D source feature volume ``(bs, c, d, h, w)``.
            kp_driving: driving keypoints. Passed unchanged to the dense motion
                network; their format is defined there and not checked here.
            kp_source: source keypoints. Passed unchanged in the same way.

        Returns:
            dict with keys:
              ``'out'``: 2-D features with ``block_expansion * 2 ** num_down_blocks``
                  channels.
              ``'deformation'``: the sampling grid produced by the dense motion
                  network, or None when no network is configured.
              ``'occlusion_map'``: the network's occlusion map if it returned one,
                  else None. It is returned even when ``flag_use_occlusion_map``
                  is False; it is only multiplied into ``'out'`` when that flag
                  is True.
        """
        # Pre-bind so the no-motion-network path returns a well-formed dict
        # instead of raising UnboundLocalError.
        occlusion_map = None
        deformation = None

        if self.dense_motion_network is not None:
            dense_motion = self.dense_motion_network(
                feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source
            )
            occlusion_map = (
                dense_motion['occlusion_map'] if 'occlusion_map' in dense_motion else None
            )
            deformation = dense_motion['deformation']
            out = self.deform_input(feature_3d, deformation)
        else:
            # No motion network configured: use the source volume unwarped.
            out = feature_3d

        # Fold depth into channels: (bs, c, d, h, w) -> (bs, c * d, h, w).
        # `reshape` (rather than `view`) also accepts a non-contiguous input
        # on the pass-through path.
        bs, c, d, h, w = out.shape
        out = out.reshape(bs, c * d, h, w)
        out = self.third(out)
        out = self.fourth(out)

        if self.flag_use_occlusion_map and occlusion_map is not None:
            # NOTE(review): occlusion_map must broadcast against `out`,
            # presumably (bs, 1, h, w) -- confirm against DenseMotionNetwork.
            out = out * occlusion_map

        return {
            'occlusion_map': occlusion_map,
            'deformation': deformation,
            'out': out,
        }
| |
|