plutosss committed on
Commit
7bd470a
·
verified ·
1 Parent(s): e3b8c7d

Upload 7 files

Browse files
Files changed (7) hide show
  1. TEED/LICENSE +21 -0
  2. TEED/README.md +68 -0
  3. TEED/dataset.py +581 -0
  4. TEED/devices.py +271 -0
  5. TEED/loss2.py +92 -0
  6. TEED/ted.py +297 -0
  7. TEED/util.py +78 -0
TEED/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Xavier Soria Poma
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
TEED/README.md ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/tiny-and-efficient-model-for-the-edge/edge-detection-on-uded)](https://paperswithcode.com/sota/edge-detection-on-uded?p=tiny-and-efficient-model-for-the-edge)
2
+
3
+ # Tiny and Efficient Model for the Edge Detection Generalization (Paper)
4
+
5
+ ## Overview
6
+
7
+ <div style="text-align:center"><img src='imgs/teedBanner.png' width=800>
8
+ </div>
9
+
10
+
11
+
12
+ Tiny and Efficient Edge Detector (TEED) is a light convolutional neural
13
+ network with only $58K$ parameters, less than $0.2$% of the
14
+ state-of-the-art models. Training on the [BIPED](https://www.kaggle.com/datasets/xavysp/biped)
15
+ dataset takes *less than 30 minutes*, with each epoch requiring
16
+ *less than 5 minutes*. Our proposed model is easy to train
17
+ and it quickly converges within the very first few epochs, while the
18
+ predicted edge-maps are crisp and of high quality, see image above.
19
+ [This paper has been accepted by ICCV 2023-Workshop RCV](https://arxiv.org/abs/2308.06468).
20
+
21
+ ... Under construction
22
+
23
+ git clone https://github.com/xavysp/TEED.git
24
+ cd TEED
25
+
26
+ Then,
27
+
28
+ ## Testing with TEED
29
+
30
+ Copy and paste your images into data/ folder, and:
31
+
32
+ python main.py --choose_test_data=-1
33
+
34
+ ## Training with TEED
35
+
36
+ Set the following lines in main.py:
37
+
38
+ 25: is_testing =False
39
+ # training with BIPED
40
+ 223: TRAIN_DATA = DATASET_NAMES[0]
41
+
42
+ then run
43
+
44
+ python main.py
45
+
46
+ Check the configurations of the datasets in dataset.py
47
+
48
+
49
+ ## UDED dataset
50
+
51
+ Here the [link](https://github.com/xavysp/UDED) to access the UDED dataset for edge detection
52
+
53
+ ## Citation
54
+
55
+ If you like TEED, why not star the project on GitHub!
56
+
57
+ [![GitHub stars](https://img.shields.io/github/stars/xavysp/TEED.svg?style=social&label=Star&maxAge=3600)](https://GitHub.com/xavysp/TEED/stargazers/)
58
+
59
+ Please cite our dataset if you find it helpful in your academic/scientific publication,
60
+ ```
61
+ @InProceedings{Soria_2023teed,
62
+ author = {Soria, Xavier and Li, Yachuan and Rouhani, Mohammad and Sappa, Angel D.},
63
+ title = {Tiny and Efficient Model for the Edge Detection Generalization},
64
+ booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) Workshops},
65
+ month = {October},
66
+ year = {2023},
67
+ pages = {1364-1373}
68
+ }
TEED/dataset.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ import json
9
+
10
# Names of the datasets the loaders understand.  The index comments matter:
# other code selects datasets by position (e.g. DATASET_NAMES[0] == 'BIPED').
DATASET_NAMES = [
    'BIPED',      # 0
    'BIPED-B2',   # 1
    'BIPED-B3',   # 2
    'BIPED-B5',   # 3
    'BIPED-B6',   # 4
    'BSDS',       # 5
    'BRIND',      # 6
    'ICEDA',      # 7
    'BSDS300',    # 8
    'CID',        # 9
    'DCD',        # 10
    'MDBD',       # 11
    'PASCAL',     # 12
    'NYUD',       # 13
    'BIPBRI',     # 14
    'UDED',       # 15 just for testing
    'DMRIR',      # 16
    'CLASSIC',    # 17
]
# BGR mean of the BIPED training images plus a 4th (presumably NIR) channel
# value — TODO confirm the 4th component's meaning against the data pipeline.
# (alternative previously tried: [108, 109.451, 112.230, 137.86])
BIPED_mean = [103.939, 116.779, 123.68, 137.86]
32
+
33
def dataset_info(dataset_name, is_linux=True):
    """Return the configuration dict for *dataset_name*.

    Each entry carries: img_height/img_width (network input size),
    train_list/test_list (pair-list file names, or None when unavailable),
    data_dir (dataset root), yita (edge-threshold used by evaluation) and
    mean (per-channel BGR+extra mean to subtract).

    :param dataset_name: one of DATASET_NAMES.
    :param is_linux: True -> Linux paths, False -> Windows paths.
    :raises KeyError: if dataset_name has no entry for the chosen platform.
    """
    # Default mean; copied per entry so mutating one dataset's 'mean'
    # cannot leak into the others.
    rgbn_mean = [104.007, 116.669, 122.679, 137.86]
    if is_linux:
        config = {
            'UDED': {
                'img_height': 512,
                'img_width': 512,
                'train_list': None,
                'test_list': 'test_pair.lst',
                'data_dir': '/root/workspace/datasets/UDED',
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            'BSDS': {
                'img_height': 512,
                'img_width': 512,
                'train_list': 'train_pair.lst',
                'test_list': 'test_pair.lst',
                'data_dir': '/root/workspace/datasets/BSDS',
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            'BRIND': {
                'img_height': 512,
                'img_width': 512,
                # full augmentation: train_pair_all.lst, reduced: train_pair.lst
                'train_list': 'train_pair_all.lst',
                'test_list': 'test_pair.lst',
                'data_dir': '/root/workspace/datasets/BRIND',
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            'ICEDA': {
                'img_height': 1024,
                'img_width': 1408,
                'train_list': None,
                'test_list': 'test_pair.lst',
                'data_dir': '/root/workspace/datasets/ICEDA',
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            'BSDS300': {
                'img_height': 512,
                'img_width': 512,
                'test_list': 'test_pair.lst',
                'train_list': None,
                'data_dir': '/root/workspace/datasets/BSDS300',
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            'PASCAL': {
                'img_height': 416,   # native 375
                'img_width': 512,    # native 500
                'test_list': 'test_pair.lst',
                'train_list': None,
                'data_dir': '/root/datasets/PASCAL',
                'yita': 0.3,
                'mean': list(rgbn_mean),
            },
            'CID': {
                'img_height': 512,
                'img_width': 512,
                'test_list': 'test_pair.lst',
                'train_list': None,
                'data_dir': '/root/datasets/CID',
                'yita': 0.3,
                'mean': list(rgbn_mean),
            },
            'NYUD': {
                'img_height': 448,   # native 425
                'img_width': 560,
                'test_list': 'test_pair.lst',
                'train_list': None,
                'data_dir': '/root/datasets/NYUD',
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            'MDBD': {
                'img_height': 720,
                'img_width': 1280,
                'test_list': 'test_pair.lst',
                'train_list': 'train_pair.lst',
                'data_dir': '/root/workspace/datasets/MDBD',
                'yita': 0.3,
                'mean': list(rgbn_mean),
            },
            'BIPED': {
                'img_height': 720,
                'img_width': 1280,
                'test_list': 'test_pair.lst',
                'train_list': 'train_pair0.lst',  # base augmentation
                # alternatives: 'train_pairB3.lst', 'train_pairB5.lst'
                'data_dir': '/root/workspace/datasets/BIPED',
                'yita': 0.5,
                'mean': BIPED_mean,
            },
            'CLASSIC': {
                'img_height': 512,
                'img_width': 512,
                'test_list': None,
                'train_list': None,
                'data_dir': 'data',  # folder of arbitrary test images
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            'BIPED-B2': {
                'img_height': 720,
                'img_width': 1280,
                'test_list': 'test_pair.lst',
                'train_list': 'train_rgb.lst',
                'data_dir': 'C:/Users/xavysp/dataset/BIPED',
                'yita': 0.5,
                'mean': BIPED_mean,
            },
            'BIPED-B3': {
                'img_height': 720,
                'img_width': 1280,
                'test_list': 'test_pair.lst',
                'train_list': 'train_rgb.lst',
                'data_dir': 'C:/Users/xavysp/dataset/BIPED',
                'yita': 0.5,
                'mean': BIPED_mean,
            },
            'BIPED-B5': {
                'img_height': 720,
                'img_width': 1280,
                'test_list': 'test_pair.lst',
                'train_list': 'train_rgb.lst',
                'data_dir': 'C:/Users/xavysp/dataset/BIPED',
                'yita': 0.5,
                'mean': BIPED_mean,
            },
            'BIPED-B6': {
                'img_height': 720,
                'img_width': 1280,
                'test_list': 'test_pair.lst',
                'train_list': 'train_rgb.lst',
                'data_dir': 'C:/Users/xavysp/dataset/BIPED',
                'yita': 0.5,
                'mean': BIPED_mean,
            },
            'DCD': {
                'img_height': 352,   # native 240
                'img_width': 480,    # native 360
                'test_list': 'test_pair.lst',
                'train_list': None,
                'data_dir': '/opt/dataset/DCD',
                'yita': 0.2,
                'mean': list(rgbn_mean),
            },
        }
    else:
        config = {
            'UDED': {
                'img_height': 512,
                'img_width': 512,
                'train_list': None,
                'test_list': 'test_pair.lst',
                'data_dir': 'C:/dataset/UDED',
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            # NOTE: some Windows entries intentionally omit 'train_list';
            # training is only wired up for a subset of datasets here.
            'BSDS': {
                'img_height': 480,
                'img_width': 480,
                'test_list': 'test_pair.lst',
                'data_dir': 'C:/dataset/BSDS',
                'yita': 0.5,
                'mean': [103.939, 116.669, 122.679, 137.86],
            },
            'BRIND': {
                'img_height': 512,
                'img_width': 512,
                'train_list': 'train_pair_all.lst',
                'test_list': 'test_pair.lst',
                'data_dir': 'C:/dataset/BRIND',
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            'ICEDA': {
                'img_height': 1024,
                'img_width': 1408,
                'train_list': None,
                'test_list': 'test_pair.lst',
                'data_dir': 'C:/dataset/ICEDA',
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            'BSDS300': {
                'img_height': 512,
                'img_width': 512,
                'test_list': 'test_pair.lst',
                'data_dir': 'C:/Users/xavysp/dataset/BSDS300',
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            'PASCAL': {
                'img_height': 375,
                'img_width': 500,
                'test_list': 'test_pair.lst',
                'data_dir': 'C:/dataset/PASCAL',
                'yita': 0.3,
                'mean': list(rgbn_mean),
            },
            'CID': {
                'img_height': 512,
                'img_width': 512,
                'test_list': 'test_pair.lst',
                'data_dir': 'C:/dataset/CID',
                'yita': 0.3,
                'mean': list(rgbn_mean),
            },
            'NYUD': {
                'img_height': 425,
                'img_width': 560,
                'test_list': 'test_pair.lst',
                'data_dir': 'C:/dataset/NYUD',
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            'MDBD': {
                'img_height': 720,
                'img_width': 1280,
                'test_list': 'test_pair.lst',
                'train_list': 'train_pair.lst',
                'data_dir': 'C:/dataset/MDBD',
                'yita': 0.3,
                'mean': list(rgbn_mean),
            },
            'BIPED': {
                'img_height': 720,
                'img_width': 1280,
                'test_list': 'test_pair.lst',
                'train_list': 'train_pair0.lst',
                'data_dir': 'C:/dataset/BIPED',
                'yita': 0.5,
                'mean': BIPED_mean,
            },
            'BIPED-B2': {
                'img_height': 720,
                'img_width': 1280,
                'test_list': 'test_pair.lst',
                'train_list': 'train_rgb.lst',
                'data_dir': 'C:/dataset/BIPED',
                'yita': 0.5,
                'mean': BIPED_mean,
            },
            'BIPED-B3': {
                'img_height': 720,
                'img_width': 1280,
                'test_list': 'test_pair.lst',
                'train_list': 'train_rgb.lst',
                'data_dir': 'C:/dataset/BIPED',
                'yita': 0.5,
                'mean': BIPED_mean,
            },
            'BIPED-B5': {
                'img_height': 720,
                'img_width': 1280,
                'test_list': 'test_pair.lst',
                'train_list': 'train_rgb.lst',
                'data_dir': 'C:/Users/xavysp/dataset/BIPED',
                'yita': 0.5,
                'mean': BIPED_mean,
            },
            'BIPED-B6': {
                'img_height': 720,
                'img_width': 1280,
                'test_list': 'test_pair.lst',
                'train_list': 'train_rgb.lst',
                'data_dir': 'C:/Users/xavysp/dataset/BIPED',
                'yita': 0.5,
                'mean': BIPED_mean,
            },
            'CLASSIC': {
                'img_height': 512,
                'img_width': 512,
                'test_list': None,
                'train_list': None,
                'data_dir': 'teed_tmp',
                'yita': 0.5,
                'mean': list(rgbn_mean),
            },
            'DCD': {
                'img_height': 240,
                'img_width': 360,
                'test_list': 'test_pair.lst',
                'data_dir': 'C:/dataset/DCD',
                'yita': 0.2,
                'mean': list(rgbn_mean),
            },
        }
    return config[dataset_name]
300
+
301
class TestDataset(Dataset):
    """Inference-time dataset.

    For 'CLASSIC' it simply lists every file in *data_root* (no ground
    truth); for all other datasets it reads (image, ground-truth) pairs from
    a list file — JSON for BIPED/BRIND/UDED/ICEDA, whitespace-separated
    text otherwise.
    """

    def __init__(self,
                 data_root,
                 test_data,
                 img_height,
                 img_width,
                 test_list=None,
                 arg=None
                 ):
        if test_data not in DATASET_NAMES:
            raise ValueError(f"Unsupported dataset: {test_data}")

        self.data_root = data_root
        self.test_data = test_data
        self.test_list = test_list
        self.args = arg
        self.up_scale = arg.up_scale
        # Keep only the BGR part of the mean; a 4th (extra-channel) value
        # may be present in arg.mean_test.
        self.mean_bgr = arg.mean_test if len(arg.mean_test) == 3 else arg.mean_test[:3]
        self.img_height = img_height
        self.img_width = img_width
        self.data_index = self._build_index()

    def _build_index(self):
        """Collect sample paths.

        Returns [image_names, None] for CLASSIC, otherwise a list of
        (image_path, gt_path) tuples.
        """
        if self.test_data == "CLASSIC":
            # Arbitrary images dropped into a folder, no labels.
            return [os.listdir(self.data_root), None]

        if not self.test_list:
            raise ValueError(
                f"Test list not provided for dataset: {self.test_data}")

        list_name = os.path.join(self.data_root, self.test_list)
        if self.test_data.upper() in ['BIPED', 'BRIND', 'UDED', 'ICEDA']:
            # JSON file: a list of [image, ground_truth] pairs.
            with open(list_name, encoding='utf-8') as f:
                pairs = json.load(f)
        else:
            # Plain text: one "<image> <ground_truth>" pair per line.
            with open(list_name, 'r') as f:
                pairs = [line.strip().split() for line in f.readlines()]

        sample_indices = []
        for pair in pairs:
            sample_indices.append(
                (os.path.join(self.data_root, pair[0]),
                 os.path.join(self.data_root, pair[1]),))
        return sample_indices

    def __len__(self):
        if self.test_data.upper() == 'CLASSIC':
            return len(self.data_index[0])
        return len(self.data_index)

    def __getitem__(self, idx):
        # Resolve paths for this sample.
        if self.data_index[1] is None:
            # CLASSIC layout: data_index == [image_names, None].
            image_path = self.data_index[0][idx] if len(self.data_index[0]) > 1 else self.data_index[0][idx - 1]
        else:
            image_path = self.data_index[idx][0]
        label_path = None if self.test_data == "CLASSIC" else self.data_index[idx][1]
        img_name = os.path.basename(image_path)
        # Predictions are always saved as .png.
        file_name = os.path.splitext(img_name)[0] + ".png"

        # Base directories per dataset layout.
        # NOTE(review): for non-CLASSIC data image_path was already joined
        # with data_root in _build_index; the join below is a no-op only when
        # that path is absolute — verify list-file contents against this.
        if self.test_data.upper() == 'BIPED':
            img_dir = os.path.join(self.data_root, 'imgs', 'test')
            gt_dir = os.path.join(self.data_root, 'edge_maps', 'test')
        elif self.test_data.upper() == 'CLASSIC':
            img_dir = self.data_root
            gt_dir = None
        else:
            img_dir = self.data_root
            gt_dir = self.data_root

        # imdecode(np.fromfile(...)) instead of imread: robust to non-ASCII
        # paths on Windows.
        image = cv2.imdecode(np.fromfile(os.path.join(img_dir, image_path), np.uint8), cv2.IMREAD_COLOR)
        if self.test_data == "CLASSIC":
            label = None
        else:
            label = cv2.imread(os.path.join(
                gt_dir, label_path), cv2.IMREAD_COLOR)

        im_shape = [image.shape[0], image.shape[1]]
        image, label = self.transform(img=image, gt=label)

        return dict(images=image, labels=label, file_names=file_name, image_shape=im_shape)

    def transform(self, img, gt):
        """Resize/normalize an image (and its label) for the network."""
        # Optional up-scaling (used by the TEED BIPBRI-light setting).
        if self.up_scale:
            img = cv2.resize(img, (0, 0), fx=1.3, fy=1.3)

        # TEED/BIPED standard: enlarge small images. Comment this block out
        # to speed up testing.
        if img.shape[0] < 512 or img.shape[1] < 512:
            img = cv2.resize(img, (0, 0), fx=1.5, fy=1.5)

        # Snap both sides up to the next multiple of 8 (the model
        # downsamples by a factor of 8).
        if img.shape[0] % 8 != 0 or img.shape[1] % 8 != 0:
            img_width = ((img.shape[1] // 8) + 1) * 8
            img_height = ((img.shape[0] // 8) + 1) * 8
            img = cv2.resize(img, (img_width, img_height))

        img = np.array(img, dtype=np.float32)
        img -= self.mean_bgr
        img = img.transpose((2, 0, 1))  # HWC -> CHW
        img = torch.from_numpy(img.copy()).float()

        if self.test_data == "CLASSIC":
            # Dummy label. NOTE(review): img is already CHW here, so this
            # zero map has shape (C, H) rather than (H, W); kept as-is since
            # CLASSIC labels are never consumed downstream — confirm.
            gt = np.zeros((img.shape[:2]))
            gt = torch.from_numpy(np.array([gt])).float()
        else:
            gt = np.array(gt, dtype=np.float32)
            if len(gt.shape) == 3:
                gt = gt[:, :, 0]
            gt /= 255.  # to [0, 1]
            gt = torch.from_numpy(np.array([gt])).float()

        return img, gt
449
+
450
+ # *************************************************
451
+ # ************* training **************************
452
+ # *************************************************
453
class BipedDataset(Dataset):
    """Training dataset for edge detection (BIPED/BRIND/BSDS-style).

    The sample list is read from arg.train_list: a whitespace-separated
    text file for BSDS, a JSON list of [image, ground_truth] pairs for
    everything else.
    """
    train_modes = ['train', 'test', ]
    dataset_types = ['rgbr', ]
    data_types = ['aug', ]

    def __init__(self,
                 data_root,
                 img_height,
                 img_width,
                 train_mode='train',
                 dataset_type='rgbr',
                 # Whether to crop image or otherwise resize image to match
                 # image height and width.
                 crop_img=False,
                 arg=None
                 ):
        self.data_root = data_root
        self.train_mode = train_mode
        self.dataset_type = dataset_type
        self.data_type = 'aug'  # be aware that this might change in the future
        self.img_height = img_height
        self.img_width = img_width
        # Keep only the BGR part of the mean (4th channel may be present).
        self.mean_bgr = arg.mean_train if len(arg.mean_train) == 3 else arg.mean_train[:3]
        self.crop_img = crop_img
        self.arg = arg

        self.data_index = self._build_index()

    def _build_index(self):
        """Read arg.train_list and return a list of (image, gt) path pairs."""
        assert self.train_mode in self.train_modes, self.train_mode
        assert self.dataset_type in self.dataset_types, self.dataset_type
        assert self.data_type in self.data_types, self.data_type

        data_root = os.path.abspath(self.data_root)
        sample_indices = []

        file_path = os.path.join(data_root, self.arg.train_list)
        if self.arg.train_data.lower() == 'bsds':
            # Text list: one "<image> <ground_truth>" pair per line.
            with open(file_path, 'r') as f:
                pairs = [line.strip().split() for line in f.readlines()]
        else:
            # JSON list of [image, ground_truth] pairs.
            with open(file_path) as f:
                pairs = json.load(f)

        for pair in pairs:
            sample_indices.append(
                (os.path.join(data_root, pair[0]),
                 os.path.join(data_root, pair[1]),))

        return sample_indices

    def __len__(self):
        return len(self.data_index)

    def __getitem__(self, idx):
        image_path, label_path = self.data_index[idx]

        # imdecode(np.fromfile(...)) instead of imread: robust to non-ASCII
        # paths on Windows.
        image = cv2.imdecode(np.fromfile(image_path, np.uint8), cv2.IMREAD_COLOR)
        # BUG FIX: np.fromfile defaults to dtype=float64; the raw bytes must
        # be read as uint8 for cv2.imdecode to decode the label image.
        label = cv2.imdecode(np.fromfile(label_path, np.uint8), cv2.IMREAD_GRAYSCALE)
        image, label = self.transform(img=image, gt=label)
        return dict(images=image, labels=label)

    def transform(self, img, gt):
        """Random crop / resize augmentation plus normalization.

        Returns (img, gt) as float tensors: img is CHW mean-subtracted,
        gt is 1xHxW in [0, 1] with weak edges boosted.
        """
        gt = np.array(gt, dtype=np.float32)
        if len(gt.shape) == 3:
            gt = gt[:, :, 0]
        gt /= 255.  # labels to [0, 1] (for LDC input and BDCN)

        img = np.array(img, dtype=np.float32)
        img -= self.mean_bgr
        i_h, i_w, _ = img.shape
        # Square training patches only. NOTE(review): crop_size is None when
        # img_height != img_width, which would make the cv2.resize calls
        # below fail — callers must pass equal height/width.
        crop_size = self.img_height if self.img_height == self.img_width else None

        if i_w > 400 and i_h > 400:
            h, w = gt.shape
            if np.random.random() > 0.4:
                # Crop directly at the target patch size.
                patch = crop_size
                i = random.randint(0, h - patch)
                j = random.randint(0, w - patch)
                img = img[i:i + patch, j:j + patch]
                gt = gt[i:i + patch, j:j + patch]
            else:
                # Crop a smaller 300x300 patch, then rescale to target size.
                patch = 300
                i = random.randint(0, h - patch)
                j = random.randint(0, w - patch)
                img = img[i:i + patch, j:j + patch]
                gt = gt[i:i + patch, j:j + patch]
                img = cv2.resize(img, dsize=(crop_size, crop_size))
                gt = cv2.resize(gt, dsize=(crop_size, crop_size))
        else:
            # Image too small to crop: plain resize.
            img = cv2.resize(img, dsize=(crop_size, crop_size))
            gt = cv2.resize(gt, dsize=(crop_size, crop_size))

        # Boost weak edges then clamp (best setting for TEED on BRIND/BIPED).
        gt[gt > 0.1] += 0.2  # previously 0.4
        gt = np.clip(gt, 0., 1.)

        img = img.transpose((2, 0, 1))  # HWC -> CHW
        img = torch.from_numpy(img.copy()).float()
        gt = torch.from_numpy(np.array([gt])).float()
        return img, gt
TEED/devices.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import contextlib
3
+ from functools import lru_cache
4
+
5
+ import torch
6
+ from modules import errors, shared, npu_specific
7
+
8
+ if sys.platform == "darwin":
9
+ from modules import mac_specific
10
+
11
+ if shared.cmd_opts.use_ipex:
12
+ from modules import xpu_specific
13
+
14
+
15
def has_xpu() -> bool:
    """Whether IPEX was requested on the command line and an XPU is present."""
    return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
17
+
18
+
19
def has_mps() -> bool:
    """Whether Apple's Metal Performance Shaders backend is available."""
    if sys.platform != "darwin":
        return False
    return mac_specific.has_mps
24
+
25
+
26
def cuda_no_autocast(device_id=None) -> bool:
    """True for NVIDIA GTX 16xx cards (compute capability 7.5) where
    torch autocast misbehaves and manual casting must be used instead."""
    if device_id is None:
        device_id = get_cuda_device_id()
    is_turing_75 = torch.cuda.get_device_capability(device_id) == (7, 5)
    is_gtx16 = torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
    return is_turing_75 and is_gtx16
33
+
34
+
35
def get_cuda_device_id():
    """CUDA device index: --device-id when given and numeric, otherwise the
    current device (index 0 falls through to torch's current device)."""
    opt = shared.cmd_opts.device_id
    requested = int(opt) if opt is not None and opt.isdigit() else 0
    return requested or torch.cuda.current_device()
41
+
42
+
43
def get_cuda_device_string():
    """Return 'cuda:<id>' when --device-id is set, plain 'cuda' otherwise."""
    device_id = shared.cmd_opts.device_id
    return "cuda" if device_id is None else f"cuda:{device_id}"
48
+
49
+
50
def get_optimal_device_name():
    """Pick the best available backend: CUDA > MPS > XPU > NPU > CPU."""
    if torch.cuda.is_available():
        return get_cuda_device_string()
    if has_mps():
        return "mps"
    if has_xpu():
        return xpu_specific.get_xpu_device_string()
    if npu_specific.has_npu:
        return npu_specific.get_npu_device_string()
    return "cpu"
64
+
65
+
66
def get_optimal_device():
    """torch.device wrapper around get_optimal_device_name()."""
    return torch.device(get_optimal_device_name())
68
+
69
+
70
def get_device_for(task):
    """Device to run *task* on, honoring --use-cpu <task> / --use-cpu all."""
    cpu_tasks = shared.cmd_opts.use_cpu
    if task in cpu_tasks or "all" in cpu_tasks:
        return cpu
    return get_optimal_device()
75
+
76
+
77
def torch_gc():
    """Release cached memory on every backend that supports it."""
    if torch.cuda.is_available():
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    if has_mps():
        mac_specific.torch_mps_gc()

    if has_xpu():
        xpu_specific.torch_xpu_gc()

    if npu_specific.has_npu:
        torch_npu_set_device()
        npu_specific.torch_npu_gc()
93
+
94
+
95
def torch_npu_set_device():
    """Pin the NPU to device 0.

    Workaround for a torch_npu bug — revert once fixed, see
    https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
    """
    if npu_specific.has_npu:
        torch.npu.set_device(0)
99
+
100
+
101
def enable_tf32():
    """Allow TF32 matmul/cuDNN kernels; no-op without CUDA."""
    if not torch.cuda.is_available():
        return

    # Enabling benchmark mode seems to let a range of cards do fp16 when
    # they otherwise can't, see
    # https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
    if cuda_no_autocast():
        torch.backends.cudnn.benchmark = True

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
111
+
112
+
113
errors.run(enable_tf32, "Enabling TF32")

cpu: torch.device = torch.device("cpu")
fp8: bool = False
# Per-component devices; assigned during application startup.
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
device_esrgan: torch.device = None
device_codeformer: torch.device = None
# Working dtypes; may be overridden elsewhere (e.g. by --no-half).
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
dtype_inference: torch.dtype = torch.float16
unet_needs_upcast = False
127
+
128
+
129
def cond_cast_unet(input):
    """Cast *input* to the UNet dtype when upcasting is required."""
    if unet_needs_upcast:
        return input.to(dtype_unet)
    return input
131
+
132
+
133
def cond_cast_float(input):
    """Cast *input* to float32 when UNet upcasting is required."""
    if unet_needs_upcast:
        return input.float()
    return input
135
+
136
+
137
nv_rng = None
# Layer types whose forward() gets wrapped by manual_cast().
patch_module_list = [
    torch.nn.Linear,
    torch.nn.Conv2d,
    torch.nn.MultiheadAttention,
    torch.nn.GroupNorm,
    torch.nn.LayerNorm,
]
145
+
146
+
147
def manual_cast_forward(target_dtype):
    """Build a forward() wrapper that runs a module in *target_dtype*.

    Tensor args are cast to target_dtype, the module's parameters are
    temporarily converted when they differ, and tensor results are cast
    back to dtype_inference.
    """
    def forward_wrapper(self, *args, **kwargs):
        # Cast inputs only if at least one positional tensor has the wrong
        # dtype (kwargs are only converted in that same case).
        if any(
            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
            for arg in args
        ):
            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}

        # Remember the parameters' dtype so it can be restored afterwards.
        org_dtype = target_dtype
        for param in self.parameters():
            if param.dtype != target_dtype:
                org_dtype = param.dtype
                break

        if org_dtype != target_dtype:
            self.to(target_dtype)
        result = self.org_forward(*args, **kwargs)
        if org_dtype != target_dtype:
            self.to(org_dtype)

        # Hand results back in the inference dtype.
        if target_dtype != dtype_inference:
            if isinstance(result, tuple):
                result = tuple(
                    i.to(dtype_inference)
                    if isinstance(i, torch.Tensor)
                    else i
                    for i in result
                )
            elif isinstance(result, torch.Tensor):
                result = result.to(dtype_inference)
        return result
    return forward_wrapper
180
+
181
+
182
@contextlib.contextmanager
def manual_cast(target_dtype):
    """Temporarily patch common layer types to autocast by hand.

    MultiheadAttention always runs in float32; everything else in
    *target_dtype*. Nested uses are no-ops: only the outermost caller
    applies and removes the patches.
    """
    applied = False
    for module_type in patch_module_list:
        if hasattr(module_type, "org_forward"):
            continue  # already patched by an outer manual_cast
        applied = True
        org_forward = module_type.forward
        if module_type == torch.nn.MultiheadAttention:
            module_type.forward = manual_cast_forward(torch.float32)
        else:
            module_type.forward = manual_cast_forward(target_dtype)
        module_type.org_forward = org_forward
    try:
        yield None
    finally:
        if applied:
            for module_type in patch_module_list:
                if hasattr(module_type, "org_forward"):
                    module_type.forward = module_type.org_forward
                    delattr(module_type, "org_forward")
203
+
204
+
205
def autocast(disable=False):
    """Return the appropriate autocast context for the configured dtypes."""
    if disable:
        return contextlib.nullcontext()

    if fp8 and device == cpu:
        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)

    if fp8 and dtype_inference == torch.float32:
        return manual_cast(dtype)

    if dtype == torch.float32 or dtype_inference == torch.float32:
        return contextlib.nullcontext()

    # Backends where torch autocast is unavailable or broken.
    if has_xpu() or has_mps() or cuda_no_autocast():
        return manual_cast(dtype)

    return torch.autocast("cuda")
222
+
223
+
224
def without_autocast(disable=False):
    """Context that suspends an active autocast region (no-op if *disable*
    is True or autocast is not currently enabled)."""
    if torch.is_autocast_enabled() and not disable:
        return torch.autocast("cuda", enabled=False)
    return contextlib.nullcontext()
226
+
227
+
228
class NansException(Exception):
    """Raised when a model stage produces a tensor that is entirely NaN."""
    pass
230
+
231
+
232
def test_for_nans(x, where):
    """Raise NansException when *x* is entirely NaN.

    Skipped when --disable-nan-check was given; the message is tailored to
    the stage (*where*: 'unet' or 'vae') and the half-precision flags.
    """
    if shared.cmd_opts.disable_nan_check:
        return

    # Note: only trips when every element is NaN.
    if not torch.all(torch.isnan(x)).item():
        return

    if where == "unet":
        message = "A tensor with all NaNs was produced in Unet."

        if not shared.cmd_opts.no_half:
            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."

    elif where == "vae":
        message = "A tensor with all NaNs was produced in VAE."

        if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
            message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
    else:
        message = "A tensor with all NaNs was produced."

    message += " Use --disable-nan-check commandline argument to disable this check."

    raise NansException(message)
256
+
257
+
258
@lru_cache
def first_time_calculation():
    """Warm up PyTorch once.

    The first computation through torch layers allocates about 700MB and
    takes roughly 2.7 seconds (at least on NVIDIA), so trigger it up front.
    """
    x = torch.zeros((1, 1)).to(device, dtype)
    linear = torch.nn.Linear(1, 1).to(device, dtype)
    linear(x)

    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv2d(x)
TEED/loss2.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from TEED.utils.AF.Fsmish import smish as Fsmish
4
+
5
def bdcn_loss2(inputs, targets, l_weight=1.1):
    """BDCN loss as modified in DexiNed: class-balanced per-pixel BCE.

    inputs: edge logits (N, C, H, W); sigmoid is applied internally.
    targets: binary edge labels, same shape.
    l_weight: scalar multiplier on the final loss.
    """
    targets = targets.long()
    mask = targets.float()
    n_pos = torch.sum((mask > 0.0).float()).float()
    n_neg = torch.sum((mask <= 0.0).float()).float()
    total = n_pos + n_neg

    # Weight positives by the negative fraction and negatives by the
    # (slightly boosted) positive fraction to balance the classes.
    mask[mask > 0.] = 1.0 * n_neg / total
    mask[mask <= 0.] = 1.1 * n_pos / total

    probs = torch.sigmoid(inputs)
    per_pixel = torch.nn.BCELoss(mask, reduction='none')(probs, targets.float())
    cost = torch.sum(per_pixel.float().mean((1, 2, 3)))
    return l_weight * cost
19
+
20
+ # ------------ cats losses ----------
21
def bdrloss(prediction, label, radius, device='cpu'):
    """Boundary tracing loss (CATS) that handles confusing pixels near edges.

    prediction: edge probabilities (N, 1, H, W).
    label: binary edge labels, same shape.
    radius: half-width of the box filter used to gather neighbourhoods.
    """
    size = 2 * radius + 1
    box = torch.ones(1, 1, size, size)  # box filter; no gradient needed
    box.requires_grad = False
    box = box.to(device)

    edge_pred = prediction * label
    pred_bdr_sum = label * F.conv2d(edge_pred, box, bias=None, stride=1, padding=radius)

    # Pixels that lie near an edge (within the window) but not on one.
    texture_mask = F.conv2d(label.float(), box, bias=None, stride=1, padding=radius)
    near_mask = (texture_mask != 0).float()
    near_mask[label == 1] = 0
    pred_texture_sum = F.conv2d(prediction * (1 - label) * near_mask, box,
                                bias=None, stride=1, padding=radius)

    softmax_map = torch.clamp(
        pred_bdr_sum / (pred_texture_sum + pred_bdr_sum + 1e-10), 1e-10, 1 - 1e-10)
    cost = -label * torch.log(softmax_map)
    cost[label == 0] = 0

    return torch.sum(cost.float().mean((1, 2, 3)))
43
+
44
def textureloss(prediction, label, mask_radius, device='cpu'):
    """Texture suppression loss (CATS): pushes predictions toward zero in
    regions that contain no labelled edge pixel.

    prediction: edge probabilities (N, 1, H, W).
    label: binary edge labels, same shape.
    mask_radius: half-width of the window used to detect label-free regions.
    """
    sum3 = torch.ones(1, 1, 3, 3)  # 3x3 neighbourhood accumulator
    sum3.requires_grad = False
    sum3 = sum3.to(device)
    size = 2 * mask_radius + 1
    sumk = torch.ones(1, 1, size, size)
    sumk.requires_grad = False
    sumk = sumk.to(device)

    pred_sums = F.conv2d(prediction.float(), sum3, bias=None, stride=1, padding=1)
    label_sums = F.conv2d(label.float(), sumk, bias=None, stride=1, padding=mask_radius)

    # 1 where no label pixel falls inside the window, 0 elsewhere.
    mask = 1 - torch.gt(label_sums, 0).float()

    loss = -torch.log(torch.clamp(1 - pred_sums / 9, 1e-10, 1 - 1e-10))
    loss[mask == 0] = 0

    return torch.sum(loss.float().mean((1, 2, 3)))
64
+
65
+
66
def cats_loss(prediction, label, l_weight=[0.,0.], device='cpu'):
    """CATS tracing loss: class-balanced BCE plus texture and boundary terms.

    prediction: edge logits (N, 1, H, W); sigmoid applied internally.
    label: edge labels (0 = background, 1 = edge, 2 = ignore).
    l_weight: (texture_factor, boundary_factor) pair.
    """
    tex_factor, bdr_factor = l_weight
    balanced_w = 1.1
    label = label.float()
    prediction = prediction.float()

    # Build the per-pixel class-balance weights without tracking gradients.
    with torch.no_grad():
        mask = label.clone()

        num_positive = torch.sum((mask == 1).float()).float()
        num_negative = torch.sum((mask == 0).float()).float()
        beta = num_negative / (num_positive + num_negative)
        mask[mask == 1] = beta
        mask[mask == 0] = balanced_w * (1 - beta)
        mask[mask == 2] = 0  # ignore pixels labelled 2

    prediction = torch.sigmoid(prediction)

    cost = torch.nn.functional.binary_cross_entropy(
        prediction.float(), label.float(), weight=mask, reduction='none')
    cost = torch.sum(cost.float().mean((1, 2, 3)))
    label_w = (label != 0).float()
    textcost = textureloss(prediction.float(), label_w.float(), mask_radius=4, device=device)
    bdrcost = bdrloss(prediction.float(), label_w.float(), radius=4, device=device)

    return cost + bdr_factor * bdrcost + tex_factor * textcost
TEED/ted.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TEED: is a Tiny but Efficient Edge Detection, it comes from the LDC-B3
2
+ # with a Slightly modification
3
+ # LDC parameters:
4
+ # 155665
5
+ # TED > 58K
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from TEED.utils.AF.Fsmish import smish as Fsmish
12
+ from TEED.utils.AF.Xsmish import Smish
13
+ from TEED.utils.img_processing import count_parameters
14
+
15
+
16
def weight_init(m):
    """Initialize conv and transposed-conv layers (used via ``model.apply``).

    Weights get Xavier-normal init (gain 1.0), biases are zeroed; every
    other module type is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.xavier_normal_(m.weight, gain=1.0)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
28
+
29
class CoFusion(nn.Module):
    """Attention-based fusion of stacked edge maps (from LDC).

    Produces per-channel attention weights and collapses the input to a
    single-channel map via a weighted sum over channels.
    """

    def __init__(self, in_ch, out_ch):
        super(CoFusion, self).__init__()
        # Hidden width is 32 (LDC used 64).
        self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.norm_layer1 = nn.GroupNorm(4, 32)

    def forward(self, x):
        weights = self.relu(self.norm_layer1(self.conv1(x)))
        weights = F.softmax(self.conv3(weights), dim=1)
        # Weighted sum over channels -> Bx1xHxW.
        return (x * weights).sum(1).unsqueeze(1)
46
+
47
+
48
class CoFusion2(nn.Module):
    """Smish-activated fusion variant (TEDv14-3): conv-smish attention
    weights applied as a channel-weighted sum."""

    def __init__(self, in_ch, out_ch):
        super(CoFusion2, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3, stride=1, padding=1)
        self.smish = Smish()

    def forward(self, x):
        weights = self.conv1(self.smish(x))
        weights = self.conv3(self.smish(weights))
        # Weighted sum over channels -> Bx1xHxW.
        return (x * weights).sum(1).unsqueeze(1)
68
+
69
class DoubleFusion(nn.Module):
    """TEED fusion block applied just before the final edge-map prediction.

    Two depthwise convolutions produce attention maps; their sum is
    collapsed across channels and passed through a final Smish.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleFusion, self).__init__()
        # Depthwise 3x3 expanding each input channel 8x.
        self.DWconv1 = nn.Conv2d(in_ch, in_ch * 8, kernel_size=3,
                                 stride=1, padding=1, groups=in_ch)
        self.PSconv1 = nn.PixelShuffle(1)  # factor-1 shuffle (identity layout)
        self.DWconv2 = nn.Conv2d(24, 24 * 1, kernel_size=3,
                                 stride=1, padding=1, groups=24)
        self.AF = Smish()

    def forward(self, x):
        attn = self.PSconv1(self.DWconv1(self.AF(x)))
        attn2 = self.PSconv1(self.DWconv2(self.AF(attn)))
        # Sum both attention maps across channels, then smish-activate.
        return Fsmish((attn2 + attn).sum(1).unsqueeze(1))
90
+
91
class _DenseLayer(nn.Sequential):
    """conv -> Smish -> conv layer whose forward takes a (features, skip)
    pair and returns the average of new features and skip (skip unchanged)."""

    def __init__(self, input_features, out_features):
        super(_DenseLayer, self).__init__()
        self.add_module('conv1', nn.Conv2d(input_features, out_features,
                                           kernel_size=3, stride=1, padding=2, bias=True))
        self.add_module('smish1', Smish())
        self.add_module('conv2', nn.Conv2d(out_features, out_features,
                                           kernel_size=3, stride=1, bias=True))

    def forward(self, x):
        features, skip = x
        # Run the sequential stack on the smish-activated features.
        new_features = super(_DenseLayer, self).forward(Fsmish(features))
        return 0.5 * (new_features + skip), skip
106
+
107
+
108
class _DenseBlock(nn.Sequential):
    """Stack of _DenseLayer modules; the first maps input_features to
    out_features, the rest keep out_features."""

    def __init__(self, num_layers, input_features, out_features):
        super(_DenseBlock, self).__init__()
        for idx in range(num_layers):
            self.add_module('denselayer%d' % (idx + 1),
                            _DenseLayer(input_features, out_features))
            input_features = out_features
115
+
116
+
117
class UpConvBlock(nn.Module):
    """Upsampling head: (1x1 conv -> Smish -> transposed conv) repeated
    up_scale times, ending in a single output channel."""

    def __init__(self, in_features, up_scale):
        super(UpConvBlock, self).__init__()
        self.up_factor = 2
        self.constant_features = 16

        layers = self.make_deconv_layers(in_features, up_scale)
        assert layers is not None, layers
        self.features = nn.Sequential(*layers)

    def make_deconv_layers(self, in_features, up_scale):
        layers = []
        # Padding per total up_scale; the deconv kernel is 2**up_scale.
        all_pads = [0, 0, 1, 3, 7]
        kernel_size = 2 ** up_scale
        pad = all_pads[up_scale]
        for step in range(up_scale):
            out_features = self.compute_out_features(step, up_scale)
            layers.append(nn.Conv2d(in_features, out_features, 1))
            layers.append(Smish())
            layers.append(nn.ConvTranspose2d(
                out_features, out_features, kernel_size, stride=2, padding=pad))
            in_features = out_features
        return layers

    def compute_out_features(self, idx, up_scale):
        # Intermediate stages keep 16 channels; the last collapses to 1.
        return 1 if idx == up_scale - 1 else self.constant_features

    def forward(self, x):
        return self.features(x)
146
+
147
+
148
class SingleConvBlock(nn.Module):
    """1x1 convolution with an optional Smish activation (use_ac)."""

    def __init__(self, in_features, out_features, stride, use_ac=False):
        super(SingleConvBlock, self).__init__()
        self.use_ac = use_ac
        self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride,
                              bias=True)
        if self.use_ac:
            self.smish = Smish()

    def forward(self, x):
        out = self.conv(x)
        return self.smish(out) if self.use_ac else out
164
+
165
class DoubleConvBlock(nn.Module):
    """Two 3x3 convolutions with Smish activations; the trailing activation
    can be disabled via use_act."""

    def __init__(self, in_features, mid_features,
                 out_features=None,
                 stride=1,
                 use_act=True):
        super(DoubleConvBlock, self).__init__()
        self.use_act = use_act
        if out_features is None:
            out_features = mid_features
        self.conv1 = nn.Conv2d(in_features, mid_features,
                               3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
        self.smish = Smish()

    def forward(self, x):
        out = self.smish(self.conv1(x))
        out = self.conv2(out)
        return self.smish(out) if self.use_act else out
187
+
188
+
189
class TED(nn.Module):
    """Tiny and Efficient Edge Detector (TEED).

    Three-stage encoder with two skip connections, one upsampling head per
    stage, and a DoubleFusion block that merges the multiscale edge maps.
    forward() returns [out_1, out_2, out_3, fused], each Bx1xHxW.
    """

    def __init__(self):
        super(TED, self).__init__()
        self.block_1 = DoubleConvBlock(3, 16, 16, stride=2)
        self.block_2 = DoubleConvBlock(16, 32, use_act=False)
        self.dblock_3 = _DenseBlock(1, 32, 48)  # before (2, 32, 64)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # skip1 connection, see fig. 2
        self.side_1 = SingleConvBlock(16, 32, 2)

        # skip2 connection, see fig. 2
        self.pre_dense_3 = SingleConvBlock(32, 48, 1)  # before (32, 64, 1)

        # USNet: one upsampling head per encoder stage
        self.up_block_1 = UpConvBlock(16, 1)
        self.up_block_2 = UpConvBlock(32, 1)
        self.up_block_3 = UpConvBlock(48, 2)

        self.block_cat = DoubleFusion(3, 3)  # TEED: DoubleFusion

        self.apply(weight_init)

    def slice(self, tensor, slice_shape):
        """Bicubic-resize *tensor* to slice_shape=(h, w) when it differs."""
        t_shape = tensor.shape
        img_h, img_w = slice_shape
        if img_w != t_shape[-1] or img_h != t_shape[2]:
            return F.interpolate(
                tensor, size=(img_h, img_w), mode='bicubic', align_corners=False)
        return tensor

    def resize_input(self, tensor):
        """Bicubic-resize spatial dims up to the next multiple of 8.

        The encoder downsamples by 8, so non-multiple sizes would misalign
        the skip connections.
        """
        t_shape = tensor.shape
        if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0:
            img_w = ((t_shape[3] // 8) + 1) * 8
            img_h = ((t_shape[2] // 8) + 1) * 8
            return F.interpolate(
                tensor, size=(img_h, img_w), mode='bicubic', align_corners=False)
        return tensor

    @staticmethod
    def crop_bdcn(data1, h, w, crop_h, crop_w):
        """Crop an (N, C, H, W) tensor to h x w starting at (crop_h, crop_w).

        Based on the BDCN implementation @ https://github.com/pkuCactus/BDCN.
        Fixed: this was previously a plain ``def`` without ``self``, so
        calling it on an instance silently passed the instance as ``data1``;
        it is now a staticmethod (class-level calls keep working).
        """
        _, _, h1, w1 = data1.size()
        assert (h <= h1 and w <= w1)
        data = data1[:, :, crop_h:crop_h + h, crop_w:crop_w + w]
        return data

    def forward(self, x, single_test=False):
        """Run the detector; returns a list of four Bx1xHxW edge logits.

        single_test is accepted for interface compatibility but unused here.
        """
        assert x.ndim == 4, x.shape
        # suppose the image size is 352x352

        # Block 1
        block_1 = self.block_1(x)            # [8,16,176,176]
        block_1_side = self.side_1(block_1)  # [8,32,88,88]

        # Block 2
        block_2 = self.block_2(block_1)       # [8,32,176,176]
        block_2_down = self.maxpool(block_2)  # [8,32,88,88]
        block_2_add = block_2_down + block_1_side  # [8,32,88,88]

        # Block 3
        block_3_pre_dense = self.pre_dense_3(block_2_down)  # block 3 L connection
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])

        # upsampling blocks
        out_1 = self.up_block_1(block_1)
        out_2 = self.up_block_2(block_2)
        out_3 = self.up_block_3(block_3)

        results = [out_1, out_2, out_3]

        # concatenate multiscale outputs and fuse to a single map
        block_cat = torch.cat(results, dim=1)  # Bx3xHxW
        block_cat = self.block_cat(block_cat)  # Bx1xHxW DoubleFusion

        results.append(block_cat)
        return results
278
+
279
# Smoke test: run one random batch through TED on CPU and print the
# input/output tensor shapes.
if __name__ == '__main__':
    batch_size = 8
    img_height = 352
    img_width = 352

    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"
    input = torch.rand(batch_size, 3, img_height, img_width).to(device)
    # target = torch.rand(batch_size, 1, img_height, img_width).to(device)
    print(f"input shape: {input.shape}")
    model = TED().to(device)
    output = model(input)
    print(f"output shapes: {[t.shape for t in output]}")

    # for i in range(20000):
    #     print(i)
    #     output = model(input)
    #     loss = nn.MSELoss()(output[-1], target)
    #     loss.backward()
TEED/util.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import os
4
+
5
+
6
def load_model(filename: str, remote_url: str, model_dir: str) -> str:
    """Return the local path of *filename* in *model_dir*, downloading it
    from *remote_url* first when it is not already present.

    Args:
        filename (str): The filename of the model.
        remote_url (str): The remote URL of the model.
        model_dir (str): Directory where the model is cached.

    Returns:
        str: The joined local path to the model file.
    """
    local_path = os.path.join(model_dir, filename)
    if os.path.exists(local_path):
        return local_path
    from scripts.utils import load_file_from_url
    load_file_from_url(remote_url, model_dir=model_dir)
    return local_path
19
+
20
+
21
def HWC3(x):
    """Convert a uint8 image (HxW, HxWx1, HxWx3 or HxWx4) to HxWx3 uint8.

    Grayscale is replicated across three channels; RGBA is alpha-composited
    over a white background. A 3-channel input is returned unchanged.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C in (1, 3, 4)
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    # C == 4: blend the color channels over white using the alpha channel.
    color = x[:, :, 0:3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = color * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
38
+
39
+
40
def make_noise_disk(H, W, C, F):
    """Generate smooth HxWxC noise in [0, 1].

    Coarse uniform noise is bicubically upsampled by roughly factor F, a
    border of F pixels is cropped away, and the result is min-max
    normalized. For C == 1 the channel axis is kept.
    """
    coarse = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
    smooth = cv2.resize(coarse, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
    smooth = smooth[F: F + H, F: F + W]
    smooth -= np.min(smooth)
    smooth /= np.max(smooth)
    if C == 1:
        smooth = smooth[:, :, None]
    return smooth
49
+
50
+
51
def nms(x, t, s):
    """Directional non-maximum suppression for an edge map.

    x: edge-strength map; t: threshold on the kept responses;
    s: Gaussian blur sigma applied first. Responses that are maximal along
    one of four orientations (horizontal, vertical, two diagonals) survive;
    those above t become 255 in the returned uint8 mask.
    """
    blurred = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    kernels = [
        np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8),  # horizontal
        np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8),  # vertical
        np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8),  # diagonal \
        np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8),  # diagonal /
    ]

    kept = np.zeros_like(blurred)
    for kernel in kernels:
        # A pixel equal to its directional dilation is a local max there.
        np.putmask(kept, cv2.dilate(blurred, kernel=kernel) == blurred, blurred)

    out = np.zeros_like(kept, dtype=np.uint8)
    out[kept > t] = 255
    return out
67
+
68
+
69
def min_max_norm(x):
    """Shift and scale *x* in place so its values span [0, 1]; returns *x*.

    The divisor is clamped to at least 1e-5 so a constant array does not
    cause a division by zero.
    """
    lo = np.min(x)
    x -= lo
    span = np.maximum(np.max(x), 1e-5)
    x /= span
    return x
73
+
74
+
75
def safe_step(x, step=2):
    """Quantize values of *x* (expected in [0, 1]) onto a grid of 1/step.

    Multiplies by step+1, truncates to an integer, and divides by step.
    """
    scaled = x.astype(np.float32) * float(step + 1)
    return scaled.astype(np.int32).astype(np.float32) / float(step)