# CBN_cGAN / Generator.py
# Author: KhadgaA
# (from commit fd31624: "Add generator_model_40.pth and segmenatation_model.py")
import torch.nn as nn
import torch
class SandwichBatchNorm2d(nn.Module):
    """Class-conditional "sandwich" batch norm.

    A shared BatchNorm2d (with its own learnable affine parameters) is followed
    by a second, class-dependent scale and shift looked up from an embedding
    table, so each class `y` modulates the normalised activations differently.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        # Shared normalisation with its own learnable affine transform.
        self.bn = nn.BatchNorm2d(num_features, affine=True)
        # One embedding row per class: first half holds gamma (scale),
        # second half holds beta (shift).
        self.embed = nn.Embedding(num_classes, num_features * 2)
        with torch.no_grad():
            self.embed.weight[:, :num_features].normal_(1, 0.02)  # scale ~ N(1, 0.02)
            self.embed.weight[:, num_features:].zero_()           # shift starts at 0

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        normed = self.bn(x)
        gamma, beta = torch.chunk(self.embed(y), 2, dim=1)
        # Broadcast each sample's (gamma, beta) over the spatial dimensions.
        return gamma[:, :, None, None] * normed + beta[:, :, None, None]
class CategoricalConditionalBatchNorm2d(nn.Module):
    """Categorical conditional batch norm.

    Normalises with a parameter-free BatchNorm2d (affine=False); the affine
    scale (gamma) and shift (beta) come entirely from a per-class embedding,
    selected by the integer label tensor `y`.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        # No learnable affine here: the embedding supplies it per class.
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        # Row layout per class: [gamma (num_features) | beta (num_features)].
        self.embed = nn.Embedding(num_classes, num_features * 2)
        with torch.no_grad():
            self.embed.weight[:, :num_features].normal_(1, 0.02)  # scale ~ N(1, 0.02)
            self.embed.weight[:, num_features:].zero_()           # shift starts at 0

    def forward(self, x, y):
        normalized = self.bn(x)
        params = self.embed(y)
        gamma = params[:, : self.num_features]
        beta = params[:, self.num_features :]
        shape = (-1, self.num_features, 1, 1)
        return gamma.reshape(shape) * normalized + beta.reshape(shape)
# --- Generator network (conditional U-Net built from the blocks below) ---
class double_convolution(nn.Module):
    """Two (Conv3x3 -> SandwichBatchNorm2d -> ReLU) stages, U-Net style.

    The class label `y` is threaded through both normalisation layers so the
    block's activations are conditioned on the class.
    """

    # NOTE: attribute names (c1, swn1, act1, ...) are state_dict keys and must
    # stay as-is for checkpoint compatibility.
    def __init__(self, in_channels, out_channels, num_classes):
        super(double_convolution, self).__init__()
        self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.swn1 = SandwichBatchNorm2d(out_channels, num_classes)
        self.act1 = nn.ReLU(inplace=True)
        self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.swn2 = SandwichBatchNorm2d(out_channels, num_classes)
        self.act2 = nn.ReLU(inplace=True)

    def forward(self, x, y):
        out = self.act1(self.swn1(self.c1(x), y))
        out = self.act2(self.swn2(self.c2(out), y))
        return out
class Generator(nn.Module):
    """U-Net style generator with class-conditional (sandwich) batch norm.

    `forward(x, y)` takes an 8-channel input `x` and an integer class-label
    tensor `y`; every double_convolution block and the post-upsampling
    SandwichBatchNorm2d layers are conditioned on `y`. Output has 3 channels
    (RGB image) at the input's spatial resolution. The spatial size must be
    divisible by 16 (four 2x2 max-pools on the contracting path).
    """

    def __init__(self, ngpu, num_classes):
        super(Generator, self).__init__()
        self.ngpu = ngpu  # kept for interface compatibility; not used in forward
        # Conditional norms applied after each decoder double_convolution stage.
        self.swn1 = SandwichBatchNorm2d(512, num_classes)
        self.swn2 = SandwichBatchNorm2d(256, num_classes)
        self.swn3 = SandwichBatchNorm2d(128, num_classes)
        self.swn4 = SandwichBatchNorm2d(64, num_classes)
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        # Contracting path: a stem that maps 8 -> 4 -> 8 -> 1 channels, then
        # the classic U-Net encoder 1 -> 64 -> 128 -> 256 -> 512 -> 1024.
        # Each double_convolution applies two 3x3 convs (see class above).
        self.down_convolution__2 = double_convolution(8, 4, num_classes)
        self.down_convolution__1 = double_convolution(4, 8, num_classes)
        self.down_convolution_0 = double_convolution(8, 1, num_classes)
        self.down_convolution_1 = double_convolution(1, 64, num_classes)
        self.down_convolution_2 = double_convolution(64, 128, num_classes)
        self.down_convolution_3 = double_convolution(128, 256, num_classes)
        self.down_convolution_4 = double_convolution(256, 512, num_classes)
        self.down_convolution_5 = double_convolution(512, 1024, num_classes)
        # Expanding path. After each transpose-conv the result is concatenated
        # with the matching encoder feature map, so `in_channels` doubles.
        self.up_transpose_1 = nn.ConvTranspose2d(
            in_channels=1024, out_channels=512, kernel_size=2, stride=2
        )
        self.up_convolution_1 = double_convolution(1024, 512, num_classes)
        self.up_transpose_2 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256, kernel_size=2, stride=2
        )
        self.up_convolution_2 = double_convolution(512, 256, num_classes)
        self.up_transpose_3 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128, kernel_size=2, stride=2
        )
        self.up_convolution_3 = double_convolution(256, 128, num_classes)
        self.up_transpose_4 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=2, stride=2
        )
        self.up_convolution_4 = double_convolution(128, 64, num_classes)
        # Final 1x1 conv projects to 3 output channels (an RGB image).
        self.out = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1)

    def forward(self, x, y):
        # Encoder: stem (no pooling) ...
        down__2 = self.down_convolution__2(x, y)
        down__1 = self.down_convolution__1(down__2, y)
        down_0 = self.down_convolution_0(down__1, y)
        # ... then four conv+pool stages; the pre-pool maps feed the skips.
        down_1 = self.down_convolution_1(down_0, y)
        down_2 = self.max_pool2d(down_1)
        down_3 = self.down_convolution_2(down_2, y)
        down_4 = self.max_pool2d(down_3)
        down_5 = self.down_convolution_3(down_4, y)
        down_6 = self.max_pool2d(down_5)
        down_7 = self.down_convolution_4(down_6, y)
        down_8 = self.max_pool2d(down_7)
        down_9 = self.down_convolution_5(down_8, y)
        # Decoder: upsample, concatenate the skip connection, convolve, then
        # apply the stage's conditional norm.
        # BUG FIX: the swn1..swn4 results were previously discarded
        # (`self.swn1(x, y)` without assignment), making them no-ops on the
        # output; their results are now assigned back to `x`.
        up_1 = self.up_transpose_1(down_9)
        x = self.up_convolution_1(torch.cat([down_7, up_1], 1), y)
        x = self.swn1(x, y)
        up_2 = self.up_transpose_2(x)
        x = self.up_convolution_2(torch.cat([down_5, up_2], 1), y)
        x = self.swn2(x, y)
        up_3 = self.up_transpose_3(x)
        x = self.up_convolution_3(torch.cat([down_3, up_3], 1), y)
        x = self.swn3(x, y)
        up_4 = self.up_transpose_4(x)
        x = self.up_convolution_4(torch.cat([down_1, up_4], 1), y)
        x = self.swn4(x, y)
        out = self.out(x)
        return out