caikybaldo999 commited on
Commit
10827e4
·
verified ·
1 Parent(s): 06bcef3

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +29 -0
model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class Generator(nn.Module):
5
+ def __init__(self, latent_dim=100):
6
+ super(Generator, self).__init__()
7
+ self.main = nn.Sequential(
8
+ nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
9
+ nn.BatchNorm2d(512),
10
+ nn.ReLU(True),
11
+
12
+ nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
13
+ nn.BatchNorm2d(256),
14
+ nn.ReLU(True),
15
+
16
+ nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
17
+ nn.BatchNorm2d(128),
18
+ nn.ReLU(True),
19
+
20
+ nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
21
+ nn.BatchNorm2d(64),
22
+ nn.ReLU(True),
23
+
24
+ nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
25
+ nn.Tanh()
26
+ )
27
+
28
+ def forward(self, input):
29
+ return self.main(input)