x444 committed on
Commit
75cbf50
·
1 Parent(s): edece61
Files changed (1) hide show
  1. bg_removal.py +663 -0
bg_removal.py ADDED
@@ -0,0 +1,663 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from torchvision.transforms.functional import normalize
9
+ import warnings
10
+ import os
11
+ warnings.filterwarnings("ignore")
12
+
13
+
14
# BCE over side outputs; `reduction='mean'` replaces the deprecated
# `size_average=True` argument (identical behavior on current PyTorch).
bce_loss = nn.BCELoss(reduction='mean')


def muti_loss_fusion(preds, target):
    """Sum the BCE losses of all side-output predictions against `target`.

    Args:
        preds: list of predicted probability maps (B, 1, H_i, W_i) in [0, 1].
        target: ground-truth map (B, 1, H, W); bilinearly resized to each
            prediction's spatial size when the sizes differ.

    Returns:
        (loss0, loss): BCE of the first (highest-resolution) prediction, and
        the sum of the BCE terms over all predictions.
    """
    loss0 = 0.0
    loss = 0.0

    for i in range(len(preds)):
        if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
            # Match the target to this prediction's resolution.
            tmp_target = F.interpolate(target, size=preds[i].size()[2:],
                                       mode='bilinear', align_corners=True)
            loss = loss + bce_loss(preds[i], tmp_target)
        else:
            loss = loss + bce_loss(preds[i], target)
        if i == 0:
            loss0 = loss
    return loss0, loss
30
+
31
+
32
# Feature-matching criteria used by muti_loss_fusion_kl; `reduction='mean'`
# replaces the deprecated `size_average=True` argument (same behavior).
fea_loss = nn.MSELoss(reduction='mean')
kl_loss = nn.KLDivLoss(reduction='mean')
l1_loss = nn.L1Loss(reduction='mean')
smooth_l1_loss = nn.SmoothL1Loss(reduction='mean')


def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
    """Side-output BCE loss plus a feature-matching (distillation) term.

    Args:
        preds: list of predicted probability maps (B, 1, H_i, W_i).
        target: ground-truth map (B, 1, H, W); resized per prediction when
            the spatial sizes differ.
        dfs: feature maps to be constrained (e.g. decoder features).
        fs: reference feature maps, matched 1:1 with `dfs`.
        mode: 'MSE', 'KL', 'MAE' or 'SmoothL1' — which feature criterion to
            add for each (dfs[i], fs[i]) pair; unrecognized modes add nothing.

    Returns:
        (loss0, loss): BCE of the first prediction, and the total loss (sum
        of all BCE terms plus all feature terms).
    """
    loss0 = 0.0
    loss = 0.0

    for i in range(len(preds)):
        if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
            tmp_target = F.interpolate(target, size=preds[i].size()[2:],
                                       mode='bilinear', align_corners=True)
            loss = loss + bce_loss(preds[i], tmp_target)
        else:
            loss = loss + bce_loss(preds[i], target)
        if i == 0:
            loss0 = loss

    for i in range(len(dfs)):
        if mode == 'MSE':
            # MSE between features as an additional distillation constraint.
            loss = loss + fea_loss(dfs[i], fs[i])
        elif mode == 'KL':
            # KLDivLoss expects log-probabilities as input, probabilities as target.
            loss = loss + kl_loss(F.log_softmax(dfs[i], dim=1), F.softmax(fs[i], dim=1))
        elif mode == 'MAE':
            loss = loss + l1_loss(dfs[i], fs[i])
        elif mode == 'SmoothL1':
            loss = loss + smooth_l1_loss(dfs[i], fs[i])

    return loss0, loss
