import torch.nn as nn


class ResidualBlock(nn.Module):
    """Two channel-preserving 3x3 convolutions wrapped in a skip connection.

    Computes ``x + norm(conv2(relu(norm(conv1(x)))))``. Both convolutions keep
    the channel count and use stride 1 with reflect padding 1, so the output
    has the same shape as the input.

    Args:
        input_channels: number of channels of the input (and of the output).
    """

    def __init__(self, input_channels) -> None:
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels, 3, 1, padding=1, padding_mode='reflect')
        self.conv2 = nn.Conv2d(input_channels, input_channels, 3, 1, padding=1, padding_mode='reflect')
        # One InstanceNorm2d is reused after both convolutions. With the default
        # affine=False it has no learnable parameters, so sharing it behaves the
        # same as two separate instances.
        self.instanceNorm = nn.InstanceNorm2d(input_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        # Keep a reference to the input for the skip connection. No copy is
        # needed because none of the layers below modify their input in place
        # (torch.Tensor has no ``.copy()`` method anyway; ``.clone()`` is the
        # tensor equivalent).
        residual = x
        x = self.conv1(x)
        x = self.instanceNorm(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.instanceNorm(x)
        return residual + x


class ContractingBlock(nn.Module):
    """Stride-2 convolution that doubles the channels and halves the spatial size.

    Args:
        input_channels: channels of the input; the output has twice as many.
        use_bn: if True, apply InstanceNorm2d after the convolution. The name
            is kept for backward compatibility; this is instance norm, not
            batch norm.
        kernel_size: convolution kernel size. Padding is fixed at 1, so with
            stride 2 an even height/width is exactly halved for kernel sizes
            3 and 4.
        activation: ``'relu'`` selects ReLU; any other value selects
            LeakyReLU with negative slope 0.2.
    """

    def __init__(self, input_channels, use_bn=True, kernel_size=3, activation='relu') -> None:
        super(ContractingBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels * 2, kernel_size, padding=1, stride=2, padding_mode='reflect')
        self.activation = nn.ReLU() if activation == 'relu' else nn.LeakyReLU(0.2)
        if use_bn:
            # Normalize the convolution's output, which has input_channels * 2
            # channels (not input_channels).
            self.normalization = nn.InstanceNorm2d(input_channels * 2)
        self.use_bn = use_bn

    def forward(self, x):
        x = self.conv1(x)
        if self.use_bn:
            # Assign the result: normalization is not an in-place operation.
            x = self.normalization(x)
        x = self.activation(x)
        return x


class ExpandingBlock(nn.Module):
    """Transposed convolution that halves the channels and doubles the spatial size.

    With kernel 3, stride 2, padding 1 and output_padding 1, an input of
    height/width ``H`` produces ``2 * H``.

    Args:
        input_channels: channels of the input; the output has
            ``input_channels // 2``.
        use_bn: if True, apply InstanceNorm2d after the transposed
            convolution (instance norm despite the name).
    """

    def __init__(self, input_channels, use_bn=True) -> None:
        super(ExpandingBlock, self).__init__()
        self.conv1 = nn.ConvTranspose2d(input_channels, input_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1)
        if use_bn:
            self.normalization = nn.InstanceNorm2d(input_channels // 2)
        self.use_bn = use_bn
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        if self.use_bn:
            x = self.normalization(x)
        x = self.activation(x)
        return x


class FeatureMapBlock(nn.Module):
    """7x7 convolution mapping ``input_channels`` to ``output_channels``.

    Reflect padding of 3 keeps the spatial size unchanged.
    """

    def __init__(self, input_channels, output_channels) -> None:
        super(FeatureMapBlock, self).__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=7, padding=3, padding_mode='reflect')

    def forward(self, x):
        x = self.conv(x)
        return x


class Generator(nn.Module):
    """Encoder / residual / decoder image generator with a tanh output.

    The pipeline is:
        FeatureMapBlock(input_channels -> hidden_dim)
        2 x ContractingBlock   (channels x4, spatial size / 4)
        num_residual_blocks x ResidualBlock (shape preserved)
        2 x ExpandingBlock     (channels / 4, spatial size x 4)
        FeatureMapBlock(hidden_dim -> output_channels)
        Tanh

    The output is therefore in [-1, 1]. Height and width should be divisible
    by 4 for the output to match the input size.

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        hidden_dim: channel width after the first feature-map block.
        num_residual_blocks: number of residual blocks in the bottleneck.
            The default of 9 matches the original architecture. Blocks are
            registered as ``res0``, ``res1``, ... so existing checkpoints
            still load.

    Raises:
        ValueError: if ``num_residual_blocks`` is negative.
    """

    def __init__(self, input_channels, output_channels, hidden_dim=64, num_residual_blocks=9) -> None:
        super(Generator, self).__init__()
        if num_residual_blocks < 0:
            raise ValueError(f"num_residual_blocks must be >= 0, got {num_residual_blocks}")
        self.upfeature = FeatureMapBlock(input_channels, hidden_dim)
        self.contract1 = ContractingBlock(hidden_dim)
        self.contract2 = ContractingBlock(hidden_dim * 2)
        # After two contracting blocks the channel count is hidden_dim * 4.
        res_mult = 4
        self.num_residual_blocks = num_residual_blocks
        for i in range(num_residual_blocks):
            # setattr on an nn.Module registers the submodule under the same
            # name ("res{i}") the explicit attributes used before.
            setattr(self, f"res{i}", ResidualBlock(hidden_dim * res_mult))
        self.expand1 = ExpandingBlock(hidden_dim * res_mult)
        self.expand2 = ExpandingBlock(hidden_dim * 2)
        self.downfeature = FeatureMapBlock(hidden_dim, output_channels)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.upfeature(x)
        x = self.contract1(x)
        x = self.contract2(x)
        for i in range(self.num_residual_blocks):
            x = getattr(self, f"res{i}")(x)
        x = self.expand1(x)
        x = self.expand2(x)
        x = self.downfeature(x)
        return self.tanh(x)


class Discriminator(nn.Module):
    """Fully convolutional discriminator producing a one-channel score map.

    Three LeakyReLU contracting blocks with kernel size 4 each halve the
    spatial size and double the channels. The first block skips
    normalization. A final 1x1 convolution reduces ``hidden_channels * 8``
    channels to 1. No sigmoid is applied, so the outputs are unbounded.

    Args:
        input_channels: channels of the input image.
        hidden_channels: channel width after the first feature-map block.
    """

    def __init__(self, input_channels, hidden_channels=64) -> None:
        super(Discriminator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, False, kernel_size=4, activation='lrelu')
        self.contract2 = ContractingBlock(hidden_channels * 2, kernel_size=4, activation='lrelu')
        self.contract3 = ContractingBlock(hidden_channels * 4, kernel_size=4, activation='lrelu')
        self.conv = nn.Conv2d(hidden_channels * 8, 1, kernel_size=1)

    def forward(self, x):
        x0 = self.upfeature(x)
        x1 = self.contract1(x0)
        x2 = self.contract2(x1)
        x3 = self.contract3(x2)
        x4 = self.conv(x3)
        return x4