Spaces:
Sleeping
Sleeping
| # import package : model | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
class ConvLayer(nn.Module):
    """2-D convolution preceded by reflection padding of ``kernel_size // 2``.

    The input is padded by mirroring its border pixels (``nn.ReflectionPad2d``)
    rather than with zeros. The wrapped ``nn.Conv2d`` then runs with its
    default ``padding=0``.

    Output size: for an odd ``kernel_size`` with ``stride=1``, the spatial size
    is preserved. For an even ``kernel_size`` with ``stride=1``, each spatial
    dimension grows by one. Larger strides downsample accordingly.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by the convolution.
        kernel_size: side length of the (square) convolution kernel; must be
            an int because it is floor-divided to obtain the padding.
        stride: convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        pad = kernel_size // 2
        # Construction order (pad, then conv) is kept so module/state_dict order is stable.
        self.reflection_pad = nn.ReflectionPad2d(pad)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Reflection-pad ``x`` and then convolve it."""
        return self.conv2d(self.reflection_pad(x))
class ResidualBlock(nn.Module):
    """Two conv/instance-norm stages wrapped by an identity skip connection.

    Computes ``in2(conv2(relu(in1(conv1(x))))) + x``. Both convolutions are
    3x3, stride 1, and map ``channels`` to ``channels``, so the output has the
    same channel count as the input. No activation is applied after the
    addition.

    Introduced in: https://arxiv.org/abs/1512.03385
    Recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html

    Args:
        channels: number of feature channels in both the input and the output.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return ConvLayer(channels, channels, kernel_size=3, stride=1)

        # Construction order fixes both the state_dict order and the RNG draw
        # order for weight initialisation, so it is kept as conv1, in1, conv2, in2.
        self.conv1 = conv3x3()
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = conv3x3()
        self.in2 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the residual branch applied to ``x``, plus ``x`` itself."""
        hidden = self.relu(self.in1(self.conv1(x)))
        branch = self.in2(self.conv2(hidden))
        return branch + x
class TransformerNet(nn.Module):
    """Image-to-image network: conv encoder -> residual stack -> conv decoder.

    Every convolution is stride 1 with an odd kernel and ``kernel_size // 2``
    reflection padding, so the output has the same spatial size as the input.
    Both input and output have 3 channels.

    Child module names (``encoder.conv1``, ``encoder.in1``,
    ``residual.resblock_<i>``, ``decoder.deconv3``) determine the
    ``state_dict`` keys. They are kept unchanged so that weights saved from the
    original fixed-size network still load with the default arguments.

    Args:
        num_channels: width of the feature maps between encoder and decoder.
            Defaults to 32, the originally hard-coded value.
        num_residual_blocks: number of ResidualBlocks in the middle stack.
            Defaults to 5, the originally hard-coded value. With 0 the
            residual stack is an empty ``nn.Sequential``, which passes its
            input through unchanged.

    Raises:
        ValueError: if ``num_channels`` < 1 or ``num_residual_blocks`` < 0.
    """

    def __init__(self, *, num_channels=32, num_residual_blocks=5):
        super(TransformerNet, self).__init__()
        if num_channels < 1:
            raise ValueError(f"num_channels must be >= 1, got {num_channels}")
        if num_residual_blocks < 0:
            raise ValueError(
                f"num_residual_blocks must be >= 0, got {num_residual_blocks}"
            )

        # Encoder convolution layers: 3-channel image -> num_channels feature maps.
        self.encoder = nn.Sequential()
        self.encoder.add_module(
            'conv1', ConvLayer(3, num_channels, kernel_size=9, stride=1)
        )
        self.encoder.add_module('in1', nn.InstanceNorm2d(num_channels, affine=True))
        self.encoder.add_module('relu1', nn.ReLU())

        # Residual layers. Names are 1-based ('resblock_1', ...) to match existing checkpoints.
        self.residual = nn.Sequential()
        for i in range(num_residual_blocks):
            self.residual.add_module(
                'resblock_%d' % (i + 1), ResidualBlock(num_channels)
            )

        # Decoder layers. Despite its name, 'deconv3' is a plain stride-1
        # ConvLayer, not a transposed convolution. Renaming it would change
        # the state_dict keys, so the name is kept.
        self.decoder = nn.Sequential()
        self.decoder.add_module(
            'deconv3', ConvLayer(num_channels, 3, kernel_size=9, stride=1)
        )

    def forward(self, x):
        """Run ``x`` through the encoder, the residual stack, then the decoder.

        Args:
            x: image tensor with 3 channels (presumably ``(N, 3, H, W)``).
                ``H`` and ``W`` must be larger than 4, the reflection padding
                used by the 9x9 convolutions.

        Returns:
            Tensor with 3 channels and the same spatial size as ``x``. No
            output activation is applied, so values are not clamped to any
            range.
        """
        encoder_output = self.encoder(x)
        residual_output = self.residual(encoder_output)
        decoder_output = self.decoder(residual_output)
        return decoder_output