| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| | from face_vid2vid.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock |
| | from face_vid2vid.modules.dense_motion import DenseMotionNetwork |
| |
|
| |
|
class OcclusionAwareGenerator(nn.Module):
    """
    Occlusion-aware image generator following the NVIDIA face-vid2vid architecture.

    Pipeline: the source image is encoded to a 2D feature map, viewed as a 3D
    feature volume, optionally warped by a dense motion field predicted from the
    source/driving keypoints, projected back to 2D, optionally multiplied by an
    occlusion map, and decoded to an ``image_channel``-channel image squashed
    by a sigmoid.
    """

    def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
                 num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
        """
        Args:
            image_channel: channels of the input and output images.
            feature_channel: forwarded to ``DenseMotionNetwork``.
            num_kp: number of keypoints; forwarded to ``DenseMotionNetwork``.
            block_expansion: base channel width of the encoder/decoder.
            max_features: cap on encoder channel width and the channel count of
                the map that is viewed as a 3D volume, so it must equal
                ``reshape_channel * reshape_depth``.
            num_down_blocks: number of down-sampling stages (and matching
                up-sampling stages in the decoder).
            reshape_channel: channel dimension of the 3D feature volume.
            reshape_depth: depth dimension of the 3D feature volume.
            num_resblocks: residual blocks in both the 3D and the 2D stages.
            estimate_occlusion_map: forwarded to ``DenseMotionNetwork``.
            dense_motion_params: extra kwargs for ``DenseMotionNetwork``. If
                None, no motion network is built and features are not warped.
            estimate_jacobian: accepted for config compatibility; unused here.
        """
        super(OcclusionAwareGenerator, self).__init__()

        if dense_motion_params is not None:
            self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
                                                           estimate_occlusion_map=estimate_occlusion_map,
                                                           **dense_motion_params)
        else:
            self.dense_motion_network = None

        self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3))

        # Encoder: channel width doubles per stage, capped at max_features.
        # Seed out_features with the width produced by `first` so `second`
        # is well-defined even when num_down_blocks == 0.
        out_features = block_expansion
        down_blocks = []
        for i in range(num_down_blocks):
            in_features = min(max_features, block_expansion * (2 ** i))
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.down_blocks = nn.ModuleList(down_blocks)

        self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)

        self.reshape_channel = reshape_channel
        self.reshape_depth = reshape_depth

        self.resblocks_3d = torch.nn.Sequential()
        for i in range(num_resblocks):
            self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))

        # Project the flattened (c*d) volume back to a 2D decoder width.
        out_features = block_expansion * (2 ** (num_down_blocks))
        self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
        self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)

        self.resblocks_2d = torch.nn.Sequential()
        for i in range(num_resblocks):
            self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1))

        # Decoder: mirror of the encoder, halving width per stage (floor block_expansion).
        up_blocks = []
        for i in range(num_down_blocks):
            in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i)))
            out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1)))
            up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.up_blocks = nn.ModuleList(up_blocks)

        self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3))
        self.estimate_occlusion_map = estimate_occlusion_map
        self.image_channel = image_channel

    def deform_input(self, inp, deformation):
        """Resample a 5D volume at the locations given by a sampling grid.

        Args:
            inp: 5D tensor ``(bs, c, d, h, w)``.
            deformation: ``grid_sample`` grid ``(bs, d', h', w', 3)``. If its
                spatial size differs from ``inp`` it is first resized to
                ``(d, h, w)`` with trilinear interpolation.

        Returns:
            The sampled tensor, ``(bs, c, d, h, w)``.
        """
        _, d_old, h_old, w_old, _ = deformation.shape
        _, _, d, h, w = inp.shape
        if (d_old, h_old, w_old) != (d, h, w):
            # interpolate expects channels-first, so move the coordinate axis to dim 1 and back.
            deformation = deformation.permute(0, 4, 1, 2, 3)
            deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear', align_corners=False)
            deformation = deformation.permute(0, 2, 3, 4, 1)
        # align_corners=False is PyTorch's default; stated explicitly to silence the warning.
        return F.grid_sample(inp, deformation, align_corners=False)

    def _encode(self, source_image):
        """Encode the image, view it as a ``(bs, reshape_channel, reshape_depth, h, w)`` volume, and apply the 3D resblocks."""
        out = self.first(source_image)
        for down_block in self.down_blocks:
            out = down_block(out)
        out = self.second(out)
        bs, _, h, w = out.shape
        feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h, w)
        return self.resblocks_3d(feature_3d)

    def _transform_features(self, feature_3d, kp_driving, kp_source, output_dict):
        """Warp the 3D features by the predicted motion and project them to 2D.

        Stores ``'mask'`` (and ``'occlusion_map'`` when predicted) in
        ``output_dict``. Without a dense motion network the features are
        projected unwarped and no occlusion masking is applied.
        """
        occlusion_map = None
        out = feature_3d
        if self.dense_motion_network is not None:
            dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
                                                     kp_source=kp_source)
            output_dict['mask'] = dense_motion['mask']
            if 'occlusion_map' in dense_motion:
                occlusion_map = dense_motion['occlusion_map']
                output_dict['occlusion_map'] = occlusion_map
            out = self.deform_input(feature_3d, dense_motion['deformation'])

        # Fold depth into channels: (bs, c, d, h, w) -> (bs, c*d, h, w).
        bs, c, d, h, w = out.shape
        out = out.view(bs, c * d, h, w)
        out = self.third(out)
        out = self.fourth(out)

        if occlusion_map is not None:
            if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
                occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear', align_corners=False)
            out = out * occlusion_map
        return out

    def forward(self, source_image, kp_driving, kp_source):
        """Generate an image from the source image and the keypoints.

        Args:
            source_image: source image batch fed to the encoder.
            kp_driving: driving keypoints, passed to the dense motion network.
            kp_source: source keypoints, passed to the dense motion network.

        Returns:
            dict with ``'prediction'`` (sigmoid output, values in (0, 1)) and,
            when a dense motion network is configured, ``'mask'`` and possibly
            ``'occlusion_map'``.
        """
        output_dict = {}
        feature_3d = self._encode(source_image)
        out = self._transform_features(feature_3d, kp_driving, kp_source, output_dict)

        out = self.resblocks_2d(out)
        for up_block in self.up_blocks:
            out = up_block(out)
        out = self.final(out)
        output_dict["prediction"] = torch.sigmoid(out)
        return output_dict
