keysun89 commited on
Commit
dc39dcf
·
verified ·
1 Parent(s): b858c17

Create genrator_2.py

Browse files
Files changed (1) hide show
  1. genrator_2.py +167 -0
genrator_2.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class _conv(nn.Conv2d):
7
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias):
8
+ super(_conv, self).__init__(in_channels = in_channels, out_channels = out_channels,
9
+ kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True)
10
+
11
+ self.weight.data = torch.normal(torch.zeros((out_channels, in_channels, kernel_size, kernel_size)), 0.02)
12
+ self.bias.data = torch.zeros((out_channels))
13
+
14
+ for p in self.parameters():
15
+ p.requires_grad = True
16
+
17
+
18
class conv(nn.Module):
    """Conv block: _conv -> optional BatchNorm2d -> optional activation module.

    Args:
        in_channel / out_channel: channel counts for the convolution.
        kernel_size: square kernel; 'same' padding of kernel_size // 2 is used.
        BN: append a BatchNorm2d after the convolution when True.
        act: activation module appended last, or None for no activation.
        stride: convolution stride.
        bias: whether the convolution has a bias term.

    Bug fix: `bias` was accepted but ignored (bias=True was hard-coded in
    the _conv call); it is now forwarded. The default (True) matches the
    old hard-coded value, so existing callers are unaffected.
    """

    def __init__(self, in_channel, out_channel, kernel_size, BN = False, act = None, stride = 1, bias = True):
        super(conv, self).__init__()
        layers = []
        layers.append(_conv(in_channels = in_channel, out_channels = out_channel,
                            kernel_size = kernel_size, stride = stride,
                            padding = (kernel_size) // 2, bias = bias))

        if BN:
            layers.append(nn.BatchNorm2d(num_features = out_channel))

        if act is not None:
            layers.append(act)

        self.body = nn.Sequential(*layers)

    def forward(self, x):
        out = self.body(x)
        return out
36
+
37
class ResBlock(nn.Module):
    """Residual block: conv+BN+act, then conv+BN, with an identity skip."""

    def __init__(self, channels, kernel_size, act = nn.ReLU(inplace = True), bias = True):
        super(ResBlock, self).__init__()
        self.body = nn.Sequential(
            conv(channels, channels, kernel_size, BN = True, act = act),
            conv(channels, channels, kernel_size, BN = True, act = None),
        )

    def forward(self, x):
        # Add the block input back onto the transformed features.
        out = self.body(x)
        out += x
        return out
49
+
50
class BasicBlock(nn.Module):
    """Entry conv feeding a chain of ResBlocks plus a BN conv, with a long skip
    from the entry conv's output to the chain's output."""

    def __init__(self, in_channels, out_channels, kernel_size, num_res_block, act = nn.ReLU(inplace = True)):
        super(BasicBlock, self).__init__()

        # Channel-matching entry convolution (no BN).
        self.conv = conv(in_channels, out_channels, kernel_size, BN = False, act = act)

        # Residual trunk followed by a final BN conv with no activation.
        trunk = [ResBlock(out_channels, kernel_size, act) for _ in range(num_res_block)]
        trunk.append(conv(out_channels, out_channels, kernel_size, BN = True, act = None))
        self.body = nn.Sequential(*trunk)

    def forward(self, x):
        entry = self.conv(x)
        out = self.body(entry)
        # Long skip around the whole trunk.
        out += entry
        return out
69
+
70
class Upsampler(nn.Module):
    """Sub-pixel upsampler: conv expands channels by scale^2, PixelShuffle
    rearranges them into a spatially larger map, optional activation last."""

    def __init__(self, channel, kernel_size, scale, act = nn.ReLU(inplace = True)):
        super(Upsampler, self).__init__()
        layers = [
            conv(channel, channel * scale * scale, kernel_size),
            nn.PixelShuffle(scale),
        ]
        if act is not None:
            layers.append(act)
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x)
85
+
86
class discrim_block(nn.Module):
    """Discriminator stage: a stride-1 conv then a stride-2 conv (both BN + act),
    doubling channels and halving spatial resolution."""

    def __init__(self, in_feats, out_feats, kernel_size, act = nn.LeakyReLU(inplace = True)):
        super(discrim_block, self).__init__()
        self.body = nn.Sequential(
            conv(in_feats, out_feats, kernel_size, BN = True, act = act),
            conv(out_feats, out_feats, kernel_size, BN = True, act = act, stride = 2),
        )

    def forward(self, x):
        return self.body(x)
97
+
98
+
99
class Generator(nn.Module):
    """SRGAN-style generator: 9x9 head conv, residual trunk with a long skip,
    pixel-shuffle upsampling tail, and a Tanh-activated output conv.

    forward returns (sr_image, trunk_features), where trunk_features are the
    pre-upsampling features after the long skip connection.
    """

    def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, num_block = 16, act = nn.PReLU(), scale=4):
        super(Generator, self).__init__()

        # Wide receptive-field entry convolution.
        self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 9, BN = False, act = act)

        # Residual trunk.
        trunk = [ResBlock(channels = n_feats, kernel_size = 3, act = act) for _ in range(num_block)]
        self.body = nn.Sequential(*trunk)

        # Post-trunk conv whose output is summed with the entry features.
        self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = True, act = None)

        # x4 is realised as two x2 pixel-shuffle stages; any other factor
        # uses a single stage of that factor.
        if(scale == 4):
            stages = [Upsampler(channel = n_feats, kernel_size = 3, scale = 2, act = act) for _ in range(2)]
        else:
            stages = [Upsampler(channel = n_feats, kernel_size = 3, scale = scale, act = act)]
        self.tail = nn.Sequential(*stages)

        self.last_conv = conv(in_channel = n_feats, out_channel = img_feat, kernel_size = 3, BN = False, act = nn.Tanh())

    def forward(self, x):
        head = self.conv01(x)

        # Trunk + long skip back to the head features.
        feat = self.conv02(self.body(head)) + head

        # Upsample and project to the image space.
        out = self.last_conv(self.tail(feat))

        return out, feat
133
+
134
class Discriminator(nn.Module):
    """SRGAN-style discriminator: two entry convs (second strided), a pyramid of
    channel-doubling discrim_blocks, then a two-layer MLP head ending in Sigmoid.

    `patch_size` must match the input's spatial size so the flattened feature
    length equals `linear_size`.
    """

    def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, act = nn.LeakyReLU(inplace = True), num_of_block = 3, patch_size = 96):
        super(Discriminator, self).__init__()
        self.act = act

        self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act)
        self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act, stride = 2)

        pyramid = [
            discrim_block(in_feats = n_feats * (2 ** i),
                          out_feats = n_feats * (2 ** (i + 1)),
                          kernel_size = 3,
                          act = self.act)
            for i in range(num_of_block)
        ]
        self.body = nn.Sequential(*pyramid)

        # Spatial size is halved (num_of_block + 1) times in total (conv02 plus
        # one stride-2 conv per block); channels double num_of_block times.
        self.linear_size = ((patch_size // (2 ** (num_of_block + 1))) ** 2) * (n_feats * (2 ** num_of_block))

        head = [
            nn.Linear(self.linear_size, 1024),
            self.act,
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        ]
        self.tail = nn.Sequential(*head)


    def forward(self, x):
        feats = self.conv02(self.conv01(x))
        feats = self.body(feats)

        # Flatten to (batch, linear_size) for the MLP head.
        flat = feats.view(-1, self.linear_size)
        return self.tail(flat)
167
+