# CycleGAN / models.py
# Yash Nagraj
# Add the training scripts for cloud training
# 275907d
import torch.nn as nn
class ResidualBlock(nn.Module):
    """Residual block: two size-preserving 3x3 convolutions plus a skip connection.

    Computes ``x + IN(conv2(ReLU(IN(conv1(x)))))``. Both convolutions keep the
    channel count, and with stride 1, padding 1 and reflect padding they keep
    the spatial size, so the skip connection can be added elementwise.

    Args:
        input_channels: number of channels of both the input and the output.
    """

    def __init__(self, input_channels) -> None:
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels, 3, 1, padding=1, padding_mode='reflect')
        self.conv2 = nn.Conv2d(input_channels, input_channels, 3, 1, padding=1, padding_mode='reflect')
        # One InstanceNorm2d is reused after both convolutions. With its
        # defaults (affine=False, track_running_stats=False) it holds no
        # learned parameters or running buffers, so sharing it is equivalent
        # to using two separate instances.
        self.instanceNorm = nn.InstanceNorm2d(input_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        # torch.Tensor has no ``.copy()`` method (only ``clone``/``copy_``).
        # No operation below modifies ``x`` in place, so keeping a plain
        # reference is a safe skip connection and avoids an extra allocation.
        residual = x
        x = self.conv1(x)
        x = self.instanceNorm(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.instanceNorm(x)
        return residual + x
class ContractingBlock(nn.Module):
    """Downsampling block: a stride-2 convolution that doubles the channel count.

    Computes ``act(IN(conv(x)))``; the instance normalization is applied only
    when ``use_bn`` is true. The convolution uses stride 2 and padding 1, so
    for kernel sizes 3 and 4 an even spatial size H becomes H // 2.

    Args:
        input_channels: channels of the input; the output has twice as many.
        use_bn: apply InstanceNorm2d after the convolution. The name is kept
            for backward compatibility; the layer is instance norm, not batch
            norm.
        kernel_size: convolution kernel size.
        activation: ``'relu'`` selects ReLU; any other value selects
            LeakyReLU(0.2).
    """

    def __init__(self, input_channels, use_bn=True, kernel_size=3, activation='relu') -> None:
        super(ContractingBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels * 2, kernel_size, padding=1, stride=2, padding_mode='reflect')
        self.activation = nn.ReLU() if activation == 'relu' else nn.LeakyReLU(0.2)
        if use_bn:
            # Must match conv1's *output* channel count (input_channels * 2).
            # With affine=False (the default) this layer has no parameters or
            # buffers, so the channel count does not affect the state_dict.
            self.normalization = nn.InstanceNorm2d(input_channels * 2)
        self.use_bn = use_bn

    def forward(self, x):
        x = self.conv1(x)
        if self.use_bn:
            # The normalization is not in-place; its result must be kept.
            x = self.normalization(x)
        x = self.activation(x)
        return x
class ExpandingBlock(nn.Module):
    """Upsampling block: a transposed convolution that halves the channel count.

    With kernel 3, stride 2, padding 1 and output_padding 1, the spatial size
    doubles (H -> 2H). The result is optionally instance-normalized and then
    passed through ReLU.

    Args:
        input_channels: channels of the input; the output has
            ``input_channels // 2``.
        use_bn: apply InstanceNorm2d after the transposed convolution.
    """

    def __init__(self, input_channels, use_bn=True) -> None:
        super(ExpandingBlock, self).__init__()
        out_channels = input_channels // 2
        self.conv1 = nn.ConvTranspose2d(
            input_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        if use_bn:
            self.normalization = nn.InstanceNorm2d(out_channels)
        self.use_bn = use_bn
        self.activation = nn.ReLU()

    def forward(self, x):
        upsampled = self.conv1(x)
        upsampled = self.normalization(upsampled) if self.use_bn else upsampled
        return self.activation(upsampled)
class FeatureMapBlock(nn.Module):
    """Change the channel count with a single 7x7 convolution.

    Padding 3 in reflect mode keeps the spatial size unchanged.

    Args:
        input_channels: channels of the input.
        output_channels: channels of the output.
    """

    def __init__(self, input_channels, output_channels) -> None:
        super(FeatureMapBlock, self).__init__()
        self.conv = nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size=7,
            padding=3,
            padding_mode='reflect',
        )

    def forward(self, x):
        return self.conv(x)
class Generator(nn.Module):
    """Encoder / residual-bottleneck / decoder image generator.

    Pipeline:
    1. A 7x7 feature map (input_channels -> hidden_dim).
    2. Two contracting blocks, each doubling the channels and halving the
       spatial size.
    3. Nine residual blocks at ``hidden_dim * 4`` channels.
    4. Two expanding blocks, each halving the channels and doubling the
       spatial size.
    5. A 7x7 feature map back to ``output_channels``, followed by tanh.

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        hidden_dim: channels produced by the first feature-map layer.
    """

    def __init__(self, input_channels, output_channels, hidden_dim=64) -> None:
        super(Generator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_dim)
        self.contract1 = ContractingBlock(hidden_dim)
        self.contract2 = ContractingBlock(hidden_dim * 2)
        res_mult = 4
        bottleneck = hidden_dim * res_mult
        # Each residual block is registered as its own attribute res0 ... res8,
        # in that order, so state_dict keys and parameter-initialization order
        # stay fixed.
        for idx in range(9):
            self.add_module(f'res{idx}', ResidualBlock(bottleneck))
        self.expand1 = ExpandingBlock(bottleneck)
        self.expand2 = ExpandingBlock(hidden_dim * 2)
        self.downfeature = FeatureMapBlock(hidden_dim, output_channels)
        self.tanh = nn.Tanh()

    def forward(self, x):
        hidden = self.contract2(self.contract1(self.upfeature(x)))
        # Apply res0 ... res8 in sequence.
        for idx in range(9):
            hidden = getattr(self, f'res{idx}')(hidden)
        hidden = self.expand2(self.expand1(hidden))
        return self.tanh(self.downfeature(hidden))
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a one-channel score map.

    A 7x7 feature map is followed by three stride-2 contracting blocks (4x4
    kernels, LeakyReLU). The first of these skips normalization. A final 1x1
    convolution projects ``hidden_channels * 8`` features to a single channel.
    The output is a spatial map of scores rather than one scalar per image.

    Args:
        input_channels: channels of the input image.
        hidden_channels: channels produced by the first feature-map layer.
    """

    def __init__(self, input_channels, hidden_channels=64) -> None:
        super(Discriminator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, use_bn=False, kernel_size=4, activation='lrelu')
        self.contract2 = ContractingBlock(hidden_channels * 2, kernel_size=4, activation='lrelu')
        self.contract3 = ContractingBlock(hidden_channels * 4, kernel_size=4, activation='lrelu')
        self.conv = nn.Conv2d(hidden_channels * 8, 1, kernel_size=1)

    def forward(self, x):
        features = self.upfeature(x)
        for stage in (self.contract1, self.contract2, self.contract3):
            features = stage(features)
        return self.conv(features)