|
|
import torch
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
|
class EnhancedBN(nn.Module):
    """Batch normalization whose output is modulated by a conditioning map.

    The input ``base`` is batch-normalized, then scaled and shifted per pixel:

        ``bn(base) * (1 + gamma(h)) + beta(h)``,  ``h = relu(mapping(sty_r))``

    where ``sty_r`` is ``sty`` bilinearly resized to ``base``'s spatial size.
    The structure mirrors SPADE-style spatially-adaptive normalization, except
    that the inner ``BatchNorm2d`` keeps its own learnable affine parameters.

    Args:
        nc: Number of channels of the tensor being normalized.
        sty_nc: Number of channels of the conditioning tensor ``sty``.
        sty_nhidden: Width of the hidden feature map shared by the ``gamma``
            and ``beta`` branches.
    """

    def __init__(self, nc: int, sty_nc: int = 3, sty_nhidden: int = 128):
        super().__init__()

        self.bn = nn.BatchNorm2d(nc)

        # Shared 3x3 projection of the conditioning map; padding=1 with
        # stride=1 keeps the spatial size unchanged.
        self.mapping = nn.Conv2d(
            in_channels=sty_nc,
            out_channels=sty_nhidden,
            kernel_size=3,
            padding=1,
            stride=1,
        )

        # Per-pixel, per-channel scale (gamma) and shift (beta) predicted
        # from the hidden map.
        self.gamma = nn.Conv2d(
            in_channels=sty_nhidden,
            out_channels=nc,
            kernel_size=3,
            padding=1,
            stride=1,
        )
        self.beta = nn.Conv2d(
            in_channels=sty_nhidden,
            out_channels=nc,
            kernel_size=3,
            padding=1,
            stride=1,
        )

        self.init_weight()

    def init_weight(self):
        """Kaiming-normal initialize the three conv weights (biases are left as-is)."""
        nn.init.kaiming_normal_(self.mapping.weight)
        nn.init.kaiming_normal_(self.gamma.weight)
        nn.init.kaiming_normal_(self.beta.weight)

    def forward(self, base, sty):
        """Normalize ``base`` and modulate it with ``sty``.

        Args:
            base: 4-D tensor ``(N, nc, H, W)`` (``BatchNorm2d`` requires 4-D input).
            sty: 4-D conditioning tensor with ``sty_nc`` channels; any spatial
                size, it is resized to ``(H, W)``.

        Returns:
            Tensor with the same shape as ``base``.
        """
        bn = self.bn(base)

        # align_corners=False is the library default; stating it explicitly
        # avoids the "default upsampling behavior ... changed" UserWarning that
        # older PyTorch versions emit on every call, without changing results.
        sty_resized = torch.nn.functional.interpolate(
            sty, size=bn.size()[2:], mode="bilinear", align_corners=False
        )
        actv = torch.nn.functional.relu(self.mapping(sty_resized))

        # "1 +" makes zero gamma output an identity scale.
        return bn * (1 + self.gamma(actv)) + self.beta(actv)
