|
|
|
|
| """
|
| Warping field estimator (W) defined in the paper. It generates a warping field from the implicit
|
| keypoint representations x_s and x_d, and uses this flow field to warp the source feature volume f_s.
|
| """
|
|
|
| from torch import nn
|
| import torch.nn.functional as F
|
| from .util import SameBlock2d
|
| from .dense_motion import DenseMotionNetwork
|
|
|
|
|
class WarpingNetwork(nn.Module):
    """Warp a source feature volume with a keypoint-driven dense motion field.

    A ``DenseMotionNetwork`` (built only when ``dense_motion_params`` is given)
    maps the source/driving keypoints to a sampling grid (``'deformation'``)
    and, optionally, an ``'occlusion_map'``. ``forward`` resamples the source
    feature volume with that grid, folds its channel and depth axes into one
    channel axis, and refines the result with two 2D layers.

    Args:
        num_kp: Number of keypoints, forwarded to ``DenseMotionNetwork``.
        block_expansion: Base channel width; ``third``/``fourth`` output
            ``block_expansion * 2 ** num_down_blocks`` channels.
        max_features: Input channel count of ``third``. ``forward`` feeds it
            ``c * d`` channels (channels times depth of the warped volume), so
            the two must agree. NOTE(review): confirm against the model config.
        num_down_blocks: Exponent used in the output channel width above.
        reshape_channel: Forwarded to ``DenseMotionNetwork`` as
            ``feature_channel``.
        estimate_occlusion_map: Forwarded to ``DenseMotionNetwork`` and stored
            on the instance.
        dense_motion_params: Extra keyword arguments for ``DenseMotionNetwork``.
            When ``None`` no motion network is built and ``forward`` raises.
        **kwargs: Recognised keys are ``upscale`` (default 1; stored, not used
            in this class) and ``flag_use_occlusion_map`` (default True;
            multiply the output by the occlusion map when one is returned).
    """

    def __init__(
        self,
        num_kp,
        block_expansion,
        max_features,
        num_down_blocks,
        reshape_channel,
        estimate_occlusion_map=False,
        dense_motion_params=None,
        **kwargs
    ):
        super().__init__()

        self.upscale = kwargs.get('upscale', 1)
        self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True)

        if dense_motion_params is not None:
            self.dense_motion_network = DenseMotionNetwork(
                num_kp=num_kp,
                feature_channel=reshape_channel,
                estimate_occlusion_map=estimate_occlusion_map,
                **dense_motion_params
            )
        else:
            self.dense_motion_network = None

        # Registration order (dense_motion_network, third, fourth) is kept as-is
        # so that state_dict keys match previously saved checkpoints.
        out_channels = block_expansion * (2 ** num_down_blocks)
        self.third = SameBlock2d(
            max_features, out_channels, kernel_size=(3, 3), padding=(1, 1), lrelu=True
        )
        self.fourth = nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels, kernel_size=1, stride=1
        )

        self.estimate_occlusion_map = estimate_occlusion_map

    def deform_input(self, inp, deformation):
        """Resample ``inp`` at the locations given by the grid ``deformation``.

        Thin wrapper over ``F.grid_sample`` with its default mode and padding and
        ``align_corners=False``. In ``forward`` ``inp`` is a 5D volume, so by the
        PyTorch convention ``deformation`` is ``(N, D_out, H_out, W_out, 3)`` with
        normalized coordinates in ``[-1, 1]``.
        """
        return F.grid_sample(inp, deformation, align_corners=False)

    def forward(self, feature_3d, kp_driving, kp_source) -> dict:
        """Warp ``feature_3d`` from the source pose towards the driving pose.

        Args:
            feature_3d: 5D source feature volume ``(bs, c, d, h, w)``; the rank is
                fixed by the unpacking of the warped result below.
            kp_driving: Driving keypoints, passed through to the motion network.
            kp_source: Source keypoints, passed through to the motion network.

        Returns:
            A dict with these keys:

            - ``'occlusion_map'``: the motion network's occlusion map, or ``None``
              when it returns none.
            - ``'deformation'``: the sampling grid that was used.
            - ``'out'``: the warped features, refined by ``third`` and ``fourth``,
              optionally multiplied by the occlusion map.

        Raises:
            RuntimeError: If the module was built without ``dense_motion_params``.
                Previously this surfaced as an opaque ``UnboundLocalError``.
        """
        if self.dense_motion_network is None:
            raise RuntimeError(
                "WarpingNetwork.forward requires a dense motion network; "
                "construct the module with dense_motion_params."
            )

        dense_motion = self.dense_motion_network(
            feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source
        )
        occlusion_map = dense_motion.get('occlusion_map')
        deformation = dense_motion['deformation']

        out = self.deform_input(feature_3d, deformation)

        # Fold depth into channels so the 2D layers can process the volume.
        bs, c, d, h, w = out.shape
        out = out.view(bs, c * d, h, w)
        out = self.fourth(self.third(out))

        if self.flag_use_occlusion_map and occlusion_map is not None:
            # Assumes occlusion_map broadcasts against out (e.g. bs x 1 x h x w)
            # - NOTE(review): confirm against DenseMotionNetwork.
            out = out * occlusion_map

        return {
            'occlusion_map': occlusion_map,
            'deformation': deformation,
            'out': out,
        }
|
|
|