from torch import nn


class Generator(nn.Module):
    """DCGAN-style generator that maps a noise vector to an image.

    The network is a stack of transposed-convolution blocks:

    * the first block (kernel 4, stride 1, padding 0) lifts the noise, viewed
      as a ``1x1`` feature map, to ``4x4``;
    * six further blocks (kernel 4, stride 2, padding 1) each double the
      spatial size: 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256.

    The output therefore has shape ``(batch, input_channels, 256, 256)``. The
    final ``Tanh`` keeps every value in ``[-1, 1]``.

    Args:
        z_dim: Number of noise values per sample.
        input_channels: Number of channels of the *generated* image (e.g. 3
            for RGB). The name is kept as-is for backward compatibility.
        hidden_dim: Base channel width. Intermediate blocks use
            ``hidden_dim * 32`` channels, halving at each block down to
            ``hidden_dim``.
    """

    def __init__(self, z_dim=100, input_channels=3, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        # Attribute name and layer order are kept stable so that saved
        # state_dicts (keys "gen.<i>.<j>.*") keep loading.
        self.gen = nn.Sequential(
            # 1x1 -> 4x4
            self.generator_block(z_dim, hidden_dim * 32, stride=1, padding=0),
            # Each block below doubles the spatial size and halves the channels.
            self.generator_block(hidden_dim * 32, hidden_dim * 16),
            self.generator_block(hidden_dim * 16, hidden_dim * 8),
            self.generator_block(hidden_dim * 8, hidden_dim * 4),
            self.generator_block(hidden_dim * 4, hidden_dim * 2),
            self.generator_block(hidden_dim * 2, hidden_dim),
            # 128x128 -> 256x256, projected to the image channel count.
            self.generator_block(hidden_dim, input_channels, final_layer=True),
        )

    def generator_block(self, input_channels, output_channels, kernel_size=4,
                        stride=2, padding=1, final_layer=False):
        """Build one upsampling block of the generator.

        Args:
            input_channels: Channels entering the transposed convolution.
            output_channels: Channels produced by the transposed convolution.
            kernel_size: Transposed-convolution kernel size.
            stride: Transposed-convolution stride. With the defaults
                (kernel 4, stride 2, padding 1) the spatial size doubles.
            padding: Transposed-convolution padding.
            final_layer: If True, build the output block
                (``ConvTranspose2d`` with bias, then ``Tanh``). Otherwise build a
                hidden block (``ConvTranspose2d`` without bias, then
                ``InstanceNorm2d(affine=True)``, then ``ReLU``).

        Returns:
            An ``nn.Sequential`` containing the block's layers.
        """
        if final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels,
                                   kernel_size=kernel_size, stride=stride,
                                   padding=padding),
                nn.Tanh(),
            )
        return nn.Sequential(
            # bias=False: instance norm subtracts the per-channel mean, which
            # cancels any constant per-channel bias, and affine=True supplies
            # a learnable shift instead.
            nn.ConvTranspose2d(input_channels, output_channels,
                               kernel_size=kernel_size, stride=stride,
                               padding=padding, bias=False),
            nn.InstanceNorm2d(output_channels, affine=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Generate a batch of images from ``noise``.

        Args:
            noise: Tensor whose first dimension is the batch size and which
                holds exactly ``z_dim`` values per sample (e.g. shape
                ``(batch, z_dim)``).

        Returns:
            Tensor of shape ``(batch, input_channels, 256, 256)`` with values
            in ``[-1, 1]``.
        """
        # reshape (not view) so non-contiguous noise tensors are accepted.
        # It still raises if the element count != batch * z_dim.
        return self.gen(noise.reshape(noise.shape[0], self.z_dim, 1, 1))