|
|
|
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Channel- and size-preserving residual block with style-conditioned norms.

    Computes ``x + bn2(conv(dropout(relu(bn1(conv(x))))))``. Each conv is 3x3
    without bias, applied after 1-pixel reflection padding, so height and width
    are unchanged. Both normalizations are ``EnhancedBN`` layers that receive
    the same ``sty`` tensor.

    Args:
        num_filters: Channel count of the input, hidden and output tensors.
    """

    def __init__(self, num_filters):
        super().__init__()

        def padded_conv():
            # Reflection-pad by 1, then a 3x3 unpadded conv: output H/W == input H/W.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(
                    in_channels=num_filters,
                    out_channels=num_filters,
                    kernel_size=3,
                    stride=1,
                    padding=0,
                    bias=False,
                ),
            ]

        self.block1 = nn.Sequential(*padded_conv())
        self.bn1 = EnhancedBN(num_filters)
        self.block2 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            *padded_conv(),
        )
        self.bn2 = EnhancedBN(num_filters)

    def forward(self, x, sty):
        """Return ``x`` plus the style-conditioned residual branch."""
        branch = self.bn1(self.block1(x), sty)
        branch = self.bn2(self.block2(branch), sty)
        return x + branch
|
|
|
|
|
|
|
|
|
# Base channel width of ResNetGenerator; its layers use ngf, 2*ngf and 4*ngf channels.
ngf: int = 64
|
|
|
|
|
|
|
|
|
class ResNetGenerator(nn.Module):
    """Image-to-image ResNet generator conditioned on a style tensor.

    Layout (every normalization is an ``EnhancedBN`` fed the same ``sty``):

    * stem: reflection pad 3 + 7x7 conv, 3 -> ngf channels, then norm + ReLU;
    * two stride-2 3x3 convs (ngf -> 2*ngf -> 4*ngf), each then norm + ReLU;
    * six ``ResidualBlock`` layers at 4*ngf channels;
    * two stride-2 transposed convs (4*ngf -> 2*ngf -> ngf), each doubling
      H and W, each then norm + ReLU;
    * head: reflection pad 3 + 7x7 conv, ngf -> 3 channels.

    The result is mapped through ``(tanh(x) + 1) / 2``, giving values in [0, 1].
    """

    def __init__(self):
        super().__init__()

        c1, c2, c4 = ngf, ngf * 2, ngf * 4

        def downsample(cin, cout):
            # One stride-2 conv wrapped in a Sequential (parameter key "<name>.0.weight").
            return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=3, stride=2, padding=1, bias=False),
            )

        def upsample(cin, cout):
            # k=3, s=2, p=1, output_padding=1 gives exactly 2x the input H and W.
            return nn.ConvTranspose2d(
                cin,
                cout,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
                bias=False,
            )

        # Full-resolution stem.
        self.block1 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, c1, kernel_size=7, padding=0, bias=False),
        )
        self.bn1 = EnhancedBN(c1)

        # Encoder.
        self.block2 = downsample(c1, c2)
        self.bn2 = EnhancedBN(c2)
        self.block3 = downsample(c2, c4)
        self.bn3 = EnhancedBN(c4)

        # Bottleneck.
        self.resblock1 = ResidualBlock(c4)
        self.resblock2 = ResidualBlock(c4)
        self.resblock3 = ResidualBlock(c4)
        self.resblock4 = ResidualBlock(c4)
        self.resblock5 = ResidualBlock(c4)
        self.resblock6 = ResidualBlock(c4)

        # Decoder.
        self.upsampl1 = upsample(c4, c2)
        self.ubn1 = EnhancedBN(c2)
        self.upsampl2 = upsample(c2, c1)
        self.ubn2 = EnhancedBN(c1)

        # Output head (this conv keeps its default bias).
        self.blockf = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(c1, 3, kernel_size=7, padding=0),
        )

    def forward(self, input, sty):
        """Translate ``input`` (3 channels) under style ``sty``; output is in [0, 1]."""
        relu = torch.nn.functional.relu
        h = input

        encoder = (
            (self.block1, self.bn1),
            (self.block2, self.bn2),
            (self.block3, self.bn3),
        )
        for layer, norm in encoder:
            h = relu(norm(layer(h), sty))

        bottleneck = (
            self.resblock1,
            self.resblock2,
            self.resblock3,
            self.resblock4,
            self.resblock5,
            self.resblock6,
        )
        for block in bottleneck:
            h = block(h, sty)

        decoder = (
            (self.upsampl1, self.ubn1),
            (self.upsampl2, self.ubn2),
        )
        for layer, norm in decoder:
            h = relu(norm(layer(h), sty))

        return (torch.tanh(self.blockf(h)) + 1) / 2
|
|
|
|
|
|
|
|
|
# Alternate public name: AIMGenerator is the same class object as ResNetGenerator.
AIMGenerator = ResNetGenerator
|
|
|
|