66
+
67
class REBNCONV(nn.Module):
    """Conv2d(3x3) -> BatchNorm2d -> ReLU block with optional dilation/stride."""

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super(REBNCONV, self).__init__()
        # Padding equals the dilation rate so the 3x3 conv preserves the
        # spatial size when stride == 1.
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate,
                                 dilation=1 * dirate, stride=stride)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch-norm -> ReLU to `x`."""
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
81
+
82
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
83
+ def _upsample_like(src,tar):
84
+
85
+ src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
86
+
87
+ return src
88
+
89
+
90
+ ### RSU-7 ###
91
class RSU7(nn.Module):
    """Residual U-block of depth 7 (U^2-Net): a 5-level max-pooling encoder,
    a dilated bottleneck, and a decoder with skip concatenations; the input
    convolution's output is added back as a residual.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        # `img_size` is kept for interface compatibility; it is not used here.
        super(RSU7, self).__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        # Input projection; its output is also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: conv + 2x2 max-pool (ceil mode keeps odd sizes coverable).
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck instead of a further pooling level.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each stage consumes the concatenation of the previous
        # decoder output (upsampled) and the matching encoder feature.
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the depth-7 residual U-block output for `x` (B, in_ch, H, W)."""
        # NOTE: an unused `b, c, h, w = x.shape` unpacking was removed.
        hxin = self.rebnconvin(x)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        # Decoder path with skip concatenations and upsampling.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection.
        return hx1d + hxin
171
+
172
+
173
+ ### RSU-6 ###
174
class RSU6(nn.Module):
    """Residual U-block of depth 6: 4-level max-pooling encoder, dilated
    bottleneck, decoder with skip concatenations, and an input residual."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        # Input projection; its output doubles as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck at the coarsest resolution.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: inputs are concatenations, hence mid_ch * 2.
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the depth-6 residual U-block output for `x`."""
        xin = self.rebnconvin(x)

        # Encoder path.
        e1 = self.rebnconv1(xin)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))
        e5 = self.rebnconv5(self.pool4(e4))

        # Bottleneck.
        e6 = self.rebnconv6(e5)

        # Decoder path with skip concatenations.
        d5 = self.rebnconv5d(torch.cat((e6, e5), 1))
        d4 = self.rebnconv4d(torch.cat((_upsample_like(d5, e4), e4), 1))
        d3 = self.rebnconv3d(torch.cat((_upsample_like(d4, e3), e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))

        # Residual connection.
        return d1 + xin
241
+
242
+ ### RSU-5 ###
243
class RSU5(nn.Module):
    """Residual U-block of depth 5: 3-level max-pooling encoder, dilated
    bottleneck, decoder with skip concatenations, and an input residual."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        # Input projection; its output doubles as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder.
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the depth-5 residual U-block output for `x`."""
        xin = self.rebnconvin(x)

        # Encoder path.
        e1 = self.rebnconv1(xin)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))

        # Bottleneck.
        e5 = self.rebnconv5(e4)

        # Decoder path with skip concatenations.
        d4 = self.rebnconv4d(torch.cat((e5, e4), 1))
        d3 = self.rebnconv3d(torch.cat((_upsample_like(d4, e3), e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))

        # Residual connection.
        return d1 + xin
299
+
300
+ ### RSU-4 ###
301
class RSU4(nn.Module):
    """Residual U-block of depth 4: 2-level max-pooling encoder, dilated
    bottleneck, decoder with skip concatenations, and an input residual."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        # Input projection; its output doubles as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the depth-4 residual U-block output for `x`."""
        xin = self.rebnconvin(x)

        # Encoder path.
        e1 = self.rebnconv1(xin)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))

        # Bottleneck.
        e4 = self.rebnconv4(e3)

        # Decoder path with skip concatenations.
        d3 = self.rebnconv3d(torch.cat((e4, e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))

        # Residual connection.
        return d1 + xin
347
+
348
+ ### RSU-4F ###
349
class RSU4F(nn.Module):
    """Dilation-only variant of RSU-4: no pooling or upsampling; the receptive
    field grows via dilation rates 1, 2, 4, 8 at full resolution, with an
    input residual connection."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        # Input projection; its output doubles as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # "Encoder" chain with increasing dilation (constant resolution).
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # "Decoder" chain with decreasing dilation, consuming concatenations.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the dilated residual U-block output for `x`."""
        xin = self.rebnconvin(x)

        e1 = self.rebnconv1(xin)
        e2 = self.rebnconv2(e1)
        e3 = self.rebnconv3(e2)
        e4 = self.rebnconv4(e3)

        # No upsampling needed: every feature map shares the input resolution.
        d3 = self.rebnconv3d(torch.cat((e4, e3), 1))
        d2 = self.rebnconv2d(torch.cat((d3, e2), 1))
        d1 = self.rebnconv1d(torch.cat((d2, e1), 1))

        # Residual connection.
        return d1 + xin
383
+
384
+
385
class myrebnconv(nn.Module):
    """Fully configurable Conv2d -> BatchNorm2d -> ReLU block."""

    def __init__(self, in_ch=3, out_ch=1, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1):
        super(myrebnconv, self).__init__()

        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size,
                              stride=stride, padding=padding,
                              dilation=dilation, groups=groups)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return relu(bn(conv(x)))."""
        out = self.conv(x)
        out = self.bn(out)
        return self.rl(out)
407
+
408
+
409
class ISNetGTEncoder(nn.Module):
    """Ground-truth encoder for ISNet: a U^2-Net-style encoder that maps a
    (by default 1-channel) mask to six sigmoid side outputs plus the raw
    stage features, which serve as feature-supervision targets for ISNetDIS.
    """

    def __init__(self, in_ch=1, out_ch=1):
        super(ISNetGTEncoder, self).__init__()

        # Stride-2 stem halves the spatial resolution up front.
        self.conv_in = myrebnconv(in_ch, 16, 3, stride=2, padding=1)

        # Encoder stages with 2x2 max-pooling between them.
        self.stage1 = RSU7(16, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 32, 128)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(128, 32, 256)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(256, 64, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 64, 512)

        # One 3x3 side head per stage, each producing `out_ch` logits.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def compute_loss(self, preds, targets):
        """Return (loss0, loss) from muti_loss_fusion over the side outputs."""
        return muti_loss_fusion(preds, targets)

    def forward(self, x):
        """Return ([sigmoid(d1)..sigmoid(d6)], [hx1..hx6]).

        The d_i are per-stage side outputs upsampled to the input's spatial
        size; the hx_i are the raw stage feature maps.
        """
        hxin = self.conv_in(x)

        # Encoder stages.
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        hx6 = self.stage6(hx)

        # Side outputs, each resized to the input resolution.
        d1 = _upsample_like(self.side1(hx1), x)
        d2 = _upsample_like(self.side2(hx2), x)
        d3 = _upsample_like(self.side3(hx3), x)
        d4 = _upsample_like(self.side4(hx4), x)
        d5 = _upsample_like(self.side5(hx5), x)
        d6 = _upsample_like(self.side6(hx6), x)

        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        return ([torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3),
                 torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)],
                [hx1, hx2, hx3, hx4, hx5, hx6])
498
+
499
class ISNetDIS(nn.Module):
    """ISNet (DIS) segmentation network: a U^2-Net-style encoder-decoder that
    returns six sigmoid side outputs (highest resolution first) and the
    decoder stage features used for feature supervision via ISNetGTEncoder.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(ISNetDIS, self).__init__()

        # Stride-2 stem. `pool_in` is unused in forward() but kept so the
        # module layout (and any saved checkpoints/references) stay intact.
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # Encoder stages with 2x2 max-pooling between them.
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # Decoder: input channels double because the upsampled decoder output
        # is concatenated with the matching encoder feature.
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # One 3x3 side head per decoder stage (plus the deepest encoder stage).
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
        """Return (loss0, loss) including the feature-supervision term."""
        return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)

    def compute_loss(self, preds, targets):
        """Return (loss0, loss) from plain side-output BCE fusion."""
        return muti_loss_fusion(preds, targets)

    def forward(self, x):
        """Return ([sigmoid(d1)..sigmoid(d6)], [hx1d..hx5d, hx6]).

        The d_i are side outputs upsampled to the input's spatial size; the
        second list holds the decoder features (and the deepest encoder one).
        """
        hxin = self.conv_in(x)
        # NOTE: the original computed `hx = self.pool_in(hxin)` here but never
        # used the result; that dead computation is dropped.

        # Encoder.
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # Decoder with skip concatenations.
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # Side outputs, each resized to the input resolution.
        d1 = _upsample_like(self.side1(hx1d), x)
        d2 = _upsample_like(self.side2(hx2d), x)
        d3 = _upsample_like(self.side3(hx3d), x)
        d4 = _upsample_like(self.side4(hx4d), x)
        d5 = _upsample_like(self.side5(hx5d), x)
        d6 = _upsample_like(self.side6(hx6), x)

        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        return ([torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3),
                 torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)],
                [hx1d, hx2d, hx3d, hx4d, hx5d, hx6])
