Nikhil Mudhalwadkar
added other files
c6d5483
raw
history blame
4.13 kB
import torch.nn as nn
import torch
import albumentations as A
# CNN block will be used repeatly later
class CNNBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=2):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 4, stride, bias=False, padding_mode='reflect'),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2)
)
def forward(self, x):
return self.conv(x)
class PatchGan(torch.nn.Module):
""" Patch GAN Architecture """
@staticmethod
def create_contracting_block(in_channels: int, out_channels: int):
"""
Create encoding layer
:param in_channels:
:param out_channels:
:return:
"""
conv_layer = torch.nn.Sequential(
torch.nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
),
torch.nn.ReLU(),
torch.nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
),
torch.nn.ReLU(),
)
max_pool = torch.nn.Sequential(
torch.nn.MaxPool2d(
stride=2,
kernel_size=2,
),
)
layer = torch.nn.Sequential(
conv_layer,
max_pool,
)
return layer
def __init__(self, input_channels: int, hidden_channels: int) -> None:
super().__init__()
self.resize_channels = torch.nn.Conv2d(
in_channels=input_channels,
out_channels=hidden_channels,
kernel_size=1,
)
self.enc1 = self.create_contracting_block(
in_channels=hidden_channels,
out_channels=hidden_channels * 2
)
self.enc2 = self.create_contracting_block(
in_channels=hidden_channels * 2,
out_channels=hidden_channels * 4
)
self.enc3 = self.create_contracting_block(
in_channels=hidden_channels * 4,
out_channels=hidden_channels * 8
)
self.enc4 = self.create_contracting_block(
in_channels=hidden_channels * 8,
out_channels=hidden_channels * 16
)
self.final_layer = torch.nn.Conv2d(
in_channels=hidden_channels * 16,
out_channels=1,
kernel_size=1,
)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
""" Forward patch gan layer """
inpt = torch.cat([x, y], axis=1)
resize_img = self.resize_channels(inpt)
enc1 = self.enc1(resize_img)
enc2 = self.enc2(enc1)
enc3 = self.enc3(enc2)
enc4 = self.enc4(enc3)
final_layer = self.final_layer(enc4)
return final_layer
# x, y <- concatenate the gen image and the input image to determin the gen image is real or not
class Discriminator(nn.Module):
def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
super().__init__()
self.initial = nn.Sequential(
nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
nn.LeakyReLU(.2)
)
# save layers into a list
layers = []
in_channels = features[0]
for feature in features[1:]:
layers.append(
CNNBlock(
in_channels,
feature,
stride=1 if feature == features[-1] else 2
),
)
in_channels = feature
# append last conv layer
layers.append(
nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect')
)
# create a model using the list of layers
self.model = nn.Sequential(*layers)
def forward(self, x, y):
x = torch.cat([x, y], dim=1)
x = self.initial(x)
return self.model(x)