| import functools |
| import torch |
| import torch.nn as nn |
| from .base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder |
|
|
def convert_flow_to_deformation(flow):
    r"""Turn a pixel-space flow field into a normalized sampling grid.

    Channel 0 of ``flow`` is rescaled by ``2 / (w - 1)`` and the remaining
    channel(s) by ``2 / (h - 1)``, converting pixel offsets into the [-1, 1]
    coordinate system. The result is moved to channels-last layout and added
    to the base grid returned by ``make_coordinate_grid``.

    Args:
        flow (tensor): Flow field of shape (b, c, h, w) predicted by the model;
            channel 0 is the x offset and channel 1 the y offset.
    Returns:
        deformation (tensor): Sampling grid used for warping; shape
            (b, h, w, 2) for a two-channel flow.
    """
    # Unpacking all four dims keeps the implicit "flow must be 4-D" check.
    _, _, height, width = flow.shape

    x_offset = flow[:, :1, ...] / (width - 1)
    y_offset = flow[:, 1:, ...] / (height - 1)
    normalized = 2 * torch.cat([x_offset, y_offset], 1)

    base_grid = make_coordinate_grid(flow)
    return base_grid + normalized.permute(0, 2, 3, 1)
|
|
def make_coordinate_grid(flow):
    r"""Build a normalized (x, y) coordinate grid the same size as a flow field.

    Along each axis the coordinates are evenly spaced, with the first entry
    equal to -1 and the last equal to 1. The grid takes its dtype and device
    from ``flow``.

    Args:
        flow (tensor): Flow field of shape (b, c, h, w); only its shape,
            dtype and device are used.
    Returns:
        grid (tensor): Tensor of shape (b, h, w, 2) whose last dimension is
            (x, y). The batch dimension is an expanded view, not a copy.
    """
    batch, _, height, width = flow.shape

    cols = 2 * (torch.arange(width).to(flow) / (width - 1)) - 1
    rows = 2 * (torch.arange(height).to(flow) / (height - 1)) - 1

    # Tile to (h, w): x varies along columns, y varies along rows.
    grid_x = cols.view(1, -1).repeat(height, 1)
    grid_y = rows.view(-1, 1).repeat(1, width)

    coords = torch.stack([grid_x, grid_y], dim=2)
    return coords.expand(batch, -1, -1, -1)
|
|
| |
def warp_image(source_image, deformation):
    r"""Warp images by bilinearly sampling them at a deformation grid.

    If the grid's spatial size differs from the image's, the grid is first
    resized to the image size with bilinear interpolation.

    Args:
        source_image (tensor): Images to be warped, shape (b, c, h, w).
        deformation (tensor): Sampling grid of shape (b, h', w', 2) holding
            normalized coordinates, nominally in [-1, 1].
    Returns:
        output (tensor): The warped images, shape (b, c, h, w).
    """
    _, grid_h, grid_w, _ = deformation.shape
    _, _, h, w = source_image.shape
    if grid_h != h or grid_w != w:
        # interpolate expects channels-first input, so move the (x, y) axis to
        # dim 1, resize, then move it back to channels-last for grid_sample.
        deformation = deformation.permute(0, 3, 1, 2)
        deformation = torch.nn.functional.interpolate(
            deformation, size=(h, w), mode='bilinear', align_corners=False)
        deformation = deformation.permute(0, 2, 3, 1)
    # align_corners=False is what both calls already used by default. Passing
    # it explicitly keeps the results identical and stops grid_sample from
    # emitting its "default behavior has changed" UserWarning on every call.
    # NOTE(review): make_coordinate_grid places -1/1 on the first/last pixel
    # (the align_corners=True convention), so sampling with
    # align_corners=False introduces a half-pixel offset. Changing it would
    # alter the outputs of any already-trained model; confirm intent first.
    return torch.nn.functional.grid_sample(
        source_image, deformation, align_corners=False)
|
|
|
|
class FaceGenerator(nn.Module):
    r"""Generator built from a mapping, a warping and an editing network.

    ``driving_source`` is encoded into a descriptor by the mapping network.
    The warping network uses that descriptor to warp ``input_image``. Unless
    only the warping stage is requested, the editing network then refines the
    warped image.

    Args:
        mapping_net (dict): Keyword arguments for ``MappingNet``.
        warpping_net (dict): Keyword arguments for ``WarpingNet``, merged with
            ``common``.
        editing_net (dict): Keyword arguments for ``EditingNet``, merged with
            ``common``.
        common (dict): Keyword arguments shared by the warping and editing
            networks.
    """

    def __init__(
        self,
        mapping_net,
        warpping_net,
        editing_net,
        common
    ):
        super(FaceGenerator, self).__init__()
        # Attribute names (including the "warpping" spelling) are kept as-is
        # because they become the parameter-name prefixes in the state dict.
        self.mapping_net = MappingNet(**mapping_net)
        self.warpping_net = WarpingNet(**warpping_net, **common)
        self.editing_net = EditingNet(**editing_net, **common)

    def forward(
        self,
        input_image,
        driving_source,
        stage=None
    ):
        r"""Run the generator.

        Args:
            input_image (tensor): Image to be animated.
            driving_source (tensor): Input to the mapping network.
            stage (str, optional): ``'warp'`` runs only the mapping and warping
                networks; any other value also runs the editing network.
        Returns:
            dict: The warping network's output dict. When the editing stage
            runs, it also contains ``'fake_image'``.
        """
        # Mapping and warping are shared by every stage; previously they were
        # duplicated in both branches.
        descriptor = self.mapping_net(driving_source)
        output = self.warpping_net(input_image, descriptor)
        if stage != 'warp':
            output['fake_image'] = self.editing_net(
                input_image, output['warp_image'], descriptor)
        return output
