PolarisFTL commited on
Commit
5dbd914
·
verified ·
1 Parent(s): 3fc7ec1

Delete nets/FAENet.py

Browse files
Files changed (1) hide show
  1. nets/FAENet.py +0 -248
nets/FAENet.py DELETED
@@ -1,248 +0,0 @@
1
- import math
2
- import cv2
3
- import torch
4
- import torch.nn as nn
5
- import torchvision
6
- from torch import nn
7
- import torch
8
-
9
- class eca_block(nn.Module):
10
- def __init__(self, channel, b=1, gamma=2):
11
- super(eca_block, self).__init__()
12
- kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
13
- kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
14
-
15
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
16
- self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
17
- self.sigmoid = nn.Sigmoid()
18
-
19
- def forward(self, x):
20
-
21
- y = self.avg_pool(x)
22
- y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
23
- y = self.sigmoid(y)
24
- return x * y.expand_as(x)
25
-
26
- class DilatedConvNet(nn.Module):
27
- def __init__(self, in_channels, out_channels, dilation, padding, kernel_size):
28
- super(DilatedConvNet, self).__init__()
29
- self.dilated_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
30
- self.relu = nn.ReLU(inplace=False)
31
-
32
- def forward(self, x):
33
-
34
- x = self.dilated_conv(x)
35
- x = self.relu(x)
36
-
37
- return x
38
-
39
- class LAM(nn.Module):
40
- def __init__(self, ch=16):
41
- super().__init__()
42
- self.eca = eca_block(ch)
43
- self.conv1 = nn.Conv2d(6, 3, 3, padding=1)
44
-
45
- def forward(self, x):
46
- x = self.eca(x)
47
- x = self.conv1(x)
48
- return x
49
-
50
- class RFEM(nn.Module):
51
- def __init__(
52
- self,
53
- ch_blocks=64,
54
- ch_mask=16,
55
- ):
56
- super().__init__()
57
-
58
- self.encoder = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
59
- nn.LeakyReLU(True),
60
- nn.Conv2d(16, ch_blocks, 3, padding=1),
61
- nn.LeakyReLU(True))
62
-
63
- self.dconv1 = DilatedConvNet(ch_blocks,
64
- ch_blocks // 4,
65
- kernel_size=3,
66
- padding=1, dilation=1)
67
- self.dconv2 = DilatedConvNet(ch_blocks,
68
- ch_blocks // 4,
69
- kernel_size=3,
70
- padding=2, dilation=2)
71
- self.dconv3 = DilatedConvNet(ch_blocks,
72
- ch_blocks // 4,
73
- kernel_size=3,
74
- padding=3, dilation=3)
75
- self.dconv4 = nn.Conv2d(ch_blocks,
76
- ch_blocks // 4,
77
- kernel_size=7,
78
- padding=3)
79
-
80
- self.decoder = nn.Sequential(nn.Conv2d(ch_blocks, 16, 3, padding=1),
81
- nn.LeakyReLU(True),
82
- nn.Conv2d(16, 3, 3, padding=1),
83
- nn.LeakyReLU(True),
84
- )
85
-
86
- self.lam = LAM(ch_mask)
87
-
88
- def forward(self, x):
89
- x1 = self.encoder(x)
90
- x1_1 = self.dconv1(x1)
91
- x1_2 = self.dconv2(x1)
92
- x1_3 = self.dconv3(x1)
93
- x1_4 = self.dconv4(x1)
94
- x1 = torch.cat([x1_1, x1_2, x1_3, x1_4], dim=1)
95
- x1 = self.decoder(x1)
96
- out = x + x1
97
- out = torch.relu(out)
98
- mask = self.lam(torch.cat([x, out], dim=1))
99
- return out, mask
100
-
101
- class ATEM(nn.Module):
102
- def __init__(self, in_ch=3, inter_ch=32, out_ch=3, kernel_size=3):
103
- super().__init__()
104
- self.encoder = nn.Sequential(
105
- nn.Conv2d(in_ch, inter_ch, kernel_size, padding=kernel_size // 2),
106
- nn.LeakyReLU(True),
107
- )
108
- self.shift_conv = nn.Sequential(
109
- nn.Conv2d(in_ch, inter_ch, kernel_size, padding=kernel_size // 2))
110
- self.scale_conv = nn.Sequential(
111
- nn.Conv2d(in_ch, inter_ch, kernel_size, padding=kernel_size // 2))
112
-
113
-
114
- self.decoder = nn.Sequential(
115
- nn.Conv2d(inter_ch, out_ch, kernel_size, padding=kernel_size // 2))
116
-
117
- def forward(self, x, tag):
118
- x = self.encoder(x)
119
- scale = self.scale_conv(tag)
120
- shift = self.shift_conv(tag)
121
- x = x +(x * scale + shift)
122
- x = self.decoder(x)
123
- return x
124
-
125
- class Trans_high(nn.Module):
126
- def __init__(self, in_ch=3, inter_ch=16, out_ch=3, kernel_size=3):
127
- super().__init__()
128
- self.atem = ATEM(in_ch, inter_ch, out_ch, kernel_size)
129
- def forward(self, x, tag):
130
- x = x + self.atem(x, tag)
131
- return x
132
-
133
-
134
- class Up_tag(nn.Module):
135
- def __init__(self, kernel_size=1, ch=3):
136
- super().__init__()
137
- self.up = nn.Sequential(
138
- nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
139
- nn.Conv2d(ch,
140
- ch,
141
- kernel_size,
142
- stride=1,
143
- padding=kernel_size // 2,
144
- bias=False))
145
-
146
- def forward(self, x):
147
- x = self.up(x)
148
- return x
149
-
150
-
151
- class Lap_Pyramid_Conv(nn.Module):
152
- def __init__(self, num_high=3, kernel_size=5, channels=3):
153
- super().__init__()
154
-
155
- self.num_high = num_high
156
- self.kernel = self.gauss_kernel(kernel_size, channels)
157
-
158
- def gauss_kernel(self, kernel_size, channels):
159
- kernel = cv2.getGaussianKernel(kernel_size, 0).dot(
160
- cv2.getGaussianKernel(kernel_size, 0).T)
161
- kernel = torch.FloatTensor(kernel).unsqueeze(0).repeat(
162
- channels, 1, 1, 1)
163
- kernel = torch.nn.Parameter(data=kernel, requires_grad=False)
164
- return kernel
165
-
166
- def conv_gauss(self, x, kernel):
167
- n_channels, _, kw, kh = kernel.shape
168
- x = torch.nn.functional.pad(x, (kw // 2, kh // 2, kw // 2, kh // 2),
169
- mode='reflect')
170
- x = torch.nn.functional.conv2d(x, kernel, groups=n_channels)
171
- return x
172
- def downsample(self, x):
173
- return x[:, :, ::2, ::2]
174
- def pyramid_down(self, x):
175
- return self.downsample(self.conv_gauss(x, self.kernel))
176
- def upsample(self, x):
177
- up = torch.zeros((x.size(0), x.size(1), x.size(2) * 2, x.size(3) * 2),
178
- device=x.device)
179
- up[:, :, ::2, ::2] = x * 4
180
-
181
- return self.conv_gauss(up, self.kernel)
182
-
183
- def pyramid_decom(self, img):
184
- self.kernel = self.kernel.to(img.device)
185
- current = img
186
- pyr = []
187
- for _ in range(self.num_high):
188
- down = self.pyramid_down(current)
189
- up = self.upsample(down)
190
- diff = current - up
191
- pyr.append(diff)
192
- current = down
193
- pyr.append(current)
194
- return pyr
195
-
196
- def pyramid_recons(self, pyr):
197
- image = pyr[0]
198
- for level in pyr[1:]:
199
- up = self.upsample(image)
200
- image = up + level
201
- return image
202
-
203
- class FAENet(nn.Module):
204
- def __init__(self,
205
- num_high=1,
206
- ch_blocks=32,
207
- up_ksize=1,
208
- high_ch=32,
209
- high_ksize=3,
210
- ch_mask=32,
211
- gauss_kernel=7):
212
- super().__init__()
213
- self.num_high = num_high
214
- self.lap_pyramid = Lap_Pyramid_Conv(num_high, gauss_kernel)
215
- self.rfem = RFEM(ch_blocks, ch_mask)
216
-
217
- for i in range(0, self.num_high):
218
- self.__setattr__('up_tag_layer_{}'.format(i),
219
- Up_tag(up_ksize, ch=3))
220
- self.__setattr__('trans_high_layer_{}'.format(i),
221
- Trans_high(3, high_ch, 3, high_ksize))
222
-
223
- def forward(self, x):
224
- pyrs = self.lap_pyramid.pyramid_decom(img=x)
225
-
226
- trans_pyrs = []
227
- trans_pyr, tag = self.rfem(pyrs[-1])
228
- trans_pyrs.append(trans_pyr)
229
-
230
- commom_tag = []
231
- for i in range(self.num_high):
232
- tag = self.__getattr__('up_tag_layer_{}'.format(i))(tag)
233
- commom_tag.append(tag)
234
-
235
- for i in range(self.num_high):
236
- trans_pyr = self.__getattr__('trans_high_layer_{}'.format(i))(
237
- pyrs[-2 - i], commom_tag[i])
238
- trans_pyrs.append(trans_pyr)
239
-
240
- out = self.lap_pyramid.pyramid_recons(trans_pyrs)
241
-
242
- return out
243
-
244
-
245
- faenet = FAENet()
246
- params = faenet.parameters()
247
- num_params = sum(p.numel() for p in params)
248
- print("FAENet parameters: {:.2f}K ".format(num_params/ 1024) + "{:.2f} MB".format(num_params/ (1024 * 1024)))