# maps/model.py
# Author: vishnuraggav — commit ce847a6 ("first commit")
import torch
import torch.nn as nn
class DownBlock(nn.Module):
    """Encoder stage: strided 4x4 convolution, optional instance norm, LeakyReLU(0.2).

    The convolution uses kernel 4, stride 2 and padding 1 (reflect mode), so the
    spatial size goes from H to floor(H / 2). Reflect padding needs each spatial
    dimension of the input to be at least 2.

    Args:
        in_filters: number of input channels.
        out_filters: number of output channels.
        normal: when True, append an affine ``InstanceNorm2d`` after the conv and
            drop the conv bias (the norm's learned shift makes it redundant).
    """

    def __init__(self, in_filters, out_filters, normal=True):
        super().__init__()
        downsample = nn.Conv2d(
            in_filters,
            out_filters,
            kernel_size=4,
            stride=2,
            padding=1,
            padding_mode='reflect',
            bias=not normal,
        )
        norm_layers = [nn.InstanceNorm2d(out_filters, affine=True)] if normal else []
        # Attribute name and layer order are kept stable: they define state_dict keys.
        self.block = nn.Sequential(
            downsample,
            *norm_layers,
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        """Apply conv -> [norm] -> LeakyReLU to ``x``."""
        out = self.block(x)
        return out
class UpBlock(nn.Module):
    """Decoder stage: transposed 4x4 conv, instance norm, ReLU, optional dropout, then skip concat.

    The transposed convolution (kernel 4, stride 2, padding 1) maps H to 2 * H.
    Its output is concatenated with ``skip_input`` along dim 1, so the result has
    ``out_filters + skip_input.shape[1]`` channels. The two tensors must agree in
    every other dimension.

    Args:
        in_filters: number of input channels.
        out_filters: channels produced by the transposed convolution.
        dropout: dropout probability; any falsy value (e.g. 0.0) omits the layer.
    """

    def __init__(self, in_filters, out_filters, dropout=0.0):
        super().__init__()
        upsample = nn.ConvTranspose2d(
            in_filters, out_filters, kernel_size=4, stride=2, padding=1, bias=False
        )
        optional_dropout = [nn.Dropout(dropout)] if dropout else []
        # Attribute name and layer order are kept stable: they define state_dict keys.
        self.model = nn.Sequential(
            upsample,
            nn.InstanceNorm2d(out_filters, affine=True),
            nn.ReLU(inplace=True),
            *optional_dropout,
        )

    def forward(self, x, skip_input):
        """Upsample ``x`` and append ``skip_input`` after it on the channel axis."""
        return torch.cat([self.model(x), skip_input], dim=1)
class Generator(nn.Module):
    """U-Net-style encoder/decoder generator with channel-concatenated skip connections.

    The network is built from three parts:

    * Encoder: one ``DownBlock`` per entry of ``features``. Each block halves H and W.
      The first block has no normalization.
    * Bottleneck: a further un-normalized ``DownBlock`` of width ``features[-1]``.
    * Decoder: one ``UpBlock`` per entry of ``reversed(features)``. Each block doubles
      H and W and concatenates the matching encoder output, deepest first. The first
      three decoder stages use dropout p=0.5.

    A final transposed conv followed by ``Tanh`` produces the output, so values lie
    in [-1, 1].

    Because there are ``len(features) + 1`` halvings and as many doublings, input H
    and W should be multiples of ``2 ** (len(features) + 1)``. With the defaults that
    is 256. Otherwise the skip concatenations get mismatched spatial sizes.

    Args:
        input_channels: channels of the input image tensor.
        features: encoder stage widths, mirrored by the decoder. Any sequence of ints
            is accepted. The default is a tuple to avoid a shared mutable default.
        out_channels: channels of the generated output (default 3, the previously
            hard-coded value).

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, input_channels, features=(64, 128, 256, 512, 512, 512, 512), out_channels=3):
        super().__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one stage width")

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        in_ch = input_channels
        for idx, feature in enumerate(features):
            self.encoder.append(DownBlock(in_ch, feature, normal=(idx != 0)))
            in_ch = feature

        # Width follows the deepest encoder stage (previously hard-coded to 512).
        # Left un-normalized, presumably because at the default input size its output
        # is 1x1, where instance statistics are degenerate.
        self.bottleneck = DownBlock(features[-1], features[-1], normal=False)

        # The last decoder stage emits features[0] channels plus an equal-width skip,
        # hence 2 * features[0] inputs (previously hard-coded to 128).
        self.final = nn.Sequential(
            nn.ConvTranspose2d(features[0] * 2, out_channels, 4, 2, 1),
            nn.Tanh(),
        )

        in_ch = features[-1]
        for idx, feature in enumerate(reversed(features)):
            # After the first stage, each UpBlock receives its predecessor's output
            # concatenated with a skip of the same width, so its input is doubled.
            block_in = in_ch if idx == 0 else in_ch * 2
            self.decoder.append(UpBlock(block_in, feature, dropout=0.5 if idx < 3 else 0.0))
            in_ch = feature

    def forward(self, x):
        """Run encoder, bottleneck and decoder, then return the Tanh-activated output."""
        skips = []
        for down in self.encoder:
            x = down(x)
            skips.append(x)
        x = self.bottleneck(x)
        # The deepest encoder output pairs with the first decoder stage.
        for up, skip in zip(self.decoder, reversed(skips)):
            x = up(x, skip)
        return self.final(x)