plutosss commited on
Commit
8b4ec8a
·
verified ·
1 Parent(s): 7bd470a

Upload 17 files

Browse files
TEED/utils/AF/Fmish.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script provides functional interface for Mish activation function.
3
+ """
4
+
5
+ # import pytorch
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
+ @torch.jit.script
11
+ def mish(input):
12
+ """
13
+ Applies the mish function element-wise:
14
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
15
+ See additional documentation for mish class.
16
+ """
17
+ return input * torch.tanh(F.softplus(input))
TEED/utils/AF/Fsmish.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script based on:
3
+ Wang, Xueliang, Honge Ren, and Achuan Wang.
4
+ "Smish: A Novel Activation Function for Deep Learning Methods.
5
+ " Electronics 11.4 (2022): 540.
6
+ """
7
+
8
+ # import pytorch
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+
13
+ @torch.jit.script
14
+ def smish(input):
15
+ """
16
+ Applies the mish function element-wise:
17
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(sigmoid(x))))
18
+ See additional documentation for mish class.
19
+ """
20
+ return input * torch.tanh(torch.log(1+torch.sigmoid(input)))
TEED/utils/AF/Xmish.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Applies the mish function element-wise:
3
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
4
+ """
5
+
6
+ # import pytorch
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ # import activation functions
12
+ import utils.AF.Fmish as Func
13
+
14
+
15
+ class Mish(nn.Module):
16
+ """
17
+ Applies the mish function element-wise:
18
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
19
+ Shape:
20
+ - Input: (N, *) where * means, any number of additional
21
+ dimensions
22
+ - Output: (N, *), same shape as the input
23
+ Examples:
24
+ >>> m = Mish()
25
+ >>> input = torch.randn(2)
26
+ >>> output = m(input)
27
+ Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
28
+ """
29
+
30
+ def __init__(self):
31
+ """
32
+ Init method.
33
+ """
34
+ super().__init__()
35
+
36
+ def forward(self, input):
37
+ """
38
+ Forward pass of the function.
39
+ """
40
+ if torch.__version__ >= "1.9":
41
+ return F.mish(input)
42
+ else:
43
+ return Func.mish(input)
TEED/utils/AF/Xsmish.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script based on:
3
+ Wang, Xueliang, Honge Ren, and Achuan Wang.
4
+ "Smish: A Novel Activation Function for Deep Learning Methods.
5
+ " Electronics 11.4 (2022): 540.
6
+ smish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + sigmoid(x)))
7
+ """
8
+
9
+ # import pytorch
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+
14
+ # import activation functions
15
+ import TEED.utils.AF.Fsmish as Func
16
+
17
+
18
+ class Smish(nn.Module):
19
+ """
20
+ Applies the mish function element-wise:
21
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
22
+ Shape:
23
+ - Input: (N, *) where * means, any number of additional
24
+ dimensions
25
+ - Output: (N, *), same shape as the input
26
+ Examples:
27
+ >>> m = Mish()
28
+ >>> input = torch.randn(2)
29
+ >>> output = m(input)
30
+ Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
31
+ """
32
+
33
+ def __init__(self):
34
+ """
35
+ Init method.
36
+ """
37
+ super().__init__()
38
+
39
+ def forward(self, input):
40
+ """
41
+ Forward pass of the function.
42
+ """
43
+ return Func.smish(input)
TEED/utils/AF/__pycache__/Fmish.cpython-38.pyc ADDED
Binary file (624 Bytes). View file
 
TEED/utils/AF/__pycache__/Fsmish.cpython-310.pyc ADDED
Binary file (745 Bytes). View file
 
TEED/utils/AF/__pycache__/Fsmish.cpython-312.pyc ADDED
Binary file (964 Bytes). View file
 
TEED/utils/AF/__pycache__/Fsmish.cpython-38.pyc ADDED
Binary file (741 Bytes). View file
 
TEED/utils/AF/__pycache__/Xmish.cpython-38.pyc ADDED
Binary file (1.45 kB). View file
 
