# Model definitions: ConvLayer, ResidualBlock and TransformerNet.
import torch
from torch import nn
from torch.nn import functional as F


class ConvLayer(nn.Module):
    """Reflection-padded 2-D convolution.

    The input is reflection-padded by ``kernel_size // 2`` on every side and
    then passed through an unpadded ``nn.Conv2d``. With an odd ``kernel_size``
    and ``stride=1`` the spatial size of the input is therefore preserved.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Side length of the square convolution kernel.
        stride: Convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        reflection_padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        # padding is done explicitly above, so Conv2d uses its default (0).
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Pad ``x`` by reflection, then convolve it."""
        out = self.reflection_pad(x)
        return self.conv2d(out)


class ResidualBlock(nn.Module):
    """ResidualBlock introduced in: https://arxiv.org/abs/1512.03385
    recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html

    Layout: conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm,
    then the block input is added back. No activation is applied after
    the addition. Channel count and spatial size are unchanged.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``x`` plus the output of the two-conv residual branch."""
        residual = x
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        return out + residual


class TransformerNet(nn.Module):
    """Image-to-image network: 9x9 conv encoder, residual stack, 9x9 conv decoder.

    All convolutions use stride 1 and odd kernels with reflection padding.
    The output therefore has the same spatial size as the input. No
    activation is applied to the output.

    Args:
        in_channels: Channels of the input image (default 3).
        out_channels: Channels of the output image (default 3).
        channels: Width of the feature maps between encoder and decoder
            (default 32).
        num_residual_blocks: Number of ``ResidualBlock`` layers (default 5).

    The defaults reproduce the original fixed architecture exactly,
    including its ``state_dict`` keys.
    """

    def __init__(self, *, in_channels=3, out_channels=3, channels=32,
                 num_residual_blocks=5):
        super().__init__()
        # Submodule names below become state_dict keys; they are kept
        # unchanged so existing checkpoints still load.

        # Encoder convolution layers
        self.encoder = nn.Sequential()
        self.encoder.add_module(
            'conv1', ConvLayer(in_channels, channels, kernel_size=9, stride=1))
        self.encoder.add_module('in1', nn.InstanceNorm2d(channels, affine=True))
        self.encoder.add_module('relu1', nn.ReLU())

        # Residual layers (named resblock_1 ... resblock_N)
        self.residual = nn.Sequential()
        for i in range(num_residual_blocks):
            self.residual.add_module(f'resblock_{i + 1}', ResidualBlock(channels))

        # Decoder layers. 'deconv3' is a plain (non-transposed) ConvLayer;
        # the historical name is kept for checkpoint compatibility.
        self.decoder = nn.Sequential()
        self.decoder.add_module(
            'deconv3', ConvLayer(channels, out_channels, kernel_size=9, stride=1))

    def forward(self, x):
        """Run encoder, residual stack and decoder on ``x`` in sequence."""
        encoder_output = self.encoder(x)
        residual_output = self.residual(encoder_output)
        return self.decoder(residual_output)