jsxyhelu commited on
Commit
81a2131
·
1 Parent(s): 5049187
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlite3 import InterfaceError
2
+ import gradio as gr
3
+ import logging
4
+ import os
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+
12
+ from utils.data_loading import BasicDataset
13
+ from unet import UNet
14
+
15
+
16
+ def predict_img(net,full_img,device,scale_factor=1,out_threshold=0.5):
17
+ net.eval()
18
+ img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
19
+ img = img.unsqueeze(0)
20
+ img = img.to(device=device, dtype=torch.float32)
21
+
22
+ with torch.no_grad():
23
+ output = net(img)
24
+
25
+ if net.n_classes > 1:
26
+ probs = F.softmax(output, dim=1)[0]
27
+ else:
28
+ probs = torch.sigmoid(output)[0]
29
+
30
+ tf = transforms.Compose([
31
+ transforms.ToPILImage(),
32
+ #transforms.Resize((full_img.size[1], full_img.size[0])),
33
+ transforms.ToTensor()
34
+ ])
35
+
36
+ full_mask = tf(probs.cpu()).squeeze()
37
+
38
+ if net.n_classes == 1:
39
+ return (full_mask > out_threshold).numpy()
40
+ else:
41
+ return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
42
+
43
+
44
+
45
+
46
+
47
+ def mask_to_image(mask: np.ndarray):
48
+ if mask.ndim == 2:
49
+ return Image.fromarray((mask * 255).astype(np.uint8))
50
+ elif mask.ndim == 3:
51
+ return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))
52
+
53
+
54
+
55
+
56
+ def to_black(image):
57
+ modelPath = "./checkpoints/skyseg0113.pth"
58
+
59
+ net = UNet(n_channels=3, n_classes=2)
60
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
61
+
62
+ logging.info(f'Loading model ')
63
+ net.to(device=device)
64
+ net.load_state_dict(torch.load(modelPath, map_location=device))
65
+
66
+ logging.info('over')
67
+ mask = predict_img(net=net,full_img=image,scale_factor=0.5,out_threshold=0.5,device=device)
68
+ output = mask_to_image(mask)
69
+ return output
70
+
71
+ interface = gr.Interface(fn=to_black, inputs="image", outputs="image" )
72
+ interface.launch()
checkpoints/skyseg0113.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c16a76f7ee7a2ad3c360ea348cae69b55b97ee3905e3da6beed42e51561f1f8a
3
+ size 69129741
evaluate.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from tqdm import tqdm
4
+
5
+ from utils.dice_score import multiclass_dice_coeff, dice_coeff
6
+
7
+
8
+ def evaluate(net, dataloader, device):
9
+ net.eval()
10
+ num_val_batches = len(dataloader)
11
+ dice_score = 0
12
+
13
+ # iterate over the validation set
14
+ for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
15
+ image, mask_true = batch['image'], batch['mask']
16
+ # move images and labels to correct device and type
17
+ image = image.to(device=device, dtype=torch.float32)
18
+ mask_true = mask_true.to(device=device, dtype=torch.long)
19
+
20
+ ####
21
+ one = torch.ones_like(mask_true)
22
+ zero = torch.zeros_like(mask_true)
23
+ mask_true = torch.where(mask_true>0,one,zero)
24
+ ####
25
+
26
+
27
+ mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
28
+
29
+ with torch.no_grad():
30
+ # predict the mask
31
+ mask_pred = net(image)
32
+
33
+ # convert to one-hot format
34
+ if net.n_classes == 1:
35
+ mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
36
+ # compute the Dice score
37
+ dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
38
+ else:
39
+ mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
40
+ # compute the Dice score, ignoring background
41
+ dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)
42
+
43
+
44
+
45
+ net.train()
46
+
47
+ # Fixes a potential division by zero error
48
+ if num_val_batches == 0:
49
+ return dice_score
50
+ return dice_score / num_val_batches
unet/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .unet_model import UNet
2
+ from .u2net import U2NET
3
+ from .u2net import U2NETP
unet/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (237 Bytes). View file
 
unet/__pycache__/u2net.cpython-38.pyc ADDED
Binary file (10.5 kB). View file
 