TEED/utils/AF/__pycache__/Xsmish.cpython-310.pyc ADDED
Binary file (1.55 kB). View file
 
TEED/utils/AF/__pycache__/Xsmish.cpython-312.pyc ADDED
Binary file (1.7 kB). View file
 
TEED/utils/AF/__pycache__/Xsmish.cpython-38.pyc ADDED
Binary file (1.53 kB). View file
 
TEED/utils/__pycache__/img_processing.cpython-310.pyc ADDED
Binary file (7.04 kB). View file
 
TEED/utils/__pycache__/img_processing.cpython-312.pyc ADDED
Binary file (14.7 kB). View file
 
TEED/utils/__pycache__/img_processing.cpython-38.pyc ADDED
Binary file (6.99 kB). View file
 
TEED/utils/img_processing.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import kornia as kn
7
+ import cv2_ext
8
+
9
+ from skimage.metrics import mean_squared_error, peak_signal_noise_ratio
10
+ from sklearn.metrics import mean_absolute_error
11
+
12
+
13
+ def image_normalization(img, img_min=0, img_max=255,
14
+ epsilon=1e-12):
15
+ """This is a typical image normalization function
16
+ where the minimum and maximum of the image is needed
17
+ source: https://en.wikipedia.org/wiki/Normalization_(image_processing)
18
+
19
+ :param img: an image could be gray scale or color
20
+ :param img_min: for default is 0
21
+ :param img_max: for default is 255
22
+
23
+ :return: a normalized image, if max is 255 the dtype is uint8
24
+ """
25
+
26
+ img = np.float32(img)
27
+ # whenever an inconsistent image
28
+ img = (img - np.min(img)) * (img_max - img_min) / \
29
+ ((np.max(img) - np.min(img)) + epsilon) + img_min
30
+ return img
31
+
32
+ def count_parameters(model=None):
33
+ if model is not None:
34
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
35
+ else:
36
+ print("Error counting model parameters line 32 img_processing.py")
37
+ raise NotImplementedError
38
+
39
+
40
+ def save_image_batch_to_disk(tensor, output_dir, file_names, img_shape=None, arg=None, is_inchannel=False):
41
+
42
+ os.makedirs(output_dir, exist_ok=True)
43
+ predict_all = arg.predict_all
44
+ if not arg.is_testing:
45
+ assert len(tensor.shape) == 4, tensor.shape
46
+ img_height,img_width = img_shape[0].item(),img_shape[1].item()
47
+
48
+ for tensor_image, file_name in zip(tensor, file_names):
49
+ image_vis = kn.utils.tensor_to_image(
50
+ torch.sigmoid(tensor_image))#[..., 0]
51
+ image_vis = (255.0*(1.0 - image_vis)).astype(np.uint8)
52
+ output_file_name = os.path.join(output_dir, file_name)
53
+ # print('image vis size', image_vis.shape)
54
+ image_vis =cv2.resize(image_vis, (img_width, img_height))
55
+ assert cv2_ext.imwrite(output_file_name, image_vis)
56
+ assert cv2_ext.imwrite('checkpoints/current_res/'+file_name, image_vis)
57
+ # print(f"Image saved in {output_file_name}")
58
+ else:
59
+ if is_inchannel:
60
+
61
+ tensor, tensor2 = tensor
62
+ fuse_name = 'fusedCH'
63
+ av_name='avgCH'
64
+ is_2tensors=True
65
+ edge_maps2 = []
66
+ for i in tensor2:
67
+ tmp = torch.sigmoid(i).cpu().detach().numpy()
68
+ edge_maps2.append(tmp)
69
+ tensor2 = np.array(edge_maps2)
70
+ else:
71
+ fuse_name = 'fused'
72
+ # av_name = 'avg'
73
+ tensor2=None
74
+ tmp_img2 = None
75
+
76
+ # output_dir_f = os.path.join(output_dir, fuse_name)# normal execution
77
+ output_dir_f = output_dir# for DMRIR
78
+ # output_dir_a = os.path.join(output_dir, av_name)
79
+ os.makedirs(output_dir_f, exist_ok=True)
80
+ # os.makedirs(output_dir_a, exist_ok=True)
81
+ if predict_all:
82
+ all_data_dir = os.path.join(output_dir, "all_edges")
83
+ os.makedirs(all_data_dir, exist_ok=True)
84
+ out1_dir = os.path.join(all_data_dir,"o1")
85
+ out2_dir = os.path.join(all_data_dir,"o2")
86
+ out3_dir = os.path.join(all_data_dir,"o3")# TEED =output 3
87
+ out4_dir = os.path.join(all_data_dir,"o4") # TEED = average
88
+ out5_dir = os.path.join(all_data_dir,"o5")# fusion # TEED
89
+ out6_dir = os.path.join(all_data_dir,"o6") # fusion
90
+ os.makedirs(out1_dir, exist_ok=True)
91
+ os.makedirs(out2_dir, exist_ok=True)
92
+ os.makedirs(out3_dir, exist_ok=True)
93
+ os.makedirs(out4_dir, exist_ok=True)
94
+ os.makedirs(out5_dir, exist_ok=True)
95
+ os.makedirs(out6_dir, exist_ok=True)
96
+
97
+ # 255.0 * (1.0 - em_a)
98
+ edge_maps = []
99
+ for i in tensor:
100
+ tmp = torch.sigmoid(i).cpu().detach().numpy()
101
+ edge_maps.append(tmp)
102
+ tensor = np.array(edge_maps)
103
+ # print(f"tensor shape: {tensor.shape}")
104
+
105
+ image_shape = [x.cpu().detach().numpy() for x in img_shape]
106
+ # (H, W) -> (W, H)
107
+ image_shape = [[y, x] for x, y in zip(image_shape[0], image_shape[1])]
108
+
109
+ assert len(image_shape) == len(file_names)
110
+
111
+ idx = 0
112
+ for i_shape, file_name in zip(image_shape, file_names):
113
+ tmp = tensor[:, idx, ...]
114
+ tmp2 = tensor2[:, idx, ...] if tensor2 is not None else None
115
+ # tmp = np.transpose(np.squeeze(tmp), [0, 1, 2])
116
+ tmp = np.squeeze(tmp)
117
+ tmp2 = np.squeeze(tmp2) if tensor2 is not None else None
118
+
119
+ # Iterate our all 7 NN outputs for a particular image
120
+ preds = []
121
+ fuse_num = tmp.shape[0]-1
122
+ for i in range(tmp.shape[0]):
123
+ tmp_img = tmp[i]
124
+ tmp_img = np.uint8(image_normalization(tmp_img))
125
+ tmp_img = cv2.bitwise_not(tmp_img)
126
+ # tmp_img[tmp_img < 0.0] = 0.0
127
+ # tmp_img = 255.0 * (1.0 - tmp_img)
128
+ if tmp2 is not None:
129
+ tmp_img2 = tmp2[i]
130
+ tmp_img2 = np.uint8(image_normalization(tmp_img2))
131
+ tmp_img2 = cv2.bitwise_not(tmp_img2)
132
+
133
+ # Resize prediction to match input image size
134
+ if not tmp_img.shape[1] == i_shape[0] or not tmp_img.shape[0] == i_shape[1]:
135
+ tmp_img = cv2.resize(tmp_img, (i_shape[0], i_shape[1]))
136
+ tmp_img2 = cv2.resize(tmp_img2, (i_shape[0], i_shape[1])) if tmp2 is not None else None
137
+
138
+
139
+ if tmp2 is not None:
140
+ tmp_mask = np.logical_and(tmp_img>128,tmp_img2<128)
141
+ tmp_img= np.where(tmp_mask, tmp_img2, tmp_img)
142
+ preds.append(tmp_img)
143
+
144
+ else:
145
+ preds.append(tmp_img)
146
+
147
+ if i == fuse_num:
148
+ # print('fuse num',tmp.shape[0], fuse_num, i)
149
+ fuse = tmp_img
150
+ fuse = fuse.astype(np.uint8)
151
+ if tmp_img2 is not None:
152
+ fuse2 = tmp_img2
153
+ fuse2 = fuse2.astype(np.uint8)
154
+ # fuse = fuse-fuse2
155
+ fuse_mask=np.logical_and(fuse>128,fuse2<128)
156
+ fuse = np.where(fuse_mask,fuse2, fuse)
157
+
158
+ # print(fuse.shape, fuse_mask.shape)
159
+
160
+ # Save predicted edge maps
161
+ average = np.array(preds, dtype=np.float32)
162
+ average = np.uint8(np.mean(average, axis=0))
163
+ output_file_name_f = os.path.join(output_dir_f, file_name)
164
+ # output_file_name_a = os.path.join(output_dir_a, file_name)
165
+ cv2.imwrite(output_file_name_f, fuse)
166
+ # cv2_ext.imwrite(output_file_name_a, average)
167
+ if predict_all:
168
+ cv2.imwrite(os.path.join(out1_dir,file_name),preds[0])
169
+ cv2.imwrite(os.path.join(out2_dir,file_name),preds[1])
170
+ cv2.imwrite(os.path.join(out3_dir,file_name),preds[2])
171
+ cv2.imwrite(os.path.join(out4_dir,file_name),average)
172
+ cv2.imwrite(os.path.join(out5_dir,file_name),fuse)
173
+ cv2.imwrite(os.path.join(out6_dir,file_name),fuse)
174
+
175
+ idx += 1
176
+
177
+
178
+ def restore_rgb(config, I, restore_rgb=False):
179
+ """
180
+ :param config: [args.channel_swap, args.mean_pixel_value]
181
+ :param I: and image or a set of images
182
+ :return: an image or a set of images restored
183
+ """
184
+
185
+ if len(I) > 3 and not type(I) == np.ndarray:
186
+ I = np.array(I)
187
+ I = I[:, :, :, 0:3]
188
+ n = I.shape[0]
189
+ for i in range(n):
190
+ x = I[i, ...]
191
+ x = np.array(x, dtype=np.float32)
192
+ x += config[1]
193
+ if restore_rgb:
194
+ x = x[:, :, config[0]]
195
+ x = image_normalization(x)
196
+ I[i, :, :, :] = x
197
+ elif len(I.shape) == 3 and I.shape[-1] == 3:
198
+ I = np.array(I, dtype=np.float32)
199
+ I += config[1]
200
+ if restore_rgb:
201
+ I = I[:, :, config[0]]
202
+ I = image_normalization(I)
203
+ else:
204
+ print("Sorry the input data size is out of our configuration")
205
+ return I
206
+
207
+
208
+ def visualize_result(imgs_list, arg):
209
+ """
210
+ data 2 image in one matrix
211
+ :param imgs_list: a list of prediction, gt and input data
212
+ :param arg:
213
+ :return: one image with the whole of imgs_list data
214
+ """
215
+ n_imgs = len(imgs_list)
216
+ data_list = []
217
+ for i in range(n_imgs):
218
+ tmp = imgs_list[i]
219
+ # print(tmp.shape)
220
+ if tmp.shape[0] == 3:
221
+ tmp = np.transpose(tmp, [1, 2, 0])
222
+ tmp = restore_rgb([
223
+ arg.channel_swap,
224
+ arg.mean_train[:3]
225
+ ], tmp)
226
+ tmp = np.uint8(image_normalization(tmp))
227
+ else:
228
+ tmp = np.squeeze(tmp)
229
+ if len(tmp.shape) == 2:
230
+ tmp = np.uint8(image_normalization(tmp))
231
+ tmp = cv2.bitwise_not(tmp)
232
+ tmp = cv2.cvtColor(tmp, cv2.COLOR_GRAY2BGR)
233
+ else:
234
+ tmp = np.uint8(image_normalization(tmp))
235
+ data_list.append(tmp)
236
+ # print(i,tmp.shape)
237
+ img = data_list[0]
238
+ if n_imgs % 2 == 0:
239
+ imgs = np.zeros((img.shape[0] * 2 + 10, img.shape[1]
240
+ * (n_imgs // 2) + ((n_imgs // 2 - 1) * 5), 3))
241
+ else:
242
+ imgs = np.zeros((img.shape[0] * 2 + 10, img.shape[1]
243
+ * ((1 + n_imgs) // 2) + ((n_imgs // 2) * 5), 3))
244
+ n_imgs += 1
245
+
246
+ k = 0
247
+ imgs = np.uint8(imgs)
248
+ i_step = img.shape[0] + 10
249
+ j_step = img.shape[1] + 5
250
+ for i in range(2):
251
+ for j in range(n_imgs // 2):
252
+ if k < len(data_list):
253
+ imgs[i * i_step:i * i_step+img.shape[0],
254
+ j * j_step:j * j_step+img.shape[1],
255
+ :] = data_list[k]
256
+ k += 1
257
+ else:
258
+ pass
259
+ return imgs
260
+
261
+
262
+
263
+ if __name__ == '__main__':
264
+
265
+ img_base_dir='tmp_edge'
266
+ gt_base_dir='C:/Users/xavysp/dataset/BIPED/edges/edge_maps/test/rgbr'
267
+ # gt_base_dir='C:/Users/xavysp/dataset/BRIND/test_edges'
268
+ # gt_base_dir='C:/Users/xavysp/dataset/UDED/gt'
269
+ vers = 'TEED model in BIPED'
270
+ list_img = os.listdir(img_base_dir)
271
+ list_gt = os.listdir(gt_base_dir)
272
+ mse_list=[]
273
+ psnr_list=[]
274
+ mae_list=[]
275
+
276
+ for img_name, gt_name in zip(list_img,list_gt):
277
+
278
+ # print(img_name, ' ', gt_name)
279
+ tmp_img = cv2.imread(os.path.join(img_base_dir,img_name),0)
280
+ tmp_img = cv2.bitwise_not(tmp_img) # if the image's background
281
+ # is white uncomment this line
282
+ tmp_gt = cv2.imread(os.path.join(gt_base_dir,gt_name),0)
283
+ # print(f"image {img_name} {tmp_img.shape}")
284
+ # print(f"gt {gt_name} {tmp_gt.shape}")
285
+ a = tmp_img.copy()
286
+ tmp_img = image_normalization(tmp_img, img_max=1.)
287
+ tmp_gt = image_normalization(tmp_gt, img_max=1.)
288
+ psnr = peak_signal_noise_ratio(tmp_gt, tmp_img)
289
+ mse = mean_squared_error(tmp_gt, tmp_img)
290
+ mae = mean_absolute_error(tmp_gt, tmp_img)
291
+ # a = cv2.bitwise_not(a) # save data
292
+ # cv2_ext.imwrite(os.path.join("tmp_res",img_name), a) # save data
293
+
294
+ psnr_list.append(psnr)
295
+ mse_list.append(mse)
296
+ mae_list.append(mae)
297
+ print(f"PSNR= {psnr} in {img_name}")
298
+
299
+ av_psnr =np.array(psnr_list).mean()
300
+ av_mse =np.array(mse_list).mean()
301
+ av_mae =np.array(mae_list).mean()
302
+ print(" MSE results: mean ", av_mse)
303
+ print(" MAE results: mean ", av_mae)
304
+ # print(mse_list)
305
+ print(" PSNR results: mean", av_psnr)
306
+ # print(psnr_list)
307
+ print('version: ',vers)
TEED/utils/train_pair0.lst ADDED
The diff for this file is too large to render. See raw diff