Spaces:
Running
Running
| from torch import nn | |
class Generator(nn.Module):
    """DCGAN-style generator that maps a batch of noise vectors to images.

    The noise is reshaped to ``(N, z_dim, 1, 1)`` and upsampled by seven
    transposed-convolution stages. The first stage (kernel 4, stride 1,
    padding 0) turns the 1x1 input into a 4x4 map. Each of the remaining six
    stages (kernel 4, stride 2, padding 1) doubles height and width, so the
    output has shape ``(N, input_channels, 256, 256)``. A final ``Tanh`` keeps
    values in ``[-1, 1]``.

    Args:
        z_dim: Number of noise values per sample expected by :meth:`forward`.
        input_channels: Channel count of the *generated* image. Despite the
            name, this is the output of this network, not its input.
        hidden_dim: Base channel width. The hidden stages use 32, 16, 8, 4, 2
            and 1 times this value.
    """

    def __init__(self, z_dim=100, input_channels=3, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        # NOTE: the nesting and order of these modules define the state_dict
        # keys (``gen.<stage>.<layer>.*``). Keep them stable so previously
        # saved weights still load.
        self.gen = nn.Sequential(
            # 1x1 -> 4x4
            self.generator_block(z_dim, hidden_dim * 32, stride=1, padding=0),
            # Each following stage doubles H and W: 4 -> 8 -> ... -> 128.
            self.generator_block(hidden_dim * 32, hidden_dim * 16),
            self.generator_block(hidden_dim * 16, hidden_dim * 8),
            self.generator_block(hidden_dim * 8, hidden_dim * 4),
            self.generator_block(hidden_dim * 4, hidden_dim * 2),
            self.generator_block(hidden_dim * 2, hidden_dim),
            # 128 -> 256, projected to the image channel count.
            self.generator_block(hidden_dim, input_channels, final_layer=True),
        )

    def generator_block(self, input_channels, output_channels, kernel_size=4,
                        stride=2, padding=1, final_layer=False):
        """Build one upsampling stage.

        Args:
            input_channels: Channels of the incoming feature map.
            output_channels: Channels produced by the transposed convolution.
            kernel_size: Transposed-convolution kernel size.
            stride: Transposed-convolution stride.
            padding: Transposed-convolution padding.
            final_layer: If true, build the output stage
                (``ConvTranspose2d -> Tanh``). Otherwise build a hidden stage
                (``ConvTranspose2d -> InstanceNorm2d -> ReLU``).

        Returns:
            An ``nn.Sequential`` holding the stage's layers.
        """
        if final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels,
                                   kernel_size=kernel_size, stride=stride,
                                   padding=padding),
                nn.Tanh(),
            )
        # The conv bias is redundant here: InstanceNorm subtracts the
        # per-channel mean, and its affine parameters provide a learnable shift.
        return nn.Sequential(
            nn.ConvTranspose2d(input_channels, output_channels,
                               kernel_size=kernel_size, stride=stride,
                               padding=padding, bias=False),
            nn.InstanceNorm2d(output_channels, affine=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Generate a batch of images.

        Args:
            noise: Tensor whose first dimension is the batch size ``N`` and
                which holds ``z_dim`` values per sample, e.g. ``(N, z_dim)``
                or ``(N, z_dim, 1, 1)``.

        Returns:
            Tensor of shape ``(N, input_channels, 256, 256)`` with values in
            ``[-1, 1]``.
        """
        # Use reshape (not view) so non-contiguous noise, such as a transposed
        # multi-dimensional tensor, is accepted instead of raising.
        return self.gen(noise.reshape(noise.shape[0], self.z_dim, 1, 1))