unet/__pycache__/unet_model.cpython-38.pyc ADDED
Binary file (1.29 kB). View file
 
unet/__pycache__/unet_parts.cpython-38.pyc ADDED
Binary file (2.84 kB). View file
 
unet/u2net.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class REBNCONV(nn.Module):
6
+ def __init__(self,in_ch=3,out_ch=3,dirate=1):
7
+ super(REBNCONV,self).__init__()
8
+
9
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
10
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
11
+ self.relu_s1 = nn.ReLU(inplace=True)
12
+
13
+ def forward(self,x):
14
+
15
+ hx = x
16
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
17
+
18
+ return xout
19
+
20
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
21
+ def _upsample_like(src,tar):
22
+
23
+ src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
24
+
25
+ return src
26
+
27
+
28
+ ### RSU-7 ###
29
+ class RSU7(nn.Module):#UNet07DRES(nn.Module):
30
+
31
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
32
+ super(RSU7,self).__init__()
33
+
34
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
35
+
36
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
37
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
38
+
39
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
40
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
41
+
42
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
43
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
44
+
45
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
46
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
47
+
48
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
49
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
50
+
51
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
52
+
53
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
54
+
55
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
56
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
57
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
58
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
59
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
60
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
61
+
62
+ def forward(self,x):
63
+
64
+ hx = x
65
+ hxin = self.rebnconvin(hx)
66
+
67
+ hx1 = self.rebnconv1(hxin)
68
+ hx = self.pool1(hx1)
69
+
70
+ hx2 = self.rebnconv2(hx)
71
+ hx = self.pool2(hx2)
72
+
73
+ hx3 = self.rebnconv3(hx)
74
+ hx = self.pool3(hx3)
75
+
76
+ hx4 = self.rebnconv4(hx)
77
+ hx = self.pool4(hx4)
78
+
79
+ hx5 = self.rebnconv5(hx)
80
+ hx = self.pool5(hx5)
81
+
82
+ hx6 = self.rebnconv6(hx)
83
+
84
+ hx7 = self.rebnconv7(hx6)
85
+
86
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
87
+ hx6dup = _upsample_like(hx6d,hx5)
88
+
89
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
90
+ hx5dup = _upsample_like(hx5d,hx4)
91
+
92
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
93
+ hx4dup = _upsample_like(hx4d,hx3)
94
+
95
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
96
+ hx3dup = _upsample_like(hx3d,hx2)
97
+
98
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
99
+ hx2dup = _upsample_like(hx2d,hx1)
100
+
101
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
102
+
103
+ return hx1d + hxin
104
+
105
+ ### RSU-6 ###
106
+ class RSU6(nn.Module):#UNet06DRES(nn.Module):
107
+
108
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
109
+ super(RSU6,self).__init__()
110
+
111
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
112
+
113
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
114
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
115
+
116
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
117
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
118
+
119
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
120
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
121
+
122
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
123
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
124
+
125
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
126
+
127
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
128
+
129
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
130
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
131
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
132
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
133
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
134
+
135
+ def forward(self,x):
136
+
137
+ hx = x
138
+
139
+ hxin = self.rebnconvin(hx)
140
+
141
+ hx1 = self.rebnconv1(hxin)
142
+ hx = self.pool1(hx1)
143
+
144
+ hx2 = self.rebnconv2(hx)
145
+ hx = self.pool2(hx2)
146
+
147
+ hx3 = self.rebnconv3(hx)
148
+ hx = self.pool3(hx3)
149
+
150
+ hx4 = self.rebnconv4(hx)
151
+ hx = self.pool4(hx4)
152
+
153
+ hx5 = self.rebnconv5(hx)
154
+
155
+ hx6 = self.rebnconv6(hx5)
156
+
157
+
158
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
159
+ hx5dup = _upsample_like(hx5d,hx4)
160
+
161
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
162
+ hx4dup = _upsample_like(hx4d,hx3)
163
+
164
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
165
+ hx3dup = _upsample_like(hx3d,hx2)
166
+
167
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
168
+ hx2dup = _upsample_like(hx2d,hx1)
169
+
170
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
171
+
172
+ return hx1d + hxin
173
+
174
+ ### RSU-5 ###
175
+ class RSU5(nn.Module):#UNet05DRES(nn.Module):
176
+
177
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
178
+ super(RSU5,self).__init__()
179
+
180
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
181
+
182
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
183
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
184
+
185
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
186
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
187
+
188
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
189
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
190
+
191
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
192
+
193
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
194
+
195
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
196
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
197
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
198
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
199
+
200
+ def forward(self,x):
201
+
202
+ hx = x
203
+
204
+ hxin = self.rebnconvin(hx)
205
+
206
+ hx1 = self.rebnconv1(hxin)
207
+ hx = self.pool1(hx1)
208
+
209
+ hx2 = self.rebnconv2(hx)
210
+ hx = self.pool2(hx2)
211
+
212
+ hx3 = self.rebnconv3(hx)
213
+ hx = self.pool3(hx3)
214
+
215
+ hx4 = self.rebnconv4(hx)
216
+
217
+ hx5 = self.rebnconv5(hx4)
218
+
219
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
220
+ hx4dup = _upsample_like(hx4d,hx3)
221
+
222
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
223
+ hx3dup = _upsample_like(hx3d,hx2)
224
+
225
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
226
+ hx2dup = _upsample_like(hx2d,hx1)
227
+
228
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
229
+
230
+ return hx1d + hxin
231
+
232
+ ### RSU-4 ###
233
+ class RSU4(nn.Module):#UNet04DRES(nn.Module):
234
+
235
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
236
+ super(RSU4,self).__init__()
237
+
238
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
239
+
240
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
241
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
242
+
243
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
244
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
245
+
246
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
247
+
248
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
249
+
250
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
251
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
252
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
253
+
254
+ def forward(self,x):
255
+
256
+ hx = x
257
+
258
+ hxin = self.rebnconvin(hx)
259
+
260
+ hx1 = self.rebnconv1(hxin)
261
+ hx = self.pool1(hx1)
262
+
263
+ hx2 = self.rebnconv2(hx)
264
+ hx = self.pool2(hx2)
265
+
266
+ hx3 = self.rebnconv3(hx)
267
+
268
+ hx4 = self.rebnconv4(hx3)
269
+
270
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
271
+ hx3dup = _upsample_like(hx3d,hx2)
272
+
273
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
274
+ hx2dup = _upsample_like(hx2d,hx1)
275
+
276
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
277
+
278
+ return hx1d + hxin
279
+
280
+ ### RSU-4F ###
281
+ class RSU4F(nn.Module):#UNet04FRES(nn.Module):
282
+
283
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
284
+ super(RSU4F,self).__init__()
285
+
286
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
287
+
288
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
289
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
290
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
291
+
292
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
293
+
294
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
295
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
296
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
297
+
298
+ def forward(self,x):
299
+
300
+ hx = x
301
+
302
+ hxin = self.rebnconvin(hx)
303
+
304
+ hx1 = self.rebnconv1(hxin)
305
+ hx2 = self.rebnconv2(hx1)
306
+ hx3 = self.rebnconv3(hx2)
307
+
308
+ hx4 = self.rebnconv4(hx3)
309
+
310
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
311
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
312
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
313
+
314
+ return hx1d + hxin
315
+
316
+
317
+ ##### U^2-Net ####
318
+ class U2NET(nn.Module):
319
+
320
+ def __init__(self,in_ch=3,out_ch=1):
321
+ super(U2NET,self).__init__()
322
+
323
+ self.stage1 = RSU7(in_ch,32,64)
324
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
325
+
326
+ self.stage2 = RSU6(64,32,128)
327
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
328
+
329
+ self.stage3 = RSU5(128,64,256)
330
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
331
+
332
+ self.stage4 = RSU4(256,128,512)
333
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
334
+
335
+ self.stage5 = RSU4F(512,256,512)
336
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
337
+
338
+ self.stage6 = RSU4F(512,256,512)
339
+
340
+ # decoder
341
+ self.stage5d = RSU4F(1024,256,512)
342
+ self.stage4d = RSU4(1024,128,256)
343
+ self.stage3d = RSU5(512,64,128)
344
+ self.stage2d = RSU6(256,32,64)
345
+ self.stage1d = RSU7(128,16,64)
346
+
347
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
348
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
349
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
350
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
351
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
352
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
353
+
354
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
355
+
356
+ def forward(self,x):
357
+
358
+ hx = x
359
+
360
+ #stage 1
361
+ hx1 = self.stage1(hx)
362
+ hx = self.pool12(hx1)
363
+
364
+ #stage 2
365
+ hx2 = self.stage2(hx)
366
+ hx = self.pool23(hx2)
367
+
368
+ #stage 3
369
+ hx3 = self.stage3(hx)
370
+ hx = self.pool34(hx3)
371
+
372
+ #stage 4
373
+ hx4 = self.stage4(hx)
374
+ hx = self.pool45(hx4)
375
+
376
+ #stage 5
377
+ hx5 = self.stage5(hx)
378
+ hx = self.pool56(hx5)
379
+
380
+ #stage 6
381
+ hx6 = self.stage6(hx)
382
+ hx6up = _upsample_like(hx6,hx5)
383
+
384
+ #-------------------- decoder --------------------
385
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
386
+ hx5dup = _upsample_like(hx5d,hx4)
387
+
388
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
389
+ hx4dup = _upsample_like(hx4d,hx3)
390
+
391
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
392
+ hx3dup = _upsample_like(hx3d,hx2)
393
+
394
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
395
+ hx2dup = _upsample_like(hx2d,hx1)
396
+
397
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
398
+
399
+
400
+ #side output
401
+ d1 = self.side1(hx1d)
402
+
403
+ d2 = self.side2(hx2d)
404
+ d2 = _upsample_like(d2,d1)
405
+
406
+ d3 = self.side3(hx3d)
407
+ d3 = _upsample_like(d3,d1)
408
+
409
+ d4 = self.side4(hx4d)
410
+ d4 = _upsample_like(d4,d1)
411
+
412
+ d5 = self.side5(hx5d)
413
+ d5 = _upsample_like(d5,d1)
414
+
415
+ d6 = self.side6(hx6)
416
+ d6 = _upsample_like(d6,d1)
417
+
418
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
419
+
420
+ return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
421
+
422
+ ### U^2-Net small ###
423
+ class U2NETP(nn.Module):
424
+
425
+ def __init__(self,in_ch=3,out_ch=1):
426
+ super(U2NETP,self).__init__()
427
+
428
+ self.stage1 = RSU7(in_ch,16,64)
429
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
430
+
431
+ self.stage2 = RSU6(64,16,64)
432
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
433
+
434
+ self.stage3 = RSU5(64,16,64)
435
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
436
+
437
+ self.stage4 = RSU4(64,16,64)
438
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
439
+
440
+ self.stage5 = RSU4F(64,16,64)
441
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
442
+
443
+ self.stage6 = RSU4F(64,16,64)
444
+
445
+ # decoder
446
+ self.stage5d = RSU4F(128,16,64)
447
+ self.stage4d = RSU4(128,16,64)
448
+ self.stage3d = RSU5(128,16,64)
449
+ self.stage2d = RSU6(128,16,64)
450
+ self.stage1d = RSU7(128,16,64)
451
+
452
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
453
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
454
+ self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
455
+ self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
456
+ self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
457
+ self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
458
+
459
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
460
+
461
+ def forward(self,x):
462
+
463
+ hx = x
464
+
465
+ #stage 1
466
+ hx1 = self.stage1(hx)
467
+ hx = self.pool12(hx1)
468
+
469
+ #stage 2
470
+ hx2 = self.stage2(hx)
471
+ hx = self.pool23(hx2)
472
+
473
+ #stage 3
474
+ hx3 = self.stage3(hx)
475
+ hx = self.pool34(hx3)
476
+
477
+ #stage 4
478
+ hx4 = self.stage4(hx)
479
+ hx = self.pool45(hx4)
480
+
481
+ #stage 5
482
+ hx5 = self.stage5(hx)
483
+ hx = self.pool56(hx5)
484
+
485
+ #stage 6
486
+ hx6 = self.stage6(hx)
487
+ hx6up = _upsample_like(hx6,hx5)
488
+
489
+ #decoder
490
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
491
+ hx5dup = _upsample_like(hx5d,hx4)
492
+
493
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
494
+ hx4dup = _upsample_like(hx4d,hx3)
495
+
496
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
497
+ hx3dup = _upsample_like(hx3d,hx2)
498
+
499
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
500
+ hx2dup = _upsample_like(hx2d,hx1)
501
+
502
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
503
+
504
+
505
+ #side output
506
+ d1 = self.side1(hx1d)
507
+
508
+ d2 = self.side2(hx2d)
509
+ d2 = _upsample_like(d2,d1)
510
+
511
+ d3 = self.side3(hx3d)
512
+ d3 = _upsample_like(d3,d1)
513
+
514
+ d4 = self.side4(hx4d)
515
+ d4 = _upsample_like(d4,d1)
516
+
517
+ d5 = self.side5(hx5d)
518
+ d5 = _upsample_like(d5,d1)
519
+
520
+ d6 = self.side6(hx6)
521
+ d6 = _upsample_like(d6,d1)
522
+
523
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
524
+
525
+ return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
unet/u2net_refactor.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import math
5
+
6
+ __all__ = ['U2NET_full', 'U2NET_lite']
7
+
8
+
9
+ def _upsample_like(x, size):
10
+ return nn.Upsample(size=size, mode='bilinear', align_corners=False)(x)
11
+
12
+
13
+ def _size_map(x, height):
14
+ # {height: size} for Upsample
15
+ size = list(x.shape[-2:])
16
+ sizes = {}
17
+ for h in range(1, height):
18
+ sizes[h] = size
19
+ size = [math.ceil(w / 2) for w in size]
20
+ return sizes
21
+
22
+
23
+ class REBNCONV(nn.Module):
24
+ def __init__(self, in_ch=3, out_ch=3, dilate=1):
25
+ super(REBNCONV, self).__init__()
26
+
27
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate)
28
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
29
+ self.relu_s1 = nn.ReLU(inplace=True)
30
+
31
+ def forward(self, x):
32
+ return self.relu_s1(self.bn_s1(self.conv_s1(x)))
33
+
34
+
35
+ class RSU(nn.Module):
36
+ def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
37
+ super(RSU, self).__init__()
38
+ self.name = name
39
+ self.height = height
40
+ self.dilated = dilated
41
+ self._make_layers(height, in_ch, mid_ch, out_ch, dilated)
42
+
43
+ def forward(self, x):
44
+ sizes = _size_map(x, self.height)
45
+ x = self.rebnconvin(x)
46
+
47
+ # U-Net like symmetric encoder-decoder structure
48
+ def unet(x, height=1):
49
+ if height < self.height:
50
+ x1 = getattr(self, f'rebnconv{height}')(x)
51
+ if not self.dilated and height < self.height - 1:
52
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
53
+ else:
54
+ x2 = unet(x1, height + 1)
55
+
56
+ x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1))
57
+ return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x
58
+ else:
59
+ return getattr(self, f'rebnconv{height}')(x)
60
+
61
+ return x + unet(x)
62
+
63
+ def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
64
+ self.add_module('rebnconvin', REBNCONV(in_ch, out_ch))
65
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
66
+
67
+ self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch))
68
+ self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch))
69
+
70
+ for i in range(2, height):
71
+ dilate = 1 if not dilated else 2 ** (i - 1)
72
+ self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
73
+ self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate))
74
+
75
+ dilate = 2 if not dilated else 2 ** (height - 1)
76
+ self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
77
+
78
+
79
+ class U2NET(nn.Module):
80
+ def __init__(self, cfgs, out_ch):
81
+ super(U2NET, self).__init__()
82
+ self.out_ch = out_ch
83
+ self._make_layers(cfgs)
84
+
85
+ def forward(self, x):
86
+ sizes = _size_map(x, self.height)
87
+ maps = [] # storage for maps
88
+
89
+ # side saliency map
90
+ def unet(x, height=1):
91
+ if height < 6:
92
+ x1 = getattr(self, f'stage{height}')(x)
93
+ x2 = unet(getattr(self, 'downsample')(x1), height + 1)
94
+ x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1))
95
+ side(x, height)
96
+ return _upsample_like(x, sizes[height - 1]) if height > 1 else x
97
+ else:
98
+ x = getattr(self, f'stage{height}')(x)
99
+ side(x, height)
100
+ return _upsample_like(x, sizes[height - 1])
101
+
102
+ def side(x, h):
103
+ # side output saliency map (before sigmoid)
104
+ x = getattr(self, f'side{h}')(x)
105
+ x = _upsample_like(x, sizes[1])
106
+ maps.append(x)
107
+
108
+ def fuse():
109
+ # fuse saliency probability maps
110
+ maps.reverse()
111
+ x = torch.cat(maps, 1)
112
+ x = getattr(self, 'outconv')(x)
113
+ maps.insert(0, x)
114
+ return [torch.sigmoid(x) for x in maps]
115
+
116
+ unet(x)
117
+ maps = fuse()
118
+ return maps
119
+
120
+ def _make_layers(self, cfgs):
121
+ self.height = int((len(cfgs) + 1) / 2)
122
+ self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
123
+ for k, v in cfgs.items():
124
+ # build rsu block
125
+ self.add_module(k, RSU(v[0], *v[1]))
126
+ if v[2] > 0:
127
+ # build side layer
128
+ self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1))
129
+ # build fuse layer
130
+ self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1))
131
+
132
+
133
+ def U2NET_full():
134
+ full = {
135
+ # cfgs for building RSUs and sides
136
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
137
+ 'stage1': ['En_1', (7, 3, 32, 64), -1],
138
+ 'stage2': ['En_2', (6, 64, 32, 128), -1],
139
+ 'stage3': ['En_3', (5, 128, 64, 256), -1],
140
+ 'stage4': ['En_4', (4, 256, 128, 512), -1],
141
+ 'stage5': ['En_5', (4, 512, 256, 512, True), -1],
142
+ 'stage6': ['En_6', (4, 512, 256, 512, True), 512],
143
+ 'stage5d': ['De_5', (4, 1024, 256, 512, True), 512],
144
+ 'stage4d': ['De_4', (4, 1024, 128, 256), 256],
145
+ 'stage3d': ['De_3', (5, 512, 64, 128), 128],
146
+ 'stage2d': ['De_2', (6, 256, 32, 64), 64],
147
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
148
+ }
149
+ return U2NET(cfgs=full, out_ch=1)
150
+
151
+
152
+ def U2NET_lite():
153
+ lite = {
154
+ # cfgs for building RSUs and sides
155
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
156
+ 'stage1': ['En_1', (7, 3, 16, 64), -1],
157
+ 'stage2': ['En_2', (6, 64, 16, 64), -1],
158
+ 'stage3': ['En_3', (5, 64, 16, 64), -1],
159
+ 'stage4': ['En_4', (4, 64, 16, 64), -1],
160
+ 'stage5': ['En_5', (4, 64, 16, 64, True), -1],
161
+ 'stage6': ['En_6', (4, 64, 16, 64, True), 64],
162
+ 'stage5d': ['De_5', (4, 128, 16, 64, True), 64],
163
+ 'stage4d': ['De_4', (4, 128, 16, 64), 64],
164
+ 'stage3d': ['De_3', (5, 128, 16, 64), 64],
165
+ 'stage2d': ['De_2', (6, 128, 16, 64), 64],
166
+ 'stage1d': ['De_1', (7, 128, 16, 64), 64],
167
+ }
168
+ return U2NET(cfgs=lite, out_ch=1)
unet/unet_model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Full assembly of the parts to form the complete network """
2
+
3
+ from .unet_parts import *
4
+
5
+
6
+ class UNet(nn.Module):
7
+ def __init__(self, n_channels, n_classes, bilinear=True):
8
+ super(UNet, self).__init__()
9
+ self.n_channels = n_channels
10
+ self.n_classes = n_classes
11
+ self.bilinear = bilinear
12
+
13
+ self.inc = DoubleConv(n_channels, 64)
14
+ self.down1 = Down(64, 128)
15
+ self.down2 = Down(128, 256)
16
+ self.down3 = Down(256, 512)
17
+ factor = 2 if bilinear else 1
18
+ self.down4 = Down(512, 1024 // factor)
19
+ self.up1 = Up(1024, 512 // factor, bilinear)
20
+ self.up2 = Up(512, 256 // factor, bilinear)
21
+ self.up3 = Up(256, 128 // factor, bilinear)
22
+ self.up4 = Up(128, 64, bilinear)
23
+ self.outc = OutConv(64, n_classes)
24
+
25
+ def forward(self, x):
26
+ x1 = self.inc(x)
27
+ x2 = self.down1(x1)
28
+ x3 = self.down2(x2)
29
+ x4 = self.down3(x3)
30
+ x5 = self.down4(x4)
31
+ x = self.up1(x5, x4)
32
+ x = self.up2(x, x3)
33
+ x = self.up3(x, x2)
34
+ x = self.up4(x, x1)
35
+ logits = self.outc(x)
36
+ return logits
unet/unet_parts.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Parts of the U-Net model """
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class DoubleConv(nn.Module):
9
+ """(convolution => [BN] => ReLU) * 2"""
10
+
11
+ def __init__(self, in_channels, out_channels, mid_channels=None):
12
+ super().__init__()
13
+ if not mid_channels:
14
+ mid_channels = out_channels
15
+ self.double_conv = nn.Sequential(
16
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
17
+ nn.BatchNorm2d(mid_channels),
18
+ nn.ReLU(inplace=True),
19
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
20
+ nn.BatchNorm2d(out_channels),
21
+ nn.ReLU(inplace=True)
22
+ )
23
+
24
+ def forward(self, x):
25
+ return self.double_conv(x)
26
+
27
+
28
+ class Down(nn.Module):
29
+ """Downscaling with maxpool then double conv"""
30
+
31
+ def __init__(self, in_channels, out_channels):
32
+ super().__init__()
33
+ self.maxpool_conv = nn.Sequential(
34
+ nn.MaxPool2d(2),
35
+ DoubleConv(in_channels, out_channels)
36
+ )
37
+
38
+ def forward(self, x):
39
+ return self.maxpool_conv(x)
40
+
41
+
42
+ class Up(nn.Module):
43
+ """Upscaling then double conv"""
44
+
45
+ def __init__(self, in_channels, out_channels, bilinear=True):
46
+ super().__init__()
47
+
48
+ # if bilinear, use the normal convolutions to reduce the number of channels
49
+ if bilinear:
50
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
51
+ self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
52
+ else:
53
+ self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
54
+ self.conv = DoubleConv(in_channels, out_channels)
55
+
56
+ def forward(self, x1, x2):
57
+ x1 = self.up(x1)
58
+ # input is CHW
59
+ diffY = x2.size()[2] - x1.size()[2]
60
+ diffX = x2.size()[3] - x1.size()[3]
61
+
62
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
63
+ diffY // 2, diffY - diffY // 2])
64
+ # if you have padding issues, see
65
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
66
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
67
+ x = torch.cat([x2, x1], dim=1)
68
+ return self.conv(x)
69
+
70
+
71
+ class OutConv(nn.Module):
72
+ def __init__(self, in_channels, out_channels):
73
+ super(OutConv, self).__init__()
74
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
75
+
76
+ def forward(self, x):
77
+ return self.conv(x)
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (139 Bytes). View file
 
