| import torch |
| import torch.nn as nn |
|
|
class ResidualBlock(nn.Module):
    """Channel-preserving residual unit: ``x + block(x)``.

    ``block`` is two rounds of reflection-pad(1) -> 3x3 conv -> instance
    norm, with a ReLU only after the first round. The 1-pixel padding
    offsets the unpadded 3x3 convolution, so the spatial size is unchanged
    and the result can be added back onto the input.

    Args:
        in_features: Number of channels of both the input and the output.
    """

    def __init__(self, in_features):
        super().__init__()
        layers = [
            # First pad/conv/norm round, followed by the only activation.
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            # Second round has no activation before the skip addition.
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        branch = self.block(x)
        return x + branch
|
|
class ResNetGenerator(nn.Module):
    """Encoder / residual-trunk / decoder image generator.

    Layer order:
      1. Stem: reflection-pad(3) -> 7x7 conv to 64 channels -> instance
         norm -> ReLU.
      2. Two stride-2 3x3 convs, each doubling the channels (64 -> 128 -> 256).
      3. ``num_residual_blocks`` ``ResidualBlock`` modules at 256 channels.
      4. Two stride-2 transposed 3x3 convs, each halving the channels
         (256 -> 128 -> 64).
      5. Head: reflection-pad(3) -> 7x7 conv to ``output_channels`` -> Tanh,
         so outputs lie in [-1, 1].

    Args:
        input_channels: Channels of the input image.
        output_channels: Channels of the generated image.
        num_residual_blocks: Number of residual blocks in the trunk.
    """

    def __init__(self, input_channels=3, output_channels=3, num_residual_blocks=9):
        super().__init__()
        width = 64  # running channel count as the stack is assembled

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, width, 7),
            nn.InstanceNorm2d(width),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: each stage halves the resolution and doubles the width.
        for _ in range(2):
            doubled = width * 2
            layers.extend([
                nn.Conv2d(width, doubled, 3, stride=2, padding=1),
                nn.InstanceNorm2d(doubled),
                nn.ReLU(inplace=True),
            ])
            width = doubled

        # Residual trunk at the bottleneck width.
        layers.extend(ResidualBlock(width) for _ in range(num_residual_blocks))

        # Upsampling: each stage doubles the resolution and halves the width.
        for _ in range(2):
            halved = width // 2
            layers.extend([
                nn.ConvTranspose2d(
                    width, halved, 3, stride=2, padding=1, output_padding=1
                ),
                nn.InstanceNorm2d(halved),
                nn.ReLU(inplace=True),
            ])
            width = halved

        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(width, output_channels, 7),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full generator stack and return the result."""
        generated = self.model(x)
        return generated
|
|
class PatchGANDiscriminator(nn.Module):
    """Fully convolutional discriminator producing a one-channel score map.

    Four 4x4 conv stages (64, 128, 256, 512 output channels), each followed
    by LeakyReLU(0.2). The first three use stride 2 and the fourth stride 1.
    Every stage except the first applies instance normalization between the
    conv and the activation. A final 4x4 conv maps 512 channels to 1, with no
    activation after it.

    Args:
        input_channels: Channels of the image being judged.
    """

    def __init__(self, input_channels=3):
        super().__init__()

        # One row per stage: (in_channels, out_channels, stride, use_norm).
        stage_specs = (
            (input_channels, 64, 2, False),
            (64, 128, 2, True),
            (128, 256, 2, True),
            (256, 512, 1, True),
        )

        layers = []
        for c_in, c_out, stride, use_norm in stage_specs:
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Final projection to a single-channel map of raw scores.
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map computed for ``x``."""
        scores = self.model(x)
        return scores
|
|
def weights_init_normal(m):
    """Initialise convolution-like modules with N(0, 0.02) weights and zero bias.

    Intended as a callback for ``nn.Module.apply``, which calls it on every
    submodule. A module is selected when its class name contains ``"Conv"``
    (this matches both ``Conv2d`` and ``ConvTranspose2d``). All other
    modules are left untouched.

    Modules whose class name contains ``"Conv"`` but which carry no
    ``weight`` of their own (for example a user-defined wrapper class) are
    skipped instead of raising ``AttributeError``.

    Args:
        m: The module to (possibly) initialise in place.
    """
    if "Conv" not in m.__class__.__name__:
        return

    weight = getattr(m, "weight", None)
    if weight is None:
        # Name matched, but there is nothing to initialise on this module.
        return
    nn.init.normal_(weight.data, 0.0, 0.02)

    bias = getattr(m, "bias", None)
    if bias is not None:
        nn.init.constant_(bias.data, 0.0)
|
|
class CycleGAN(nn.Module):
    """Container for the four networks of a two-domain CycleGAN setup.

    Attributes created:
        G_A2B: ``ResNetGenerator`` taking ``channels_a`` input channels and
            producing ``channels_b`` output channels.
        G_B2A: ``ResNetGenerator`` taking ``channels_b`` input channels and
            producing ``channels_a`` output channels.
        D_A: ``PatchGANDiscriminator`` for ``channels_a``-channel images.
        D_B: ``PatchGANDiscriminator`` for ``channels_b``-channel images.

    After construction, ``weights_init_normal`` is applied to every
    submodule. The class defines no ``forward`` in this block; callers use
    the four sub-networks directly.

    Args:
        channels_a: Image channels in domain A. Defaults to 3.
        channels_b: Image channels in domain B. Defaults to 3.
        num_residual_blocks: Residual blocks in each generator. Defaults to 9.

    With all defaults the networks built are the same as the previous
    zero-argument constructor produced.
    """

    def __init__(self, channels_a=3, channels_b=3, num_residual_blocks=9):
        super().__init__()
        # Construction order is kept (G_A2B, G_B2A, D_A, D_B) so parameter
        # registration order and state-dict layout stay unchanged.
        self.G_A2B = ResNetGenerator(
            input_channels=channels_a,
            output_channels=channels_b,
            num_residual_blocks=num_residual_blocks,
        )
        self.G_B2A = ResNetGenerator(
            input_channels=channels_b,
            output_channels=channels_a,
            num_residual_blocks=num_residual_blocks,
        )
        self.D_A = PatchGANDiscriminator(input_channels=channels_a)
        self.D_B = PatchGANDiscriminator(input_channels=channels_b)

        # Re-initialise every conv-like layer in all four networks.
        self.apply(weights_init_normal)
|
|