619
+
620
+
621
+ resize = transforms.Resize(512)
622
+
623
+
624
def bg_removal(img, model_path='/sensei-fs-3/users/sxiao/chart-dataset/bg_removal.pth'):
    """Remove the background of a PIL image with a pretrained ISNetDIS model.

    Args:
        img: input PIL.Image (any mode; converted to RGB internally).
        model_path: path to the ISNetDIS state_dict checkpoint (the former
            hard-coded path is kept as the default for compatibility).

    Returns:
        The RGB input with an alpha channel attached: foreground opaque (255),
        background transparent (0).
    """
    input_size = [1024, 1024]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = ISNetDIS()
    # map_location covers both the CUDA and CPU cases.
    net.load_state_dict(torch.load(model_path, map_location=device))
    net = net.to(device)
    net.eval()

    image_origin = img.convert("RGB")
    im = np.array(image_origin)
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    im_shp = im.shape[0:2]

    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # F.interpolate replaces the deprecated F.upsample. The uint8 cast mirrors
    # the original preprocessing (interpolated values truncated before scaling).
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), input_size,
                              mode="bilinear").type(torch.uint8)
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # Fix: the original called image.cuda() unconditionally, crashing on
    # CPU-only hosts even though the CPU load path existed above.
    image = image.to(device)

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        result = net(image)

    # First (highest-resolution) side output, resized back to the input size.
    result = torch.squeeze(F.interpolate(result[0][0], im_shp, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)  # min-max normalize to [0, 1]

    a = (result[0] * 255).cpu().data.numpy().astype(np.uint8)
    avg = np.average(a)
    # Binarize at half the mean intensity: foreground -> 255, background -> 0.
    b = np.where(a > avg // 2, 255, 0).astype(np.uint8)
    mask = Image.fromarray(b)
    image_origin.putalpha(mask)
    return image_origin
658
+
659
+ # current_path = os.getcwd()
660
+ # print(current_path)
661
+ # img = Image.open(current_path+'/data/bar1.png')
662
+ # result = bg_removal(img)
663
+ # result.save('bg_removel.png')