utils/__pycache__/data_loading.cpython-38.pyc ADDED
Binary file (3.38 kB). View file
 
utils/__pycache__/dice_score.cpython-38.pyc ADDED
Binary file (1.39 kB). View file
 
utils/__pycache__/utils.cpython-38.pyc ADDED
Binary file (700 Bytes). View file
 
utils/data_loading.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from os import listdir
3
+ from os.path import splitext
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+ from torch.utils.data import Dataset
10
+
11
+
12
+ class BasicDataset(Dataset):
13
+ def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ''):
14
+ self.images_dir = Path(images_dir)
15
+ self.masks_dir = Path(masks_dir)
16
+ assert 0 < scale <= 1, 'Scale must be between 0 and 1'
17
+ self.scale = scale
18
+ self.mask_suffix = mask_suffix
19
+
20
+ self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith('.')]
21
+ if not self.ids:
22
+ raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')
23
+ logging.info(f'Creating dataset with {len(self.ids)} examples')
24
+
25
+ def __len__(self):
26
+ return len(self.ids)
27
+
28
+ @classmethod
29
+ def preprocess(cls, pil_img, scale, is_mask):
30
+ w= h = pil_img.size
31
+ newW, newH = int(scale * w), int(scale * h)
32
+ assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
33
+ #pil_img = pil_img.resize((newW, newH))
34
+
35
+ img_ndarray = np.asarray(pil_img)
36
+
37
+ if img_ndarray.ndim == 2 and not is_mask:
38
+ img_ndarray = img_ndarray[np.newaxis, ...]
39
+ elif not is_mask:
40
+ img_ndarray = img_ndarray.transpose((2, 0, 1))
41
+
42
+ if not is_mask:
43
+ img_ndarray = img_ndarray / 255
44
+
45
+ return img_ndarray
46
+
47
+ @classmethod
48
+ def load(cls, filename):
49
+ ext = splitext(filename)[1]
50
+ if ext in ['.npz', '.npy']:
51
+ return Image.fromarray(np.load(filename))
52
+ elif ext in ['.pt', '.pth']:
53
+ return Image.fromarray(torch.load(filename).numpy())
54
+ else:
55
+ return Image.open(filename)
56
+
57
+ def __getitem__(self, idx):
58
+ name = self.ids[idx]
59
+ mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*'))
60
+ img_file = list(self.images_dir.glob(name + '.*'))
61
+
62
+ assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
63
+ assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
64
+ mask = self.load(mask_file[0])
65
+ img = self.load(img_file[0])
66
+
67
+ assert img.size == mask.size, \
68
+ 'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'
69
+
70
+ img = self.preprocess(img, self.scale, is_mask=False)
71
+ mask = self.preprocess(mask, self.scale, is_mask=True)
72
+
73
+ return {
74
+ 'image': torch.as_tensor(img.copy()).float().contiguous(),
75
+ 'mask': torch.as_tensor(mask.copy()).long().contiguous()
76
+ }
77
+
78
+
79
+ class GODataset(BasicDataset):
80
+ def __init__(self, images_dir, masks_dir, scale=1):
81
+ super().__init__(images_dir, masks_dir, scale, mask_suffix='_gt')
utils/dice_score.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+
4
+
5
+ def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
6
+ # Average of Dice coefficient for all batches, or for a single mask
7
+ assert input.size() == target.size()
8
+ if input.dim() == 2 and reduce_batch_first:
9
+ raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
10
+
11
+ if input.dim() == 2 or reduce_batch_first:
12
+ inter = torch.dot(input.reshape(-1), target.reshape(-1))
13
+ sets_sum = torch.sum(input) + torch.sum(target)
14
+ if sets_sum.item() == 0:
15
+ sets_sum = 2 * inter
16
+
17
+ return (2 * inter + epsilon) / (sets_sum + epsilon)
18
+ else:
19
+ # compute and average metric for each batch element
20
+ dice = 0
21
+ for i in range(input.shape[0]):
22
+ dice += dice_coeff(input[i, ...], target[i, ...])
23
+ return dice / input.shape[0]
24
+
25
+
26
+ def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
27
+ # Average of Dice coefficient for all classes
28
+ assert input.size() == target.size()
29
+ dice = 0
30
+ for channel in range(input.shape[1]):
31
+ dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
32
+
33
+ return dice / input.shape[1]
34
+
35
+
36
+ def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
37
+ # Dice loss (objective to minimize) between 0 and 1
38
+ assert input.size() == target.size()
39
+ fn = multiclass_dice_coeff if multiclass else dice_coeff
40
+ return 1 - fn(input, target, reduce_batch_first=True)
utils/utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+
3
+
4
+ def plot_img_and_mask(img, mask):
5
+ classes = mask.shape[0] if len(mask.shape) > 2 else 1
6
+ fig, ax = plt.subplots(1, classes + 1)
7
+ ax[0].set_title('Input image')
8
+ ax[0].imshow(img)
9
+ if classes > 1:
10
+ for i in range(classes):
11
+ ax[i + 1].set_title(f'Output mask (class {i + 1})')
12
+ ax[i + 1].imshow(mask[:, :, i])
13
+ else:
14
+ ax[1].set_title(f'Output mask')
15
+ ax[1].imshow(mask)
16
+ plt.xticks([]), plt.yticks([])
17
+ plt.show()