aka7774 commited on
Commit
dcbc4e3
·
verified ·
1 Parent(s): fbbd4d4

Upload 9 files

Browse files
anime-seg/README.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ ## Anime Segmentation Models
6
+
7
+ models of [https://github.com/SkyTNT/anime-segmentation](https://github.com/SkyTNT/anime-segmentation)
anime-seg/isnetis.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c8f6b9a77386c54dcdbf55b6c917108c4bdf4328abca9152c7bce5727b74d18
3
+ size 204275908
anime-seg/isnetis.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f15622d853e8260172812b657053460e20806f04b9e05147d49af7bed31a6e99
3
+ size 176069933
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+ import pytorch_lightning as pl
8
+ from model import ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet
9
+
10
+ def get_mask(model, input_img):
11
+ h, w = input_img.shape[0], input_img.shape[1]
12
+ ph, pw = 0, 0
13
+ tmpImg = np.zeros([h, w, 3], dtype=np.float16)
14
+ tmpImg[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) / 255
15
+ tmpImg = tmpImg.transpose((2, 0, 1))
16
+ tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor).to(model.device)
17
+ with torch.no_grad():
18
+ pred = model(tmpImg)
19
+ pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
20
+ pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w, h))[:, :, np.newaxis]
21
+ return pred
22
+
23
+ def get_net(net_name):
24
+ if net_name == "isnet":
25
+ return ISNetDIS()
26
+ elif net_name == "isnet_is":
27
+ return ISNetDIS()
28
+ elif net_name == "isnet_gt":
29
+ return ISNetGTEncoder()
30
+ elif net_name == "u2net":
31
+ return U2NET_full2()
32
+ elif net_name == "u2netl":
33
+ return U2NET_lite2()
34
+ elif net_name == "modnet":
35
+ return MODNet()
36
+ raise NotImplemented
37
+
38
+ # from anime-segmentation.train
39
+ class AnimeSegmentation(pl.LightningModule):
40
+ def __init__(self, net_name):
41
+ super().__init__()
42
+ assert net_name in ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet"]
43
+ self.net = get_net(net_name)
44
+ if net_name == "isnet_is":
45
+ self.gt_encoder = get_net("isnet_gt")
46
+ for param in self.gt_encoder.parameters():
47
+ param.requires_grad = False
48
+ else:
49
+ self.gt_encoder = None
50
+
51
+ @classmethod
52
+ def try_load(cls, net_name, ckpt_path, map_location=None):
53
+ state_dict = torch.load(ckpt_path, map_location=map_location)
54
+ if "epoch" in state_dict:
55
+ return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location)
56
+ else:
57
+ model = cls(net_name)
58
+ if any([k.startswith("net.") for k, v in state_dict.items()]):
59
+ model.load_state_dict(state_dict)
60
+ else:
61
+ model.net.load_state_dict(state_dict)
62
+ return model
63
+
64
+ def forward(self, x):
65
+ if isinstance(self.net, ISNetDIS):
66
+ return self.net(x)[0][0].sigmoid()
67
+ if isinstance(self.net, ISNetGTEncoder):
68
+ return self.net(x)[0][0].sigmoid()
69
+ elif isinstance(self.net, U2NET):
70
+ return self.net(x)[0].sigmoid()
71
+ elif isinstance(self.net, MODNet):
72
+ return self.net(x, True)[2]
73
+ raise NotImplemented
74
+
75
+ def animeseg(image):
76
+ if not image:
77
+ return None
78
+
79
+ model = AnimeSegmentation.try_load('isnet_is', 'anime-seg/isnetis.ckpt', 'cuda')
80
+ model.eval()
81
+ model.to('cuda')
82
+
83
+ img = np.array(image, dtype=np.uint8)
84
+ mask = get_mask(model, img)
85
+ img = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(np.uint8)
86
+ return img
87
+
88
+ with gr.Blocks() as demo:
89
+ title = gr.Markdown('# katanuki')
90
+ with gr.Row():
91
+ src_image = gr.Image(label="Source", sources="upload", interactive=True, type="pil")
92
+ dst_image = gr.Image(label="Result", interactive=False, type="numpy")
93
+
94
+ src_image.change(
95
+ fn=animeseg,
96
+ inputs=[src_image],
97
+ outputs=[dst_image],
98
+ )
99
+
100
+ demo.launch()
model/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .u2net import U2NET_full
2
+ from .u2net import U2NET_full2
3
+ from .u2net import U2NET_lite
4
+ from .u2net import U2NET_lite2
5
+ from .u2net import U2NET
6
+ from .isnet import ISNetDIS, ISNetGTEncoder
7
+ from .modnet import MODNet
model/isnet.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Codes are borrowed from
2
+ # https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision import models
7
+ import torch.nn.functional as F
8
+
9
+ bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
10
+
11
+
12
+ def muti_loss_fusion(preds, target):
13
+ loss0 = 0.0
14
+ loss = 0.0
15
+
16
+ for i in range(0, len(preds)):
17
+ if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
18
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
19
+ loss = loss + bce_loss(preds[i], tmp_target)
20
+ else:
21
+ loss = loss + bce_loss(preds[i], target)
22
+ if i == 0:
23
+ loss0 = loss
24
+ return loss0, loss
25
+
26
+
27
+ fea_loss = nn.MSELoss(reduction="mean")
28
+ kl_loss = nn.KLDivLoss(reduction="mean")
29
+ l1_loss = nn.L1Loss(reduction="mean")
30
+ smooth_l1_loss = nn.SmoothL1Loss(reduction="mean")
31
+
32
+
33
+ def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
34
+ loss0 = 0.0
35
+ loss = 0.0
36
+
37
+ for i in range(0, len(preds)):
38
+ if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
39
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
40
+ loss = loss + bce_loss(preds[i], tmp_target)
41
+ else:
42
+ loss = loss + bce_loss(preds[i], target)
43
+ if i == 0:
44
+ loss0 = loss
45
+
46
+ for i in range(0, len(dfs)):
47
+ df = dfs[i]
48
+ fs_i = fs[i]
49
+ if mode == 'MSE':
50
+ loss = loss + fea_loss(df, fs_i) ### add the mse loss of features as additional constraints
51
+ elif mode == 'KL':
52
+ loss = loss + kl_loss(F.log_softmax(df, dim=1), F.softmax(fs_i, dim=1))
53
+ elif mode == 'MAE':
54
+ loss = loss + l1_loss(df, fs_i)
55
+ elif mode == 'SmoothL1':
56
+ loss = loss + smooth_l1_loss(df, fs_i)
57
+
58
+ return loss0, loss
59
+
60
+
61
+ class REBNCONV(nn.Module):
62
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
63
+ super(REBNCONV, self).__init__()
64
+
65
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
66
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
67
+ self.relu_s1 = nn.ReLU(inplace=True)
68
+
69
+ def forward(self, x):
70
+ hx = x
71
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
72
+
73
+ return xout
74
+
75
+
76
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
77
+ def _upsample_like(src, tar):
78
+ src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
79
+
80
+ return src
81
+
82
+
83
+ ### RSU-7 ###
84
+ class RSU7(nn.Module):
85
+
86
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
87
+ super(RSU7, self).__init__()
88
+
89
+ self.in_ch = in_ch
90
+ self.mid_ch = mid_ch
91
+ self.out_ch = out_ch
92
+
93
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
94
+
95
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
96
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
97
+
98
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
99
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
100
+
101
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
102
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
103
+
104
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
105
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
106
+
107
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
108
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
109
+
110
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
111
+
112
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
113
+
114
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
115
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
116
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
117
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
118
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
119
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
120
+
121
+ def forward(self, x):
122
+ b, c, h, w = x.shape
123
+
124
+ hx = x
125
+ hxin = self.rebnconvin(hx)
126
+
127
+ hx1 = self.rebnconv1(hxin)
128
+ hx = self.pool1(hx1)
129
+
130
+ hx2 = self.rebnconv2(hx)
131
+ hx = self.pool2(hx2)
132
+
133
+ hx3 = self.rebnconv3(hx)
134
+ hx = self.pool3(hx3)
135
+
136
+ hx4 = self.rebnconv4(hx)
137
+ hx = self.pool4(hx4)
138
+
139
+ hx5 = self.rebnconv5(hx)
140
+ hx = self.pool5(hx5)
141
+
142
+ hx6 = self.rebnconv6(hx)
143
+
144
+ hx7 = self.rebnconv7(hx6)
145
+
146
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
147
+ hx6dup = _upsample_like(hx6d, hx5)
148
+
149
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
150
+ hx5dup = _upsample_like(hx5d, hx4)
151
+
152
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
153
+ hx4dup = _upsample_like(hx4d, hx3)
154
+
155
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
156
+ hx3dup = _upsample_like(hx3d, hx2)
157
+
158
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
159
+ hx2dup = _upsample_like(hx2d, hx1)
160
+
161
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
162
+
163
+ return hx1d + hxin
164
+
165
+
166
+ ### RSU-6 ###
167
+ class RSU6(nn.Module):
168
+
169
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
170
+ super(RSU6, self).__init__()
171
+
172
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
173
+
174
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
175
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
176
+
177
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
178
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
179
+
180
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
181
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
182
+
183
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
184
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
185
+
186
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
187
+
188
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
189
+
190
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
191
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
192
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
193
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
194
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
195
+
196
+ def forward(self, x):
197
+ hx = x
198
+
199
+ hxin = self.rebnconvin(hx)
200
+
201
+ hx1 = self.rebnconv1(hxin)
202
+ hx = self.pool1(hx1)
203
+
204
+ hx2 = self.rebnconv2(hx)
205
+ hx = self.pool2(hx2)
206
+
207
+ hx3 = self.rebnconv3(hx)
208
+ hx = self.pool3(hx3)
209
+
210
+ hx4 = self.rebnconv4(hx)
211
+ hx = self.pool4(hx4)
212
+
213
+ hx5 = self.rebnconv5(hx)
214
+
215
+ hx6 = self.rebnconv6(hx5)
216
+
217
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
218
+ hx5dup = _upsample_like(hx5d, hx4)
219
+
220
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
221
+ hx4dup = _upsample_like(hx4d, hx3)
222
+
223
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
224
+ hx3dup = _upsample_like(hx3d, hx2)
225
+
226
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
227
+ hx2dup = _upsample_like(hx2d, hx1)
228
+
229
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
230
+
231
+ return hx1d + hxin
232
+
233
+
234
+ ### RSU-5 ###
235
+ class RSU5(nn.Module):
236
+
237
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
238
+ super(RSU5, self).__init__()
239
+
240
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
241
+
242
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
243
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
244
+
245
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
246
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
247
+
248
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
249
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
250
+
251
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
252
+
253
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
254
+
255
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
256
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
257
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
258
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
259
+
260
+ def forward(self, x):
261
+ hx = x
262
+
263
+ hxin = self.rebnconvin(hx)
264
+
265
+ hx1 = self.rebnconv1(hxin)
266
+ hx = self.pool1(hx1)
267
+
268
+ hx2 = self.rebnconv2(hx)
269
+ hx = self.pool2(hx2)
270
+
271
+ hx3 = self.rebnconv3(hx)
272
+ hx = self.pool3(hx3)
273
+
274
+ hx4 = self.rebnconv4(hx)
275
+
276
+ hx5 = self.rebnconv5(hx4)
277
+
278
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
279
+ hx4dup = _upsample_like(hx4d, hx3)
280
+
281
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
282
+ hx3dup = _upsample_like(hx3d, hx2)
283
+
284
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
285
+ hx2dup = _upsample_like(hx2d, hx1)
286
+
287
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
288
+
289
+ return hx1d + hxin
290
+
291
+
292
+ ### RSU-4 ###
293
+ class RSU4(nn.Module):
294
+
295
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
296
+ super(RSU4, self).__init__()
297
+
298
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
299
+
300
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
301
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
302
+
303
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
304
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
305
+
306
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
307
+
308
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
309
+
310
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
311
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
312
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
313
+
314
+ def forward(self, x):
315
+ hx = x
316
+
317
+ hxin = self.rebnconvin(hx)
318
+
319
+ hx1 = self.rebnconv1(hxin)
320
+ hx = self.pool1(hx1)
321
+
322
+ hx2 = self.rebnconv2(hx)
323
+ hx = self.pool2(hx2)
324
+
325
+ hx3 = self.rebnconv3(hx)
326
+
327
+ hx4 = self.rebnconv4(hx3)
328
+
329
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
330
+ hx3dup = _upsample_like(hx3d, hx2)
331
+
332
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
333
+ hx2dup = _upsample_like(hx2d, hx1)
334
+
335
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
336
+
337
+ return hx1d + hxin
338
+
339
+
340
+ ### RSU-4F ###
341
+ class RSU4F(nn.Module):
342
+
343
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
344
+ super(RSU4F, self).__init__()
345
+
346
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
347
+
348
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
349
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
350
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
351
+
352
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
353
+
354
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
355
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
356
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
357
+
358
+ def forward(self, x):
359
+ hx = x
360
+
361
+ hxin = self.rebnconvin(hx)
362
+
363
+ hx1 = self.rebnconv1(hxin)
364
+ hx2 = self.rebnconv2(hx1)
365
+ hx3 = self.rebnconv3(hx2)
366
+
367
+ hx4 = self.rebnconv4(hx3)
368
+
369
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
370
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
371
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
372
+
373
+ return hx1d + hxin
374
+
375
+
376
+ class myrebnconv(nn.Module):
377
+ def __init__(self, in_ch=3,
378
+ out_ch=1,
379
+ kernel_size=3,
380
+ stride=1,
381
+ padding=1,
382
+ dilation=1,
383
+ groups=1):
384
+ super(myrebnconv, self).__init__()
385
+
386
+ self.conv = nn.Conv2d(in_ch,
387
+ out_ch,
388
+ kernel_size=kernel_size,
389
+ stride=stride,
390
+ padding=padding,
391
+ dilation=dilation,
392
+ groups=groups)
393
+ self.bn = nn.BatchNorm2d(out_ch)
394
+ self.rl = nn.ReLU(inplace=True)
395
+
396
+ def forward(self, x):
397
+ return self.rl(self.bn(self.conv(x)))
398
+
399
+
400
+ class ISNetGTEncoder(nn.Module):
401
+
402
+ def __init__(self, in_ch=1, out_ch=1):
403
+ super(ISNetGTEncoder, self).__init__()
404
+
405
+ self.conv_in = myrebnconv(in_ch, 16, 3, stride=2, padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
406
+
407
+ self.stage1 = RSU7(16, 16, 64)
408
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
409
+
410
+ self.stage2 = RSU6(64, 16, 64)
411
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
412
+
413
+ self.stage3 = RSU5(64, 32, 128)
414
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
415
+
416
+ self.stage4 = RSU4(128, 32, 256)
417
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
418
+
419
+ self.stage5 = RSU4F(256, 64, 512)
420
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
421
+
422
+ self.stage6 = RSU4F(512, 64, 512)
423
+
424
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
425
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
426
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
427
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
428
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
429
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
430
+
431
+ @staticmethod
432
+ def compute_loss(args):
433
+ preds, targets = args
434
+ return muti_loss_fusion(preds, targets)
435
+
436
+ def forward(self, x):
437
+ hx = x
438
+
439
+ hxin = self.conv_in(hx)
440
+ # hx = self.pool_in(hxin)
441
+
442
+ # stage 1
443
+ hx1 = self.stage1(hxin)
444
+ hx = self.pool12(hx1)
445
+
446
+ # stage 2
447
+ hx2 = self.stage2(hx)
448
+ hx = self.pool23(hx2)
449
+
450
+ # stage 3
451
+ hx3 = self.stage3(hx)
452
+ hx = self.pool34(hx3)
453
+
454
+ # stage 4
455
+ hx4 = self.stage4(hx)
456
+ hx = self.pool45(hx4)
457
+
458
+ # stage 5
459
+ hx5 = self.stage5(hx)
460
+ hx = self.pool56(hx5)
461
+
462
+ # stage 6
463
+ hx6 = self.stage6(hx)
464
+
465
+ # side output
466
+ d1 = self.side1(hx1)
467
+ d1 = _upsample_like(d1, x)
468
+
469
+ d2 = self.side2(hx2)
470
+ d2 = _upsample_like(d2, x)
471
+
472
+ d3 = self.side3(hx3)
473
+ d3 = _upsample_like(d3, x)
474
+
475
+ d4 = self.side4(hx4)
476
+ d4 = _upsample_like(d4, x)
477
+
478
+ d5 = self.side5(hx5)
479
+ d5 = _upsample_like(d5, x)
480
+
481
+ d6 = self.side6(hx6)
482
+ d6 = _upsample_like(d6, x)
483
+
484
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
485
+
486
+ # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1, hx2, hx3, hx4, hx5, hx6]
487
+ return [d1, d2, d3, d4, d5, d6], [hx1, hx2, hx3, hx4, hx5, hx6]
488
+
489
+
490
+ class ISNetDIS(nn.Module):
491
+
492
+ def __init__(self, in_ch=3, out_ch=1):
493
+ super(ISNetDIS, self).__init__()
494
+
495
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
496
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
497
+
498
+ self.stage1 = RSU7(64, 32, 64)
499
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
500
+
501
+ self.stage2 = RSU6(64, 32, 128)
502
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
503
+
504
+ self.stage3 = RSU5(128, 64, 256)
505
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
506
+
507
+ self.stage4 = RSU4(256, 128, 512)
508
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
509
+
510
+ self.stage5 = RSU4F(512, 256, 512)
511
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
512
+
513
+ self.stage6 = RSU4F(512, 256, 512)
514
+
515
+ # decoder
516
+ self.stage5d = RSU4F(1024, 256, 512)
517
+ self.stage4d = RSU4(1024, 128, 256)
518
+ self.stage3d = RSU5(512, 64, 128)
519
+ self.stage2d = RSU6(256, 32, 64)
520
+ self.stage1d = RSU7(128, 16, 64)
521
+
522
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
523
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
524
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
525
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
526
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
527
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
528
+
529
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
530
+
531
+ @staticmethod
532
+ def compute_loss_kl(preds, targets, dfs, fs, mode='MSE'):
533
+ return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
534
+
535
+ @staticmethod
536
+ def compute_loss(args):
537
+ if len(args) == 3:
538
+ ds, dfs, labels = args
539
+ return muti_loss_fusion(ds, labels)
540
+ else:
541
+ ds, dfs, labels, fs = args
542
+ return muti_loss_fusion_kl(ds, labels, dfs, fs, mode="MSE")
543
+
544
+ def forward(self, x):
545
+ hx = x
546
+
547
+ hxin = self.conv_in(hx)
548
+ hx = self.pool_in(hxin)
549
+
550
+ # stage 1
551
+ hx1 = self.stage1(hxin)
552
+ hx = self.pool12(hx1)
553
+
554
+ # stage 2
555
+ hx2 = self.stage2(hx)
556
+ hx = self.pool23(hx2)
557
+
558
+ # stage 3
559
+ hx3 = self.stage3(hx)
560
+ hx = self.pool34(hx3)
561
+
562
+ # stage 4
563
+ hx4 = self.stage4(hx)
564
+ hx = self.pool45(hx4)
565
+
566
+ # stage 5
567
+ hx5 = self.stage5(hx)
568
+ hx = self.pool56(hx5)
569
+
570
+ # stage 6
571
+ hx6 = self.stage6(hx)
572
+ hx6up = _upsample_like(hx6, hx5)
573
+
574
+ # -------------------- decoder --------------------
575
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
576
+ hx5dup = _upsample_like(hx5d, hx4)
577
+
578
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
579
+ hx4dup = _upsample_like(hx4d, hx3)
580
+
581
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
582
+ hx3dup = _upsample_like(hx3d, hx2)
583
+
584
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
585
+ hx2dup = _upsample_like(hx2d, hx1)
586
+
587
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
588
+
589
+ # side output
590
+ d1 = self.side1(hx1d)
591
+ d1 = _upsample_like(d1, x)
592
+
593
+ d2 = self.side2(hx2d)
594
+ d2 = _upsample_like(d2, x)
595
+
596
+ d3 = self.side3(hx3d)
597
+ d3 = _upsample_like(d3, x)
598
+
599
+ d4 = self.side4(hx4d)
600
+ d4 = _upsample_like(d4, x)
601
+
602
+ d5 = self.side5(hx5d)
603
+ d5 = _upsample_like(d5, x)
604
+
605
+ d6 = self.side6(hx6)
606
+ d6 = _upsample_like(d6, x)
607
+
608
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
609
+
610
+ # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
611
+ return [d1, d2, d3, d4, d5, d6], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
model/modnet.py ADDED
@@ -0,0 +1,667 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Codes are borrowed from
2
+ # https://github.com/ZHKKKe/MODNet/blob/master/src/trainer.py
3
+ # https://github.com/ZHKKKe/MODNet/blob/master/src/models/backbones/mobilenetv2.py
4
+ # https://github.com/ZHKKKe/MODNet/blob/master/src/models/modnet.py
5
+
6
+ import numpy as np
7
+ import scipy
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import os
12
+ import math
13
+ import torch
14
+ from scipy.ndimage import gaussian_filter
15
+
16
+
17
+ # ----------------------------------------------------------------------------------
18
+ # Loss Functions
19
+ # ----------------------------------------------------------------------------------
20
+
21
+
22
+ class GaussianBlurLayer(nn.Module):
23
+ """ Add Gaussian Blur to a 4D tensors
24
+ This layer takes a 4D tensor of {N, C, H, W} as input.
25
+ The Gaussian blur will be performed in given channel number (C) splitly.
26
+ """
27
+
28
+ def __init__(self, channels, kernel_size):
29
+ """
30
+ Arguments:
31
+ channels (int): Channel for input tensor
32
+ kernel_size (int): Size of the kernel used in blurring
33
+ """
34
+
35
+ super(GaussianBlurLayer, self).__init__()
36
+ self.channels = channels
37
+ self.kernel_size = kernel_size
38
+ assert self.kernel_size % 2 != 0
39
+
40
+ self.op = nn.Sequential(
41
+ nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
42
+ nn.Conv2d(channels, channels, self.kernel_size,
43
+ stride=1, padding=0, bias=None, groups=channels)
44
+ )
45
+
46
+ self._init_kernel()
47
+
48
+ def forward(self, x):
49
+ """
50
+ Arguments:
51
+ x (torch.Tensor): input 4D tensor
52
+ Returns:
53
+ torch.Tensor: Blurred version of the input
54
+ """
55
+
56
+ if not len(list(x.shape)) == 4:
57
+ print('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
58
+ exit()
59
+ elif not x.shape[1] == self.channels:
60
+ print('In \'GaussianBlurLayer\', the required channel ({0}) is'
61
+ 'not the same as input ({1})\n'.format(self.channels, x.shape[1]))
62
+ exit()
63
+
64
+ return self.op(x)
65
+
66
+ def _init_kernel(self):
67
+ sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
68
+
69
+ n = np.zeros((self.kernel_size, self.kernel_size))
70
+ i = math.floor(self.kernel_size / 2)
71
+ n[i, i] = 1
72
+ kernel = gaussian_filter(n, sigma)
73
+
74
+ for name, param in self.named_parameters():
75
+ param.data.copy_(torch.from_numpy(kernel))
76
+ param.requires_grad = False
77
+
78
+
79
+ blurer = GaussianBlurLayer(1, 3)
80
+
81
+
82
+ def loss_func(pred_semantic, pred_detail, pred_matte, image, trimap, gt_matte,
83
+ semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0):
84
+ """ loss of MODNet
85
+ Arguments:
86
+ blurer: GaussianBlurLayer
87
+ pred_semantic: model output
88
+ pred_detail: model output
89
+ pred_matte: model output
90
+ image : input RGB image ts pixel values should be normalized
91
+ trimap : trimap used to calculate the losses
92
+ its pixel values can be 0, 0.5, or 1
93
+ (foreground=1, background=0, unknown=0.5)
94
+ gt_matte: ground truth alpha matte its pixel values are between [0, 1]
95
+ semantic_scale (float): scale of the semantic loss
96
+ NOTE: please adjust according to your dataset
97
+ detail_scale (float): scale of the detail loss
98
+ NOTE: please adjust according to your dataset
99
+ matte_scale (float): scale of the matte loss
100
+ NOTE: please adjust according to your dataset
101
+
102
+ Returns:
103
+ semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
104
+ detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
105
+ matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]
106
+ """
107
+
108
+ trimap = trimap.float()
109
+ # calculate the boundary mask from the trimap
110
+ boundaries = (trimap < 0.5) + (trimap > 0.5)
111
+
112
+ # calculate the semantic loss
113
+ gt_semantic = F.interpolate(gt_matte, scale_factor=1 / 16, mode='bilinear')
114
+ gt_semantic = blurer(gt_semantic)
115
+ semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
116
+ semantic_loss = semantic_scale * semantic_loss
117
+
118
+ # calculate the detail loss
119
+ pred_boundary_detail = torch.where(boundaries, trimap, pred_detail.float())
120
+ gt_detail = torch.where(boundaries, trimap, gt_matte.float())
121
+ detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail.float()))
122
+ detail_loss = detail_scale * detail_loss
123
+
124
+ # calculate the matte loss
125
+ pred_boundary_matte = torch.where(boundaries, trimap, pred_matte.float())
126
+ matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
127
+ matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
128
+ + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
129
+ matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
130
+ matte_loss = matte_scale * matte_loss
131
+
132
+ return semantic_loss, detail_loss, matte_loss
133
+
134
+
135
+ # ------------------------------------------------------------------------------
136
+ # Useful functions
137
+ # ------------------------------------------------------------------------------
138
+
139
+ def _make_divisible(v, divisor, min_value=None):
140
+ if min_value is None:
141
+ min_value = divisor
142
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
143
+ # Make sure that round down does not go down by more than 10%.
144
+ if new_v < 0.9 * v:
145
+ new_v += divisor
146
+ return new_v
147
+
148
+
149
+ def conv_bn(inp, oup, stride):
150
+ return nn.Sequential(
151
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
152
+ nn.BatchNorm2d(oup),
153
+ nn.ReLU6(inplace=True)
154
+ )
155
+
156
+
157
+ def conv_1x1_bn(inp, oup):
158
+ return nn.Sequential(
159
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
160
+ nn.BatchNorm2d(oup),
161
+ nn.ReLU6(inplace=True)
162
+ )
163
+
164
+
165
+ # ------------------------------------------------------------------------------
166
+ # Class of Inverted Residual block
167
+ # ------------------------------------------------------------------------------
168
+
169
+ class InvertedResidual(nn.Module):
170
+ def __init__(self, inp, oup, stride, expansion, dilation=1):
171
+ super(InvertedResidual, self).__init__()
172
+ self.stride = stride
173
+ assert stride in [1, 2]
174
+
175
+ hidden_dim = round(inp * expansion)
176
+ self.use_res_connect = self.stride == 1 and inp == oup
177
+
178
+ if expansion == 1:
179
+ self.conv = nn.Sequential(
180
+ # dw
181
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
182
+ nn.BatchNorm2d(hidden_dim),
183
+ nn.ReLU6(inplace=True),
184
+ # pw-linear
185
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
186
+ nn.BatchNorm2d(oup),
187
+ )
188
+ else:
189
+ self.conv = nn.Sequential(
190
+ # pw
191
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
192
+ nn.BatchNorm2d(hidden_dim),
193
+ nn.ReLU6(inplace=True),
194
+ # dw
195
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
196
+ nn.BatchNorm2d(hidden_dim),
197
+ nn.ReLU6(inplace=True),
198
+ # pw-linear
199
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
200
+ nn.BatchNorm2d(oup),
201
+ )
202
+
203
+ def forward(self, x):
204
+ if self.use_res_connect:
205
+ return x + self.conv(x)
206
+ else:
207
+ return self.conv(x)
208
+
209
+
210
+ # ------------------------------------------------------------------------------
211
+ # Class of MobileNetV2
212
+ # ------------------------------------------------------------------------------
213
+
214
+ class MobileNetV2(nn.Module):
215
+ def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
216
+ super(MobileNetV2, self).__init__()
217
+ self.in_channels = in_channels
218
+ self.num_classes = num_classes
219
+ input_channel = 32
220
+ last_channel = 1280
221
+ interverted_residual_setting = [
222
+ # t, c, n, s
223
+ [1, 16, 1, 1],
224
+ [expansion, 24, 2, 2],
225
+ [expansion, 32, 3, 2],
226
+ [expansion, 64, 4, 2],
227
+ [expansion, 96, 3, 1],
228
+ [expansion, 160, 3, 2],
229
+ [expansion, 320, 1, 1],
230
+ ]
231
+
232
+ # building first layer
233
+ input_channel = _make_divisible(input_channel * alpha, 8)
234
+ self.last_channel = _make_divisible(last_channel * alpha, 8) if alpha > 1.0 else last_channel
235
+ self.features = [conv_bn(self.in_channels, input_channel, 2)]
236
+
237
+ # building inverted residual blocks
238
+ for t, c, n, s in interverted_residual_setting:
239
+ output_channel = _make_divisible(int(c * alpha), 8)
240
+ for i in range(n):
241
+ if i == 0:
242
+ self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
243
+ else:
244
+ self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
245
+ input_channel = output_channel
246
+
247
+ # building last several layers
248
+ self.features.append(conv_1x1_bn(input_channel, self.last_channel))
249
+
250
+ # make it nn.Sequential
251
+ self.features = nn.Sequential(*self.features)
252
+
253
+ # building classifier
254
+ if self.num_classes is not None:
255
+ self.classifier = nn.Sequential(
256
+ nn.Dropout(0.2),
257
+ nn.Linear(self.last_channel, num_classes),
258
+ )
259
+
260
+ # Initialize weights
261
+ self._init_weights()
262
+
263
+ def forward(self, x):
264
+ # Stage1
265
+ x = self.features[0](x)
266
+ x = self.features[1](x)
267
+ # Stage2
268
+ x = self.features[2](x)
269
+ x = self.features[3](x)
270
+ # Stage3
271
+ x = self.features[4](x)
272
+ x = self.features[5](x)
273
+ x = self.features[6](x)
274
+ # Stage4
275
+ x = self.features[7](x)
276
+ x = self.features[8](x)
277
+ x = self.features[9](x)
278
+ x = self.features[10](x)
279
+ x = self.features[11](x)
280
+ x = self.features[12](x)
281
+ x = self.features[13](x)
282
+ # Stage5
283
+ x = self.features[14](x)
284
+ x = self.features[15](x)
285
+ x = self.features[16](x)
286
+ x = self.features[17](x)
287
+ x = self.features[18](x)
288
+
289
+ # Classification
290
+ if self.num_classes is not None:
291
+ x = x.mean(dim=(2, 3))
292
+ x = self.classifier(x)
293
+
294
+ # Output
295
+ return x
296
+
297
+ def _load_pretrained_model(self, pretrained_file):
298
+ pretrain_dict = torch.load(pretrained_file, map_location='cpu')
299
+ model_dict = {}
300
+ state_dict = self.state_dict()
301
+ print("[MobileNetV2] Loading pretrained model...")
302
+ for k, v in pretrain_dict.items():
303
+ if k in state_dict:
304
+ model_dict[k] = v
305
+ else:
306
+ print(k, "is ignored")
307
+ state_dict.update(model_dict)
308
+ self.load_state_dict(state_dict)
309
+
310
+ def _init_weights(self):
311
+ for m in self.modules():
312
+ if isinstance(m, nn.Conv2d):
313
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
314
+ m.weight.data.normal_(0, math.sqrt(2. / n))
315
+ if m.bias is not None:
316
+ m.bias.data.zero_()
317
+ elif isinstance(m, nn.BatchNorm2d):
318
+ m.weight.data.fill_(1)
319
+ m.bias.data.zero_()
320
+ elif isinstance(m, nn.Linear):
321
+ n = m.weight.size(1)
322
+ m.weight.data.normal_(0, 0.01)
323
+ m.bias.data.zero_()
324
+
325
+
326
+ class BaseBackbone(nn.Module):
327
+ """ Superclass of Replaceable Backbone Model for Semantic Estimation
328
+ """
329
+
330
+ def __init__(self, in_channels):
331
+ super(BaseBackbone, self).__init__()
332
+ self.in_channels = in_channels
333
+
334
+ self.model = None
335
+ self.enc_channels = []
336
+
337
+ def forward(self, x):
338
+ raise NotImplementedError
339
+
340
+ def load_pretrained_ckpt(self):
341
+ raise NotImplementedError
342
+
343
+
344
+ class MobileNetV2Backbone(BaseBackbone):
345
+ """ MobileNetV2 Backbone
346
+ """
347
+
348
+ def __init__(self, in_channels):
349
+ super(MobileNetV2Backbone, self).__init__(in_channels)
350
+
351
+ self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
352
+ self.enc_channels = [16, 24, 32, 96, 1280]
353
+
354
+ def forward(self, x):
355
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
356
+ x = self.model.features[0](x)
357
+ x = self.model.features[1](x)
358
+ enc2x = x
359
+
360
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
361
+ x = self.model.features[2](x)
362
+ x = self.model.features[3](x)
363
+ enc4x = x
364
+
365
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
366
+ x = self.model.features[4](x)
367
+ x = self.model.features[5](x)
368
+ x = self.model.features[6](x)
369
+ enc8x = x
370
+
371
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
372
+ x = self.model.features[7](x)
373
+ x = self.model.features[8](x)
374
+ x = self.model.features[9](x)
375
+ x = self.model.features[10](x)
376
+ x = self.model.features[11](x)
377
+ x = self.model.features[12](x)
378
+ x = self.model.features[13](x)
379
+ enc16x = x
380
+
381
+ # x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
382
+ x = self.model.features[14](x)
383
+ x = self.model.features[15](x)
384
+ x = self.model.features[16](x)
385
+ x = self.model.features[17](x)
386
+ x = self.model.features[18](x)
387
+ enc32x = x
388
+ return [enc2x, enc4x, enc8x, enc16x, enc32x]
389
+
390
+ def load_pretrained_ckpt(self):
391
+ # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
392
+ ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
393
+ if not os.path.exists(ckpt_path):
394
+ print('cannot find the pretrained mobilenetv2 backbone')
395
+ exit()
396
+
397
+ ckpt = torch.load(ckpt_path)
398
+ self.model.load_state_dict(ckpt)
399
+
400
+
401
+ SUPPORTED_BACKBONES = {
402
+ 'mobilenetv2': MobileNetV2Backbone,
403
+ }
404
+
405
+
406
+ # ------------------------------------------------------------------------------
407
+ # MODNet Basic Modules
408
+ # ------------------------------------------------------------------------------
409
+
410
+ class IBNorm(nn.Module):
411
+ """ Combine Instance Norm and Batch Norm into One Layer
412
+ """
413
+
414
+ def __init__(self, in_channels):
415
+ super(IBNorm, self).__init__()
416
+ in_channels = in_channels
417
+ self.bnorm_channels = int(in_channels / 2)
418
+ self.inorm_channels = in_channels - self.bnorm_channels
419
+
420
+ self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
421
+ self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
422
+
423
+ def forward(self, x):
424
+ bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
425
+ in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
426
+
427
+ return torch.cat((bn_x, in_x), 1)
428
+
429
+
430
+ class Conv2dIBNormRelu(nn.Module):
431
+ """ Convolution + IBNorm + ReLu
432
+ """
433
+
434
+ def __init__(self, in_channels, out_channels, kernel_size,
435
+ stride=1, padding=0, dilation=1, groups=1, bias=True,
436
+ with_ibn=True, with_relu=True):
437
+ super(Conv2dIBNormRelu, self).__init__()
438
+
439
+ layers = [
440
+ nn.Conv2d(in_channels, out_channels, kernel_size,
441
+ stride=stride, padding=padding, dilation=dilation,
442
+ groups=groups, bias=bias)
443
+ ]
444
+
445
+ if with_ibn:
446
+ layers.append(IBNorm(out_channels))
447
+ if with_relu:
448
+ layers.append(nn.ReLU(inplace=True))
449
+
450
+ self.layers = nn.Sequential(*layers)
451
+
452
+ def forward(self, x):
453
+ return self.layers(x)
454
+
455
+
456
+ class SEBlock(nn.Module):
457
+ """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
458
+ """
459
+
460
+ def __init__(self, in_channels, out_channels, reduction=1):
461
+ super(SEBlock, self).__init__()
462
+ self.pool = nn.AdaptiveAvgPool2d(1)
463
+ self.fc = nn.Sequential(
464
+ nn.Linear(in_channels, int(in_channels // reduction), bias=False),
465
+ nn.ReLU(inplace=True),
466
+ nn.Linear(int(in_channels // reduction), out_channels, bias=False),
467
+ nn.Sigmoid()
468
+ )
469
+
470
+ def forward(self, x):
471
+ b, c, _, _ = x.size()
472
+ w = self.pool(x).view(b, c)
473
+ w = self.fc(w).view(b, c, 1, 1)
474
+
475
+ return x * w.expand_as(x)
476
+
477
+
478
+ # ------------------------------------------------------------------------------
479
+ # MODNet Branches
480
+ # ------------------------------------------------------------------------------
481
+
482
+ class LRBranch(nn.Module):
483
+ """ Low Resolution Branch of MODNet
484
+ """
485
+
486
+ def __init__(self, backbone):
487
+ super(LRBranch, self).__init__()
488
+
489
+ enc_channels = backbone.enc_channels
490
+
491
+ self.backbone = backbone
492
+ self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
493
+ self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
494
+ self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
495
+ self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False,
496
+ with_relu=False)
497
+
498
+ def forward(self, img, inference):
499
+ enc_features = self.backbone.forward(img)
500
+ enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]
501
+
502
+ enc32x = self.se_block(enc32x)
503
+ lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
504
+ lr16x = self.conv_lr16x(lr16x)
505
+ lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
506
+ lr8x = self.conv_lr8x(lr8x)
507
+
508
+ pred_semantic = None
509
+ if not inference:
510
+ lr = self.conv_lr(lr8x)
511
+ pred_semantic = torch.sigmoid(lr)
512
+
513
+ return pred_semantic, lr8x, [enc2x, enc4x]
514
+
515
+
516
+ class HRBranch(nn.Module):
517
+ """ High Resolution Branch of MODNet
518
+ """
519
+
520
+ def __init__(self, hr_channels, enc_channels):
521
+ super(HRBranch, self).__init__()
522
+
523
+ self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
524
+ self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)
525
+
526
+ self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
527
+ self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
528
+
529
+ self.conv_hr4x = nn.Sequential(
530
+ Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
531
+ Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
532
+ Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
533
+ )
534
+
535
+ self.conv_hr2x = nn.Sequential(
536
+ Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
537
+ Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
538
+ Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
539
+ Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
540
+ )
541
+
542
+ self.conv_hr = nn.Sequential(
543
+ Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
544
+ Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
545
+ )
546
+
547
+ def forward(self, img, enc2x, enc4x, lr8x, inference):
548
+ img2x = F.interpolate(img, scale_factor=1 / 2, mode='bilinear', align_corners=False)
549
+ img4x = F.interpolate(img, scale_factor=1 / 4, mode='bilinear', align_corners=False)
550
+
551
+ enc2x = self.tohr_enc2x(enc2x)
552
+ hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))
553
+
554
+ enc4x = self.tohr_enc4x(enc4x)
555
+ hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))
556
+
557
+ lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
558
+ hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))
559
+
560
+ hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
561
+ hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))
562
+
563
+ pred_detail = None
564
+ if not inference:
565
+ hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
566
+ hr = self.conv_hr(torch.cat((hr, img), dim=1))
567
+ pred_detail = torch.sigmoid(hr)
568
+
569
+ return pred_detail, hr2x
570
+
571
+
572
+ class FusionBranch(nn.Module):
573
+ """ Fusion Branch of MODNet
574
+ """
575
+
576
+ def __init__(self, hr_channels, enc_channels):
577
+ super(FusionBranch, self).__init__()
578
+ self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
579
+
580
+ self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
581
+ self.conv_f = nn.Sequential(
582
+ Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
583
+ Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
584
+ )
585
+
586
+ def forward(self, img, lr8x, hr2x):
587
+ lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
588
+ lr4x = self.conv_lr4x(lr4x)
589
+ lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)
590
+
591
+ f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
592
+ f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
593
+ f = self.conv_f(torch.cat((f, img), dim=1))
594
+ pred_matte = torch.sigmoid(f)
595
+
596
+ return pred_matte
597
+
598
+
599
+ # ------------------------------------------------------------------------------
600
+ # MODNet
601
+ # ------------------------------------------------------------------------------
602
+
603
+ class MODNet(nn.Module):
604
+ """ Architecture of MODNet
605
+ """
606
+
607
+ def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=False):
608
+ super(MODNet, self).__init__()
609
+
610
+ self.in_channels = in_channels
611
+ self.hr_channels = hr_channels
612
+ self.backbone_arch = backbone_arch
613
+ self.backbone_pretrained = backbone_pretrained
614
+
615
+ self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)
616
+
617
+ self.lr_branch = LRBranch(self.backbone)
618
+ self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
619
+ self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)
620
+
621
+ for m in self.modules():
622
+ if isinstance(m, nn.Conv2d):
623
+ self._init_conv(m)
624
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
625
+ self._init_norm(m)
626
+
627
+ if self.backbone_pretrained:
628
+ self.backbone.load_pretrained_ckpt()
629
+
630
+ def forward(self, img, inference):
631
+ pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
632
+ pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
633
+ pred_matte = self.f_branch(img, lr8x, hr2x)
634
+
635
+ return pred_semantic, pred_detail, pred_matte
636
+
637
+ @staticmethod
638
+ def compute_loss(args):
639
+ pred_semantic, pred_detail, pred_matte, image, trimap, gt_matte = args
640
+ semantic_loss, detail_loss, matte_loss = loss_func(pred_semantic, pred_detail, pred_matte,
641
+ image, trimap, gt_matte)
642
+ loss = semantic_loss + detail_loss + matte_loss
643
+ return matte_loss, loss
644
+
645
+ def freeze_norm(self):
646
+ norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
647
+ for m in self.modules():
648
+ for n in norm_types:
649
+ if isinstance(m, n):
650
+ m.eval()
651
+ continue
652
+
653
+ def _init_conv(self, conv):
654
+ nn.init.kaiming_uniform_(
655
+ conv.weight, a=0, mode='fan_in', nonlinearity='relu')
656
+ if conv.bias is not None:
657
+ nn.init.constant_(conv.bias, 0)
658
+
659
+ def _init_norm(self, norm):
660
+ if norm.weight is not None:
661
+ nn.init.constant_(norm.weight, 1)
662
+ nn.init.constant_(norm.bias, 0)
663
+
664
+ def _apply(self, fn):
665
+ super(MODNet, self)._apply(fn)
666
+ blurer._apply(fn) # let blurer's device same as modnet
667
+ return self
model/u2net.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Codes are borrowed from
2
+ # https://github.com/xuebinqin/U-2-Net/blob/master/model/u2net_refactor.py
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ __all__ = ['U2NET_full', 'U2NET_full2', 'U2NET_lite', 'U2NET_lite2', "U2NET"]
10
+
11
+ bce_loss = nn.BCEWithLogitsLoss(reduction='mean')
12
+
13
+
14
+ def _upsample_like(x, size):
15
+ return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
16
+
17
+
18
+ def _size_map(x, height):
19
+ # {height: size} for Upsample
20
+ size = list(x.shape[-2:])
21
+ sizes = {}
22
+ for h in range(1, height):
23
+ sizes[h] = size
24
+ size = [math.ceil(w / 2) for w in size]
25
+ return sizes
26
+
27
+
28
+ class REBNCONV(nn.Module):
29
+ def __init__(self, in_ch=3, out_ch=3, dilate=1):
30
+ super(REBNCONV, self).__init__()
31
+
32
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate)
33
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
34
+ self.relu_s1 = nn.ReLU(inplace=True)
35
+
36
+ def forward(self, x):
37
+ return self.relu_s1(self.bn_s1(self.conv_s1(x)))
38
+
39
+
40
+ class RSU(nn.Module):
41
+ def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
42
+ super(RSU, self).__init__()
43
+ self.name = name
44
+ self.height = height
45
+ self.dilated = dilated
46
+ self._make_layers(height, in_ch, mid_ch, out_ch, dilated)
47
+
48
+ def forward(self, x):
49
+ sizes = _size_map(x, self.height)
50
+ x = self.rebnconvin(x)
51
+
52
+ # U-Net like symmetric encoder-decoder structure
53
+ def unet(x, height=1):
54
+ if height < self.height:
55
+ x1 = getattr(self, f'rebnconv{height}')(x)
56
+ if not self.dilated and height < self.height - 1:
57
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
58
+ else:
59
+ x2 = unet(x1, height + 1)
60
+
61
+ x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1))
62
+ return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x
63
+ else:
64
+ return getattr(self, f'rebnconv{height}')(x)
65
+
66
+ return x + unet(x)
67
+
68
+ def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
69
+ self.add_module('rebnconvin', REBNCONV(in_ch, out_ch))
70
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
71
+
72
+ self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch))
73
+ self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch))
74
+
75
+ for i in range(2, height):
76
+ dilate = 1 if not dilated else 2 ** (i - 1)
77
+ self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
78
+ self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate))
79
+
80
+ dilate = 2 if not dilated else 2 ** (height - 1)
81
+ self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
82
+
83
+
84
+ class U2NET(nn.Module):
85
+ def __init__(self, cfgs, out_ch):
86
+ super(U2NET, self).__init__()
87
+ self.out_ch = out_ch
88
+ self._make_layers(cfgs)
89
+
90
+ def forward(self, x):
91
+ sizes = _size_map(x, self.height)
92
+ maps = [] # storage for maps
93
+
94
+ # side saliency map
95
+ def unet(x, height=1):
96
+ if height < 6:
97
+ x1 = getattr(self, f'stage{height}')(x)
98
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
99
+ x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1))
100
+ side(x, height)
101
+ return _upsample_like(x, sizes[height - 1]) if height > 1 else x
102
+ else:
103
+ x = getattr(self, f'stage{height}')(x)
104
+ side(x, height)
105
+ return _upsample_like(x, sizes[height - 1])
106
+
107
+ def side(x, h):
108
+ # side output saliency map (before sigmoid)
109
+ x = getattr(self, f'side{h}')(x)
110
+ x = _upsample_like(x, sizes[1])
111
+ maps.append(x)
112
+
113
+ def fuse():
114
+ # fuse saliency probability maps
115
+ maps.reverse()
116
+ x = torch.cat(maps, 1)
117
+ x = getattr(self, 'outconv')(x)
118
+ maps.insert(0, x)
119
+ # return [torch.sigmoid(x) for x in maps]
120
+ return [x for x in maps]
121
+
122
+ unet(x)
123
+ maps = fuse()
124
+ return maps
125
+
126
+ @staticmethod
127
+ def compute_loss(args):
128
+ preds, labels_v = args
129
+ d0, d1, d2, d3, d4, d5, d6 = preds
130
+ loss0 = bce_loss(d0, labels_v)
131
+ loss1 = bce_loss(d1, labels_v)
132
+ loss2 = bce_loss(d2, labels_v)
133
+ loss3 = bce_loss(d3, labels_v)
134
+ loss4 = bce_loss(d4, labels_v)
135
+ loss5 = bce_loss(d5, labels_v)
136
+ loss6 = bce_loss(d6, labels_v)
137
+
138
+ loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
139
+
140
+ return loss0, loss
141
+
142
+ def _make_layers(self, cfgs):
143
+ self.height = int((len(cfgs) + 1) / 2)
144
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
145
+ for k, v in cfgs.items():
146
+ # build rsu block
147
+ self.add_module(k, RSU(v[0], *v[1]))
148
+ if v[2] > 0:
149
+ # build side layer
150
+ self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1))
151
+ # build fuse layer
152
+ self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1))
153
+
154
+
155
+ def U2NET_full():
156
+ full = {
157
+ # cfgs for building RSUs and sides
158
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
159
+ 'stage1': ['En_1', (7, 3, 32, 64), -1],
160
+ 'stage2': ['En_2', (6, 64, 32, 128), -1],
161
+ 'stage3': ['En_3', (5, 128, 64, 256), -1],
162
+ 'stage4': ['En_4', (4, 256, 128, 512), -1],
163
+ 'stage5': ['En_5', (4, 512, 256, 512, True), -1],
164
+ 'stage6': ['En_6', (4, 512, 256, 512, True), 512],
165
+ 'stage5d': ['De_5', (4, 1024, 256, 512, True), 512],
166
+ 'stage4d': ['De_4', (4, 1024, 128, 256), 256],
167
+ 'stage3d': ['De_3', (5, 512, 64, 128), 128],
168
+ 'stage2d': ['De_2', (6, 256, 32, 64), 64],
169
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
170
+ }
171
+ return U2NET(cfgs=full, out_ch=1)
172
+
173
+
174
+ def U2NET_full2():
175
+ full = {
176
+ # cfgs for building RSUs and sides
177
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
178
+ 'stage1': ['En_1', (8, 3, 32, 64), -1],
179
+ 'stage2': ['En_2', (7, 64, 32, 128), -1],
180
+ 'stage3': ['En_3', (6, 128, 64, 256), -1],
181
+ 'stage4': ['En_4', (5, 256, 128, 512), -1],
182
+ 'stage5': ['En_5', (5, 512, 256, 512, True), -1],
183
+ 'stage6': ['En_6', (5, 512, 256, 512, True), 512],
184
+ 'stage5d': ['De_5', (5, 1024, 256, 512, True), 512],
185
+ 'stage4d': ['De_4', (5, 1024, 128, 256), 256],
186
+ 'stage3d': ['De_3', (6, 512, 64, 128), 128],
187
+ 'stage2d': ['De_2', (7, 256, 32, 64), 64],
188
+ 'stage1d': ['De_1', (8, 128, 16, 64), 64],
189
+ }
190
+ return U2NET(cfgs=full, out_ch=1)
191
+
192
+
193
+ def U2NET_lite():
194
+ lite = {
195
+ # cfgs for building RSUs and sides
196
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
197
+ 'stage1': ['En_1', (7, 3, 16, 64), -1],
198
+ 'stage2': ['En_2', (6, 64, 16, 64), -1],
199
+ 'stage3': ['En_3', (5, 64, 16, 64), -1],
200
+ 'stage4': ['En_4', (4, 64, 16, 64), -1],
201
+ 'stage5': ['En_5', (4, 64, 16, 64, True), -1],
202
+ 'stage6': ['En_6', (4, 64, 16, 64, True), 64],
203
+ 'stage5d': ['De_5', (4, 128, 16, 64, True), 64],
204
+ 'stage4d': ['De_4', (4, 128, 16, 64), 64],
205
+ 'stage3d': ['De_3', (5, 128, 16, 64), 64],
206
+ 'stage2d': ['De_2', (6, 128, 16, 64), 64],
207
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
208
+ }
209
+ return U2NET(cfgs=lite, out_ch=1)
210
+
211
+
212
+ def U2NET_lite2():
213
+ lite = {
214
+ # cfgs for building RSUs and sides
215
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
216
+ 'stage1': ['En_1', (8, 3, 16, 64), -1],
217
+ 'stage2': ['En_2', (7, 64, 16, 64), -1],
218
+ 'stage3': ['En_3', (6, 64, 16, 64), -1],
219
+ 'stage4': ['En_4', (5, 64, 16, 64), -1],
220
+ 'stage5': ['En_5', (5, 64, 16, 64, True), -1],
221
+ 'stage6': ['En_6', (5, 64, 16, 64, True), 64],
222
+ 'stage5d': ['De_5', (5, 128, 16, 64, True), 64],
223
+ 'stage4d': ['De_4', (5, 128, 16, 64), 64],
224
+ 'stage3d': ['De_3', (6, 128, 16, 64), 64],
225
+ 'stage2d': ['De_2', (7, 128, 16, 64), 64],
226
+ 'stage1d': ['De_1', (8, 128, 16, 64), 64],
227
+ }
228
+ return U2NET(cfgs=lite, out_ch=1)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ opencv-python
2
+ pytorch_lightning
3
+ torchvision