devoppro commited on
Commit
cf3c3e8
·
verified ·
1 Parent(s): 6ee8442

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +37 -0
model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ NOISE_DIM = 256
5
+
6
+ class Generator(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ self.fc = nn.Linear(NOISE_DIM, 4*4*512)
11
+
12
+ self.net = nn.Sequential(
13
+ nn.BatchNorm2d(512),
14
+ nn.Upsample(scale_factor=2),
15
+ nn.Conv2d(512, 256, 3, padding=1),
16
+ nn.BatchNorm2d(256),
17
+ nn.ReLU(True),
18
+
19
+ nn.Upsample(scale_factor=2),
20
+ nn.Conv2d(256, 128, 3, padding=1),
21
+ nn.BatchNorm2d(128),
22
+ nn.ReLU(True),
23
+
24
+ nn.Upsample(scale_factor=2),
25
+ nn.Conv2d(128, 64, 3, padding=1),
26
+ nn.BatchNorm2d(64),
27
+ nn.ReLU(True),
28
+
29
+ nn.Upsample(scale_factor=2),
30
+ nn.Conv2d(64, 3, 3, padding=1),
31
+ nn.Tanh()
32
+ )
33
+
34
+ def forward(self, noise):
35
+ x = self.fc(noise)
36
+ x = x.view(-1, 512, 4, 4)
37
+ return self.net(x)