Spaces:
Runtime error
Runtime error
Upload 7 files
Browse files- TEED/LICENSE +21 -0
- TEED/README.md +68 -0
- TEED/dataset.py +581 -0
- TEED/devices.py +271 -0
- TEED/loss2.py +92 -0
- TEED/ted.py +297 -0
- TEED/util.py +78 -0
TEED/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2022 Xavier Soria Poma
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
TEED/README.md
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[](https://paperswithcode.com/sota/edge-detection-on-uded?p=tiny-and-efficient-model-for-the-edge)
|
| 2 |
+
|
| 3 |
+
# Tiny and Efficient Model for the Edge Detection Generalization (Paper)
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
<div style="text-align:center"><img src='imgs/teedBanner.png' width=800>
|
| 8 |
+
</div>
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
Tiny and Efficient Edge Detector (TEED) is a light convolutional neural
|
| 13 |
+
network with only $58K$ parameters, less than $0.2$% of the
|
| 14 |
+
state-of-the-art models. Training on the [BIPED](https://www.kaggle.com/datasets/xavysp/biped)
|
| 15 |
+
dataset takes *less than 30 minutes*, with each epoch requiring
|
| 16 |
+
*less than 5 minutes*. Our proposed model is easy to train
|
| 17 |
+
and it quickly converges within the very first few epochs, while the
|
| 18 |
+
predicted edge-maps are crisp and of high quality, see image above.
|
| 19 |
+
[This paper has been accepted by ICCV 2023-Workshop RCV](https://arxiv.org/abs/2308.06468).
|
| 20 |
+
|
| 21 |
+
... Under construction
|
| 22 |
+
|
| 23 |
+
git clone https://github.com/xavysp/TEED.git
|
| 24 |
+
cd TEED
|
| 25 |
+
|
| 26 |
+
Then,
|
| 27 |
+
|
| 28 |
+
## Testing with TEED
|
| 29 |
+
|
| 30 |
+
Copy and paste your images into data/ folder, and:
|
| 31 |
+
|
| 32 |
+
python main.py --choose_test_data=-1
|
| 33 |
+
|
| 34 |
+
## Training with TEED
|
| 35 |
+
|
| 36 |
+
Set the following lines in main.py:
|
| 37 |
+
|
| 38 |
+
25: is_testing =False
|
| 39 |
+
# training with BIPED
|
| 40 |
+
223: TRAIN_DATA = DATASET_NAMES[0]
|
| 41 |
+
|
| 42 |
+
then run
|
| 43 |
+
|
| 44 |
+
python main.py
|
| 45 |
+
|
| 46 |
+
Check the configurations of the datasets in dataset.py
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
## UDED dataset
|
| 50 |
+
|
| 51 |
+
Here the [link](https://github.com/xavysp/UDED) to access the UDED dataset for edge detection
|
| 52 |
+
|
| 53 |
+
## Citation
|
| 54 |
+
|
| 55 |
+
If you like TEED, why not star the project on GitHub!
|
| 56 |
+
|
| 57 |
+
[](https://GitHub.com/xavysp/TEED/stargazers/)
|
| 58 |
+
|
| 59 |
+
Please cite our work if you find it helpful in your academic/scientific publication,
|
| 60 |
+
```
|
| 61 |
+
@InProceedings{Soria_2023teed,
|
| 62 |
+
author = {Soria, Xavier and Li, Yachuan and Rouhani, Mohammad and Sappa, Angel D.},
|
| 63 |
+
title = {Tiny and Efficient Model for the Edge Detection Generalization},
|
| 64 |
+
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) Workshops},
|
| 65 |
+
month = {October},
|
| 66 |
+
year = {2023},
|
| 67 |
+
pages = {1364-1373}
|
| 68 |
+
}
|
TEED/dataset.py
ADDED
|
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
# Names of every dataset this module knows how to load. The index of each
# entry matters: main.py selects a dataset as DATASET_NAMES[i] (e.g.
# TRAIN_DATA = DATASET_NAMES[0] for BIPED), hence the index hints below.
DATASET_NAMES = [
    'BIPED',
    'BIPED-B2',
    'BIPED-B3',
    'BIPED-B5',
    'BIPED-B6',
    'BSDS', # 5
    'BRIND', # 6
    'ICEDA', #7
    'BSDS300',
    'CID', #9
    'DCD',
    'MDBD', #11
    'PASCAL',
    'NYUD', #13
    'BIPBRI',
    'UDED', # 15 just for testing
    'DMRIR',
    'CLASSIC'
] # 8
# Per-channel BGR mean of the BIPED training set plus a 4th (NIR) value;
# only the first three components are used for RGB inputs.
# [108, 109.451,112.230,137.86]
BIPED_mean = [103.939,116.779,123.68,137.86]
|
| 32 |
+
|
| 33 |
+
def dataset_info(dataset_name, is_linux=True):
    """Return the configuration dict for one dataset.

    Parameters
    ----------
    dataset_name : str
        One of the names in ``DATASET_NAMES`` (e.g. 'BIPED', 'BSDS',
        'CLASSIC').
    is_linux : bool, optional
        Selects the Linux path layout (default) or the Windows layout,
        which uses different ``data_dir`` roots and, for a few datasets,
        different image sizes.

    Returns
    -------
    dict
        Keys: ``img_height``, ``img_width``, ``train_list``, ``test_list``,
        ``data_dir``, ``yita`` (evaluation threshold) and ``mean``
        (per-channel BGR(+NIR) mean). Every entry now carries the full key
        set — previously several Windows entries were missing
        ``train_list``, so callers reading it crashed with ``KeyError``.

    Raises
    ------
    KeyError
        If ``dataset_name`` has no configuration (unchanged behavior).
    """
    # Same values as the module-level BIPED_mean; inlined so this function
    # is self-contained.
    biped_mean = [103.939, 116.779, 123.68, 137.86]
    default_mean = [104.007, 116.669, 122.679, 137.86]

    def _cfg(height, width, data_dir, yita,
             mean=None, train_list=None, test_list='test_pair.lst'):
        """Build one config entry; defaults cover the common test-only case."""
        return {
            'img_height': height,
            'img_width': width,
            'train_list': train_list,
            'test_list': test_list,
            'data_dir': data_dir,
            'yita': yita,
            'mean': default_mean if mean is None else mean,
        }

    if is_linux:
        config = {
            'UDED': _cfg(512, 512, '/root/workspace/datasets/UDED', 0.5),
            'BSDS': _cfg(512, 512, '/root/workspace/datasets/BSDS', 0.5,
                         train_list='train_pair.lst'),
            # 'train_pair_all.lst' = full augmentation; 'train_pair.lst' = reduced.
            'BRIND': _cfg(512, 512, '/root/workspace/datasets/BRIND', 0.5,
                          train_list='train_pair_all.lst'),
            'ICEDA': _cfg(1024, 1408, '/root/workspace/datasets/ICEDA', 0.5),
            'BSDS300': _cfg(512, 512, '/root/workspace/datasets/BSDS300', 0.5),
            'PASCAL': _cfg(416, 512, '/root/datasets/PASCAL', 0.3),
            'CID': _cfg(512, 512, '/root/datasets/CID', 0.3),
            'NYUD': _cfg(448, 560, '/root/datasets/NYUD', 0.5),
            'MDBD': _cfg(720, 1280, '/root/workspace/datasets/MDBD', 0.3,
                         train_list='train_pair.lst'),
            # Alternative augmentations: 'train_pairB3.lst', 'train_pairB5.lst'.
            'BIPED': _cfg(720, 1280, '/root/workspace/datasets/BIPED', 0.5,
                          mean=biped_mean, train_list='train_pair0.lst'),
            # CLASSIC: arbitrary user images dropped into data/, no lists.
            'CLASSIC': _cfg(512, 512, 'data', 0.5, test_list=None),
            'BIPED-B2': _cfg(720, 1280, 'C:/Users/xavysp/dataset/BIPED', 0.5,
                             mean=biped_mean, train_list='train_rgb.lst'),
            'BIPED-B3': _cfg(720, 1280, 'C:/Users/xavysp/dataset/BIPED', 0.5,
                             mean=biped_mean, train_list='train_rgb.lst'),
            'BIPED-B5': _cfg(720, 1280, 'C:/Users/xavysp/dataset/BIPED', 0.5,
                             mean=biped_mean, train_list='train_rgb.lst'),
            'BIPED-B6': _cfg(720, 1280, 'C:/Users/xavysp/dataset/BIPED', 0.5,
                             mean=biped_mean, train_list='train_rgb.lst'),
            'DCD': _cfg(352, 480, '/opt/dataset/DCD', 0.2),
        }
    else:
        config = {
            'UDED': _cfg(512, 512, 'C:/dataset/UDED', 0.5),
            'BSDS': _cfg(480, 480, 'C:/dataset/BSDS', 0.5,
                         mean=[103.939, 116.669, 122.679, 137.86]),
            'BRIND': _cfg(512, 512, 'C:/dataset/BRIND', 0.5,
                          train_list='train_pair_all.lst'),
            'ICEDA': _cfg(1024, 1408, 'C:/dataset/ICEDA', 0.5),
            'BSDS300': _cfg(512, 512, 'C:/Users/xavysp/dataset/BSDS300', 0.5),
            'PASCAL': _cfg(375, 500, 'C:/dataset/PASCAL', 0.3),
            'CID': _cfg(512, 512, 'C:/dataset/CID', 0.3),
            'NYUD': _cfg(425, 560, 'C:/dataset/NYUD', 0.5),
            'MDBD': _cfg(720, 1280, 'C:/dataset/MDBD', 0.3,
                         train_list='train_pair.lst'),
            'BIPED': _cfg(720, 1280, 'C:/dataset/BIPED', 0.5,
                          mean=biped_mean, train_list='train_pair0.lst'),
            'BIPED-B2': _cfg(720, 1280, 'C:/dataset/BIPED', 0.5,
                             mean=biped_mean, train_list='train_rgb.lst'),
            'BIPED-B3': _cfg(720, 1280, 'C:/dataset/BIPED', 0.5,
                             mean=biped_mean, train_list='train_rgb.lst'),
            'BIPED-B5': _cfg(720, 1280, 'C:/Users/xavysp/dataset/BIPED', 0.5,
                             mean=biped_mean, train_list='train_rgb.lst'),
            'BIPED-B6': _cfg(720, 1280, 'C:/Users/xavysp/dataset/BIPED', 0.5,
                             mean=biped_mean, train_list='train_rgb.lst'),
            'CLASSIC': _cfg(512, 512, 'teed_tmp', 0.5, test_list=None),
            'DCD': _cfg(240, 360, 'C:/dataset/DCD', 0.2),
        }
    return config[dataset_name]
|
| 300 |
+
|
| 301 |
+
class TestDataset(Dataset):
    """Inference-time dataset.

    For the special 'CLASSIC' dataset the ``data_root`` folder is scanned
    directly and no ground-truth labels are loaded; every other dataset
    reads image/label pairs from a list file (JSON for BIPED/BRIND/UDED/
    ICEDA, whitespace-separated text otherwise).
    """

    def __init__(self,
                 data_root,
                 test_data,
                 img_height,
                 img_width,
                 test_list=None,
                 arg=None
                 ):
        # `arg` is the parsed CLI namespace; it must provide `up_scale`
        # and `mean_test` (BGR mean, optionally with a 4th NIR component).
        if test_data not in DATASET_NAMES:
            raise ValueError(f"Unsupported dataset: {test_data}")

        self.data_root = data_root
        self.test_data = test_data
        self.test_list = test_list
        self.args = arg
        self.up_scale = arg.up_scale
        # Keep only the three BGR channel means when a 4th value is present.
        self.mean_bgr = arg.mean_test if len(arg.mean_test) == 3 else arg.mean_test[:3]
        self.img_height = img_height
        self.img_width = img_width
        self.data_index = self._build_index()


    def _build_index(self):
        """Build the sample index.

        Returns ``[image_names, None]`` for CLASSIC (a flat directory
        listing and no labels), otherwise a list of
        ``(image_path, label_path)`` tuples joined onto ``data_root``.
        """
        sample_indices = []
        if self.test_data == "CLASSIC":
            # for single image testing
            images_path = os.listdir(self.data_root)
            labels_path = None
            sample_indices = [images_path, labels_path]
        else:
            # image and label paths are located in a list file

            if not self.test_list:
                raise ValueError(
                    f"Test list not provided for dataset: {self.test_data}")

            list_name = os.path.join(self.data_root, self.test_list)
            if self.test_data.upper() in ['BIPED', 'BRIND','UDED','ICEDA']:
                # These datasets ship their pair list as a JSON array of
                # [image, ground_truth] pairs.
                with open(list_name,encoding='utf-8') as f:
                    files = json.load(f)
                for pair in files:
                    tmp_img = pair[0]
                    tmp_gt = pair[1]
                    sample_indices.append(
                        (os.path.join(self.data_root, tmp_img),
                         os.path.join(self.data_root, tmp_gt),))
            else:
                # Plain-text list: one "image_path label_path" pair per line.
                with open(list_name, 'r') as f:
                    files = f.readlines()
                files = [line.strip() for line in files]
                pairs = [line.split() for line in files]

                for pair in pairs:
                    tmp_img = pair[0]
                    tmp_gt = pair[1]
                    sample_indices.append(
                        (os.path.join(self.data_root, tmp_img),
                         os.path.join(self.data_root, tmp_gt),))
        return sample_indices

    def __len__(self):
        # CLASSIC stores [image_names, None]; others store pair tuples.
        return len(self.data_index[0]) if self.test_data.upper() == 'CLASSIC' else len(self.data_index)

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Keys: ``images`` (normalized CHW float tensor), ``labels``
        (edge-map tensor, dummy zeros for CLASSIC), ``file_names``
        (output name with .png extension), ``image_shape`` ([H, W] of the
        original image before resizing).
        """
        # get data sample
        # image_path, label_path = self.data_index[idx]
        # data_index[1] is None only for CLASSIC (the [paths, None] layout).
        if self.data_index[1] is None:
            image_path = self.data_index[0][idx] if len(self.data_index[0]) > 1 else self.data_index[0][idx - 1]
        else:
            image_path = self.data_index[idx][0]
        label_path = None if self.test_data == "CLASSIC" else self.data_index[idx][1]
        img_name = os.path.basename(image_path)
        # print(img_name)
        file_name = os.path.splitext(img_name)[0] + ".png"

        # base dir
        if self.test_data.upper() == 'BIPED':
            img_dir = os.path.join(self.data_root, 'imgs', 'test')
            gt_dir = os.path.join(self.data_root, 'edge_maps', 'test')
        elif self.test_data.upper() == 'CLASSIC':
            img_dir = self.data_root
            gt_dir = None
        else:
            img_dir = self.data_root
            gt_dir = self.data_root

        # load data
        # NOTE: image_path from _build_index may already be absolute; in
        # that case os.path.join(img_dir, image_path) yields image_path
        # unchanged (os.path.join drops earlier parts before an absolute
        # component). np.fromfile + imdecode handles non-ASCII paths.
        image = cv2.imdecode(np.fromfile(os.path.join(img_dir, image_path), np.uint8), cv2.IMREAD_COLOR)
        if not self.test_data == "CLASSIC":
            label = cv2.imread(os.path.join(
                gt_dir, label_path), cv2.IMREAD_COLOR)
        else:
            label = None

        im_shape = [image.shape[0], image.shape[1]]
        image, label = self.transform(img=image, gt=label)

        return dict(images=image, labels=label, file_names=file_name, image_shape=im_shape)

    def transform(self, img, gt):
        """Resize/pad the test image and convert (img, gt) to tensors.

        The image is mean-subtracted (BGR), optionally up-scaled, forced
        to dimensions divisible by 8, and returned as a CHW float tensor.
        """
        # gt[gt< 51] = 0 # test without gt discrimination
        # up scale test image
        if self.up_scale:
            # For TEED BIPBRIlight Upscale
            img = cv2.resize(img,(0,0),fx=1.3,fy=1.3)

        if img.shape[0] < 512 or img.shape[1] < 512:
            #TEED BIPED standard proposal if you want speed up the test, comment this block
            img = cv2.resize(img, (0, 0), fx=1.5, fy=1.5)
        # else:
        #     img = cv2.resize(img, (0, 0), fx=1.1, fy=1.1)

        # Make sure images and labels are divisible by 2^4=16
        # NOTE(review): the check below actually enforces divisibility by 8,
        # not 16, despite the comment above — confirm which the model needs.
        if img.shape[0] % 8 != 0 or img.shape[1] % 8 != 0:
            img_width = ((img.shape[1] // 8) + 1) * 8
            img_height = ((img.shape[0] // 8) + 1) * 8
            img = cv2.resize(img, (img_width, img_height))
            # gt = cv2.resize(gt, (img_width, img_height))
        else:
            pass
        # img_width = self.args.test_img_width
        # img_height = self.args.test_img_height
        # img = cv2.resize(img, (img_width, img_height))
        # gt = cv2.resize(gt, (img_width, img_height))
        # # For FPS
        # img = cv2.resize(img, (496,320))

        img = np.array(img, dtype=np.float32)
        # if self.rgb:
        #     img = img[:, :, ::-1]  # RGB->BGR

        img -= self.mean_bgr
        img = img.transpose((2, 0, 1))
        img = torch.from_numpy(img.copy()).float()

        if self.test_data == "CLASSIC":
            # Dummy label: CLASSIC has no ground truth.
            # NOTE(review): img is CHW here, so img.shape[:2] is (C, H),
            # not (H, W); the zeros tensor is a placeholder only — verify
            # downstream code never reads its shape.
            gt = np.zeros((img.shape[:2]))
            gt = torch.from_numpy(np.array([gt])).float()
        else:
            gt = np.array(gt, dtype=np.float32)
            if len(gt.shape) == 3:
                gt = gt[:, :, 0]
            gt /= 255.
            gt = torch.from_numpy(np.array([gt])).float()

        return img, gt
|
| 449 |
+
|
| 450 |
+
# *************************************************
|
| 451 |
+
# ************* training **************************
|
| 452 |
+
# *************************************************
|
| 453 |
+
# *************************************************
# ************* training **************************
# *************************************************
class BipedDataset(Dataset):
    """Training dataset for TEED (BIPED / BRIND / MDBD / BSDS).

    Reads (image, edge-map) pairs listed in ``arg.train_list`` (JSON for
    most datasets, whitespace-separated text for BSDS), then applies a
    random square crop / resize augmentation in :meth:`transform`.
    """

    train_modes = ['train', 'test', ]
    dataset_types = ['rgbr', ]
    data_types = ['aug', ]

    def __init__(self,
                 data_root,
                 img_height,
                 img_width,
                 train_mode='train',
                 dataset_type='rgbr',
                 # Whether to crop image or otherwise resize image to match
                 # image height and width.
                 crop_img=False,
                 arg=None
                 ):
        # `arg` is the parsed CLI namespace; it must provide `train_list`,
        # `train_data` and `mean_train` (BGR mean, optionally with a 4th
        # NIR component).
        self.data_root = data_root
        self.train_mode = train_mode
        self.dataset_type = dataset_type
        self.data_type = 'aug'  # be aware that this might change in the future
        self.img_height = img_height
        self.img_width = img_width
        # Keep only the three BGR channel means when a 4th value is present.
        self.mean_bgr = arg.mean_train if len(arg.mean_train) == 3 else arg.mean_train[:3]
        self.crop_img = crop_img
        self.arg = arg

        self.data_index = self._build_index()

    def _build_index(self):
        """Return a list of (image_path, label_path) tuples from the list file."""
        assert self.train_mode in self.train_modes, self.train_mode
        assert self.dataset_type in self.dataset_types, self.dataset_type
        assert self.data_type in self.data_types, self.data_type

        data_root = os.path.abspath(self.data_root)
        sample_indices = []

        file_path = os.path.join(data_root, self.arg.train_list)
        if self.arg.train_data.lower() == 'bsds':
            # BSDS ships a plain-text list: "image_path label_path" per line.
            with open(file_path, 'r') as f:
                files = f.readlines()
            files = [line.strip() for line in files]

            pairs = [line.split() for line in files]
            for pair in pairs:
                tmp_img = pair[0]
                tmp_gt = pair[1]
                sample_indices.append(
                    (os.path.join(data_root, tmp_img),
                     os.path.join(data_root, tmp_gt),))
        else:
            # Every other dataset ships a JSON array of [image, gt] pairs.
            with open(file_path) as f:
                files = json.load(f)
            for pair in files:
                tmp_img = pair[0]
                tmp_gt = pair[1]
                sample_indices.append(
                    (os.path.join(data_root, tmp_img),
                     os.path.join(data_root, tmp_gt),))

        return sample_indices

    def __len__(self):
        return len(self.data_index)

    def __getitem__(self, idx):
        """Load, augment and return one training sample as a dict."""
        image_path, label_path = self.data_index[idx]

        # np.fromfile + cv2.imdecode handles non-ASCII paths on Windows.
        image = cv2.imdecode(np.fromfile(image_path, np.uint8), cv2.IMREAD_COLOR)
        # BUGFIX: np.fromfile defaults to float64; cv2.imdecode requires a
        # uint8 byte buffer, so the dtype must be given explicitly (the
        # image line above already did this).
        label = cv2.imdecode(np.fromfile(label_path, np.uint8), cv2.IMREAD_GRAYSCALE)
        image, label = self.transform(img=image, gt=label)
        return dict(images=image, labels=label)

    def transform(self, img, gt):
        """Random-crop/resize augmentation; returns (CHW image, 1xHxW gt) tensors.

        Large images (>400px each side) are randomly cropped: 60% of the
        time a crop_size crop is taken directly, otherwise a 300px crop is
        taken and resized up to crop_size. Smaller images are resized.
        """
        gt = np.array(gt, dtype=np.float32)
        if len(gt.shape) == 3:
            gt = gt[:, :, 0]

        gt /= 255.  # for LDC input and BDCN

        img = np.array(img, dtype=np.float32)
        img -= self.mean_bgr
        i_h, i_w, _ = img.shape
        # Square training crops only (448 / MDBD=480 / BIPED=480-400 / BSDS=352).
        # NOTE(review): crop_size is None when img_height != img_width; the
        # branches below then fail — this class assumes square training
        # sizes. Also assumes crop_size <= min(h, w) for >400px images.
        crop_size = self.img_height if self.img_height == self.img_width else None

        # Second augmentation (for BIPED/MDBD); threshold was 420 before.
        if i_w > 400 and i_h > 400:
            h, w = gt.shape
            if np.random.random() > 0.4:
                # Direct crop at the training resolution.
                LR_img_size = crop_size  # BIPED=256/240/200, MDBD=352, BSDS=176
                i = random.randint(0, h - LR_img_size)
                j = random.randint(0, w - LR_img_size)
                img = img[i:i + LR_img_size, j:j + LR_img_size]
                gt = gt[i:i + LR_img_size, j:j + LR_img_size]
            else:
                # Smaller crop, then upsample to the training resolution.
                LR_img_size = 300  # BIPED=208-352, MDBD=352-480, BSDS=176-320
                i = random.randint(0, h - LR_img_size)
                j = random.randint(0, w - LR_img_size)
                img = img[i:i + LR_img_size, j:j + LR_img_size]
                gt = gt[i:i + LR_img_size, j:j + LR_img_size]
                img = cv2.resize(img, dsize=(crop_size, crop_size), )
                gt = cv2.resize(gt, dsize=(crop_size, crop_size))
        else:
            # Image too small to crop: resize it to the training resolution.
            img = cv2.resize(img, dsize=(crop_size, crop_size))
            gt = cv2.resize(gt, dsize=(crop_size, crop_size))
        # Boost faint edges (BRIND; best for TEED+BIPED), then re-clip to [0, 1].
        gt[gt > 0.1] += 0.2  # 0.4
        gt = np.clip(gt, 0., 1.)

        img = img.transpose((2, 0, 1))
        img = torch.from_numpy(img.copy()).float()
        gt = torch.from_numpy(np.array([gt])).float()
        return img, gt
|
TEED/devices.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import contextlib
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from modules import errors, shared, npu_specific
|
| 7 |
+
|
| 8 |
+
if sys.platform == "darwin":
|
| 9 |
+
from modules import mac_specific
|
| 10 |
+
|
| 11 |
+
if shared.cmd_opts.use_ipex:
|
| 12 |
+
from modules import xpu_specific
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def has_xpu() -> bool:
    """True when the user asked for IPEX and an XPU device is present."""
    return shared.cmd_opts.use_ipex and xpu_specific.has_xpu


def has_mps() -> bool:
    """True on macOS builds of torch that expose the MPS backend."""
    return sys.platform == "darwin" and mac_specific.has_mps


def cuda_no_autocast(device_id=None) -> bool:
    """GTX 16xx cards (compute capability 7.5) misbehave under torch.autocast."""
    device_id = get_cuda_device_id() if device_id is None else device_id
    capability = torch.cuda.get_device_capability(device_id)
    name = torch.cuda.get_device_name(device_id)
    return capability == (7, 5) and name.startswith("NVIDIA GeForce GTX 16")


def get_cuda_device_id():
    """Resolve the CUDA device index from --device-id, else torch's current device."""
    opt = shared.cmd_opts.device_id
    chosen = int(opt) if opt is not None and opt.isdigit() else 0
    # `or` keeps the original semantics: index 0 defers to torch's current device
    return chosen or torch.cuda.current_device()


def get_cuda_device_string():
    """Return "cuda" or "cuda:<id>" when --device-id was supplied."""
    opt = shared.cmd_opts.device_id
    return "cuda" if opt is None else f"cuda:{opt}"


def get_optimal_device_name():
    """Pick the best available backend name: CUDA > MPS > XPU > NPU > CPU."""
    if torch.cuda.is_available():
        return get_cuda_device_string()
    if has_mps():
        return "mps"
    if has_xpu():
        return xpu_specific.get_xpu_device_string()
    if npu_specific.has_npu:
        return npu_specific.get_npu_device_string()
    return "cpu"


def get_optimal_device():
    """torch.device wrapper around get_optimal_device_name()."""
    return torch.device(get_optimal_device_name())


def get_device_for(task):
    """Device for a named task, honouring the --use-cpu overrides."""
    cpu_tasks = shared.cmd_opts.use_cpu
    if task in cpu_tasks or "all" in cpu_tasks:
        return cpu
    return get_optimal_device()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def torch_gc():
    """Release cached memory on every accelerator backend that is present."""

    if torch.cuda.is_available():
        # empty_cache/ipc_collect act on the active device, so select it first
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    if has_mps():
        mac_specific.torch_mps_gc()

    if has_xpu():
        xpu_specific.torch_xpu_gc()

    if npu_specific.has_npu:
        # device must be (re)selected before collecting, see torch_npu_set_device
        torch_npu_set_device()
        npu_specific.torch_npu_gc()


def torch_npu_set_device():
    # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
    if npu_specific.has_npu:
        torch.npu.set_device(0)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def enable_tf32():
    """Enable TF32 matmul/cudnn kernels; also turns on cudnn benchmark for GTX 16xx cards."""
    if torch.cuda.is_available():

        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
        if cuda_no_autocast():
            torch.backends.cudnn.benchmark = True

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


# runs at import time; errors.run logs (instead of crashing) if TF32 setup fails
errors.run(enable_tf32, "Enabling TF32")
|
| 114 |
+
|
| 115 |
+
cpu: torch.device = torch.device("cpu")
fp8: bool = False
# device slots are populated by startup code elsewhere; None until then
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
device_esrgan: torch.device = None
device_codeformer: torch.device = None
# working dtypes; presumably overridden to float32 by --no-half style options — confirm at startup code
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
dtype_inference: torch.dtype = torch.float16
unet_needs_upcast = False
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def cond_cast_unet(input):
    """Cast to the unet dtype when upcasting is required, otherwise pass through."""
    if unet_needs_upcast:
        return input.to(dtype_unet)
    return input


def cond_cast_float(input):
    """Cast to float32 when upcasting is required, otherwise pass through."""
    if unet_needs_upcast:
        return input.float()
    return input


nv_rng = None
# layer types whose forward() gets wrapped by manual_cast()
patch_module_list = [
    torch.nn.Linear,
    torch.nn.Conv2d,
    torch.nn.MultiheadAttention,
    torch.nn.GroupNorm,
    torch.nn.LayerNorm,
]
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def manual_cast_forward(target_dtype):
    """Build a forward() replacement that makes a layer compute in ``target_dtype``.

    Used by manual_cast() on cards/backends where torch.autocast is unusable:
    tensor arguments are cast in, parameters are temporarily converted, and
    outputs are converted back to the global ``dtype_inference``.
    """
    def forward_wrapper(self, *args, **kwargs):
        # cast mismatched tensor arguments (positional check triggers both casts)
        if any(
            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
            for arg in args
        ):
            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}

        # remember the layer's current parameter dtype so it can be restored
        org_dtype = target_dtype
        for param in self.parameters():
            if param.dtype != target_dtype:
                org_dtype = param.dtype
                break

        if org_dtype != target_dtype:
            self.to(target_dtype)
        # org_forward is the unpatched method saved by manual_cast()
        result = self.org_forward(*args, **kwargs)
        if org_dtype != target_dtype:
            self.to(org_dtype)

        # convert outputs back to the inference dtype callers expect
        if target_dtype != dtype_inference:
            if isinstance(result, tuple):
                result = tuple(
                    i.to(dtype_inference)
                    if isinstance(i, torch.Tensor)
                    else i
                    for i in result
                )
            elif isinstance(result, torch.Tensor):
                result = result.to(dtype_inference)
        return result
    return forward_wrapper
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@contextlib.contextmanager
def manual_cast(target_dtype):
    """Context manager that monkey-patches forward() of the layer types in
    ``patch_module_list`` with a dtype-casting wrapper, restoring them on exit.

    Re-entrant: already-patched types (``org_forward`` present) are skipped,
    and only the invocation that actually patched (``applied``) unpatches.
    """
    applied = False
    for module_type in patch_module_list:
        if hasattr(module_type, "org_forward"):
            continue
        applied = True
        org_forward = module_type.forward
        # MultiheadAttention is always cast to float32 rather than target_dtype
        if module_type == torch.nn.MultiheadAttention:
            module_type.forward = manual_cast_forward(torch.float32)
        else:
            module_type.forward = manual_cast_forward(target_dtype)
        module_type.org_forward = org_forward
    try:
        yield None
    finally:
        if applied:
            for module_type in patch_module_list:
                if hasattr(module_type, "org_forward"):
                    module_type.forward = module_type.org_forward
                    delattr(module_type, "org_forward")
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def autocast(disable=False):
    """Return the mixed-precision context appropriate for the configured
    device/dtype: torch.autocast, a manual-cast patch, or a no-op.

    The guard order below is significant — fp8 cases take precedence, then
    pure-fp32 runs, then backends where torch.autocast is unusable.
    """
    if disable:
        return contextlib.nullcontext()

    # fp8 weights on CPU: use bfloat16 autocast
    if fp8 and device == cpu:
        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)

    # fp8 weights with fp32 inference: cast manually
    if fp8 and dtype_inference == torch.float32:
        return manual_cast(dtype)

    # pure fp32 runs need no casting at all
    if dtype == torch.float32 or dtype_inference == torch.float32:
        return contextlib.nullcontext()

    # backends where torch.autocast is unusable fall back to manual patching
    if has_xpu() or has_mps() or cuda_no_autocast():
        return manual_cast(dtype)

    return torch.autocast("cuda")


def without_autocast(disable=False):
    """Context that suspends an active autocast region; no-op when not inside one."""
    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class NansException(Exception):
    """Raised when a computation step produces an all-NaN tensor."""
    pass


def test_for_nans(x, where):
    """Raise NansException with a targeted hint when ``x`` is entirely NaN.

    ``where`` selects the message ("unet" / "vae" / anything else); the check
    is skipped entirely under --disable-nan-check.
    """
    if shared.cmd_opts.disable_nan_check:
        return

    # only trips when EVERY element is NaN (a fully broken result)
    if not torch.all(torch.isnan(x)).item():
        return

    if where == "unet":
        message = "A tensor with all NaNs was produced in Unet."

        if not shared.cmd_opts.no_half:
            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."

    elif where == "vae":
        message = "A tensor with all NaNs was produced in VAE."

        if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
            message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
    else:
        message = "A tensor with all NaNs was produced."

    message += " Use --disable-nan-check commandline argument to disable this check."

    raise NansException(message)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
@lru_cache
def first_time_calculation():
    """
    just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
    spends about 2.7 seconds doing that, at least with NVidia.
    """
    # lru_cache ensures the warm-up runs only once per process

    x = torch.zeros((1, 1)).to(device, dtype)
    linear = torch.nn.Linear(1, 1).to(device, dtype)
    linear(x)

    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv2d(x)
|
TEED/loss2.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from TEED.utils.AF.Fsmish import smish as Fsmish
|
| 4 |
+
|
| 5 |
+
def bdcn_loss2(inputs, targets, l_weight=1.1):
    """Class-balanced BCE edge loss (BDCN variant as modified in DexiNed).

    Edge pixels are weighted by the background frequency and background
    pixels by a slightly boosted edge frequency, so sparse edges are not
    drowned out.
    """
    targets = targets.long()
    weights = targets.float()
    n_pos = torch.sum((weights > 0.0).float()).float()
    n_neg = torch.sum((weights <= 0.0).float()).float()
    total = n_pos + n_neg

    # balancing weights (assignment order matters: positives first)
    weights[weights > 0.] = 1.0 * n_neg / total
    weights[weights <= 0.] = 1.1 * n_pos / total

    probs = torch.sigmoid(inputs)
    per_pixel = torch.nn.BCELoss(weights, reduction='none')(probs, targets.float())
    loss = torch.sum(per_pixel.float().mean((1, 2, 3)))
    return l_weight * loss
|
| 19 |
+
|
| 20 |
+
# ------------ cats losses ----------
|
| 21 |
+
def bdrloss(prediction, label, radius, device='cpu'):
    '''
    Boundary tracing loss (CATS): handles the confusing pixels near edges.
    '''
    ksize = 2 * radius + 1
    box = torch.ones(1, 1, ksize, ksize)
    box.requires_grad = False
    box = box.to(device)

    boundary_pred = prediction * label
    pred_bdr_sum = label * F.conv2d(boundary_pred, box, bias=None, stride=1, padding=radius)

    # non-edge pixels whose neighbourhood still touches an edge
    texture_mask = F.conv2d(label.float(), box, bias=None, stride=1, padding=radius)
    mask = (texture_mask != 0).float()
    mask[label == 1] = 0
    pred_texture_sum = F.conv2d(prediction * (1 - label) * mask, box, bias=None, stride=1, padding=radius)

    softmax_map = torch.clamp(pred_bdr_sum / (pred_texture_sum + pred_bdr_sum + 1e-10), 1e-10, 1 - 1e-10)
    cost = -label * torch.log(softmax_map)
    cost[label == 0] = 0

    return torch.sum(cost.float().mean((1, 2, 3)))
|
| 43 |
+
|
| 44 |
+
def textureloss(prediction, label, mask_radius, device='cpu'):
|
| 45 |
+
'''
|
| 46 |
+
The texture suppression loss that smooths the texture regions.
|
| 47 |
+
'''
|
| 48 |
+
filt1 = torch.ones(1, 1, 3, 3)
|
| 49 |
+
filt1.requires_grad = False
|
| 50 |
+
filt1 = filt1.to(device)
|
| 51 |
+
filt2 = torch.ones(1, 1, 2*mask_radius+1, 2*mask_radius+1)
|
| 52 |
+
filt2.requires_grad = False
|
| 53 |
+
filt2 = filt2.to(device)
|
| 54 |
+
|
| 55 |
+
pred_sums = F.conv2d(prediction.float(), filt1, bias=None, stride=1, padding=1)
|
| 56 |
+
label_sums = F.conv2d(label.float(), filt2, bias=None, stride=1, padding=mask_radius)
|
| 57 |
+
|
| 58 |
+
mask = 1 - torch.gt(label_sums, 0).float()
|
| 59 |
+
|
| 60 |
+
loss = -torch.log(torch.clamp(1-pred_sums/9, 1e-10, 1-1e-10))
|
| 61 |
+
loss[mask == 0] = 0
|
| 62 |
+
|
| 63 |
+
return torch.sum(loss.float().mean((1, 2, 3)))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def cats_loss(prediction, label, l_weight=(0., 0.), device='cpu'):
    """Tracing loss from CATS: balanced BCE + texture suppression + boundary tracing.

    Args:
        prediction: raw logits, Bx1xHxW.
        label: ground-truth edge map; 1 = edge, 0 = background, 2 = ignored.
        l_weight: (texture_factor, boundary_factor) weights for the two
            auxiliary terms. The default (0, 0) reduces to the BCE term only.
            NOTE: the default used to be a mutable list literal ``[0., 0.]``;
            a tuple avoids the shared-mutable-default pitfall and accepts the
            same call sites unchanged.
        device: device for the auxiliary-loss convolution kernels.

    Returns:
        Scalar tensor combining the three terms.
    """
    tex_factor, bdr_factor = l_weight
    balanced_w = 1.1
    label = label.float()
    prediction = prediction.float()
    with torch.no_grad():
        mask = label.clone()

        # class-balancing weights; pixels labelled 2 are zero-weighted (ignored)
        num_positive = torch.sum((mask == 1).float()).float()
        num_negative = torch.sum((mask == 0).float()).float()
        beta = num_negative / (num_positive + num_negative)
        mask[mask == 1] = beta
        mask[mask == 0] = balanced_w * (1 - beta)
        mask[mask == 2] = 0

    prediction = torch.sigmoid(prediction)

    cost = torch.nn.functional.binary_cross_entropy(
        prediction.float(), label.float(), weight=mask, reduction='none')
    cost = torch.sum(cost.float().mean((1, 2, 3)))
    label_w = (label != 0).float()
    textcost = textureloss(prediction.float(), label_w.float(), mask_radius=4, device=device)
    bdrcost = bdrloss(prediction.float(), label_w.float(), radius=4, device=device)

    return cost + bdr_factor * bdrcost + tex_factor * textcost
|
TEED/ted.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TEED: is a Tiny but Efficient Edge Detection, it comes from the LDC-B3
|
| 2 |
+
# with a Slightly modification
|
| 3 |
+
# LDC parameters:
|
| 4 |
+
# 155665
|
| 5 |
+
# TED > 58K
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from TEED.utils.AF.Fsmish import smish as Fsmish
|
| 12 |
+
from TEED.utils.AF.Xsmish import Smish
|
| 13 |
+
from TEED.utils.img_processing import count_parameters
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def weight_init(m):
    """Xavier-initialise conv and deconv layers and zero their biases.

    Intended for ``module.apply(weight_init)``; other layer types pass through.
    """
    # Conv2d and the ConvTranspose2d layers of the fusion/upsampling blocks
    # receive identical treatment, so a single branch covers both.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.xavier_normal_(m.weight, gain=1.0)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
|
| 28 |
+
|
| 29 |
+
class CoFusion(nn.Module):
    """Attention-based fusion of multi-scale edge maps (from LDC).

    A two-conv attention head produces a channel-wise softmax map; the fused
    edge map is the attention-weighted sum of the input channels (Bx1xHxW).
    """

    def __init__(self, in_ch, out_ch):
        super(CoFusion, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3,
                               stride=1, padding=1)  # before 64
        self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3,
                               stride=1, padding=1)  # before 64 instead of 32
        self.relu = nn.ReLU()
        self.norm_layer1 = nn.GroupNorm(4, 32)  # before 64

    def forward(self, x):
        hidden = self.relu(self.norm_layer1(self.conv1(x)))
        weights = F.softmax(self.conv3(hidden), dim=1)
        fused = (x * weights).sum(1)
        return fused.unsqueeze(1)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class CoFusion2(nn.Module):
    """Fusion head variant from TEDv14-3: Smish-activated convs, no softmax."""

    def __init__(self, in_ch, out_ch):
        super(CoFusion2, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3,
                               stride=1, padding=1)  # before 64
        self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3,
                               stride=1, padding=1)  # before 64 instead of 32
        self.smish = Smish()  # activation shared by both stages

    def forward(self, x):
        weights = self.conv1(self.smish(x))
        weights = self.conv3(self.smish(weights))
        # weighted channel sum collapses to a single-channel edge map
        return (x * weights).sum(1).unsqueeze(1)
|
| 68 |
+
|
| 69 |
+
class DoubleFusion(nn.Module):
    """TEED fusion block applied just before the final edge-map prediction.

    Two depth-wise convolutions with Smish activations; PixelShuffle(1) is an
    identity kept to match the published architecture exactly.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleFusion, self).__init__()
        self.DWconv1 = nn.Conv2d(in_ch, in_ch * 8, kernel_size=3,
                                 stride=1, padding=1, groups=in_ch)
        self.PSconv1 = nn.PixelShuffle(1)  # upscale factor 1: no-op reshape
        self.DWconv2 = nn.Conv2d(24, 24, kernel_size=3,
                                 stride=1, padding=1, groups=24)
        self.AF = Smish()

    def forward(self, x):
        stage1 = self.PSconv1(self.DWconv1(self.AF(x)))      # (B, in_ch*8, H, W)
        stage2 = self.PSconv1(self.DWconv2(self.AF(stage1)))
        # residual channel sum squashed back through smish -> Bx1xHxW
        return Fsmish((stage2 + stage1).sum(1).unsqueeze(1))
|
| 90 |
+
|
| 91 |
+
class _DenseLayer(nn.Sequential):
    # Two-conv dense layer. forward() takes an (x1, x2) pair: x1 is the
    # running feature map, x2 the skip connection; the output is averaged
    # with x2 and the untouched skip is forwarded for the next layer.
    def __init__(self, input_features, out_features):
        super(_DenseLayer, self).__init__()

        # NOTE: padding=2 on conv1 and no padding on conv2 keeps the overall
        # spatial size unchanged (+2 then -2).
        self.add_module('conv1', nn.Conv2d(input_features, out_features,
                                           kernel_size=3, stride=1, padding=2, bias=True)),
        self.add_module('smish1', Smish()),
        self.add_module('conv2', nn.Conv2d(out_features, out_features,
                                           kernel_size=3, stride=1, bias=True))
    def forward(self, x):
        x1, x2 = x

        # Smish pre-activation, then the Sequential stack (conv1-smish1-conv2)
        new_features = super(_DenseLayer, self).forward(Fsmish(x1))  # F.relu()

        return 0.5 * (new_features + x2), x2
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class _DenseBlock(nn.Sequential):
    """Chain of ``num_layers`` _DenseLayer modules threading one skip input."""

    def __init__(self, num_layers, input_features, out_features):
        super(_DenseBlock, self).__init__()
        features = input_features
        for idx in range(num_layers):
            self.add_module('denselayer%d' % (idx + 1),
                            _DenseLayer(features, out_features))
            features = out_features
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class UpConvBlock(nn.Module):
    """Upsampling head: (1x1 conv, Smish, transposed conv) repeated ``up_scale`` times."""

    def __init__(self, in_features, up_scale):
        super(UpConvBlock, self).__init__()
        self.up_factor = 2
        self.constant_features = 16

        layers = self.make_deconv_layers(in_features, up_scale)
        assert layers is not None, layers
        self.features = nn.Sequential(*layers)

    def make_deconv_layers(self, in_features, up_scale):
        # kernel/padding depend only on up_scale, so hoist them out of the loop
        all_pads = [0, 0, 1, 3, 7]
        kernel_size = 2 ** up_scale
        pad = all_pads[up_scale]  # kernel_size-1
        layers = []
        for i in range(up_scale):
            out_features = self.compute_out_features(i, up_scale)
            layers.extend([
                nn.Conv2d(in_features, out_features, 1),
                Smish(),
                nn.ConvTranspose2d(out_features, out_features, kernel_size,
                                   stride=2, padding=pad),
            ])
            in_features = out_features
        return layers

    def compute_out_features(self, idx, up_scale):
        # the last stage collapses to a single edge channel
        return 1 if idx == up_scale - 1 else self.constant_features

    def forward(self, x):
        return self.features(x)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class SingleConvBlock(nn.Module):
    """1x1 convolution with an optional Smish activation."""

    def __init__(self, in_features, out_features, stride, use_ac=False):
        super(SingleConvBlock, self).__init__()
        self.use_ac = use_ac
        self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride,
                              bias=True)
        if self.use_ac:
            self.smish = Smish()

    def forward(self, x):
        out = self.conv(x)
        return self.smish(out) if self.use_ac else out
|
| 164 |
+
|
| 165 |
+
class DoubleConvBlock(nn.Module):
    """Two 3x3 convolutions with Smish activations; the final activation is optional."""

    def __init__(self, in_features, mid_features,
                 out_features=None,
                 stride=1,
                 use_act=True):
        super(DoubleConvBlock, self).__init__()

        self.use_act = use_act
        out_features = mid_features if out_features is None else out_features
        self.conv1 = nn.Conv2d(in_features, mid_features,
                               3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
        self.smish = Smish()

    def forward(self, x):
        h = self.smish(self.conv1(x))
        h = self.conv2(h)
        return self.smish(h) if self.use_act else h
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class TED(nn.Module):
    """Tiny and Efficient Edge Detector (TEED, ~58K parameters).

    Three downscaling stages produce multi-scale edge maps, which are
    upsampled back to the input resolution and fused by DoubleFusion.
    """

    def __init__(self):
        super(TED, self).__init__()
        self.block_1 = DoubleConvBlock(3, 16, 16, stride=2,)
        self.block_2 = DoubleConvBlock(16, 32, use_act=False)
        self.dblock_3 = _DenseBlock(1, 32, 48)  # [32,48,100,100] before (2, 32, 64)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # skip1 connection, see fig. 2
        self.side_1 = SingleConvBlock(16, 32, 2)

        # skip2 connection, see fig. 2
        self.pre_dense_3 = SingleConvBlock(32, 48, 1)  # before (32, 64, 1)

        # USNet
        self.up_block_1 = UpConvBlock(16, 1)
        self.up_block_2 = UpConvBlock(32, 1)
        self.up_block_3 = UpConvBlock(48, 2)  # (32, 64, 1)

        self.block_cat = DoubleFusion(3, 3)  # TEED: DoubleFusion

        self.apply(weight_init)

    def slice(self, tensor, slice_shape):
        """Bicubic-resize `tensor` to `slice_shape` (h, w) when sizes differ."""
        t_shape = tensor.shape
        img_h, img_w = slice_shape
        if img_w != t_shape[-1] or img_h != t_shape[2]:
            new_tensor = F.interpolate(
                tensor, size=(img_h, img_w), mode='bicubic', align_corners=False)
        else:
            new_tensor = tensor
        return new_tensor

    def resize_input(self, tensor):
        """Bicubic-resize to the next multiple of 8 in both spatial dims."""
        t_shape = tensor.shape
        if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0:
            img_w = ((t_shape[3] // 8) + 1) * 8
            img_h = ((t_shape[2] // 8) + 1) * 8
            new_tensor = F.interpolate(
                tensor, size=(img_h, img_w), mode='bicubic', align_corners=False)
        else:
            new_tensor = tensor
        return new_tensor

    @staticmethod
    def crop_bdcn(data1, h, w, crop_h, crop_w):
        # Based on BDCN Implementation @ https://github.com/pkuCactus/BDCN
        # BUGFIX: this was declared as an instance method without `self`, so an
        # instance call would have bound the tensor argument to the TED object.
        # @staticmethod keeps the class-level call signature intact.
        _, _, h1, w1 = data1.size()
        assert (h <= h1 and w <= w1)
        data = data1[:, :, crop_h:crop_h + h, crop_w:crop_w + w]
        return data

    def forward(self, x, single_test=False):
        """Return [out_scale1, out_scale2, out_scale3, fused]; each is Bx1xHxW logits."""
        assert x.ndim == 4, x.shape
        # suppose the image size is 352x352

        # Block 1
        block_1 = self.block_1(x)  # [8,16,176,176]
        block_1_side = self.side_1(block_1)  # 16 [8,32,88,88]

        # Block 2
        block_2 = self.block_2(block_1)  # 32 # [8,32,176,176]
        block_2_down = self.maxpool(block_2)  # [8,32,88,88]
        block_2_add = block_2_down + block_1_side  # [8,32,88,88]

        # Block 3
        block_3_pre_dense = self.pre_dense_3(block_2_down)  # block 3 L connection
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])  # [8,48,88,88]

        # upsampling blocks
        out_1 = self.up_block_1(block_1)
        out_2 = self.up_block_2(block_2)
        out_3 = self.up_block_3(block_3)

        results = [out_1, out_2, out_3]

        # concatenate multiscale outputs
        block_cat = torch.cat(results, dim=1)  # Bx3xHxW
        block_cat = self.block_cat(block_cat)  # Bx1xHxW DoubleFusion

        results.append(block_cat)
        return results
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
if __name__ == '__main__':
    # smoke test: push one random batch through the model on CPU
    batch_size = 8
    img_height = 352
    img_width = 352

    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"
    input = torch.rand(batch_size, 3, img_height, img_width).to(device)
    # target = torch.rand(batch_size, 1, img_height, img_width).to(device)
    print(f"input shape: {input.shape}")
    model = TED().to(device)
    output = model(input)
    print(f"output shapes: {[t.shape for t in output]}")

    # for i in range(20000):
    #     print(i)
    #     output = model(input)
    #     loss = nn.MSELoss()(output[-1], target)
    #     loss.backward()
|
TEED/util.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def load_model(filename: str, remote_url: str, model_dir: str) -> str:
    """Return the local path of ``filename`` under ``model_dir``, downloading
    it from ``remote_url`` first when it is not already present.

    Args:
        filename (str): The filename of the model.
        remote_url (str): The remote URL of the model.
        model_dir (str): Directory the model is stored in / downloaded to.

    Returns:
        str: Path to the model file inside ``model_dir``.
    """
    local_path = os.path.join(model_dir, filename)
    if os.path.exists(local_path):
        return local_path
    # deferred import: scripts.utils pulls in heavier dependencies
    from scripts.utils import load_file_from_url
    load_file_from_url(remote_url, model_dir=model_dir)
    return local_path
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def HWC3(x):
    """Normalise a uint8 image to 3-channel HxWx3.

    Grayscale (2-D or single-channel) inputs are replicated across channels;
    RGBA is alpha-composited onto a white background.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)
    # channels == 4: blend onto white using the alpha channel
    color = x[:, :, 0:3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = color * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def make_noise_disk(H, W, C, F):
    """Smooth random noise in [0, 1] of shape HxWxC with feature scale F."""
    coarse = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
    upscaled = cv2.resize(coarse, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
    # crop away the F-pixel border introduced by the padded coarse grid
    noise = upscaled[F: F + H, F: F + W]
    noise -= np.min(noise)
    noise /= np.max(noise)
    if C == 1:
        # cv2.resize drops the trailing axis for single-channel data; restore it
        noise = noise[:, :, None]
    return noise
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def nms(x, t, s):
    """Thin edges by directional non-maximum suppression.

    Blur with sigma ``s``, keep pixels that are maxima along any of the four
    principal directions, then binarise at threshold ``t`` (output is 0/255).
    """
    blurred = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    kernels = (
        np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8),  # horizontal
        np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8),  # vertical
        np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8),  # diagonal \
        np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8),  # diagonal /
    )

    kept = np.zeros_like(blurred)
    for kernel in kernels:
        # a pixel equal to its dilation is a local maximum along this direction
        np.putmask(kept, cv2.dilate(blurred, kernel=kernel) == blurred, blurred)

    out = np.zeros_like(kept, dtype=np.uint8)
    out[kept > t] = 255
    return out
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def min_max_norm(x):
    """Rescale ``x`` to span [0, 1] in place and return it.

    The divisor is clamped to 1e-5 so a constant input does not divide by zero.
    """
    lo = np.min(x)
    x -= lo
    hi = np.max(x)
    x /= np.maximum(hi, 1e-5)
    return x
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def safe_step(x, step=2):
    """Quantise ``x`` (expected in [0, 1]) onto ``step`` discrete levels."""
    scaled = x.astype(np.float32) * float(step + 1)
    # truncate to integers, then map back into the unit range
    return scaled.astype(np.int32).astype(np.float32) / float(step)
|