# GeoGen-API / models.py
# saniaE
# created fastapi
# e1aa346
from torch import nn
class Generator(nn.Module):
    """Transposed-convolution generator that maps a noise vector to an image.

    The network is a stack of seven blocks. The first block expands a
    ``(N, z_dim, 1, 1)`` input to 4x4 (kernel 4, stride 1, padding 0); each of
    the following six blocks doubles the spatial size (kernel 4, stride 2,
    padding 1), so the output has shape ``(N, input_channels, 256, 256)``.
    The last block ends in ``Tanh``, so output values lie in ``[-1, 1]``.

    Args:
        z_dim: Number of channels of the noise input.
        input_channels: Number of channels of the generated output.
        hidden_dim: Base channel width; hidden blocks use multiples
            32, 16, 8, 4, 2 and 1 of this value.
    """

    def __init__(self, z_dim: int = 100, input_channels: int = 3, hidden_dim: int = 64) -> None:
        super().__init__()
        self.z_dim = z_dim

        # Channel widths of the hidden blocks, widest first.
        widths = [hidden_dim * factor for factor in (32, 16, 8, 4, 2, 1)]

        # Blocks are created in the same order as they are applied so that
        # parameter names (gen.0 ... gen.6) and seeded initialisation stay
        # identical to the explicit, hand-written layout.
        blocks = [self.generator_block(z_dim, widths[0], stride=1, padding=0)]
        blocks.extend(
            self.generator_block(in_ch, out_ch)
            for in_ch, out_ch in zip(widths, widths[1:])
        )
        blocks.append(
            self.generator_block(widths[-1], input_channels, final_layer=True)
        )
        self.gen = nn.Sequential(*blocks)

    def generator_block(
        self,
        input_channels: int,
        output_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        final_layer: bool = False,
    ) -> nn.Sequential:
        """Build one upsampling block.

        Args:
            input_channels: Channels entering the transposed convolution.
            output_channels: Channels produced by the transposed convolution.
            kernel_size: Kernel size of the transposed convolution.
            stride: Stride of the transposed convolution.
            padding: Padding of the transposed convolution.
            final_layer: If true, build the output block
                (transposed convolution with bias followed by ``Tanh``);
                otherwise build a hidden block (bias-free transposed
                convolution, affine instance norm, in-place ReLU).

        Returns:
            The block as an ``nn.Sequential``.
        """
        if final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(
                    input_channels,
                    output_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                ),
                nn.Tanh(),
            )

        return nn.Sequential(
            # bias=False: the affine InstanceNorm2d that follows has its own
            # learnable shift.
            nn.ConvTranspose2d(
                input_channels,
                output_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            nn.InstanceNorm2d(output_channels, affine=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Generate images from a batch of noise.

        Args:
            noise: Tensor whose first dimension is the batch size and whose
                remaining elements total ``z_dim`` per sample; it is reshaped
                to ``(len(noise), z_dim, 1, 1)`` before entering the network.

        Returns:
            The output of the block stack for the reshaped noise.
        """
        # reshape (rather than view) also accepts non-contiguous input; for
        # contiguous input it returns the same view as before.
        return self.gen(noise.reshape(len(noise), self.z_dim, 1, 1))