File size: 1,489 Bytes
e1aa346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from torch import nn 

class Generator(nn.Module):
    """DCGAN-style generator: maps a flat latent vector to an image.

    The latent code is reshaped to (N, z_dim, 1, 1) and pushed through a
    stack of transposed convolutions — a 4x4 stem followed by six x2
    upsampling stages — ending in a Tanh so outputs lie in [-1, 1].
    """

    def __init__(self, z_dim=100, input_channels=3, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        # Channel widths along the upsampling path: 32h -> 16h -> ... -> h.
        widths = [hidden_dim * mult for mult in (32, 16, 8, 4, 2, 1)]
        # Stem: project the (z_dim, 1, 1) latent to a 4x4 feature map.
        stages = [self.generator_block(z_dim, widths[0], stride=1, padding=0)]
        # Each subsequent stage doubles spatial size and halves channels.
        stages += [
            self.generator_block(c_in, c_out)
            for c_in, c_out in zip(widths, widths[1:])
        ]
        # Output stage: project to image channels, squash with Tanh.
        stages.append(
            self.generator_block(widths[-1], input_channels, final_layer=True)
        )
        self.gen = nn.Sequential(*stages)

    def generator_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=1, final_layer=False):
        """Build one generator stage.

        Hidden stages are ConvTranspose2d (no bias) + InstanceNorm + ReLU;
        the final stage is ConvTranspose2d + Tanh with its default bias.
        """
        if final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding),
                nn.Tanh(),
            )
        # bias is redundant before the affine InstanceNorm, hence disabled.
        return nn.Sequential(
            nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.InstanceNorm2d(output_channels, affine=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Generate images from a (N, z_dim) batch of latent vectors."""
        latent = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(latent)