|
|
| import torch.nn as nn |
class ResidualBlock(nn.Module):
    """Residual block: two reflect-padded 3x3 convolutions with instance
    normalization, a ReLU after the first conv, and an identity skip
    connection. Channel count and spatial size are preserved.

    Args:
        input_channels: number of channels of the input (and output) tensor.
    """

    def __init__(self, input_channels) -> None:
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels, 3, 1,
                               padding=1, padding_mode='reflect')
        self.conv2 = nn.Conv2d(input_channels, input_channels, 3, 1,
                               padding=1, padding_mode='reflect')
        # InstanceNorm2d is parameter-free by default (affine=False), so a
        # single instance can safely be reused after both convolutions.
        self.instanceNorm = nn.InstanceNorm2d(input_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        # Bug fix: torch.Tensor has no .copy() method — the original raised
        # AttributeError here. .clone() keeps the skip-connection input
        # independent of any downstream in-place modification.
        original = x.clone()
        x = self.conv1(x)
        x = self.instanceNorm(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.instanceNorm(x)
        return original + x
| |
|
|
|
|
class ContractingBlock(nn.Module):
    """Downsampling block: strided conv -> optional instance norm -> activation.

    The convolution doubles the channel count and (with stride 2) roughly
    halves the spatial dimensions.

    Args:
        input_channels: channels of the incoming tensor; output has twice as many.
        use_bn: whether to apply instance normalization after the conv.
        kernel_size: conv kernel size (3 for the generator, 4 for the discriminator).
        activation: 'relu' for ReLU, anything else selects LeakyReLU(0.2).
    """

    def __init__(self, input_channels, use_bn=True, kernel_size=3, activation='relu') -> None:
        super(ContractingBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels * 2, kernel_size,
                               padding=1, stride=2, padding_mode='reflect')
        self.activation = nn.ReLU() if activation == 'relu' else nn.LeakyReLU(0.2)
        if use_bn:
            # Bug fix: the conv outputs input_channels * 2 channels, so the
            # norm must be sized accordingly (was input_channels).
            self.normalization = nn.InstanceNorm2d(input_channels * 2)
        self.use_bn = use_bn

    def forward(self, x):
        x = self.conv1(x)
        if self.use_bn:
            # Bug fix: the original discarded the normalized result, so
            # normalization never took effect; assign it back to x.
            x = self.normalization(x)
        x = self.activation(x)
        return x
| |
|
|
class ExpandingBlock(nn.Module):
    """Upsampling block: a transposed convolution that halves the channel
    count and doubles the spatial dimensions, followed by optional instance
    normalization and a ReLU.

    Args:
        input_channels: channels of the incoming tensor; output has half as many.
        use_bn: whether to apply instance normalization after the transposed conv.
    """

    def __init__(self, input_channels, use_bn=True) -> None:
        super(ExpandingBlock, self).__init__()
        out_channels = input_channels // 2
        # stride=2 with padding=1 and output_padding=1 yields exactly 2x
        # the input height/width for kernel_size=3.
        self.conv1 = nn.ConvTranspose2d(
            input_channels, out_channels,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        if use_bn:
            self.normalization = nn.InstanceNorm2d(out_channels)
        self.use_bn = use_bn
        self.activation = nn.ReLU()

    def forward(self, x):
        out = self.conv1(x)
        if self.use_bn:
            out = self.normalization(out)
        return self.activation(out)
| |
|
|
|
|
class FeatureMapBlock(nn.Module):
    """Maps between image space and feature space with one 7x7
    reflect-padded convolution; spatial dimensions are unchanged.

    Args:
        input_channels: channels of the incoming tensor.
        output_channels: channels of the produced feature map.
    """

    def __init__(self, input_channels, output_channels) -> None:
        super(FeatureMapBlock, self).__init__()
        # padding=3 keeps H and W constant for a 7x7 kernel.
        self.conv = nn.Conv2d(
            input_channels, output_channels,
            kernel_size=7, padding=3, padding_mode='reflect',
        )

    def forward(self, x):
        return self.conv(x)
|
|
class Generator(nn.Module):
    """CycleGAN-style generator: feature lift, two contracting stages,
    nine residual blocks at the bottleneck, two expanding stages, a
    feature projection back to image space, and a final tanh.

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        hidden_dim: base feature width (doubled at each contracting stage).
    """

    def __init__(self, input_channels, output_channels, hidden_dim=64) -> None:
        super(Generator, self).__init__()
        res_channels = hidden_dim * 4
        self.upfeature = FeatureMapBlock(input_channels, hidden_dim)
        self.contract1 = ContractingBlock(hidden_dim)
        self.contract2 = ContractingBlock(hidden_dim * 2)
        # Register the nine residual blocks as res0 ... res8.
        for i in range(9):
            setattr(self, 'res%d' % i, ResidualBlock(res_channels))
        self.expand1 = ExpandingBlock(res_channels)
        self.expand2 = ExpandingBlock(hidden_dim * 2)
        self.downfeature = FeatureMapBlock(hidden_dim, output_channels)
        self.tanh = nn.Tanh()

    def forward(self, x):
        stages = [self.upfeature, self.contract1, self.contract2]
        stages += [getattr(self, 'res%d' % i) for i in range(9)]
        stages += [self.expand1, self.expand2, self.downfeature, self.tanh]
        for stage in stages:
            x = stage(x)
        return x
| |
|
|
class Discriminator(nn.Module):
    """PatchGAN-style discriminator: feature lift, three LeakyReLU
    contracting stages (the first without normalization), then a 1x1
    convolution producing a single-channel logit map.

    Args:
        input_channels: channels of the input image.
        hidden_channels: base feature width (doubled at each stage).
    """

    def __init__(self, input_channels, hidden_channels=64) -> None:
        super(Discriminator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        # First stage skips normalization (use_bn=False).
        self.contract1 = ContractingBlock(hidden_channels, False,
                                          kernel_size=4, activation='lrelu')
        self.contract2 = ContractingBlock(hidden_channels * 2,
                                          kernel_size=4, activation='lrelu')
        self.contract3 = ContractingBlock(hidden_channels * 4,
                                          kernel_size=4, activation='lrelu')
        self.conv = nn.Conv2d(hidden_channels * 8, 1, kernel_size=1)

    def forward(self, x):
        for layer in (self.upfeature, self.contract1, self.contract2,
                      self.contract3, self.conv):
            x = layer(x)
        return x