| |
|
| |
|
class SPADEDecoder(nn.Module):
    """SPADE-conditioned decoder mapping a 2D feature map to an image.

    The input feature map is also used as the SPADE conditioning map for every
    block. It goes through a 3x3 conv that doubles the channel count, six SPADE
    residual blocks, and two (2x upsample + SPADE residual block) stages. A final
    3x3 conv then produces ``image_channel`` outputs, squashed by a sigmoid.
    """

    def __init__(self, ic=256, oc=64, norm_G='spadespectralinstance', label_nc=None, image_channel=3):
        """
        Args:
            ic: channels of the input feature map.
            oc: channels entering the final output conv.
            norm_G: normalization spec passed to every ``SPADEResnetBlock``.
            label_nc: channels of the conditioning map. Defaults to ``ic``
                because ``forward`` uses the input feature as the conditioning map.
            image_channel: channels of the output image.

        The defaults reproduce the original hard-coded configuration, so
        ``SPADEDecoder()`` builds the same network with the same parameter names.
        """
        super().__init__()
        if label_nc is None:
            label_nc = ic

        self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1)
        self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
        self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
        self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
        self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
        self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
        self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
        self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc)
        self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc)
        self.conv_img = nn.Conv2d(oc, image_channel, 3, padding=1)
        # nn.Upsample defaults to nearest-neighbour interpolation.
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, feature):
        """Decode ``feature`` to an image with values in (0, 1).

        The unmodified ``feature`` serves as the SPADE conditioning map for all blocks.
        """
        seg = feature
        x = self.fc(feature)
        for block in (self.G_middle_0, self.G_middle_1, self.G_middle_2,
                      self.G_middle_3, self.G_middle_4, self.G_middle_5):
            x = block(x, seg)
        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        return torch.sigmoid(x)
