import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Residual block used at the bottleneck of the ResNet generator.

    Applies two reflection-padded 3x3 convolutions, each followed by
    instance normalization with a ReLU between them, and adds the result to
    the input. Channel count and spatial size are unchanged.

    Args:
        in_features: Number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        # Identity skip connection: output = x + F(x).
        return x + self.block(x)


class ResNetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 conv stem -> two stride-2 downsampling convs (channels
    doubled each time) -> ``num_residual_blocks`` residual blocks -> two
    stride-2 transposed convs (channels halved each time) -> 7x7 conv ->
    tanh, so outputs lie in (-1, 1).

    Spatial size is preserved when the input height and width are divisible
    by 4 (two factor-2 down/upsampling stages); other sizes yield an output
    of a different size.

    Args:
        input_channels: Channels of the input image.
        output_channels: Channels of the generated image.
        num_residual_blocks: Number of ``ResidualBlock`` layers at the
            bottleneck.
        base_features: Channels after the stem convolution; the bottleneck
            uses ``4 * base_features``. Defaults to 64.
    """

    def __init__(self, input_channels=3, output_channels=3,
                 num_residual_blocks=9, base_features=64):
        super().__init__()

        # Stem: reflection padding of 3 keeps the 7x7 conv size-preserving.
        features = base_features
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, features, 7),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: each stage halves H/W and doubles the channels.
        for _ in range(2):
            next_features = features * 2
            model += [
                nn.Conv2d(features, next_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(next_features),
                nn.ReLU(inplace=True),
            ]
            features = next_features

        # Bottleneck.
        model += [ResidualBlock(features) for _ in range(num_residual_blocks)]

        # Upsampling: output_padding=1 makes each stage exactly double H/W.
        for _ in range(2):
            next_features = features // 2
            model += [
                nn.ConvTranspose2d(features, next_features, 3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(next_features),
                nn.ReLU(inplace=True),
            ]
            features = next_features

        # Output head; tanh bounds the generated values to (-1, 1).
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(features, output_channels, 7),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class PatchGANDiscriminator(nn.Module):
    """PatchGAN discriminator producing a grid of per-patch scores.

    Four 4x4 conv blocks (strides 2, 2, 2, 1; instance norm on all but the
    first; LeakyReLU with slope 0.2) followed by a 4x4 conv to one channel.
    No final activation is applied, so the outputs are raw scores. For a
    256x256 input the output map is 1x30x30.

    Args:
        input_channels: Channels of the image being judged.
        base_features: Channels of the first conv; later blocks use 2x, 4x
            and 8x this value. Defaults to 64.
    """

    def __init__(self, input_channels=3, base_features=64):
        super().__init__()

        def discriminator_block(in_filters, out_filters, stride=2, normalize=True):
            """Return the layers [conv, (instance norm), LeakyReLU] as a list."""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=stride, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        f = base_features
        self.model = nn.Sequential(
            *discriminator_block(input_channels, f, normalize=False),
            *discriminator_block(f, f * 2),
            *discriminator_block(f * 2, f * 4),
            *discriminator_block(f * 4, f * 8, stride=1),
            nn.Conv2d(f * 8, 1, 4, padding=1),
        )

    def forward(self, x):
        return self.model(x)


def weights_init_normal(m):
    """Initialize convolution layers in place; intended for ``Module.apply``.

    Any module whose class name contains ``'Conv'`` (e.g. ``Conv2d``,
    ``ConvTranspose2d``) gets weights drawn from a normal distribution with
    mean 0 and std 0.02, and a zero bias if it has one. Other modules are
    left untouched.
    """
    classname = type(m).__name__
    if 'Conv' in classname:
        # nn.init functions already run under torch.no_grad(), so the
        # parameters are passed directly rather than via the legacy `.data`.
        nn.init.normal_(m.weight, 0.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias, 0.0)


class CycleGAN(nn.Module):
    """Container for the four CycleGAN networks.

    Holds two generators (``G_A2B``: domain A -> B, ``G_B2A``: B -> A) and
    one PatchGAN discriminator per domain (``D_A``, ``D_B``). All conv
    layers are initialized with ``weights_init_normal``. This module defines
    no ``forward`` of its own; call the sub-networks directly.

    Args:
        channels_a: Image channels in domain A. Defaults to 3.
        channels_b: Image channels in domain B. Defaults to 3.
        num_residual_blocks: Residual blocks per generator. Defaults to 9.
    """

    def __init__(self, channels_a=3, channels_b=3, num_residual_blocks=9):
        super().__init__()
        self.G_A2B = ResNetGenerator(channels_a, channels_b, num_residual_blocks)
        self.G_B2A = ResNetGenerator(channels_b, channels_a, num_residual_blocks)
        self.D_A = PatchGANDiscriminator(channels_a)
        self.D_B = PatchGANDiscriminator(channels_b)
        self.apply(weights_init_normal)