| | |
| |
|
| | """ |
| | SPADE decoder (G) defined in the paper, which takes the warped feature as input and generates the animated image. |
| | """ |
| |
|
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| | from .util import SPADEResnetBlock |
| |
|
| |
|
class SPADEDecoder(nn.Module):
    """Decoder that maps a warped feature map to a 3-channel image.

    The incoming feature is used in two ways: it is projected by ``fc`` to
    form the signal that flows through the network, and it is also handed
    unchanged as the second argument (``seg``) to every
    ``SPADEResnetBlock`` (presumably as the SPADE conditioning map; see
    ``SPADEResnetBlock`` to confirm).

    Args:
        upscale: ``None`` or a value ``<= 1`` selects a plain 3x3 convolution
            as the image head. Any larger value selects a convolution
            followed by ``PixelShuffle``. NOTE: that head always enlarges by
            a fixed factor of 2, regardless of the actual ``upscale`` value.
        max_features: Upper bound on the expected input channel count.
        block_expansion: Base channel width; the expected input channel count
            is ``block_expansion * 2 ** num_down_blocks`` capped at
            ``max_features``.
        out_channels: Channels produced by the last SPADE block, i.e. the
            input width of the image head.
        num_down_blocks: Exponent used in the channel formula above. Must be
            at least 1.

    Raises:
        ValueError: If ``num_down_blocks`` is smaller than 1.
    """

    def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2):
        if num_down_blocks < 1:
            raise ValueError(f"num_down_blocks must be >= 1, got {num_down_blocks}")

        # nn.Module must be initialised before attributes are assigned on it.
        super().__init__()
        self.upscale = upscale

        # Channel count of the feature map this decoder expects as input.
        input_channels = min(max_features, block_expansion * 2 ** num_down_blocks)
        hidden_channels = 2 * input_channels
        norm_G = 'spadespectralinstance'
        # The input feature itself is passed to the SPADE blocks as `seg`,
        # so its channel count is also the label channel count.
        label_num_channels = input_channels

        self.fc = nn.Conv2d(input_channels, hidden_channels, 3, padding=1)
        # Attribute names and registration order are kept stable so that
        # existing checkpoints (state-dict keys) continue to load.
        self.G_middle_0 = SPADEResnetBlock(hidden_channels, hidden_channels, norm_G, label_num_channels)
        self.G_middle_1 = SPADEResnetBlock(hidden_channels, hidden_channels, norm_G, label_num_channels)
        self.G_middle_2 = SPADEResnetBlock(hidden_channels, hidden_channels, norm_G, label_num_channels)
        self.G_middle_3 = SPADEResnetBlock(hidden_channels, hidden_channels, norm_G, label_num_channels)
        self.G_middle_4 = SPADEResnetBlock(hidden_channels, hidden_channels, norm_G, label_num_channels)
        self.G_middle_5 = SPADEResnetBlock(hidden_channels, hidden_channels, norm_G, label_num_channels)
        self.up_0 = SPADEResnetBlock(hidden_channels, input_channels, norm_G, label_num_channels)
        self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels)
        self.up = nn.Upsample(scale_factor=2)

        if self.upscale is None or self.upscale <= 1:
            self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
        else:
            # 3 * (2 * 2) channels are rearranged by PixelShuffle into a
            # 3-channel image at twice the spatial resolution.
            self.conv_img = nn.Sequential(
                nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
                nn.PixelShuffle(upscale_factor=2)
            )

    def forward(self, feature):
        """Generate an image from ``feature``.

        Args:
            feature: Tensor accepted by ``self.fc``, i.e. a 4-D (or
                unbatched 3-D) tensor whose channel count equals the
                ``input_channels`` computed in ``__init__``.

        Returns:
            A 3-channel tensor with values in ``(0, 1)`` (sigmoid output).
            The signal passes through two 2x ``nn.Upsample`` steps, plus a
            2x ``PixelShuffle`` when the upscaling head is active.
        """
        seg = feature
        x = self.fc(feature)

        middle_blocks = (
            self.G_middle_0,
            self.G_middle_1,
            self.G_middle_2,
            self.G_middle_3,
            self.G_middle_4,
            self.G_middle_5,
        )
        for block in middle_blocks:
            x = block(x, seg)

        # Two upsampling stages, each followed by a SPADE block.
        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        return torch.sigmoid(x)