PolarisFTL committed
Commit 320cf10 · verified · 1 Parent(s): e0e2576

Upload FAENet.py

Files changed (1)
  1. nets/FAENet.py +248 -0
nets/FAENet.py ADDED
@@ -0,0 +1,248 @@
import math

import cv2
import torch
import torch.nn as nn

class eca_block(nn.Module):
    """Efficient Channel Attention: gates channels via a 1-D conv over pooled features."""

    def __init__(self, channel, b=1, gamma=2):
        super(eca_block, self).__init__()
        # Kernel size adapts to the channel count; force it odd so padding centres it.
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size,
                              padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # (B, C, H, W) -> (B, C, 1, 1), then run the 1-D conv across the channel axis.
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)

class DilatedConvNet(nn.Module):
    """One dilated convolution branch followed by ReLU."""

    def __init__(self, in_channels, out_channels, dilation, padding, kernel_size):
        super(DilatedConvNet, self).__init__()
        self.dilated_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                      stride=1, padding=padding, dilation=dilation)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.dilated_conv(x)
        x = self.relu(x)
        return x

class LAM(nn.Module):
    """ECA gating followed by a 6-to-3-channel fusion conv; produces the guidance tag."""

    def __init__(self, ch=16):
        super().__init__()
        # `ch` only sets the ECA kernel size; the module itself runs on the
        # 6-channel concatenation of RFEM's input and output.
        self.eca = eca_block(ch)
        self.conv1 = nn.Conv2d(6, 3, 3, padding=1)

    def forward(self, x):
        x = self.eca(x)
        x = self.conv1(x)
        return x

class RFEM(nn.Module):
    """Residual feature-extraction module applied to the low-frequency pyramid level."""

    def __init__(self, ch_blocks=64, ch_mask=16):
        super().__init__()

        # The original code passed nn.LeakyReLU(True), which sets
        # negative_slope=1.0 (i.e. the identity); inplace=True is the likely intent.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(16, ch_blocks, 3, padding=1),
            nn.LeakyReLU(inplace=True))

        # Four parallel branches (dilations 1-3 plus a 7x7 conv), each ch_blocks // 4 wide.
        self.dconv1 = DilatedConvNet(ch_blocks, ch_blocks // 4,
                                     kernel_size=3, padding=1, dilation=1)
        self.dconv2 = DilatedConvNet(ch_blocks, ch_blocks // 4,
                                     kernel_size=3, padding=2, dilation=2)
        self.dconv3 = DilatedConvNet(ch_blocks, ch_blocks // 4,
                                     kernel_size=3, padding=3, dilation=3)
        self.dconv4 = nn.Conv2d(ch_blocks, ch_blocks // 4,
                                kernel_size=7, padding=3)

        self.decoder = nn.Sequential(
            nn.Conv2d(ch_blocks, 16, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(16, 3, 3, padding=1),
            nn.LeakyReLU(inplace=True))

        self.lam = LAM(ch_mask)

    def forward(self, x):
        x1 = self.encoder(x)
        x1_1 = self.dconv1(x1)
        x1_2 = self.dconv2(x1)
        x1_3 = self.dconv3(x1)
        x1_4 = self.dconv4(x1)
        # The four quarter-width branches concatenate back to ch_blocks channels.
        x1 = torch.cat([x1_1, x1_2, x1_3, x1_4], dim=1)
        x1 = self.decoder(x1)
        out = torch.relu(x + x1)  # residual connection
        # Predict the guidance tag from the input/output pair.
        mask = self.lam(torch.cat([x, out], dim=1))
        return out, mask

class ATEM(nn.Module):
    """Modulates encoded features with a per-pixel scale and shift predicted from the tag."""

    def __init__(self, in_ch=3, inter_ch=32, out_ch=3, kernel_size=3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_ch, inter_ch, kernel_size, padding=kernel_size // 2),
            nn.LeakyReLU(inplace=True),  # was nn.LeakyReLU(True), i.e. slope 1.0
        )
        self.shift_conv = nn.Sequential(
            nn.Conv2d(in_ch, inter_ch, kernel_size, padding=kernel_size // 2))
        self.scale_conv = nn.Sequential(
            nn.Conv2d(in_ch, inter_ch, kernel_size, padding=kernel_size // 2))

        self.decoder = nn.Sequential(
            nn.Conv2d(inter_ch, out_ch, kernel_size, padding=kernel_size // 2))

    def forward(self, x, tag):
        x = self.encoder(x)
        scale = self.scale_conv(tag)
        shift = self.shift_conv(tag)
        x = x + (x * scale + shift)  # residual affine modulation
        x = self.decoder(x)
        return x

class Trans_high(nn.Module):
    """Refines one high-frequency pyramid level, conditioned on the upsampled tag."""

    def __init__(self, in_ch=3, inter_ch=16, out_ch=3, kernel_size=3):
        super().__init__()
        self.atem = ATEM(in_ch, inter_ch, out_ch, kernel_size)

    def forward(self, x, tag):
        x = x + self.atem(x, tag)
        return x

class Up_tag(nn.Module):
    """2x bilinear upsampling of the tag, followed by a pointwise conv."""

    def __init__(self, kernel_size=1, ch=3):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(ch, ch, kernel_size, stride=1,
                      padding=kernel_size // 2, bias=False))

    def forward(self, x):
        x = self.up(x)
        return x

class Lap_Pyramid_Conv(nn.Module):
    """Differentiable Laplacian pyramid built with a fixed Gaussian kernel."""

    def __init__(self, num_high=3, kernel_size=5, channels=3):
        super().__init__()
        self.num_high = num_high
        self.kernel = self.gauss_kernel(kernel_size, channels)

    def gauss_kernel(self, kernel_size, channels):
        # Outer product of two 1-D Gaussian kernels, replicated per channel
        # for a depthwise convolution; frozen (requires_grad=False).
        kernel = cv2.getGaussianKernel(kernel_size, 0).dot(
            cv2.getGaussianKernel(kernel_size, 0).T)
        kernel = torch.FloatTensor(kernel).unsqueeze(0).repeat(channels, 1, 1, 1)
        kernel = torch.nn.Parameter(data=kernel, requires_grad=False)
        return kernel

    def conv_gauss(self, x, kernel):
        n_channels, _, kw, kh = kernel.shape
        x = torch.nn.functional.pad(x, (kw // 2, kh // 2, kw // 2, kh // 2),
                                    mode='reflect')
        x = torch.nn.functional.conv2d(x, kernel, groups=n_channels)
        return x

    def downsample(self, x):
        return x[:, :, ::2, ::2]

    def pyramid_down(self, x):
        return self.downsample(self.conv_gauss(x, self.kernel))

    def upsample(self, x):
        # Zero-stuff then blur; the factor 4 compensates for the inserted zeros.
        up = torch.zeros((x.size(0), x.size(1), x.size(2) * 2, x.size(3) * 2),
                         device=x.device)
        up[:, :, ::2, ::2] = x * 4
        return self.conv_gauss(up, self.kernel)

    def pyramid_decom(self, img):
        self.kernel = self.kernel.to(img.device)
        current = img
        pyr = []
        for _ in range(self.num_high):
            down = self.pyramid_down(current)
            up = self.upsample(down)
            diff = current - up  # high-frequency residual at this level
            pyr.append(diff)
            current = down
        pyr.append(current)  # low-frequency level goes last
        return pyr

    def pyramid_recons(self, pyr):
        # Expects the low-frequency level first, then high-frequency levels
        # from coarse to fine (the order FAENet.forward produces).
        image = pyr[0]
        for level in pyr[1:]:
            up = self.upsample(image)
            image = up + level
        return image

class FAENet(nn.Module):
    def __init__(self,
                 num_high=1,
                 ch_blocks=32,
                 up_ksize=1,
                 high_ch=32,
                 high_ksize=3,
                 ch_mask=32,
                 gauss_kernel=7):
        super().__init__()
        self.num_high = num_high
        self.lap_pyramid = Lap_Pyramid_Conv(num_high, gauss_kernel)
        self.rfem = RFEM(ch_blocks, ch_mask)

        # One tag-upsampling layer and one high-frequency refiner per pyramid level.
        for i in range(0, self.num_high):
            self.__setattr__('up_tag_layer_{}'.format(i),
                             Up_tag(up_ksize, ch=3))
            self.__setattr__('trans_high_layer_{}'.format(i),
                             Trans_high(3, high_ch, 3, high_ksize))

    def forward(self, x):
        # Decompose into num_high high-frequency levels plus a low-frequency level.
        pyrs = self.lap_pyramid.pyramid_decom(img=x)

        # Enhance the low-frequency level and obtain the guidance tag.
        trans_pyrs = []
        trans_pyr, tag = self.rfem(pyrs[-1])
        trans_pyrs.append(trans_pyr)

        # Upsample the tag once per level so it matches each level's resolution.
        common_tag = []
        for i in range(self.num_high):
            tag = self.__getattr__('up_tag_layer_{}'.format(i))(tag)
            common_tag.append(tag)

        # Refine the high-frequency levels from coarse to fine.
        for i in range(self.num_high):
            trans_pyr = self.__getattr__('trans_high_layer_{}'.format(i))(
                pyrs[-2 - i], common_tag[i])
            trans_pyrs.append(trans_pyr)

        out = self.lap_pyramid.pyramid_recons(trans_pyrs)

        return out

if __name__ == "__main__":
    faenet = FAENet()
    num_params = sum(p.numel() for p in faenet.parameters())
    # The MB figure assumes 4 bytes (float32) per parameter; the original print
    # divided the raw parameter count by 1024**2, which is not a memory size.
    print("FAENet parameters: {:.2f}K, {:.2f} MB (float32)".format(
        num_params / 1024, num_params * 4 / (1024 * 1024)))
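
A minimal smoke test for the upload, assuming the file is importable as nets.FAENet from the repo root (the import path and the 256x256 input size are assumptions, not part of the commit): it checks that the output shape matches the input and that the Laplacian pyramid round-trips.

# Sketch of a sanity check; the import path and input size are assumptions.
import torch

from nets.FAENet import FAENet

net = FAENet(num_high=1).eval()
x = torch.rand(1, 3, 256, 256)  # any spatial size divisible by 2**num_high works
with torch.no_grad():
    y = net(x)
print(y.shape)  # expected: torch.Size([1, 3, 256, 256])

# pyramid_decom returns [fine high-freq, ..., coarse high-freq, low-freq];
# pyramid_recons expects the reverse order (low-frequency level first).
pyr = net.lap_pyramid.pyramid_decom(x)
rec = net.lap_pyramid.pyramid_recons(list(reversed(pyr)))
print("max |x - rec| =", (x - rec).abs().max().item())  # ~0 up to float round-off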