| |
|
| |
|
class OcclusionAwareSPADEGenerator(nn.Module):
    """Occlusion-aware generator that decodes with a ``SPADEDecoder``.

    Pipeline: the source image is encoded to a 2D feature map, viewed as a 3D
    feature volume, optionally warped by a dense motion field predicted from the
    source/driving keypoints, projected back to 2D, optionally multiplied by an
    occlusion map, and decoded by ``SPADEDecoder``.
    """

    def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
                 num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
        """
        Args:
            image_channel: channels of the input image.
            feature_channel: forwarded to ``DenseMotionNetwork``.
            num_kp: number of keypoints; forwarded to ``DenseMotionNetwork``.
            block_expansion: base channel width of the encoder.
            max_features: cap on encoder channel width and the channel count of
                the map that is viewed as a 3D volume, so it must equal
                ``reshape_channel * reshape_depth``.
            num_down_blocks: number of encoder down-sampling stages.
            reshape_channel: channel dimension of the 3D feature volume.
            reshape_depth: depth dimension of the 3D feature volume.
            num_resblocks: number of 3D residual blocks.
            estimate_occlusion_map: forwarded to ``DenseMotionNetwork``.
            dense_motion_params: extra kwargs for ``DenseMotionNetwork``. If
                None, no motion network is built and features are not warped.
            estimate_jacobian: accepted for config compatibility; unused here.
        """
        super(OcclusionAwareSPADEGenerator, self).__init__()

        if dense_motion_params is not None:
            self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
                                                           estimate_occlusion_map=estimate_occlusion_map,
                                                           **dense_motion_params)
        else:
            self.dense_motion_network = None

        self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))

        # Encoder: channel width doubles per stage, capped at max_features.
        # Seed out_features with the width produced by `first` so `second`
        # is well-defined even when num_down_blocks == 0.
        out_features = block_expansion
        down_blocks = []
        for i in range(num_down_blocks):
            in_features = min(max_features, block_expansion * (2 ** i))
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.down_blocks = nn.ModuleList(down_blocks)

        self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)

        self.reshape_channel = reshape_channel
        self.reshape_depth = reshape_depth

        self.resblocks_3d = torch.nn.Sequential()
        for i in range(num_resblocks):
            self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))

        # Project the flattened (c*d) volume back to a 2D width for the decoder.
        out_features = block_expansion * (2 ** (num_down_blocks))
        self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
        self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)

        self.estimate_occlusion_map = estimate_occlusion_map
        self.image_channel = image_channel

        # NOTE(review): SPADEDecoder() is built with its default input width
        # (256), so out_features above is expected to be 256 — confirm configs.
        self.decoder = SPADEDecoder()

    def deform_input(self, inp, deformation):
        """Resample a 5D volume at the locations given by a sampling grid.

        Args:
            inp: 5D tensor ``(bs, c, d, h, w)``.
            deformation: ``grid_sample`` grid ``(bs, d', h', w', 3)``. If its
                spatial size differs from ``inp`` it is first resized to
                ``(d, h, w)`` with trilinear interpolation.

        Returns:
            The sampled tensor, ``(bs, c, d, h, w)``.
        """
        _, d_old, h_old, w_old, _ = deformation.shape
        _, _, d, h, w = inp.shape
        if (d_old, h_old, w_old) != (d, h, w):
            # interpolate expects channels-first, so move the coordinate axis to dim 1 and back.
            deformation = deformation.permute(0, 4, 1, 2, 3)
            deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear', align_corners=False)
            deformation = deformation.permute(0, 2, 3, 4, 1)
        # align_corners=False is PyTorch's default; stated explicitly to silence the warning.
        return F.grid_sample(inp, deformation, align_corners=False)

    def _encode(self, source_image):
        """Encode the image, view it as a ``(bs, reshape_channel, reshape_depth, h, w)`` volume, and apply the 3D resblocks."""
        out = self.first(source_image)
        for down_block in self.down_blocks:
            out = down_block(out)
        out = self.second(out)
        bs, _, h, w = out.shape
        feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h, w)
        return self.resblocks_3d(feature_3d)

    def _transform_features(self, feature_3d, kp_driving, kp_source, output_dict):
        """Warp the 3D features by the predicted motion and project them to 2D.

        Stores ``'mask'`` (and ``'occlusion_map'`` when predicted) in
        ``output_dict``. Without a dense motion network the features are
        projected unwarped and no occlusion masking is applied.
        """
        occlusion_map = None
        out = feature_3d
        if self.dense_motion_network is not None:
            dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
                                                     kp_source=kp_source)
            output_dict['mask'] = dense_motion['mask']
            if 'occlusion_map' in dense_motion:
                occlusion_map = dense_motion['occlusion_map']
                output_dict['occlusion_map'] = occlusion_map
            out = self.deform_input(feature_3d, dense_motion['deformation'])

        # Fold depth into channels: (bs, c, d, h, w) -> (bs, c*d, h, w).
        bs, c, d, h, w = out.shape
        out = out.view(bs, c * d, h, w)
        out = self.third(out)
        out = self.fourth(out)

        if occlusion_map is not None:
            if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
                occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear', align_corners=False)
            out = out * occlusion_map
        return out

    def forward(self, source_image, kp_driving, kp_source, fp16=False):
        """Generate an image from the source image and the keypoints.

        Args:
            source_image: source image batch fed to the encoder.
            kp_driving: driving keypoint dict (its ``'value'`` entry is read).
            kp_source: source keypoint dict (its ``'value'`` entry is read).
            fp16: if True, cast the image and the keypoint ``'value'`` tensors
                to half precision first. The caller's dicts are not modified.

        Returns:
            dict with ``'prediction'`` (decoder output) and, when a dense motion
            network is configured, ``'mask'`` and possibly ``'occlusion_map'``.
        """
        if fp16:
            source_image = source_image.half()
            # Shallow-copy rather than writing into the caller's dicts, which are
            # commonly reused across calls (e.g. the same source keypoints per frame).
            kp_driving = {**kp_driving, 'value': kp_driving['value'].half()}
            kp_source = {**kp_source, 'value': kp_source['value'].half()}

        output_dict = {}
        feature_3d = self._encode(source_image)
        out = self._transform_features(feature_3d, kp_driving, kp_source, output_dict)
        output_dict["prediction"] = self.decoder(out)
        return output_dict