hikmatfarhat commited on
Commit
114475c
·
1 Parent(s): cd4cfe7

Upload dcgan.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dcgan.py +99 -0
dcgan.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ # The networks are taken from
5
+ # https://arxiv.org/abs/1511.06434
6
class TBlock(nn.Module):
    """Generator upsampling block: ConvTranspose2d -> optional norm -> ReLU.

    Architecture follows DCGAN (https://arxiv.org/abs/1511.06434).
    `norm_type` is forwarded verbatim to `norm_layer`.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, pad, norm_type):
        super().__init__()
        layers = [
            # bias=False: the following norm layer makes a bias redundant
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride, pad, bias=False),
            norm_layer(out_ch, norm_type),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the transposed-conv block to `x`."""
        return self.net(x)
16
class CBlock(nn.Module):
    """Discriminator downsampling block: Conv2d -> optional norm -> LeakyReLU(0.2).

    Architecture follows DCGAN (https://arxiv.org/abs/1511.06434).

    Args:
        in_ch: input channels.
        out_ch: output channels.
        kernel_size, stride, pad: forwarded to nn.Conv2d.
        norm_type: normalization name forwarded to `norm_layer`.
    """

    # BUG FIX: the previous default norm_type="batch" is not a value
    # norm_layer accepts ("BatchNorm2d", "GroupNorm", None, "None"),
    # so relying on the default raised ValueError at construction time.
    def __init__(self, in_ch, out_ch, kernel_size, stride, pad,
                 norm_type: str = "BatchNorm2d"):
        super().__init__()
        self.net = nn.Sequential(
            # bias=False: the following norm layer makes a bias redundant
            nn.Conv2d(in_ch, out_ch, kernel_size, stride, pad, bias=False),
            norm_layer(out_ch, norm_type),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Apply the conv block to `x`."""
        return self.net(x)
26
+
27
class Generator(nn.Module):
    """DCGAN generator (https://arxiv.org/abs/1511.06434).

    Maps a (N, zdim, 1, 1) latent tensor through four TBlocks and a final
    transposed convolution, producing a (N, out_ch, 64, 64) image.
    NOTE(review): `img_size` only scales the feature-map widths (nf = 2*img_size);
    the spatial output is fixed at 64x64 by the layer stack — confirm intent.

    Args:
        img_size: base width multiplier for the feature maps.
        out_ch: number of output image channels.
        zdim: length of the latent vector.
        norm_type: normalization name forwarded to each TBlock.
        final_activation: optional name of a `torch` function (e.g. "tanh")
            applied to the output; None returns the raw conv output.
    """

    def __init__(self, img_size=64, out_ch=3, zdim=100,
                 norm_type="BatchNorm2d", final_activation=None):
        super().__init__()
        width = 2 * img_size
        if final_activation is None:
            self.final_activation = None
        else:
            self.final_activation = getattr(torch, final_activation)

        self.net = nn.Sequential(
            # 1x1 -> 4x4
            TBlock(zdim, 8 * width, 4, 1, 0, norm_type),
            # 4x4 -> 8x8
            TBlock(8 * width, 4 * width, 4, 2, 1, norm_type),
            # 8x8 -> 16x16
            TBlock(4 * width, 2 * width, 4, 2, 1, norm_type),
            # 16x16 -> 32x32
            TBlock(2 * width, width, 4, 2, 1, norm_type),
            # 32x32 -> 64x64 image
            nn.ConvTranspose2d(width, out_ch, 4, 2, 1, bias=False),
        )

    def forward(self, x):
        """Generate images from latent codes; apply final_activation if set."""
        out = self.net(x)
        if self.final_activation is None:
            return out
        return self.final_activation(out)
60
+
61
+
62
class Discriminator(nn.Module):
    """DCGAN discriminator (https://arxiv.org/abs/1511.06434).

    Takes a (N, in_ch, 64, 64) image and emits a (N, 1, 1, 1) score map
    via a conv stem, three CBlocks, and a final 4x4 convolution.

    Args:
        img_size: base width multiplier for the feature maps (nf = img_size).
        in_ch: number of input image channels.
        norm_type: normalization name forwarded to each CBlock.
        final_activation: optional name of a `torch` function (e.g. "sigmoid")
            applied to the output; None returns the raw logits.
    """

    def __init__(self, img_size=64, in_ch=3,
                 norm_type="BatchNorm2d", final_activation=None):
        super().__init__()
        width = img_size
        if final_activation is None:
            self.final_activation = None
        else:
            self.final_activation = getattr(torch, final_activation)

        self.net = nn.Sequential(
            # 64x64 -> 32x32 (no norm on the stem, per DCGAN)
            nn.Conv2d(in_ch, width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
            # 32x32 -> 16x16
            CBlock(width, 2 * width, 4, 2, 1, norm_type),
            # 16x16 -> 8x8
            CBlock(2 * width, 4 * width, 4, 2, 1, norm_type),
            # 8x8 -> 4x4
            CBlock(4 * width, 8 * width, 4, 2, 1, norm_type),
            # 4x4 -> 1x1 score
            nn.Conv2d(8 * width, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        """Score a batch of images; apply final_activation if set."""
        out = self.net(x)
        if self.final_activation is None:
            return out
        return self.final_activation(out)
85
+
86
+ class norm_layer(nn.Module):
87
+ def __init__(self, num_channels,norm_type: str = None):
88
+ super().__init__()
89
+ if norm_type == "BatchNorm2d":
90
+ self.norm = nn.BatchNorm2d(num_channels)
91
+ elif norm_type == "GroupNorm":
92
+ self.norm = nn.GroupNorm(num_channels, num_channels)
93
+ elif norm_type is None or norm_type == "None":
94
+ self.norm=None
95
+ else:
96
+ raise ValueError(f"Unknown normalization type: {norm_type}")
97
+
98
+ def forward(self, x):
99
+ return x if self.norm is None else self.norm(x)