|
|
|
|
|
|
|
|
import torch
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
|
class ResidualBlock(nn.Module):
|
|
|
def __init__(self, num_filters):
|
|
|
super(ResidualBlock, self).__init__()
|
|
|
self.block = nn.Sequential(
|
|
|
nn.ReflectionPad2d(1),
|
|
|
nn.Conv2d(
|
|
|
in_channels=num_filters,
|
|
|
out_channels=num_filters,
|
|
|
kernel_size=3,
|
|
|
stride=1,
|
|
|
padding=0,
|
|
|
bias=False,
|
|
|
),
|
|
|
nn.BatchNorm2d(num_filters),
|
|
|
nn.ReLU(True),
|
|
|
nn.Dropout(0.5),
|
|
|
nn.ReflectionPad2d(1),
|
|
|
nn.Conv2d(
|
|
|
in_channels=num_filters,
|
|
|
out_channels=num_filters,
|
|
|
kernel_size=3,
|
|
|
stride=1,
|
|
|
padding=0,
|
|
|
bias=False,
|
|
|
),
|
|
|
nn.BatchNorm2d(num_filters),
|
|
|
)
|
|
|
|
|
|
def forward(self, x):
|
|
|
residual = self.block(x)
|
|
|
return x + residual
|
|
|
|
|
|
|
|
|
ngf = 64
|
|
|
|
|
|
|
|
|
class ResNetGenerator(nn.Module):
|
|
|
"""
|
|
|
https://github.com/Alibaba-AAIG/Beyond-ImageNet-Attack/blob/863b758ee4f4a6d3d4e7777c5f94f457fa449f73/generator.py#L14
|
|
|
|
|
|
Test Case:
|
|
|
>>> netG = ResNetGenerator()
|
|
|
>>> test_sample = torch.rand(1, 3, 224, 224)
|
|
|
>>> print("Generator output:", netG(test_sample).size())
|
|
|
>>> print(
|
|
|
>>> "Generator parameters:",
|
|
|
>>> sum(p.numel() for p in netG.parameters() if p.requires_grad),
|
|
|
>>> )
|
|
|
"""
|
|
|
|
|
|
def __init__(self, inception=False):
|
|
|
super(ResNetGenerator, self).__init__()
|
|
|
self.inception = inception
|
|
|
self.block1 = nn.Sequential(
|
|
|
nn.ReflectionPad2d(3),
|
|
|
nn.Conv2d(3, ngf, kernel_size=7, padding=0, bias=False),
|
|
|
nn.BatchNorm2d(ngf),
|
|
|
nn.ReLU(True),
|
|
|
)
|
|
|
|
|
|
self.block2 = nn.Sequential(
|
|
|
nn.Conv2d(
|
|
|
ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False
|
|
|
),
|
|
|
nn.BatchNorm2d(ngf * 2),
|
|
|
nn.ReLU(True),
|
|
|
)
|
|
|
|
|
|
self.block3 = nn.Sequential(
|
|
|
nn.Conv2d(
|
|
|
ngf * 2,
|
|
|
ngf * 4,
|
|
|
kernel_size=3,
|
|
|
stride=2,
|
|
|
padding=1,
|
|
|
bias=False,
|
|
|
),
|
|
|
nn.BatchNorm2d(ngf * 4),
|
|
|
nn.ReLU(True),
|
|
|
)
|
|
|
|
|
|
self.resblock1 = ResidualBlock(ngf * 4)
|
|
|
self.resblock2 = ResidualBlock(ngf * 4)
|
|
|
self.resblock3 = ResidualBlock(ngf * 4)
|
|
|
self.resblock4 = ResidualBlock(ngf * 4)
|
|
|
self.resblock5 = ResidualBlock(ngf * 4)
|
|
|
self.resblock6 = ResidualBlock(ngf * 4)
|
|
|
|
|
|
self.upsampl1 = nn.Sequential(
|
|
|
nn.ConvTranspose2d(
|
|
|
ngf * 4,
|
|
|
ngf * 2,
|
|
|
kernel_size=3,
|
|
|
stride=2,
|
|
|
padding=1,
|
|
|
output_padding=1,
|
|
|
bias=False,
|
|
|
),
|
|
|
nn.BatchNorm2d(ngf * 2),
|
|
|
nn.ReLU(True),
|
|
|
)
|
|
|
|
|
|
self.upsampl2 = nn.Sequential(
|
|
|
nn.ConvTranspose2d(
|
|
|
ngf * 2,
|
|
|
ngf,
|
|
|
kernel_size=3,
|
|
|
stride=2,
|
|
|
padding=1,
|
|
|
output_padding=1,
|
|
|
bias=False,
|
|
|
),
|
|
|
nn.BatchNorm2d(ngf),
|
|
|
nn.ReLU(True),
|
|
|
)
|
|
|
|
|
|
self.blockf = nn.Sequential(
|
|
|
nn.ReflectionPad2d(3), nn.Conv2d(ngf, 3, kernel_size=7, padding=0)
|
|
|
)
|
|
|
self.crop = nn.ConstantPad2d((0, -1, -1, 0), 0)
|
|
|
|
|
|
def forward(self, input):
|
|
|
x = self.block1(input)
|
|
|
x = self.block2(x)
|
|
|
x = self.block3(x)
|
|
|
x = self.resblock1(x)
|
|
|
x = self.resblock2(x)
|
|
|
x = self.resblock3(x)
|
|
|
x = self.resblock4(x)
|
|
|
x = self.resblock5(x)
|
|
|
x = self.resblock6(x)
|
|
|
x = self.upsampl1(x)
|
|
|
x = self.upsampl2(x)
|
|
|
x = self.blockf(x)
|
|
|
if self.inception:
|
|
|
x = self.crop(x)
|
|
|
return (torch.tanh(x) + 1) / 2
|
|
|
|
|
|
|
|
|
CDAGenerator = ResNetGenerator
|
|
|
|