|
|
class MappingNet(nn.Module):
    r"""1-D convolutional encoder that produces an average-pooled descriptor.

    A kernel-7 convolution is followed by ``layer`` residual blocks. Each
    block is a LeakyReLU plus an unpadded kernel-3 convolution with
    dilation 3, which shortens the sequence by 6. The skip connection
    therefore crops 3 steps from each end to match. The result is averaged
    over the remaining length.

    Args:
        coeff_nc (int): Number of input channels.
        descriptor_nc (int): Number of output (descriptor) channels.
        layer (int): Number of residual blocks.
    """

    def __init__(self, coeff_nc, descriptor_nc, layer):
        super(MappingNet, self).__init__()

        self.layer = layer
        # One parameter-free activation instance is shared by every block.
        act = nn.LeakyReLU(0.1)

        self.first = nn.Sequential(
            torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))

        # Blocks are registered as encoder0, encoder1, ... so the state-dict
        # keys stay stable.
        for idx in range(layer):
            block = nn.Sequential(
                act,
                torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3),
            )
            setattr(self, f'encoder{idx}', block)

        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.output_nc = descriptor_nc

    def forward(self, input_3dmm):
        r"""Encode a (b, coeff_nc, L) tensor into a (b, descriptor_nc, 1) tensor.

        Every convolution needs a non-empty output, so L >= 7 + 6 * layer.
        """
        hidden = self.first(input_3dmm)
        for idx in range(self.layer):
            block = getattr(self, f'encoder{idx}')
            hidden = block(hidden) + hidden[:, :, 3:-3]
        return self.pooling(hidden)
|
|
class WarpingNet(nn.Module):
    r"""Predict a flow field from an image and a descriptor, then warp the image.

    An ``ADAINHourglass`` conditioned on ``descriptor`` processes the input
    image. A norm -> LeakyReLU -> 7x7 conv head maps its features to a
    2-channel flow field. That flow is converted into a sampling grid and
    used to warp ``input_image``.
    """

    def __init__(
        self,
        image_nc,
        descriptor_nc,
        base_nc,
        max_nc,
        encoder_layer,
        decoder_layer,
        use_spect
    ):
        super(WarpingNet, self).__init__()

        # One LeakyReLU instance is shared by the hourglass and the flow head.
        act = nn.LeakyReLU(0.1)
        make_norm = functools.partial(LayerNorm2d, affine=True)

        self.descriptor_nc = descriptor_nc
        self.hourglass = ADAINHourglass(
            image_nc, self.descriptor_nc, base_nc, max_nc,
            encoder_layer, decoder_layer,
            nonlinearity=act, use_spect=use_spect)

        feat_nc = self.hourglass.output_nc
        self.flow_out = nn.Sequential(
            make_norm(feat_nc),
            act,
            nn.Conv2d(feat_nc, 2, kernel_size=7, stride=1, padding=3))

        # NOTE(review): not used by forward() in this file; it has no
        # parameters.
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, input_image, descriptor):
        r"""Return a dict with the ``'flow_field'`` and the ``'warp_image'``."""
        features = self.hourglass(input_image, descriptor)
        flow = self.flow_out(features)
        grid = convert_flow_to_deformation(flow)
        return {
            'flow_field': flow,
            'warp_image': warp_image(input_image, grid),
        }
|
|
|
|
class EditingNet(nn.Module):
    r"""Refine a warped image using the source image and a descriptor.

    The input and warped images are concatenated along the channel axis
    (hence ``image_nc * 2`` encoder input channels). The result is encoded
    with ``FineEncoder`` and decoded with ``FineDecoder``, which is
    conditioned on the descriptor.
    """

    def __init__(
            self,
            image_nc,
            descriptor_nc,
            layer,
            base_nc,
            max_nc,
            num_res_blocks,
            use_spect):
        super(EditingNet, self).__init__()

        # The same activation and norm factory are handed to both halves.
        act = nn.LeakyReLU(0.1)
        make_norm = functools.partial(LayerNorm2d, affine=True)
        self.descriptor_nc = descriptor_nc

        self.encoder = FineEncoder(
            image_nc * 2, base_nc, max_nc, layer,
            norm_layer=make_norm, nonlinearity=act, use_spect=use_spect)
        self.decoder = FineDecoder(
            image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks,
            norm_layer=make_norm, nonlinearity=act, use_spect=use_spect)

    def forward(self, input_image, warp_image, descriptor):
        r"""Return the refined image for an (input, warped) image pair."""
        stacked = torch.cat([input_image, warp_image], 1)
        encoded = self.encoder(stacked)
        return self.decoder(encoded, descriptor)
|
|