# Season-Style-Transfer / model/TransformerNet.py
# (uploaded by russel0719 in commit f0e8ef0, "Upload 6 files")
# Imports required by the model definitions below.
import torch
from torch import nn
from torch.nn import functional as F
class ConvLayer(nn.Module):
    """2-D convolution preceded by reflection padding.

    The input is reflection-padded by ``kernel_size // 2`` on every side
    before being convolved. With an odd ``kernel_size`` and ``stride=1``
    this keeps the spatial size unchanged.

    Args:
        in_channels: number of channels in the input.
        out_channels: number of channels produced by the convolution.
        kernel_size: side length of the square convolution kernel.
        stride: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        pad = kernel_size // 2
        # Registered in this order (pad, then conv) so the submodule listing
        # and state_dict layout stay stable.
        self.reflection_pad = nn.ReflectionPad2d(pad)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Reflection-pad ``x`` and apply the convolution."""
        padded = self.reflection_pad(x)
        return self.conv2d(padded)
class ResidualBlock(nn.Module):
    """Residual block with two conv + instance-norm stages and an identity skip.

    introduced in: https://arxiv.org/abs/1512.03385
    recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html

    Computes ``x + IN2(conv2(ReLU(IN1(conv1(x)))))``. No activation is
    applied after the addition. Both convolutions are 3x3, stride 1, and
    ``channels -> channels``, so the output has the same shape as the input.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        # Submodule registration order is kept so state_dict keys match.
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``x`` plus the two-stage residual branch applied to ``x``."""
        h = self.conv1(x)
        h = self.relu(self.in1(h))
        h = self.in2(self.conv2(h))
        return x + h
class TransformerNet(nn.Module):
    """Image-to-image network: encoder, a stack of residual blocks, decoder.

    Every convolution is a ``ConvLayer`` with stride 1 and reflection padding
    of ``kernel_size // 2``, so the output keeps the input's height and width.
    The spatial dimensions must be large enough for the 9x9 layers'
    reflection padding of 4. The decoder output is returned as-is, with no
    final activation or clamping.

    The first convolution expects 3 input channels and the decoder emits 3.
    Presumably the input is an RGB batch of shape (N, 3, H, W); confirm this
    against the caller's preprocessing.

    Args:
        channels: feature width used by the encoder output, every residual
            block, and the decoder input. Defaults to 32, the original
            hard-coded value.
        num_residual_blocks: number of ``ResidualBlock`` layers in the
            middle stack. Defaults to 5, the original hard-coded value.
            ``0`` makes the residual stack an identity.

    Raises:
        ValueError: if ``channels < 1`` or ``num_residual_blocks < 0``.
    """

    def __init__(self, channels=32, num_residual_blocks=5):
        super(TransformerNet, self).__init__()
        if channels < 1:
            raise ValueError("channels must be >= 1, got %r" % (channels,))
        if num_residual_blocks < 0:
            raise ValueError(
                "num_residual_blocks must be >= 0, got %r" % (num_residual_blocks,)
            )
        # Encoder convolution layers: 9x9 conv lifting 3 channels to `channels`.
        self.encoder = nn.Sequential()
        self.encoder.add_module('conv1', ConvLayer(3, channels, kernel_size=9, stride=1))
        self.encoder.add_module('in1', nn.InstanceNorm2d(channels, affine=True))
        self.encoder.add_module('relu1', nn.ReLU())
        # Residual layers. The 1-based names 'resblock_1'..'resblock_N' are kept
        # so checkpoints saved from the original module still load.
        self.residual = nn.Sequential()
        for i in range(num_residual_blocks):
            self.residual.add_module('resblock_%d' % (i + 1), ResidualBlock(channels))
        # Decoder layers: a single 9x9 conv back to 3 channels. The name
        # 'deconv3' is kept for state_dict compatibility, although it is an
        # ordinary (not transposed) convolution.
        self.decoder = nn.Sequential()
        self.decoder.add_module('deconv3', ConvLayer(channels, 3, kernel_size=9, stride=1))

    def forward(self, x):
        """Run ``x`` through the encoder, residual stack, and decoder."""
        encoder_output = self.encoder(x)
        residual_output = self.residual(encoder_output)
        decoder_output = self.decoder(residual_output)
        return decoder_output