Spaces:
Running
Running
File size: 1,489 Bytes
e1aa346 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | from torch import nn
class Generator(nn.Module):
    """Map latent noise vectors to images with a stack of transposed convolutions.

    ``forward`` reshapes each noise vector to a 1x1 spatial map. The first
    block upsamples that to 4x4, and every following block doubles the
    spatial size. The result has shape ``(N, input_channels, 256, 256)`` with
    values in ``[-1, 1]`` (final Tanh).

    Args:
        z_dim: Number of noise features per sample.
        input_channels: Channel count of the *generated* image (e.g. 3 for
            RGB). The name is kept for backward compatibility; for this
            network it is the output channel count.
        hidden_dim: Base channel width. The intermediate blocks produce
            ``hidden_dim * 32, 16, 8, 4, 2, 1`` channels, in that order.
    """

    def __init__(self, z_dim=100, input_channels=3, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        # Block order (gen.0 ... gen.6) is unchanged so saved state_dicts keep
        # the same parameter names.
        self.gen = nn.Sequential(
            # 1x1 -> 4x4: kernel 4, stride 1, no padding.
            self.generator_block(z_dim, hidden_dim * 32, stride=1, padding=0),
            # Default blocks (kernel 4, stride 2, padding 1) double H and W:
            # 4 -> 8 -> 16 -> 32 -> 64 -> 128.
            self.generator_block(hidden_dim * 32, hidden_dim * 16),
            self.generator_block(hidden_dim * 16, hidden_dim * 8),
            self.generator_block(hidden_dim * 8, hidden_dim * 4),
            self.generator_block(hidden_dim * 4, hidden_dim * 2),
            self.generator_block(hidden_dim * 2, hidden_dim),
            # 128 -> 256, projected to image channels and squashed by Tanh.
            self.generator_block(hidden_dim, input_channels, final_layer=True),
        )

    def generator_block(self, input_channels, output_channels, kernel_size=4,
                        stride=2, padding=1, final_layer=False):
        """Build one upsampling stage.

        Args:
            input_channels: Channels entering this block.
            output_channels: Channels produced by this block.
            kernel_size: Passed to ``nn.ConvTranspose2d``.
            stride: Passed to ``nn.ConvTranspose2d``.
            padding: Passed to ``nn.ConvTranspose2d``. With the defaults
                (kernel 4, stride 2, padding 1) the block doubles H and W.
            final_layer: If True, build ConvTranspose2d + Tanh (no norm) to
                emit the image. Otherwise build ConvTranspose2d +
                InstanceNorm2d + ReLU.

        Returns:
            nn.Sequential: The assembled block.
        """
        if final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels,
                                   kernel_size=kernel_size, stride=stride,
                                   padding=padding),
                nn.Tanh(),
            )
        return nn.Sequential(
            # No conv bias: InstanceNorm2d subtracts the per-channel mean, so a
            # bias would cancel out. The norm's affine shift replaces it.
            nn.ConvTranspose2d(input_channels, output_channels,
                               kernel_size=kernel_size, stride=stride,
                               padding=padding, bias=False),
            nn.InstanceNorm2d(output_channels, affine=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Generate a batch of images from noise.

        Args:
            noise: Tensor whose first dimension is the batch size N and which
                holds ``N * z_dim`` elements in total, e.g. shape
                ``(N, z_dim)`` or ``(N, z_dim, 1, 1)``.

        Returns:
            Tensor of shape ``(N, input_channels, 256, 256)`` with values in
            ``[-1, 1]``.

        Raises:
            RuntimeError: If ``noise`` cannot be reshaped to
                ``(N, z_dim, 1, 1)``.
        """
        # reshape (not view) so non-contiguous noise, e.g. transposed or
        # permuted tensors, is accepted. It only copies when a view is impossible.
        return self.gen(noise.reshape(len(noise), self.z_dim, 1, 1))