lime-j commited on
Commit
15a930e
·
1 Parent(s): e0037be

Upload 89 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +9 -0
  2. RDNet-main/RDNet-main/.gitignore +1 -0
  3. RDNet-main/RDNet-main/README.md +99 -0
  4. RDNet-main/RDNet-main/VOC2012_224_train_png.txt +0 -0
  5. RDNet-main/RDNet-main/data/VOC2012_224_train_png.txt +0 -0
  6. RDNet-main/RDNet-main/data/__pycache__/dataset_sir.cpython-38.pyc +0 -0
  7. RDNet-main/RDNet-main/data/__pycache__/image_folder.cpython-38.pyc +0 -0
  8. RDNet-main/RDNet-main/data/__pycache__/torchdata.cpython-38.pyc +0 -0
  9. RDNet-main/RDNet-main/data/__pycache__/transforms.cpython-38.pyc +0 -0
  10. RDNet-main/RDNet-main/data/dataset_sir.py +332 -0
  11. RDNet-main/RDNet-main/data/image_folder.py +51 -0
  12. RDNet-main/RDNet-main/data/real_test.txt +20 -0
  13. RDNet-main/RDNet-main/data/torchdata.py +67 -0
  14. RDNet-main/RDNet-main/data/transforms.py +301 -0
  15. RDNet-main/RDNet-main/engine.py +178 -0
  16. RDNet-main/RDNet-main/figures/Input_car.jpg +0 -0
  17. RDNet-main/RDNet-main/figures/Input_class.png +3 -0
  18. RDNet-main/RDNet-main/figures/Input_green.png +3 -0
  19. RDNet-main/RDNet-main/figures/Ours_car.png +3 -0
  20. RDNet-main/RDNet-main/figures/Ours_class.png +3 -0
  21. RDNet-main/RDNet-main/figures/Ours_green.png +3 -0
  22. RDNet-main/RDNet-main/figures/Ours_white.png +3 -0
  23. RDNet-main/RDNet-main/figures/Title.png +0 -0
  24. RDNet-main/RDNet-main/figures/input_white.jpg +0 -0
  25. RDNet-main/RDNet-main/figures/net.png +3 -0
  26. RDNet-main/RDNet-main/figures/result.png +3 -0
  27. RDNet-main/RDNet-main/figures/vis.png +3 -0
  28. RDNet-main/RDNet-main/models/__init__.py +11 -0
  29. RDNet-main/RDNet-main/models/__pycache__/__init__.cpython-38.pyc +0 -0
  30. RDNet-main/RDNet-main/models/__pycache__/base_model.cpython-38.pyc +0 -0
  31. RDNet-main/RDNet-main/models/__pycache__/cls_model_eval_nocls_reg.cpython-38.pyc +0 -0
  32. RDNet-main/RDNet-main/models/__pycache__/losses.cpython-38.pyc +0 -0
  33. RDNet-main/RDNet-main/models/__pycache__/networks.cpython-38.pyc +0 -0
  34. RDNet-main/RDNet-main/models/__pycache__/vgg.cpython-38.pyc +0 -0
  35. RDNet-main/RDNet-main/models/__pycache__/vit_feature_extractor.cpython-38.pyc +0 -0
  36. RDNet-main/RDNet-main/models/arch/NAFNET.py +480 -0
  37. RDNet-main/RDNet-main/models/arch/RDnet_.py +202 -0
  38. RDNet-main/RDNet-main/models/arch/__pycache__/RDnet_.cpython-38.pyc +0 -0
  39. RDNet-main/RDNet-main/models/arch/__pycache__/classifier.cpython-38.pyc +0 -0
  40. RDNet-main/RDNet-main/models/arch/__pycache__/focalnet.cpython-38.pyc +0 -0
  41. RDNet-main/RDNet-main/models/arch/__pycache__/modules_sig.cpython-38.pyc +0 -0
  42. RDNet-main/RDNet-main/models/arch/__pycache__/reverse_function.cpython-38.pyc +0 -0
  43. RDNet-main/RDNet-main/models/arch/classifier.py +49 -0
  44. RDNet-main/RDNet-main/models/arch/decode.py +36 -0
  45. RDNet-main/RDNet-main/models/arch/focalnet.py +589 -0
  46. RDNet-main/RDNet-main/models/arch/modules_sig.py +304 -0
  47. RDNet-main/RDNet-main/models/arch/reverse_function.py +153 -0
  48. RDNet-main/RDNet-main/models/arch/vgg.py +90 -0
  49. RDNet-main/RDNet-main/models/base_model.py +71 -0
  50. RDNet-main/RDNet-main/models/cls_model_eval_nocls_reg.py +517 -0
.gitattributes CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ RDNet-main/RDNet-main/figures/Input_class.png filter=lfs diff=lfs merge=lfs -text
37
+ RDNet-main/RDNet-main/figures/Input_green.png filter=lfs diff=lfs merge=lfs -text
38
+ RDNet-main/RDNet-main/figures/net.png filter=lfs diff=lfs merge=lfs -text
39
+ RDNet-main/RDNet-main/figures/Ours_car.png filter=lfs diff=lfs merge=lfs -text
40
+ RDNet-main/RDNet-main/figures/Ours_class.png filter=lfs diff=lfs merge=lfs -text
41
+ RDNet-main/RDNet-main/figures/Ours_green.png filter=lfs diff=lfs merge=lfs -text
42
+ RDNet-main/RDNet-main/figures/Ours_white.png filter=lfs diff=lfs merge=lfs -text
43
+ RDNet-main/RDNet-main/figures/result.png filter=lfs diff=lfs merge=lfs -text
44
+ RDNet-main/RDNet-main/figures/vis.png filter=lfs diff=lfs merge=lfs -text
RDNet-main/RDNet-main/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .DS_Store
RDNet-main/RDNet-main/README.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img src="https://github.com/lime-j/RDNet/blob/main/figures/Title.png?raw=true" width=95%>
3
+ <p>
4
+
5
+ # Reversible Decoupling Network for Single Image Reflection Removal
6
+
7
+ <div align="center">
8
+
9
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/reversible-decoupling-network-for-single/reflection-removal-on-sir-2-objects)](https://paperswithcode.com/sota/reflection-removal-on-sir-2-objects?p=reversible-decoupling-network-for-single)
10
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/reversible-decoupling-network-for-single/reflection-removal-on-sir-2-wild)](https://paperswithcode.com/sota/reflection-removal-on-sir-2-wild?p=reversible-decoupling-network-for-single)
11
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/reversible-decoupling-network-for-single/reflection-removal-on-sir-2-postcard)](https://paperswithcode.com/sota/reflection-removal-on-sir-2-postcard?p=reversible-decoupling-network-for-single)
12
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/reversible-decoupling-network-for-single/reflection-removal-on-nature)](https://paperswithcode.com/sota/reflection-removal-on-nature?p=reversible-decoupling-network-for-single)
13
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/reversible-decoupling-network-for-single/reflection-removal-on-real20)](https://paperswithcode.com/sota/reflection-removal-on-real20?p=reversible-decoupling-network-for-single)
14
+
15
+ </div>
16
+ <p align="center" style="font-size: larger;">
17
+ <a href="https://arxiv.org/abs/2410.08063"> Reversible Decoupling Network for Single Image Reflection Removal</a>
18
+ </p>
19
+ <p align="center">
20
+ <a href="https://github.com/WHiTEWoLFJ"> Hao Zhao</a> ⚔️,
21
+ <a href="https://github.com/lime-j"> Mingjia Li</a> ⚔️,
22
+ <a href="https://github.com/mingcv"> Qiming Hu</a>,
23
+ <a href="https://sites.google.com/view/xjguo"> Xiaojie Guo</a> 🦅,
24
+ <p align="center">(⚔️: equal contribution, 🦅 : corresponding author)</p>
25
+ </p>
26
+
27
+ <p align="center">
28
+ <img src="https://github.com/lime-j/RDNet/blob/main/figures/net.png?raw=true" width=95%>
29
+ <p>
30
+ **Our work is accepted by CVPR 2025! See you at the conference!**
31
+ <details>
32
+ <summary>Click for the Abstract of RDNet</summary>
33
+ We present a Reversible Decoupling Network (RDNet), which employs a reversible encoder to secure valuable information while flexibly decoupling transmission-and-reflection-relevant features during the forward pass. Furthermore, we customize a transmission-rate-aware prompt generator to dynamically calibrate features, further boosting performance. Extensive experiments demonstrate the superiority of RDNet over existing SOTA methods on five widely-adopted benchmark datasets.
34
+ </details>
35
+
36
+ ## 🚀Todo
37
+
38
+ - [ ] Release the Training code of RDNet.
39
+
40
+ ## 🌠 Gallery
41
+
42
+
43
+ <table class="center">
44
+ <tr>
45
+ <td><p style="text-align: center">Class Room</p></td>
46
+ <td><p style="text-align: center">White Wall Chamber</p></td>
47
+ </tr>
48
+ <tr>
49
+ <td>
50
+ <div style="width: 100%; max-width: 600px; position: relative;">
51
+ <img src="https://github.com/lime-j/RDNet/blob/main/figures/Input_class.png?raw=true" style="width: 100%; height: 300px; display: block;">
52
+ <img src="https://github.com/lime-j/RDNet/blob/main/figures/Ours_class.png?raw=true" style="width: 100%; height: 300px; display: block; position: absolute; top: 0; left: 0; opacity: 0; transition: opacity 0.5s;" onmouseover="this.style.opacity=1;" onmouseout="this.style.opacity=0;">
53
+ </div>
54
+ </td>
55
+ <td>
56
+ <div style="width: 100%; max-width: 600px; position: relative;">
57
+ <img src="https://github.com/lime-j/RDNet/blob/main/figures/input_white.jpg?raw=true" style="width: 100%; height: 300px; display: block;">
58
+ <img src="https://github.com/lime-j/RDNet/blob/main/figures/Ours_white.png?raw=true" style="width: 100%; height: 300px; display: block; position: absolute; top: 0; left: 0; opacity: 0; transition: opacity 0.5s;" onmouseover="this.style.opacity=1;" onmouseout="this.style.opacity=0;">
59
+ </div>
60
+ </td>
61
+ </tr>
62
+ <tr>
63
+ <td><p style="text-align: center">Car Window</p></td>
64
+ <td><p style="text-align: center">Very Green Office</p></td>
65
+ </tr>
66
+ <tr>
67
+ <td>
68
+ <div style="width: 100%; max-width: 600px; position: relative;">
69
+ <img src="https://github.com/lime-j/RDNet/blob/main/figures/Input_car.jpg?raw=true" style="width: 100%; height: 300px; display: block;">
70
+ <img src="https://github.com/lime-j/RDNet/blob/main/figures/Ours_car.png?raw=true" style="width: 100%; height: 300px; display: block; position: absolute; top: 0; left: 0; opacity: 0; transition: opacity 0.5s;" onmouseover="this.style.opacity=1;" onmouseout="this.style.opacity=0;">
71
+ </div>
72
+ </td>
73
+ <td>
74
+ <div style="width: 100%; max-width: 600px; position: relative;">
75
+ <img src="https://github.com/lime-j/RDNet/blob/main/figures/Input_green.png?raw=true" style="width: 100%; height: 300px; display: block;">
76
+ <img src="https://github.com/lime-j/RDNet/blob/main/figures/Ours_green.png?raw=true" style="width: 100%; height: 300px; display: block; position: absolute; top: 0; left: 0; opacity: 0; transition: opacity 0.5s;" onmouseover="this.style.opacity=1;" onmouseout="this.style.opacity=0;">
77
+ </div>
78
+ </td>
79
+ </tr>
80
+ </table>
81
+
82
+ ## Requirements
83
+ We recommend torch 2.x for our code, but it should work fine with most modern versions.
84
+
85
+ ```
86
+ pip install torch>=2.0 torchvision
87
+ pip install einops ema-pytorch fsspec fvcore huggingface-hub matplotlib numpy opencv-python omegaconf pytorch-msssim scikit-image scikit-learn scipy tensorboard tensorboardx wandb timm
88
+ ```
89
+
90
+ # Testing
91
+ The checkpoint for the main network is available at https://checkpoints.mingjia.li/rdnet.pth , while the model for cls_model is at https://checkpoints.mingjia.li/cls_model.pth . Please put cls_model.pth under the "pretrained" folder.
92
+
93
+ ```python
94
+ python3 test_sirs.py --icnn_path <path to the main checkpoint> --resume
95
+ ```
96
+ # Training
97
+
98
+ Training script will be released in a few days.
99
+
RDNet-main/RDNet-main/VOC2012_224_train_png.txt ADDED
The diff for this file is too large to render. See raw diff
 
RDNet-main/RDNet-main/data/VOC2012_224_train_png.txt ADDED
The diff for this file is too large to render. See raw diff
 
RDNet-main/RDNet-main/data/__pycache__/dataset_sir.cpython-38.pyc ADDED
Binary file (10.9 kB). View file
 
RDNet-main/RDNet-main/data/__pycache__/image_folder.cpython-38.pyc ADDED
Binary file (1.58 kB). View file
 
RDNet-main/RDNet-main/data/__pycache__/torchdata.cpython-38.pyc ADDED
Binary file (2.86 kB). View file
 
RDNet-main/RDNet-main/data/__pycache__/transforms.cpython-38.pyc ADDED
Binary file (9.37 kB). View file
 
RDNet-main/RDNet-main/data/dataset_sir.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os.path
3
+ import os.path
4
+ import random
5
+ from os.path import join
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch.utils.data
10
+ import torchvision.transforms.functional as TF
11
+ from PIL import Image
12
+ from scipy.signal import convolve2d
13
+
14
+ from data.image_folder import make_dataset
15
+ from data.torchdata import Dataset as BaseDataset
16
+ from data.transforms import to_tensor
17
+
18
+
19
def __scale_width(img, target_width):
    """Resize *img* so its width equals ``target_width``, keeping aspect ratio.

    The resulting height is truncated, then rounded up to the nearest even
    number so later half-resolution operations stay well defined.
    """
    orig_w, orig_h = img.size
    if orig_w == target_width:
        return img
    new_h = int(target_width * orig_h / orig_w)
    new_h = math.ceil(new_h / 2.) * 2  # force even height
    return img.resize((target_width, new_h), Image.BICUBIC)


def __scale_height(img, target_height):
    """Resize *img* so its height equals ``target_height``, keeping aspect ratio.

    The resulting width is truncated, then rounded up to the nearest even
    number, mirroring ``__scale_width``.
    """
    orig_w, orig_h = img.size
    if orig_h == target_height:
        return img
    new_w = int(target_height * orig_w / orig_h)
    new_w = math.ceil(new_w / 2.) * 2  # force even width
    return img.resize((new_w, target_height), Image.BICUBIC)
37
+
38
+
39
def paired_data_transforms(img_1, img_2, unaligned_transforms=False):
    """Jointly augment a PIL image pair for training.

    Both images receive the same random rescale (shorter side resized to an
    even target in [320, 640]), random horizontal flip, random 90-degree
    rotation, and a shared 320x320 crop.  When ``unaligned_transforms`` is
    set, the crop window of ``img_2`` is jittered by up to +/-10 px to
    simulate slightly misaligned pairs.
    """

    def sample_crop(img, output_size):
        # Random top-left corner for an output_size crop fully inside img.
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        top = random.randint(0, h - th)
        left = random.randint(0, w - tw)
        return top, left, th, tw

    target_size = int(random.randint(320, 640) / 2.) * 2  # even target size
    ow, oh = img_1.size
    # Scale the *shorter* side to target_size so both dims stay >= 320.
    if ow >= oh:
        img_1 = __scale_height(img_1, target_size)
        img_2 = __scale_height(img_2, target_size)
    else:
        img_1 = __scale_width(img_1, target_size)
        img_2 = __scale_width(img_2, target_size)

    if random.random() < 0.5:
        img_1 = TF.hflip(img_1)
        img_2 = TF.hflip(img_2)

    if random.random() < 0.5:
        angle = random.choice([90, 180, 270])
        img_1 = TF.rotate(img_1, angle)
        img_2 = TF.rotate(img_2, angle)

    top, left, crop_h, crop_w = sample_crop(img_1, (320, 320))
    img_1 = TF.crop(img_1, top, left, crop_h, crop_w)

    if unaligned_transforms:
        # Shift the second crop window slightly to create unaligned pairs.
        top += random.randint(-10, 10)
        left += random.randint(-10, 10)

    img_2 = TF.crop(img_2, top, left, crop_h, crop_w)

    return img_1, img_2
81
+
82
+
83
class ReflectionSynthesis(object):
    """Synthesize a blended image from a (transmission, reflection) pair.

    The reflection layer is Gaussian-blurred with a randomly sampled kernel
    size and sigma, both layers are randomly attenuated, and the two are
    combined either with the screen-like model ``T + R - T*R`` or with a
    clipped additive model.
    """

    def __init__(self):
        # Candidate Gaussian kernel sizes and their sampling probabilities.
        self.kernel_sizes = [5, 7, 9, 11]
        self.kernel_probs = [0.1, 0.2, 0.3, 0.4]

        # Sigma range of the Gaussian blur applied to the reflection layer.
        self.sigma_range = [2, 5]
        # Attenuation ranges for transmission (alpha) and reflection (beta).
        self.alpha_range = [0.8, 1.0]
        self.beta_range = [0.4, 1.0]

    def __call__(self, T_, R_):
        T_ = np.asarray(T_, np.float32) / 255.
        R_ = np.asarray(R_, np.float32) / 255.

        # Blur the reflection channel-by-channel with a random 2-D kernel.
        kernel_size = np.random.choice(self.kernel_sizes, p=self.kernel_probs)
        sigma = np.random.uniform(self.sigma_range[0], self.sigma_range[1])
        kernel_1d = cv2.getGaussianKernel(kernel_size, sigma)
        kernel_2d = np.dot(kernel_1d, kernel_1d.T)
        for channel in range(3):
            R_[..., channel] = convolve2d(R_[..., channel], kernel_2d, mode='same')

        alpha = np.random.uniform(self.alpha_range[0], self.alpha_range[1])
        beta = np.random.uniform(self.beta_range[0], self.beta_range[1])
        T, R = alpha * T_, beta * R_

        if random.random() < 0.7:
            # Screen-like blend keeps the result within [0, 1].
            I = T + R - T * R
        else:
            I = T + R
            if np.max(I) > 1:
                # Pull the overflowing reflection down before re-blending.
                overflow = I[I > 1]
                shift = (np.mean(overflow) - 1) * 1.3
                I = np.clip(T + np.clip(R - shift, 0, 1), 0, 1)

        # Returns the *unattenuated* T_ and blurred R_ alongside the blend.
        return T_, R_, I
120
+
121
+
122
class DataLoader(torch.utils.data.DataLoader):
    """torch DataLoader that can re-synthesize its dataset between epochs."""

    def __init__(self, dataset, batch_size, shuffle, *args, **kwargs):
        super(DataLoader, self).__init__(dataset, batch_size, shuffle, *args, **kwargs)
        # Keep our own copy of the flag; the base class does not expose it.
        self.shuffle = shuffle

    def reset(self):
        # Re-shuffle / re-pair the underlying dataset (training loaders only).
        if self.shuffle:
            print('Reset Dataset...')
            self.dataset.reset()
131
+
132
+
133
class DSRDataset(BaseDataset):
    """Synthetic reflection training set built from a single image pool.

    The image list is split in half: one half supplies transmission layers,
    the other reflection layers, and ``ReflectionSynthesis`` blends each
    pair on the fly.
    """

    def __init__(self, datadir, fns=None, size=None, enable_transforms=True):
        super(DSRDataset, self).__init__()
        self.size = size
        self.datadir = datadir
        self.enable_transforms = enable_transforms
        sortkey = lambda key: os.path.split(key)[-1]
        self.paths = sorted(make_dataset(datadir, fns), key=sortkey)
        if size is not None:
            # NOTE(review): np.random.choice samples *with* replacement here.
            self.paths = np.random.choice(self.paths, size)

        self.syn_model = ReflectionSynthesis()
        self.reset(shuffle=False)

    def reset(self, shuffle=True):
        """Re-split the pool into transmission/reflection halves."""
        if shuffle:
            random.shuffle(self.paths)
        half = len(self.paths) // 2
        self.B_paths = self.paths[0:half]
        self.R_paths = self.paths[half:2 * half]

    def data_synthesis(self, t_img, r_img):
        """Augment a PIL pair, synthesize the blend, and return B, R, M tensors."""
        if self.enable_transforms:
            t_img, r_img = paired_data_transforms(t_img, r_img)
        t_img, r_img, m_img = self.syn_model(t_img, r_img)
        return TF.to_tensor(t_img), TF.to_tensor(r_img), TF.to_tensor(m_img)

    def __getitem__(self, index):
        B_path = self.B_paths[index % len(self.B_paths)]
        R_path = self.R_paths[index % len(self.R_paths)]

        t_img = Image.open(B_path).convert('RGB')
        r_img = Image.open(R_path).convert('RGB')

        B, R, M = self.data_synthesis(t_img, r_img)
        # target_r is the residual M - B rather than the synthesized R tensor.
        return {'input': M, 'target_t': B, 'target_r': M - B,
                'fn': os.path.basename(B_path), 'real': False}

    def __len__(self):
        if self.size is not None:
            return min(max(len(self.B_paths), len(self.R_paths)), self.size)
        return max(len(self.B_paths), len(self.R_paths))
185
+
186
+
187
class DSRTestDataset(BaseDataset):
    """Paired test set laid out as ``<datadir>/blended`` and ``/transmission_layer``."""

    def __init__(self, datadir, fns=None, size=None, enable_transforms=False, unaligned_transforms=False,
                 round_factor=1, flag=None, if_align=True):
        super(DSRTestDataset, self).__init__()
        self.size = size
        self.datadir = datadir
        self.fns = fns or os.listdir(join(datadir, 'blended'))
        self.enable_transforms = enable_transforms
        self.unaligned_transforms = unaligned_transforms
        self.round_factor = round_factor
        self.flag = flag
        # NOTE(review): alignment is forced on regardless of the `if_align`
        # argument (original behavior kept) — confirm whether intentional.
        self.if_align = True  # if_align

        if size is not None:
            self.fns = self.fns[:size]

    def align(self, x1, x2):
        """Resize both images down to the nearest multiple of 32 per dimension."""
        h = x1.height // 32 * 32
        w = x1.width // 32 * 32
        return x1.resize((w, h)), x2.resize((w, h))

    def __getitem__(self, index):
        fn = self.fns[index]

        t_img = Image.open(join(self.datadir, 'transmission_layer', fn)).convert('RGB')
        m_img = Image.open(join(self.datadir, 'blended', fn)).convert('RGB')

        if self.if_align:
            t_img, m_img = self.align(t_img, m_img)

        if self.enable_transforms:
            t_img, m_img = paired_data_transforms(t_img, m_img, self.unaligned_transforms)

        B = TF.to_tensor(t_img)
        M = TF.to_tensor(m_img)

        sample = {'input': M, 'target_t': B, 'fn': fn, 'real': True, 'target_r': M - B}
        if self.flag is not None:
            sample.update(self.flag)
        return sample

    def __len__(self):
        return len(self.fns) if self.size is None else min(len(self.fns), self.size)
235
+
236
+
237
class SIRTestDataset(BaseDataset):
    """SIR benchmark test set with ground-truth transmission AND reflection layers."""

    def __init__(self, datadir, fns=None, size=None, if_align=True):
        super(SIRTestDataset, self).__init__()
        self.size = size
        self.datadir = datadir
        self.fns = fns or os.listdir(join(datadir, 'blended'))
        self.if_align = if_align

        if size is not None:
            self.fns = self.fns[:size]

    def align(self, x1, x2, x3):
        """Resize all three images down to the nearest multiple of 32 per dimension."""
        h = x1.height // 32 * 32
        w = x1.width // 32 * 32
        return x1.resize((w, h)), x2.resize((w, h)), x3.resize((w, h))

    def __getitem__(self, index):
        fn = self.fns[index]

        t_img = Image.open(join(self.datadir, 'transmission_layer', fn)).convert('RGB')
        r_img = Image.open(join(self.datadir, 'reflection_layer', fn)).convert('RGB')
        m_img = Image.open(join(self.datadir, 'blended', fn)).convert('RGB')

        if self.if_align:
            t_img, r_img, m_img = self.align(t_img, r_img, m_img)

        B = TF.to_tensor(t_img)
        R = TF.to_tensor(r_img)
        M = TF.to_tensor(m_img)

        # target_r is the captured reflection; target_r_hat is the residual M - B.
        return {'input': M, 'target_t': B, 'fn': fn, 'real': True,
                'target_r': R, 'target_r_hat': M - B}

    def __len__(self):
        return len(self.fns) if self.size is None else min(len(self.fns), self.size)
278
+
279
+
280
class RealDataset(BaseDataset):
    """Unlabeled real images: only the blended input exists (no ground truth)."""

    def __init__(self, datadir, fns=None, size=None):
        super(RealDataset, self).__init__()
        self.size = size
        self.datadir = datadir
        self.fns = fns or os.listdir(join(datadir))

        if size is not None:
            self.fns = self.fns[:size]

    def align(self, x):
        """Resize the image down to the nearest multiple of 32 per dimension."""
        h = x.height // 32 * 32
        w = x.width // 32 * 32
        return x.resize((w, h))

    def __getitem__(self, index):
        fn = self.fns[index]
        m_img = Image.open(join(self.datadir, fn)).convert('RGB')
        M = to_tensor(self.align(m_img))
        # -1 marks the missing transmission target for unlabeled data.
        return {'input': M, 'target_t': -1, 'fn': fn}

    def __len__(self):
        return len(self.fns) if self.size is None else min(len(self.fns), self.size)
309
+
310
+
311
class FusionDataset(BaseDataset):
    """Sample each item from one of several datasets with fixed probabilities."""

    def __init__(self, datasets, fusion_ratios=None):
        self.datasets = datasets
        self.size = sum(len(dataset) for dataset in datasets)
        # Default to a uniform mixture over the given datasets.
        self.fusion_ratios = fusion_ratios or [1. / len(datasets)] * len(datasets)
        print('[i] using a fusion dataset: %d %s imgs fused with ratio %s' % (
            self.size, [len(dataset) for dataset in datasets], self.fusion_ratios))

    def reset(self):
        for dataset in self.datasets:
            dataset.reset()

    def __getitem__(self, index):
        # Sequential sampling: each ratio is renormalized by the remaining
        # probability mass so the overall draw follows fusion_ratios; the
        # last dataset acts as the guaranteed fallback.
        remaining = 1
        for pos, ratio in enumerate(self.fusion_ratios):
            if random.random() < ratio / remaining or pos == len(self.fusion_ratios) - 1:
                chosen = self.datasets[pos]
                return chosen[index % len(chosen)]
            remaining -= ratio

    def __len__(self):
        return self.size
RDNet-main/RDNet-main/data/image_folder.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###############################################################################
2
+ # Code from
3
+ # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4
+ # Modified the original code so that it also loads images from the current
5
+ # directory as well as the subdirectories
6
+ ###############################################################################
7
+
8
+ import torch.utils.data as data
9
+
10
+ from PIL import Image
11
+ import os
12
+ import os.path
13
+
14
# Recognized image file extensions (lower- and upper-case variants).
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
18
+
19
+
20
def read_fns(filename):
    """Read a listing file and return its lines with surrounding whitespace stripped."""
    with open(filename) as handle:
        return [line.strip() for line in handle]
25
+
26
+
27
def is_image_file(filename):
    """Return True if *filename* ends with a recognized image extension."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
29
+
30
+
31
def make_dataset(dir, fns=None):
    """Collect image file paths under *dir*.

    When ``fns`` is None, walks ``dir`` recursively and returns every image
    file found; otherwise joins each listed name (filtered by
    ``is_image_file``) onto ``dir``.

    Raises:
        NotADirectoryError: if ``dir`` is not an existing directory.
    """
    # Explicit raise instead of `assert`: assertions are stripped under -O,
    # which would silently skip this validation.
    if not os.path.isdir(dir):
        raise NotADirectoryError('%s is not a valid directory' % dir)

    images = []
    if fns is None:
        for root, _, fnames in sorted(os.walk(dir)):
            images.extend(os.path.join(root, fname)
                          for fname in fnames if is_image_file(fname))
    else:
        images.extend(os.path.join(dir, fname)
                      for fname in fns if is_image_file(fname))
    return images
48
+
49
+
50
def default_loader(path):
    """Open *path* with PIL and force-convert it to 3-channel RGB."""
    return Image.open(path).convert('RGB')
RDNet-main/RDNet-main/data/real_test.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 3.jpg
2
+ 4.jpg
3
+ 9.jpg
4
+ 12.jpg
5
+ 15.jpg
6
+ 22.jpg
7
+ 23.jpg
8
+ 25.jpg
9
+ 29.jpg
10
+ 39.jpg
11
+ 46.jpg
12
+ 47.jpg
13
+ 58.jpg
14
+ 86.jpg
15
+ 87.jpg
16
+ 89.jpg
17
+ 93.jpg
18
+ 103.jpg
19
+ 107.jpg
20
+ 110.jpg
RDNet-main/RDNet-main/data/torchdata.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bisect
2
+ import warnings
3
+
4
+
5
class Dataset(object):
    """Abstract dataset base class.

    Subclasses must implement ``__len__`` (dataset size) and ``__getitem__``
    (integer indexing in ``range(len(self))``).
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        # `a + b` yields a lazily-evaluated concatenation of the two datasets.
        return ConcatDataset([self, other])

    def reset(self):
        # Hook for epoch-wise re-shuffling; no-op by default.
        return


class ConcatDataset(Dataset):
    """Concatenation of multiple datasets, indexed on the fly.

    Useful for assembling different existing datasets, possibly large-scale
    ones, without materializing the combined collection.

    Arguments:
        datasets (sequence): List of datasets to be concatenated.
    """

    @staticmethod
    def cumsum(sequence):
        # Running totals of dataset lengths, e.g. lengths [3, 5] -> [3, 8].
        totals, running = [], 0
        for entry in sequence:
            running += len(entry)
            totals.append(running)
        return totals

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Binary-search the owning dataset, then offset into it.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        # Misspelled alias kept for backward compatibility.
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
RDNet-main/RDNet-main/data/transforms.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+
3
+ import math
4
+ import random
5
+
6
+ import torch
7
+ from PIL import Image
8
+
9
+ try:
10
+ import accimage
11
+ except ImportError:
12
+ accimage = None
13
+ import numpy as np
14
+ import scipy.stats as st
15
+ import cv2
16
+ import collections
17
+ import torchvision.transforms as transforms
18
+ import util.util as util
19
+ from scipy.signal import convolve2d
20
+
21
+
22
+ # utility
23
def _is_pil_image(img):
    """Return True for PIL (or accimage, when available) image instances."""
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    return isinstance(img, Image.Image)
28
+
29
+
30
def _is_tensor_image(img):
    """Return True for a 3-D torch tensor (C x H x W image layout assumed)."""
    return torch.is_tensor(img) and img.dim() == 3
32
+
33
+
34
def _is_numpy_image(img):
    """Return True for a 2-D or 3-D numpy array (grayscale or H x W x C)."""
    return isinstance(img, np.ndarray) and img.ndim in (2, 3)
36
+
37
+
38
def arrshow(arr):
    """Debug helper: display a numeric array as an 8-bit PIL image."""
    Image.fromarray(arr.astype(np.uint8)).show()
40
+
41
+
42
def get_transform(opt):
    """Build a torchvision transform pipeline from the option flags.

    Supported ``opt.resize_or_crop`` modes: 'resize_and_crop', 'crop',
    'scale_width', 'scale_width_and_crop'.  A random horizontal flip is
    appended when training unless ``opt.no_flip`` is set.
    """
    transform_list = []
    osizes = util.parse_args(opt.loadSize)
    fineSize = util.parse_args(opt.fineSize)
    mode = opt.resize_or_crop
    if mode == 'resize_and_crop':
        # Randomly pick one of the configured square load sizes, then crop.
        resize_choices = [transforms.Resize([osize, osize], Image.BICUBIC)
                          for osize in osizes]
        transform_list.append(transforms.RandomChoice(resize_choices))
        transform_list.append(transforms.RandomCrop(fineSize))
    elif mode == 'crop':
        transform_list.append(transforms.RandomCrop(fineSize))
    elif mode == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, fineSize)))
    elif mode == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    return transforms.Compose(transform_list)
66
+
67
+
68
# PIL image / ndarray -> tensor normalized channel-wise to [-1, 1].
to_norm_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Plain [0, 1] tensor conversion.
to_tensor = transforms.ToTensor()
77
+
78
+
79
def __scale_width(img, target_width):
    """Resize *img* so its width equals ``target_width``, keeping aspect ratio.

    The resulting height is truncated, then rounded up to the nearest even
    number so later half-resolution operations stay well defined.
    """
    orig_w, orig_h = img.size
    if orig_w == target_width:
        return img
    new_h = int(target_width * orig_h / orig_w)
    new_h = math.ceil(new_h / 2.) * 2  # force even height
    return img.resize((target_width, new_h), Image.BICUBIC)
87
+
88
+
89
+ # functional
90
def gaussian_blur(img, kernel_size, sigma):
    """Gaussian-blur a PIL image with OpenCV and return a new PIL image.

    cv2.GaussianBlur filters each band independently, avoiding the unwanted
    inter-band filtering that scipy's gaussian_filter would apply.

    Args:
        img: PIL image to blur.
        kernel_size: int (square kernel) or length-2 sequence ``(kx, ky)``.
        sigma: Gaussian standard deviation.

    Raises:
        TypeError: if ``img`` is not a PIL image.
        ValueError: if a sequence ``kernel_size`` does not have length 2.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    arr = np.asarray(img)
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    else:
        # Fix: `collections.Sequence` was removed in Python 3.10 — the ABC
        # lives in collections.abc.  Also raise instead of assert, since
        # assertions are stripped under -O.
        from collections.abc import Sequence
        if isinstance(kernel_size, Sequence) and len(kernel_size) != 2:
            raise ValueError('kernel_size must have exactly 2 elements')
    blurred = cv2.GaussianBlur(arr, kernel_size, sigma)  # band-by-band blur
    return Image.fromarray(blurred)
103
+
104
+
105
+ # transforms
106
class GaussianBlur(object):
    """Transform wrapper around :func:`gaussian_blur` with fixed parameters."""

    def __init__(self, kernel_size=11, sigma=3):
        self.kernel_size = kernel_size
        self.sigma = sigma

    def __call__(self, img):
        return gaussian_blur(img, self.kernel_size, self.sigma)
113
+
114
+
115
class ReflectionSythesis_0(object):
    """Reflection image data synthesis for weakly-supervised learning
    of ICCV 2017 paper *"A Generic Deep Architecture for Single Image
    Reflection Removal and Image Smoothing"* — variant that replaces the
    background with a randomly generated flat/near-flat color field.
    """

    def __init__(self, kernel_sizes=None, low_sigma=2, high_sigma=5, low_gamma=1.3,
                 high_gamma=1.3, low_delta=0.4, high_delta=1.8):
        self.kernel_sizes = kernel_sizes or [11]
        self.low_sigma = low_sigma
        self.high_sigma = high_sigma
        self.low_gamma = low_gamma
        self.high_gamma = high_gamma
        self.low_delta = low_delta
        self.high_delta = high_delta
        print('[i] reflection sythesis model: {}'.format({
            'kernel_sizes': kernel_sizes, 'low_sigma': low_sigma, 'high_sigma': high_sigma,
            'low_gamma': low_gamma, 'high_gamma': high_gamma}))

    def __call__(self, B, R):
        if not _is_pil_image(B):
            raise TypeError('B should be PIL Image. Got {}'.format(type(B)))
        if not _is_pil_image(R):
            raise TypeError('R should be PIL Image. Got {}'.format(type(R)))

        B_ = np.asarray(B, np.float32)
        # NOTE(review): the real background content is discarded — B_ is
        # replaced by a random flat (40%) or per-channel (60%) color field
        # of the same shape.  Kept as-is to preserve behavior.
        if random.random() < 0.4:
            B_ = np.tile(np.random.uniform(0, 30, (1, 1, 1)), B_.shape) / 255.
        else:
            B_ = np.tile(np.random.normal(50, 50, (1, 1, 3)),
                         (B_.shape[0], B_.shape[1], 1)).clip(0, 255) / 255.
        R_ = np.asarray(R, np.float32) / 255.

        kernel_size = np.random.choice(self.kernel_sizes)
        sigma = np.random.uniform(self.low_sigma, self.high_sigma)
        gamma = np.random.uniform(self.low_gamma, self.high_gamma)
        delta = np.random.uniform(self.low_delta, self.high_delta)
        R_blur = R_
        # NOTE(review): the kernel size is hard-coded to 11 here, ignoring
        # the sampled kernel_size above — kept as-is to preserve behavior.
        kernel = cv2.getGaussianKernel(11, sigma)
        kernel2d = np.dot(kernel, kernel.T)

        for ch in range(3):
            R_blur[..., ch] = convolve2d(R_blur[..., ch], kernel2d, mode='same')

        # Darken by the mean intensity, rescale, then additively blend.
        R_blur = np.clip(R_blur - np.mean(R_blur) * gamma, 0, 1)
        R_blur = np.clip(R_blur * delta, 0, 1)
        M_ = np.clip(R_blur + B_, 0, 1)

        return B_, R_blur, M_
161
+
162
+
163
class ReflectionSythesis_1(object):
    """Reflection image data synthesis for weakly-supervised learning
    of ICCV 2017 paper *"A Generic Deep Architecture for Single Image Reflection Removal and Image Smoothing"*

    ``__call__`` returns ``(B, R_blur, M)`` as float ndarrays in [0, 1]:
    the background, the blurred reflection, and their
    (overflow-corrected) mixture.
    """

    def __init__(self, kernel_sizes=None, low_sigma=2, high_sigma=5, low_gamma=1.3, high_gamma=1.3):
        self.kernel_sizes = kernel_sizes or [11]
        self.low_sigma = low_sigma
        self.high_sigma = high_sigma
        self.low_gamma = low_gamma
        self.high_gamma = high_gamma
        print('[i] reflection sythesis model: {}'.format({
            'kernel_sizes': kernel_sizes, 'low_sigma': low_sigma, 'high_sigma': high_sigma,
            'low_gamma': low_gamma, 'high_gamma': high_gamma}))

    def __call__(self, B, R):
        if not _is_pil_image(B):
            raise TypeError('B should be PIL Image. Got {}'.format(type(B)))
        if not _is_pil_image(R):
            raise TypeError('R should be PIL Image. Got {}'.format(type(R)))

        B_ = np.asarray(B, np.float32) / 255.
        R_ = np.asarray(R, np.float32) / 255.

        kernel_size = int(np.random.choice(self.kernel_sizes))
        sigma = np.random.uniform(self.low_sigma, self.high_sigma)
        gamma = np.random.uniform(self.low_gamma, self.high_gamma)
        R_blur = R_
        # BUG FIX: the sampled kernel_size was previously ignored
        # (the kernel width was hard-coded to 11).
        kernel = cv2.getGaussianKernel(kernel_size, sigma)
        kernel2d = np.dot(kernel, kernel.T)

        # Blur each color band with the separable 2-D Gaussian kernel.
        for i in range(3):
            R_blur[..., i] = convolve2d(R_blur[..., i], kernel2d, mode='same')

        M_ = B_ + R_blur

        # If the mixture overflows [0, 1], darken the reflection by the mean
        # overshoot (scaled by gamma) and re-blend.
        if np.max(M_) > 1:
            m = M_[M_ > 1]
            m = (np.mean(m) - 1) * gamma
            R_blur = np.clip(R_blur - m, 0, 1)
            M_ = np.clip(R_blur + B_, 0, 1)

        return B_, R_blur, M_
206
+
207
+
208
class Sobel(object):
    """Transform that maps a PIL image to its Sobel edge-magnitude image."""

    def __call__(self, img):
        if not _is_pil_image(img):
            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

        # Work on a single-channel (grayscale) copy of the input.
        luma = np.array(img.convert('L'))

        # Horizontal and vertical first derivatives in 16-bit signed space.
        grad_x = cv2.Sobel(luma, cv2.CV_16S, 1, 0)
        grad_y = cv2.Sobel(luma, cv2.CV_16S, 0, 1)

        # Absolute values scaled back to 8-bit, then averaged 50/50.
        abs_x = cv2.convertScaleAbs(grad_x)
        abs_y = cv2.convertScaleAbs(grad_y)
        edges = cv2.addWeighted(abs_x, 0.5, abs_y, 0.5, 0)
        return Image.fromarray(edges)
222
+
223
+
224
class ReflectionSythesis_2(object):
    """Reflection image data synthesis for weakly-supervised learning
    of CVPR 2018 paper *"Single Image Reflection Separation with Perceptual Losses"*
    """

    def __init__(self, kernel_sizes=None):
        # Pool of blur *sigmas* (despite the name) sampled per call:
        # 80 evenly spaced values in [1, 5].
        self.kernel_sizes = kernel_sizes or np.linspace(1, 5, 80)

    @staticmethod
    def gkern(kernlen=100, nsig=1):
        """Returns a 2D Gaussian kernel array."""
        interval = (2 * nsig + 1.) / (kernlen)
        x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
        kern1d = np.diff(st.norm.cdf(x))
        kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
        kernel = kernel_raw / kernel_raw.sum()
        # Peak-normalize so the mask maximum is exactly 1.
        kernel = kernel / kernel.max()
        return kernel

    def __call__(self, t, r):
        # t: transmission (background) image, r: reflection image.
        # Both are converted here to float arrays in [0, 1].
        t = np.float32(t) / 255.
        r = np.float32(r) / 255.
        ori_t = t
        # create a vignetting mask
        g_mask = self.gkern(560, 3)
        g_mask = np.dstack((g_mask, g_mask, g_mask))
        sigma = self.kernel_sizes[np.random.randint(0, len(self.kernel_sizes))]

        # Blend in linear-light space (undo gamma 2.2).
        t = np.power(t, 2.2)
        r = np.power(r, 2.2)

        # Kernel width derived from sigma (always odd).
        sz = int(2 * np.ceil(2 * sigma) + 1)

        r_blur = cv2.GaussianBlur(r, (sz, sz), sigma, sigma, 0)
        blend = r_blur + t

        # Random attenuation factor in [1.08, 1.18).
        att = 1.08 + np.random.random() / 10.0

        # Per channel: subtract enough intensity so the blend stays <= 1.
        for i in range(3):
            maski = blend[:, :, i] > 1
            mean_i = max(1., np.sum(blend[:, :, i] * maski) / (maski.sum() + 1e-6))
            r_blur[:, :, i] = r_blur[:, :, i] - (mean_i - 1) * att
        r_blur[r_blur >= 1] = 1
        r_blur[r_blur <= 0] = 0

        # Random crop of the 560x560 vignetting mask.
        # NOTE(review): assumes image height/width < 550; otherwise
        # np.random.randint gets an empty range — confirm input sizes.
        h, w = r_blur.shape[0:2]
        neww = np.random.randint(0, 560 - w - 10)
        newh = np.random.randint(0, 560 - h - 10)
        alpha1 = g_mask[newh:newh + h, neww:neww + w, :]
        alpha2 = 1 - np.random.random() / 5.0
        r_blur_mask = np.multiply(r_blur, alpha1)
        blend = r_blur_mask + t * alpha2

        # Back to gamma-encoded space, clip the blend into [0, 1].
        t = np.power(t, 1 / 2.2)
        r_blur_mask = np.power(r_blur_mask, 1 / 2.2)
        blend = np.power(blend, 1 / 2.2)
        blend[blend >= 1] = 1
        blend[blend <= 0] = 0

        # Note: the returned transmission is the ORIGINAL (un-attenuated) t.
        return np.float32(ori_t), np.float32(r_blur_mask), np.float32(blend)
284
+
285
+
286
+ # Examples
287
if __name__ == '__main__':
    # Manual smoke tests for the transforms above.
    """cv2 imread"""
    # img = cv2.imread('testdata_reflection_real/19-input.png')
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # img2 = cv2.GaussianBlur(img, (11,11), 3)

    """Sobel Operator"""
    # img = np.array(Image.open('datasets/VOC224/train/B/2007_000250.png').convert('L'))

    """Reflection Sythesis"""
    # NOTE: fill in real image paths before running this demo.
    b = Image.open('')
    r = Image.open('')
    G = ReflectionSythesis_0()
    # BUG FIX: __call__ returns a 3-tuple (background, blurred reflection,
    # mixture) of float ndarrays, not 2 PIL images; unpack all three and
    # convert back to a PIL image before showing.
    B_, R_blur, M_ = G(b, r)
    Image.fromarray((R_blur * 255).astype(np.uint8)).show()
RDNet-main/RDNet-main/engine.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import util.util as util
3
+ from models import make_model
4
+ import time
5
+ import os
6
+ import sys
7
+ from os.path import join
8
+ from util.visualizer import Visualizer
9
+ import tqdm
10
+ import visdom
11
+ import numpy as np
12
+ from tools import mutils
13
+
14
class Engine(object):
    """Training/evaluation driver: owns the model, logging, and checkpointing.

    Tracks the best combined PSNR over four evaluation sets and saves the
    model whenever it improves. `iterations`/`epoch` are proxied to the
    underlying model so they survive checkpoint save/load.
    """

    def __init__(self, opt,eval_dataset_real,eval_dataset_solidobject,eval_dataset_postcard,eval_dataloader_wild):
        # opt: parsed options object (must provide .name, .model, .select, etc.).
        self.opt = opt
        self.writer = None       # summary writer, created in __setup
        self.visualizer = None   # Visualizer, created in __setup
        self.model = None        # model instance, created in __setup
        self.best_val_loss = 1e6
        self.eval_dataset_real = eval_dataset_real
        self.eval_dataset_solidobject = eval_dataset_solidobject
        self.eval_dataset_postcard = eval_dataset_postcard
        self.eval_dataloader_wild = eval_dataloader_wild
        # Per-run results directory, timestamped to avoid collisions.
        self.result_dir = os.path.join(f'./experiment/{self.opt.name}/results',
                                       mutils.get_formatted_time())
        self.biggest_psnr=0
        self.__setup()

    def __setup(self):
        """Create the experiment directory, the model, and the loggers."""
        self.basedir = join('experiment', self.opt.name)
        os.makedirs(self.basedir, exist_ok=True)

        opt = self.opt

        """Model"""
        self.model = make_model(self.opt.model)  # models.__dict__[self.opt.model]()
        self.model.initialize(opt)
        if True:
            print("IN")
            self.writer = util.get_summary_writer(os.path.join(self.basedir, 'logs'))
            self.visualizer = Visualizer(opt)

    def train(self, train_loader, **kwargs):
        """Run one training epoch, then (periodically) evaluate and checkpoint.

        kwargs are forwarded to model.optimize_parameters().
        """
        print('\nEpoch: %d' % self.epoch)
        avg_meters = util.AverageMeters()
        opt = self.opt
        model = self.model
        epoch = self.epoch

        epoch_start_time = time.time()
        for i, data in tqdm.tqdm(enumerate(train_loader)):

            iter_start_time = time.time()
            iterations = self.iterations

            model.set_input(data, mode='train')
            model.optimize_parameters(**kwargs)

            errors = model.get_current_errors()
            avg_meters.update(errors)
            util.progress_bar(i, len(train_loader), str(avg_meters))
            util.write_loss(self.writer, 'train', avg_meters, iterations)
            # Every 100 iterations, push current output images to the writer.
            if iterations%100==0:
                imgs=[]
                output_clean,output_reflection,input=model.return_output()
                # output_clean,input=model.return_output()

                # HWC uint8-range -> CHW float in [0, 1] for the visualizer.
                output_clean=np.transpose(output_clean,(2,0,1))/255
                #output_reflection = np.transpose(output_reflection, (2, 0, 1))/255
                input = np.transpose(input, (2, 0, 1))/255
                imgs.append(output_clean)
                #imgs.append(output_reflection)
                imgs.append(input)
                util.get_visual(self.writer,iterations,imgs)
            if iterations % opt.print_freq == 0 and opt.display_id != 0:
                # NOTE(review): `t` is computed but never used — display
                # hookup appears unfinished.
                t = (time.time() - iter_start_time)

            self.iterations += 1

        self.epoch += 1

        if True:#not self.opt.no_log:
            if self.epoch % opt.save_epoch_freq == 0:
                save_dir = os.path.join(self.result_dir, '%03d' % self.epoch)
                os.makedirs(save_dir, exist_ok=True)
                matrix_real=self.eval(self.eval_dataset_real, dataset_name='testdata_real20', savedir=save_dir, suffix='real20')
                matrix_solid=self.eval(self.eval_dataset_solidobject, dataset_name='testdata_solidobject', savedir=save_dir,
                                       suffix='solidobject')
                matrix_post=self.eval(self.eval_dataset_postcard, dataset_name='testdata_postcard', savedir=save_dir, suffix='postcard')
                matrix_wild=self.eval(self.eval_dataloader_wild, dataset_name='testdata_wild', savedir=save_dir, suffix='wild')
                # Weights 20/200/199/55 are the per-dataset image counts,
                # so sum/474 is the image-weighted mean PSNR.
                sum_PSNR_real=matrix_real['PSNR']*20
                sum_PSNR_solid=matrix_solid['PSNR']*200
                sum_PSNR_post=matrix_post['PSNR']*199
                sum_PSNR_wild=matrix_wild['PSNR']*55
                print("sum_PSNR_real: ",matrix_real['PSNR'],"sum_PSNR_solid: ",matrix_solid['PSNR'],"sum_PSNR_post: ",matrix_post['PSNR'],"sum_PSNR_wild: ",matrix_wild['PSNR'])
                sum_PSNR = float(sum_PSNR_real + sum_PSNR_solid + sum_PSNR_post + sum_PSNR_wild)/474.0
                print('总PSNR:', sum_PSNR)
                # Checkpoint only when the combined PSNR improves.
                if sum_PSNR>self.biggest_psnr:
                    self.biggest_psnr=sum_PSNR
                    print('saving the model at epoch %d, iters %d' %(self.epoch, self.iterations))
                    model.save()
                print('highest: ',self.biggest_psnr,' name: ',opt.name)

            print('saving the latest model at the end of epoch %d, iters %d' %
                  (self.epoch, self.iterations))
            model.save(label='latest')

        print('Time Taken: %d sec' %
              (time.time() - epoch_start_time))

        # model.update_learning_rate()
        # Some loaders expose reset(); ignore those that don't.
        try:
            train_loader.reset()
        except:
            pass

    def eval(self, val_loader, dataset_name, savedir='./tmp', loss_key=None, **kwargs):
        """Evaluate on *val_loader*; returns the averaged metric meters.

        When *savedir* is given, per-image PSNR/SSIM lines are written to
        metrics.txt inside it. When *loss_key* is given, the model is
        checkpointed whenever that metric improves.
        """
        # print(dataset_name)
        if savedir is not None:
            os.makedirs(savedir, exist_ok=True)
            self.f = open(os.path.join(savedir, 'metrics.txt'), 'w+')
            self.f.write(dataset_name + '\n')
        avg_meters = util.AverageMeters()
        model = self.model
        opt = self.opt
        with torch.no_grad():
            for i, data in enumerate(val_loader):
                # Optional single-image filter via opt.select.
                if self.opt.select is not None and data['fn'][0] not in [f'{self.opt.select}.jpg']:
                    continue
                #print(data.shape())
                index = model.eval(data, savedir=savedir, **kwargs)

                # print(data['fn'][0], index)
                if savedir is not None:
                    self.f.write(f"{data['fn'][0]} {index['PSNR']} {index['SSIM']}\n")
                avg_meters.update(index)
                util.progress_bar(i, len(val_loader), str(avg_meters))

        if not opt.no_log:
            util.write_loss(self.writer, join('eval', dataset_name), avg_meters, self.epoch)

        if loss_key is not None:
            val_loss = avg_meters[loss_key]
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                print('saving the best model at the end of epoch %d, iters %d' %
                      (self.epoch, self.iterations))
                model.save(label='best_{}_{}'.format(loss_key, dataset_name))

        return avg_meters

    def test(self, test_loader, savedir=None, **kwargs):
        """Run inference over *test_loader* (no metrics, no gradients)."""
        model = self.model
        opt = self.opt
        with torch.no_grad():
            for i, data in enumerate(test_loader):
                model.test(data, savedir=savedir, **kwargs)
                util.progress_bar(i, len(test_loader))

    def save_eval(self, label):
        """Delegate checkpoint-for-eval saving to the model."""
        self.model.save_eval(label)

    # iterations/epoch are stored on the model so they travel with checkpoints.
    @property
    def iterations(self):
        return self.model.iterations

    @iterations.setter
    def iterations(self, i):
        self.model.iterations = i

    @property
    def epoch(self):
        return self.model.epoch

    @epoch.setter
    def epoch(self, e):
        self.model.epoch = e
RDNet-main/RDNet-main/figures/Input_car.jpg ADDED
RDNet-main/RDNet-main/figures/Input_class.png ADDED

Git LFS Details

  • SHA256: 9b3823f5b2f4319e23470a1a747bb2974ddc63f323fed61eb8ceedfce4d48343
  • Pointer size: 131 Bytes
  • Size of remote file: 246 kB
RDNet-main/RDNet-main/figures/Input_green.png ADDED

Git LFS Details

  • SHA256: 62805a64a7167f0000a4ec1c8e92f0b45a2f7f6fdd9ec1bb7d623ae2f5d5cffe
  • Pointer size: 131 Bytes
  • Size of remote file: 418 kB
RDNet-main/RDNet-main/figures/Ours_car.png ADDED

Git LFS Details

  • SHA256: 313fbf8070c481775b44153eaea645f35ca8112d7616b5af8ab2e982a37e030e
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
RDNet-main/RDNet-main/figures/Ours_class.png ADDED

Git LFS Details

  • SHA256: e4d97e42e8953fb7c5af9b8d7cfd2123ffeb10e734f50f98bd40b7f531f2f02b
  • Pointer size: 131 Bytes
  • Size of remote file: 280 kB
RDNet-main/RDNet-main/figures/Ours_green.png ADDED

Git LFS Details

  • SHA256: ee3fb53a2f9f410c2e3b8d9679ba3296034786c922fcc70fcd6681af0ce43b36
  • Pointer size: 131 Bytes
  • Size of remote file: 414 kB
RDNet-main/RDNet-main/figures/Ours_white.png ADDED

Git LFS Details

  • SHA256: 9b79ca2d5c76f21e947ec93752ae21e33c301f4099edb8375925a6bb0274977d
  • Pointer size: 131 Bytes
  • Size of remote file: 187 kB
RDNet-main/RDNet-main/figures/Title.png ADDED
RDNet-main/RDNet-main/figures/input_white.jpg ADDED
RDNet-main/RDNet-main/figures/net.png ADDED

Git LFS Details

  • SHA256: d0293129d5ef9c40eb72c2cb33863f4a37b45062f4369285387081da3644a8bf
  • Pointer size: 131 Bytes
  • Size of remote file: 725 kB
RDNet-main/RDNet-main/figures/result.png ADDED

Git LFS Details

  • SHA256: 7bf2e5f68b691f3b0f6246d35f88ffe2a36a12b3c79b7020ba9483ce9fef355c
  • Pointer size: 131 Bytes
  • Size of remote file: 184 kB
RDNet-main/RDNet-main/figures/vis.png ADDED

Git LFS Details

  • SHA256: 325aed759f19aaae59e9a06c1ae4b8c1e4d3adf1cae2d8c092c1c836834828d8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.21 MB
RDNet-main/RDNet-main/models/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ from models.arch import *
4
+
5
+ from models.cls_model_eval_nocls_reg import ClsModel as ClsReg
6
+
7
+
8
def make_model(name: str):
    """Factory for the reflection-removal model.

    NOTE(review): the *name* argument is currently ignored — every call
    returns a fresh ClsReg instance. Confirm whether selection by name
    was intended.
    """
    model = ClsReg()
    return model
RDNet-main/RDNet-main/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (414 Bytes). View file
 
RDNet-main/RDNet-main/models/__pycache__/base_model.cpython-38.pyc ADDED
Binary file (3.02 kB). View file
 
RDNet-main/RDNet-main/models/__pycache__/cls_model_eval_nocls_reg.cpython-38.pyc ADDED
Binary file (17.4 kB). View file
 
RDNet-main/RDNet-main/models/__pycache__/losses.cpython-38.pyc ADDED
Binary file (15.3 kB). View file
 
RDNet-main/RDNet-main/models/__pycache__/networks.cpython-38.pyc ADDED
Binary file (9.34 kB). View file
 
RDNet-main/RDNet-main/models/__pycache__/vgg.cpython-38.pyc ADDED
Binary file (2.15 kB). View file
 
RDNet-main/RDNet-main/models/__pycache__/vit_feature_extractor.cpython-38.pyc ADDED
Binary file (6.95 kB). View file
 
RDNet-main/RDNet-main/models/arch/NAFNET.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2022 megvii-model. All Rights Reserved.
3
+ # ------------------------------------------------------------------------
4
+
5
+ '''
6
+ Simple Baselines for Image Restoration
7
+
8
+ @article{chen2022simple,
9
+ title={Simple Baselines for Image Restoration},
10
+ author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
11
+ journal={arXiv preprint arXiv:2204.04676},
12
+ year={2022}
13
+ }
14
+ '''
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ # from .models.archs.arch_util import LayerNorm2d
20
+ import sys
21
+ sys.path.append('/ghome/zhuyr/Deref_RW/networks/')
22
+
23
class LayerNormFunction(torch.autograd.Function):
    """Custom-autograd LayerNorm over the channel dimension of an NCHW tensor.

    Forward normalizes each (n, h, w) position across channels and applies a
    per-channel affine (weight, bias); backward is hand-derived.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        # eps is a plain float; stash it on ctx for backward.
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        # Save the *pre-affine* normalized activations for the backward pass.
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_variables
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        # Gradient of the normalization: remove the components of g along y
        # and along the channel mean, then rescale by 1/std.
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        # weight/bias gradients reduce over batch and spatial dims;
        # eps receives no gradient (None).
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
            dim=0), None
49
+
50
class LayerNorm2d(nn.Module):
    """Channel-wise LayerNorm for NCHW tensors, backed by LayerNormFunction."""

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        # Direct Parameter assignment registers the parameter exactly like
        # register_parameter('weight', ...), keeping state-dict keys identical.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        """Normalize *x* across channels with the learned affine."""
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
60
+
61
class SimpleGate(nn.Module):
    """Gating unit: split the channel dim in half and multiply the halves."""

    def forward(self, x):
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return first_half * second_half
65
+
66
class NAFBlock(nn.Module):
    """Nonlinear-Activation-Free block ("Simple Baselines for Image Restoration").

    Two residual sub-blocks: (1) pointwise conv -> depthwise 3x3 -> SimpleGate
    -> simplified channel attention -> pointwise conv; (2) a SimpleGate FFN.
    Each sub-block output is scaled by a learnable per-channel factor
    (beta / gamma, zero-initialized so the block starts as identity).
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        # Depthwise 3x3 (groups == channels).
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
                               bias=True)
        # SimpleGate halves the channel count, hence dw_channel // 2 below.
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        # Learnable residual scales; zero init makes the block initially identity.
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        # Sub-block 1: norm -> expand -> depthwise -> gate -> attention -> project.
        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        # Sub-block 2: gated feed-forward on the intermediate residual y.
        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma
120
+
121
+
122
class NAFNet(nn.Module):
    """NAFNet: U-shaped encoder/decoder built from NAFBlocks.

    The input is zero-padded to a multiple of ``2 ** len(enc_blk_nums)`` so
    every down/upsampling stage divides evenly; the output is cropped back
    to the original height/width.
    """

    def __init__(self, img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28],
                 dec_blk_nums=[1, 1, 1, 1], global_residual = False, drop_flag = False, drop_rate = 0.4):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=3, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.global_residual = global_residual
        self.drop_flag = drop_flag

        if drop_flag:
            self.dropout = nn.Dropout2d(p=drop_rate)

        # Encoder: NAFBlock stacks, each followed by a stride-2 downsample.
        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        # Decoder: pixel-shuffle upsample followed by NAFBlock stacks.
        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)
        # First 3 channels of the (padded) input, used for the global residual.
        base_inp = inp[:, :3, :, :]
        x = self.intro(inp)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        # Skip connections pair each decoder stage with the matching encoder.
        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        if self.drop_flag:
            x = self.dropout(x)

        x = self.ending(x)
        if self.global_residual:
            # Predict a residual on top of the input instead of the full image.
            # FIX: removed the dead `else: x` no-op statement.
            x = x + base_inp
        # Crop the padding back off.
        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Zero-pad H/W on the bottom/right to a multiple of self.padder_size."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x
214
+
215
+
216
+
217
class NAFNet_wDetHead(nn.Module):
    """NAFNet variant with a sparse-reflection detection head.

    A 1-channel sparse reflection map is encoded separately and merged into
    the intro features (by concat, add, or multiplicative modulation,
    selected via *concat* / *merge_manner*) before the U-shaped body runs.
    """

    def __init__(self, img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28],
                 dec_blk_nums=[1, 1, 1, 1], global_residual = False, drop_flag = False, drop_rate = 0.4,
                 concat = False, merge_manner = 0):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=3, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.global_residual = global_residual
        self.drop_flag = drop_flag
        self.concat = concat
        self.merge_manner = merge_manner

        if drop_flag:
            self.dropout = nn.Dropout2d(p=drop_rate)

        # --------------------------- Merge sparse & Img -------------------------------------------------------
        self.intro_Det = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                                   bias=True)
        self.DetEnc = nn.Sequential( *[NAFBlock(width) for _ in range(3)] )
        if self.concat:
            self.Merge_conv = nn.Conv2d(in_channels=width *2 , out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                                        bias=True)
        else:
            self.Merge_conv = nn.Conv2d(in_channels=width , out_channels=width, kernel_size=3, padding=1, stride=1,
                                        groups=1,
                                        bias=True)
        # --------------------------- Merge sparse & Img -------------------------------------------------------

        # Encoder: NAFBlock stacks, each followed by a stride-2 downsample.
        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        # Decoder: pixel-shuffle upsample followed by NAFBlock stacks.
        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp, spare_ref):
        """Forward pass; *spare_ref* is the 1-channel sparse reflection map."""
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)
        base_inp = inp  #[:, :3, :, :]
        x = self.intro(inp)

        fea_sparse = self.DetEnc(self.intro_Det(spare_ref))

        # Merge the sparse-reflection features into the intro features.
        if self.merge_manner ==0 and self.concat:
            x = torch.cat([x, fea_sparse], dim=1)
            x = self.Merge_conv(x)
        elif self.merge_manner == 1 and not self.concat:
            x = x + fea_sparse
            x = self.Merge_conv(x)
        elif self.merge_manner == 2 and not self.concat:
            x = x + fea_sparse *x
            x = self.Merge_conv(x)
        else:
            # Unsupported (concat, merge_manner) combination: warn and fall
            # through without merging. FIX: removed the dead `x = x` no-op.
            print('Merge Flag Error!!!(No Merge Operation) ---zyr 1031 ')

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        if self.drop_flag:
            x = self.dropout(x)

        x = self.ending(x)
        if self.global_residual:
            # Predict a residual on top of the (padded) input.
            # FIX: removed the dead `else: x` no-op statement.
            x = x + base_inp
        # Crop the padding back off.
        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Zero-pad H/W on the bottom/right to a multiple of self.padder_size."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x
340
+
341
+
342
class NAFNet_refine(nn.Module):
    """Refinement NAFNet: consumes the input image concatenated with a
    previous prediction (hence img_channel defaults to 6) and outputs a
    refined 3-channel image.
    """

    def __init__(self, img_channel=6, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28],
                 dec_blk_nums=[1, 1, 1, 1], global_residual = False):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=3, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.global_residual = global_residual

        # Encoder: NAFBlock stacks, each followed by a stride-2 downsample.
        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        # Decoder: pixel-shuffle upsample followed by NAFBlock stacks.
        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp, pre_pred):
        """Refine *pre_pred* given the original input *inp* (both NCHW)."""
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)
        pre_pred = self.check_image_size(pre_pred)

        network_in = torch.cat([inp, pre_pred ], dim= 1)

        x = self.intro(network_in)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        if self.global_residual:
            # BUG FIX: the residual previously sliced the BATCH dimension
            # (`inp[:3, :, :, :]`); the intent is the first 3 (image)
            # channels of the padded input. Also removed the dead `else: x`.
            x = x + inp[:, :3, :, :]
        # Crop the padding back off.
        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Zero-pad H/W on the bottom/right to a multiple of self.padder_size."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x
431
+
432
+
433
def print_param_number(net):
    """Print the total number of elements across all parameters of *net*."""
    total = sum(p.numel() for p in net.parameters())
    print('#generator parameters:', total)
435
if __name__ == '__main__':
    # Smoke test: build the detection-head NAFNet and run one dummy forward.
    img_channel = 3
    width = 32

    # enc_blks = [2, 2, 4, 8]
    # middle_blk_num = 12
    # dec_blks = [2, 2, 2, 2]

    # enc_blks = [2, 2, 4, 8]
    # middle_blk_num = 12
    # dec_blks = [2, 2, 2, 2]

    # enc_blks = [1, 1, 1, 28]
    # middle_blk_num = 1
    # dec_blks = [1, 1, 1, 1]

    enc_blks = [1, 1, 1, 28]
    middle_blk_num = 1
    dec_blks = [1, 1, 1, 1]

    net = NAFNet_wDetHead(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
                          enc_blk_nums=enc_blks, dec_blk_nums=dec_blks,global_residual = True,
                          concat= True, merge_manner= 2)  #.cuda()
    #print(net)
    size = 352
    # NOTE(review): concat=True with merge_manner=2 hits the "Merge Flag
    # Error" fallback branch in forward — confirm the intended combination.
    input = torch.randn([1,3,128, 128])#.cuda() inp_shape = (5, 3, 128, 128)
    spare = torch.randn([1,1,128, 128])
    print(net(input, spare).size())
    print_param_number(net)



    #net_local = NAFNetLocal()#.cuda()

    #print_param_number(net)
    # print(net_local(input).size())
    # inp_shape = (3, 256, 256)
    #
    # from ptflops import get_model_complexity_info
    #
    # macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)
    #
    # params = float(params[:-3])
    # macs = float(macs[:-4])
    #
    # print(macs, params)
RDNet-main/RDNet-main/models/arch/RDnet_.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from models.arch.focalnet import build_focalnet
3
+ import torch
4
+ import torch.nn as nn
5
+ from models.arch.modules_sig import ConvNextBlock, Decoder, LayerNorm, NAFBlock, SimDecoder, UpSampleConvnext
6
+ from models.arch.reverse_function import ReverseFunction
7
+ from timm.models.layers import trunc_normal_
8
+
9
class Fusion(nn.Module):
    """Fuse features flowing into one level of a reversible column.

    Combines a 2x-downsampled feature map from the level below with a
    2x-upsampled feature map from the level above.  The first column has no
    "up" input, so it only downsamples; level 3 (the deepest) never has an
    "up" input either.

    Fixes over the original: removed the unused (and typo'd) locals
    ``channels_dowm``/``channels_up``, and made a ``None`` upper input fall
    back to the down path instead of raising UnboundLocalError.
    """

    def __init__(self, level, channels, first_col) -> None:
        super().__init__()

        self.level = level
        self.first_col = first_col
        # Levels 1-3 halve spatial size and change channel count; level 0 has
        # no lower level to pull from.
        self.down = nn.Sequential(
            nn.Conv2d(channels[level - 1], channels[level], kernel_size=2, stride=2),
            LayerNorm(channels[level], eps=1e-6, data_format="channels_first"),
        ) if level in [1, 2, 3] else nn.Identity()
        if not first_col:
            # Levels 0-2 receive an upsampled map from the level above.
            self.up = UpSampleConvnext(1, channels[level + 1], channels[level]) if level in [0, 1, 2] else nn.Identity()

    def forward(self, *args):
        """args = (c_down, c_up); returns the fused feature map."""
        c_down, c_up = args
        if self.first_col:
            # First column: no top-down pathway yet.
            return self.down(c_down)
        if self.level == 3 or c_up is None:
            # Deepest level (or missing upper input): down path only.
            return self.down(c_down)
        return self.up(c_up) + self.down(c_down)
37
+
38
class Level(nn.Module):
    """One resolution level of a column: a Fusion stage followed by a stack of
    ``layers[level]`` residual blocks at that level's channel width."""

    def __init__(self, level, channels, layers, kernel_size, first_col, dp_rate=0.0, block_type=ConvNextBlock) -> None:
        super().__init__()
        # Index of this level's first block in the global drop-path schedule.
        offset = sum(layers[:level])
        expansion = 4
        self.fusion = Fusion(level, channels, first_col)
        block_list = [
            block_type(channels[level], expansion * channels[level], channels[level],
                       kernel_size=kernel_size, layer_scale_init_value=1e-6,
                       drop_path=dp_rate[offset + idx])
            for idx in range(layers[level])
        ]
        self.blocks = nn.Sequential(*block_list)

    def forward(self, *args):
        fused = self.fusion(*args)
        return self.blocks(fused)
53
+
54
+
55
class SubNet(nn.Module):
    """One column of the reversible-column network.

    Holds four ``Level`` modules plus learnable per-channel shortcut scales
    (alpha0..alpha3).  In ``save_memory`` mode the column runs through
    ``ReverseFunction`` so activations can be reconstructed during backward
    instead of being stored.
    """

    def __init__(self, channels, layers, kernel_size, first_col, dp_rates, save_memory, block_type=ConvNextBlock) -> None:
        super().__init__()
        shortcut_scale_init_value = 0.5
        self.save_memory = save_memory
        # One learnable (1, C, 1, 1) scale per level's identity shortcut.
        self.alpha0 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[0], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha1 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[1], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha2 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[2], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha3 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[3], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None

        self.level0 = Level(0, channels, layers, kernel_size, first_col, dp_rates, block_type=block_type)

        self.level1 = Level(1, channels, layers, kernel_size, first_col, dp_rates, block_type=block_type)

        self.level2 = Level(2, channels, layers, kernel_size, first_col, dp_rates, block_type=block_type)

        self.level3 = Level(3, channels, layers, kernel_size, first_col, dp_rates, block_type=block_type)

    def _forward_nonreverse(self, *args):
        # Plain evaluation: each level adds its scaled shortcut to the fused
        # output of the previous level (bottom-up) and the level above.
        x, c0, c1, c2, c3 = args
        c0 = self.alpha0 * c0 + self.level0(x, c1)
        c1 = self.alpha1 * c1 + self.level1(c0, c2)
        c2 = self.alpha2 * c2 + self.level2(c1, c3)
        c3 = self.alpha3 * c3 + self.level3(c2, None)
        return c0, c1, c2, c3

    def _forward_reverse(self, *args):
        # Memory-saving evaluation through the custom reversible autograd op.
        x, c0, c1, c2, c3 = args
        local_funs = [self.level0, self.level1, self.level2, self.level3]
        alpha = [self.alpha0, self.alpha1, self.alpha2, self.alpha3]
        _, c0, c1, c2, c3 = ReverseFunction.apply(
            local_funs, alpha, *args)

        return c0, c1, c2, c3

    def forward(self, *args):
        # Keep shortcut scales away from zero; the reverse pass presumably
        # divides by alpha, so tiny magnitudes would be unstable — see
        # ReverseFunction for confirmation.
        self._clamp_abs(self.alpha0.data, 1e-3)
        self._clamp_abs(self.alpha1.data, 1e-3)
        self._clamp_abs(self.alpha2.data, 1e-3)
        self._clamp_abs(self.alpha3.data, 1e-3)
        if self.save_memory:
            return self._forward_reverse(*args)
        else:
            return self._forward_nonreverse(*args)

    def _clamp_abs(self, data, value):
        # In-place: force |data| >= value while preserving sign.
        with torch.no_grad():
            sign = data.sign()
            data.abs_().clamp_(value)
            data *= sign
110
+
111
class StarReLU(nn.Module):
    """StarReLU activation: ``s * relu(x) ** 2 + b`` with learnable scalar
    scale ``s`` and bias ``b`` (each optionally frozen)."""

    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=True):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        # Single-element parameters so they broadcast over any input shape.
        self.scale = nn.Parameter(torch.ones(1) * scale_value,
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(torch.ones(1) * bias_value,
                                 requires_grad=bias_learnable)

    def forward(self, x):
        activated = self.relu(x)
        return self.scale * activated * activated + self.bias
127
+
128
class FullNet_NLP(nn.Module):
    """Reversible-column restoration network on top of a FocalNet backbone
    ("baseball"), with a prompt branch that modulates the stem features from a
    6-dim conditioning vector.

    ``forward`` returns ``(x_cls_out, x_img_out)``; ``x_cls_out`` is always
    empty here, and ``x_img_out`` collects one decoder output per column for
    the last ``loss_col`` columns (for intermediate supervision).
    """

    def __init__(self, channels=[32, 64, 96, 128], layers=[2, 3, 6, 3], num_subnet=5,loss_col=4, kernel_size=3, num_classes=1000,
                 drop_path=0.0, save_memory=True, inter_supv=True, head_init_scale=None, pretrained_cols=16) -> None:
        super().__init__()
        self.num_subnet = num_subnet
        self.Loss_col=(loss_col+1)
        self.inter_supv = inter_supv
        self.channels = channels
        self.layers = layers
        # NOTE(review): stem_comp is defined but not used in forward() below —
        # the FocalNet stem (x_stem) is used instead; confirm before removing.
        self.stem_comp = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=5, stride=2, padding=2),
            LayerNorm(channels[0], eps=1e-6, data_format="channels_first")
        )
        # Maps the 6-dim conditioning vector to a per-channel modulation of the
        # stem features.
        self.prompt=nn.Sequential(nn.Linear(in_features=6,out_features=512),
                                  StarReLU(),
                                  nn.Linear(in_features=512,out_features=channels[0]),
                                  StarReLU(),
                                  )
        # Linearly increasing stochastic-depth schedule over all blocks.
        dp_rate = [x.item() for x in torch.linspace(0, drop_path, sum(layers))]
        for i in range(num_subnet):
            first_col = True if i == 0 else False
            self.add_module(f'subnet{str(i)}', SubNet(
                channels, layers, kernel_size, first_col,
                dp_rates=dp_rate, save_memory=save_memory,
                block_type=NAFBlock))

        # In-place reversal so the decoder receives dims deepest-first.
        # NOTE(review): this also mutates self.channels (same list object).
        channels.reverse()
        self.decoder_blocks = nn.ModuleList(
            [Decoder(depth=[1, 1, 1, 1], dim=channels, block_type=NAFBlock, kernel_size=3) for _ in
             range(3)])

        self.apply(self._init_weights)
        # Frozen(?) FocalNet feature extractor plus 1x1 convs adapting its
        # 192-based widths to this network's 64-based widths.
        self.baseball = build_focalnet('focalnet_L_384_22k_fl4')
        self.baseball_adapter = nn.ModuleList()
        self.baseball_adapter.append(nn.Conv2d(192, 64, kernel_size=1))
        self.baseball_adapter.append(nn.Conv2d(192, 64, kernel_size=1))
        self.baseball_adapter.append(nn.Conv2d(192 * 2, 64 * 2, kernel_size=1))
        self.baseball_adapter.append(nn.Conv2d(192 * 4, 64 * 4, kernel_size=1))
        self.baseball_adapter.append(nn.Conv2d(192 * 8, 64 * 8, kernel_size=1))

    def forward(self, x_in,alpha,prompt=True):
        """x_in: input image batch; alpha: (B, 6) conditioning vector;
        prompt: if True, modulate stem features with the prompt branch."""
        x_cls_out = []
        x_img_out = []
        c0, c1, c2, c3 = 0, 0, 0, 0
        interval = self.num_subnet // 4  # NOTE(review): unused below.

        # Backbone pyramid features + stem features.
        x_base, x_stem = self.baseball(x_in)
        c0, c1, c2, c3 = x_base
        x_stem = self.baseball_adapter[0](x_stem)
        c0, c1, c2, c3 = self.baseball_adapter[1](c0),\
                         self.baseball_adapter[2](c1),\
                         self.baseball_adapter[3](c2),\
                         self.baseball_adapter[4](c3)
        if prompt==True:
            # Per-channel gating of the stem features by the prompt vector.
            prompt_alpha=self.prompt(alpha)
            prompt_alpha = prompt_alpha.unsqueeze(-1).unsqueeze(-1)
            x=prompt_alpha*x_stem
        else :
            x = x_stem
        for i in range(self.num_subnet):
            c0, c1, c2, c3 = getattr(self, f'subnet{str(i)}')(x, c0, c1, c2, c3)
            # Only the last Loss_col-1 columns contribute supervised outputs.
            # Output is a residual: duplicated input (6 channels) minus the
            # decoder prediction.
            if i>(self.num_subnet-self.Loss_col):
                x_img_out.append(torch.cat([x_in, x_in], dim=-3) - self.decoder_blocks[-1](c3, c2, c1, c0) )

        return x_cls_out, x_img_out

    def _init_weights(self, module):
        # NOTE(review): assumes module.bias is not None; a bias-free Conv2d or
        # Linear would raise here — confirm all layers keep default bias=True.
        if isinstance(module, nn.Conv2d):
            trunc_normal_(module.weight, std=.02)
            nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Linear):
            trunc_normal_(module.weight, std=.02)
            nn.init.constant_(module.bias, 0)
201
+
202
+
RDNet-main/RDNet-main/models/arch/__pycache__/RDnet_.cpython-38.pyc ADDED
Binary file (8.23 kB). View file
 
RDNet-main/RDNet-main/models/arch/__pycache__/classifier.cpython-38.pyc ADDED
Binary file (2.14 kB). View file
 
RDNet-main/RDNet-main/models/arch/__pycache__/focalnet.cpython-38.pyc ADDED
Binary file (15.8 kB). View file
 
RDNet-main/RDNet-main/models/arch/__pycache__/modules_sig.cpython-38.pyc ADDED
Binary file (11 kB). View file
 
RDNet-main/RDNet-main/models/arch/__pycache__/reverse_function.cpython-38.pyc ADDED
Binary file (4.74 kB). View file
 
RDNet-main/RDNet-main/models/arch/classifier.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import timm
3
+ import torch
4
+ import torch.nn.functional as F
5
class PretrainedConvNext(nn.Module):
    """ConvNext-based regressor producing a 6-dim vector per image (presumably
    3 scale + 3 shift correction coefficients; see the commented-out
    alpha/beta split below — TODO confirm against callers)."""

    def __init__(self, model_name='convnext_base', pretrained=True):
        super(PretrainedConvNext, self).__init__()
        # Load the ConvNext trunk from timm (classifier head removed).
        # NOTE(review): the `pretrained` argument is ignored — timm is always
        # called with pretrained=False; weights presumably come from an
        # external checkpoint. Confirm before relying on it.
        self.model = timm.create_model(model_name, pretrained=False, num_classes=0)
        self.head = nn.Linear(768, 6)  # 768 = convnext_base/small feature dim

    def forward(self, x):
        # Resize to the classifier's native resolution; no_grad because the
        # resize itself needs no gradient.
        with torch.no_grad():
            cls_input = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=True)
        # Forward pass through the ConvNext trunk, then project to 6 values.
        out = self.model(cls_input)
        out = self.head(out)
        # alpha, beta = out[..., :3].unsqueeze(-1).unsqueeze(-1),\
        #              out[..., 3:].unsqueeze(-1).unsqueeze(-1)

        #out = alpha * x + beta
        # print(out.shape)
        return out#alpha,beta#out #out[..., :3], out[..., 3:]
23
class PretrainedConvNext_e2e(nn.Module):
    """End-to-end variant of PretrainedConvNext: predicts per-channel scale
    (alpha) and shift (beta) and applies them directly to the input image,
    returning the corrected image."""

    def __init__(self, model_name='convnext_base', pretrained=True):
        super(PretrainedConvNext_e2e, self).__init__()
        # Load the ConvNext trunk from timm (here `pretrained` IS honored,
        # unlike PretrainedConvNext).
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.head = nn.Linear(768, 6)  # 3 alpha + 3 beta coefficients

    def forward(self, x):
        # Resize to the classifier's native resolution; no_grad because the
        # resize itself needs no gradient.
        with torch.no_grad():
            cls_input = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=True)
        # Predict 6 coefficients, split into per-channel scale and shift.
        out = self.model(cls_input)
        out = self.head(out)
        alpha, beta = out[..., :3].unsqueeze(-1).unsqueeze(-1),\
                      out[..., 3:].unsqueeze(-1).unsqueeze(-1)

        # Apply the affine correction to the full-resolution input.
        out = alpha * x + beta
        #print(out.shape)
        return out#alpha,beta#out #out[..., :3], out[..., 3:]
41
+
42
if __name__ == "__main__":
    # Smoke test for the classifier wrapper.
    model = PretrainedConvNext('convnext_small_in22k')
    print("Testing PretrainedConvNext model...")
    # Dummy batch shaped like ImageNet inputs.
    dummy_input = torch.randn(20, 3, 224, 224)
    # Bug fix: forward() returns a single (B, 6) tensor, not a pair, so the
    # original `output_x, output_y = model(...)` raised at runtime.
    output = model(dummy_input)
    print("Output shape:", output.shape)
    print("Test completed successfully.")
RDNet-main/RDNet-main/models/arch/decode.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
def make_layers(cfg, batch_norm=False):
    """Build a VGG-style feature extractor from a configuration list.

    Each entry of *cfg* is either ``'M'`` (2x2 max-pool, stride 2) or an int
    giving the output channels of a 3x3 same-padded convolution, which is
    followed by (optional) BatchNorm and an in-place ReLU.
    """
    layers = []
    channels = 3  # RGB input
    for spec in cfg:
        if spec == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        layers.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(spec))
        layers.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*layers)
17
+
18
# VGG-19 ('E') feature configuration: ints are conv output channels, 'M' marks
# a 2x2 max-pool (see make_layers above for the expansion rules).
cfgs = {
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
}
21
+
22
+
23
class VGG(nn.Module):
    """Thin wrapper exposing a feature-extractor module as a model.

    Bug fix: ``forward`` previously ran the features but fell off the end and
    returned ``None``; it now returns the extracted feature map.
    """

    def __init__(self, features):
        super(VGG, self).__init__()
        self.features = features

    def forward(self, x):
        """Run *x* through the feature extractor and return the result."""
        return self.features(x)
30
+
31
def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    """Construct a VGG model for the named configuration.

    ``arch``/``pretrained``/``progress`` are accepted for torchvision-style
    signature compatibility but no weights are loaded here.
    """
    features = make_layers(cfgs[cfg], batch_norm=batch_norm)
    return VGG(features, **kwargs)
34
+
35
def encoder(pretrained=False, progress=True, **kwargs):
    """Build the VGG-19 ('E' config, no batch norm) feature encoder."""
    return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
RDNet-main/RDNet-main/models/arch/focalnet.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # FocalNet for Semantic Segmentation
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Jianwei Yang
6
+ # --------------------------------------------------------
7
+ import math
8
+ import time
9
+ import numpy as np
10
+ import json
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
16
+
17
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> dropout -> Linear -> dropout.

    Hidden and output widths default to the input width when not given.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
36
+
37
class FocalModulation(nn.Module):
    """ Focal Modulation

    Args:
        dim (int): Number of input channels.
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        focal_level (int): Number of focal levels
        focal_window (int): Focal window size at focal level 1
        focal_factor (int, default=2): Step to increase the focal window
        use_postln (bool, default=False): Whether use post-modulation layernorm
    """

    def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False,
                 use_postln_in_modulation=False, normalize_modulator=False):

        super().__init__()
        self.dim = dim

        # specific args for focalv3
        self.focal_level = focal_level
        self.focal_window = focal_window
        self.focal_factor = focal_factor
        self.use_postln_in_modulation = use_postln_in_modulation
        self.normalize_modulator = normalize_modulator
        # NOTE(review): `use_postln` is accepted but never stored/used here —
        # only `use_postln_in_modulation` takes effect. Confirm intended.

        # Single projection producing query, context, and (focal_level+1) gates.
        self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True)
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)

        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.focal_layers = nn.ModuleList()

        if self.use_postln_in_modulation:
            self.ln = nn.LayerNorm(dim)

        # Depthwise convs with growing kernels: k = focal_factor*level + window.
        for k in range(self.focal_level):
            kernel_size = self.focal_factor*k + self.focal_window
            self.focal_layers.append(
                nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim,
                              padding=kernel_size//2, bias=False),
                    nn.GELU(),
                )
            )

    def forward(self, x):
        """ Forward function.

        Args:
            x: input features with shape of (B, H, W, C)
        """
        B, nH, nW, C = x.shape
        x = self.f(x)
        x = x.permute(0, 3, 1, 2).contiguous()  # to channels-first for convs
        q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1)

        # Accumulate gated hierarchical contexts; each focal layer widens the
        # receptive field of ctx.
        ctx_all = 0
        for l in range(self.focal_level):
            ctx = self.focal_layers[l](ctx)
            ctx_all = ctx_all + ctx*gates[:, l:l+1]
        # Global (spatially averaged) context gated by the last gate channel.
        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
        ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:]
        if self.normalize_modulator:
            ctx_all = ctx_all / (self.focal_level+1)

        # Modulate the query with the aggregated context.
        x_out = q * self.h(ctx_all)
        x_out = x_out.permute(0, 2, 3, 1).contiguous()  # back to channels-last
        if self.use_postln_in_modulation:
            x_out = self.ln(x_out)
        x_out = self.proj(x_out)
        x_out = self.proj_drop(x_out)
        return x_out
110
+
111
class FocalModulationBlock(nn.Module):
    """ Focal Modulation Block.

    Args:
        dim (int): Number of input channels.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        focal_level (int): number of focal levels
        focal_window (int): focal kernel size at level 1
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 focal_level=2, focal_window=9,
                 use_postln=False, use_postln_in_modulation=False,
                 normalize_modulator=False,
                 use_layerscale=False,
                 layerscale_value=1e-4):
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.use_postln = use_postln
        self.use_layerscale = use_layerscale

        self.norm1 = norm_layer(dim)
        self.modulation = FocalModulation(
            dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop,
            use_postln_in_modulation=use_postln_in_modulation,
            normalize_modulator=normalize_modulator,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # Spatial size must be set by the caller (BasicLayer) before forward.
        self.H = None
        self.W = None

        # Residual-branch scales; plain 1.0 unless layerscale is enabled.
        self.gamma_1 = 1.0
        self.gamma_2 = 1.0
        if self.use_layerscale:
            self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        # Pre-LN vs post-LN placement of norm1 around the modulation.
        if not self.use_postln:
            x = self.norm1(x)
        x = x.view(B, H, W, C)

        # FM
        x = self.modulation(x).view(B, H * W, C)
        if self.use_postln:
            x = self.norm1(x)

        # FFN
        x = shortcut + self.drop_path(self.gamma_1 * x)

        if self.use_postln:
            x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
        else:
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))

        return x
191
+
192
class BasicLayer(nn.Module):
    """ A basic focal modulation layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        focal_level (int): Number of focal levels
        focal_window (int): Focal window size at focal level 1
        use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self,
                 dim,
                 depth,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 focal_window=9,
                 focal_level=2,
                 use_conv_embed=False,
                 use_postln=False,
                 use_postln_in_modulation=False,
                 normalize_modulator=False,
                 use_layerscale=False,
                 use_checkpoint=False
                 ):
        super().__init__()
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks (per-block drop_path when a schedule list is given)
        self.blocks = nn.ModuleList([
            FocalModulationBlock(
                dim=dim,
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                focal_window=focal_window,
                focal_level=focal_level,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                normalize_modulator=normalize_modulator,
                use_layerscale=use_layerscale,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer (halves spatial size, doubles channels)
        if downsample is not None:
            self.downsample = downsample(
                patch_size=2,
                in_chans=dim, embed_dim=2*dim,
                use_conv_embed=use_conv_embed,
                norm_layer=norm_layer,
                is_stem=False
            )

        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            (x_out, H, W, x_next, H_next, W_next): this stage's output at its
            own resolution plus the (possibly downsampled) input to the next
            stage.
        """

        for blk in self.blocks:
            # Blocks read spatial size from attributes, not arguments.
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            # Back to (B, C, H, W) for the conv-based downsampler, then flatten.
            x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
            x_down = self.downsample(x_reshaped)
            x_down = x_down.flatten(2).transpose(1, 2)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W
281
+
282
+
283
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
        use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False
        is_stem (bool): Is the stem block or not.
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if use_conv_embed:
            # if we choose to use conv embedding, then we treat the stem and non-stem differently
            if is_stem:
                kernel_size = 7; padding = 3; stride = 2
            else:
                kernel_size = 3; padding = 1; stride = 2
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
        else:
            # Non-overlapping patches: kernel == stride == patch_size.
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Project (B, C, H, W) input to patch embeddings, padding H/W up to a
        multiple of patch_size first."""
        _, _, H, W = x.size()
        # F.pad pads the last dims first: (left, right) is width-only,
        # (0, 0, top, bottom) reaches the height dim.
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            # LayerNorm expects channels-last tokens; flatten, norm, restore.
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x
334
+
335
+
336
class FocalNet(nn.Module):
    """ FocalNet backbone.

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute postion embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        drop_rate (float): Dropout rate.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        focal_levels (Sequence[int]): Number of focal levels at four stages
        focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages
        use_conv_embed (bool): Whether use overlapped convolution for patch embedding
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=1600,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.3,  # 0.3 or 0.4 works better for large+ models
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 focal_levels=[3,3,3,3],
                 focal_windows=[3,3,3,3],
                 use_conv_embed=False,
                 use_postln=False,
                 use_postln_in_modulation=False,
                 use_layerscale=False,
                 normalize_modulator=False,
                 use_checkpoint=False,
                 ):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
            use_conv_embed=use_conv_embed, is_stem=True)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers: channel width doubles per stage, downsampling between
        # all but the last stage.
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
                focal_window=focal_windows[i_layer],
                focal_level=focal_levels[i_layer],
                use_conv_embed=use_conv_embed,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                normalize_modulator=normalize_modulator,
                use_layerscale=use_layerscale,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        # Freeze the patch embed (stage "0"), and with frozen_stages >= 2 also
        # freeze whole transformer stages.
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        NOTE(review): the checkpoint-loading path references
        ``get_root_logger`` and ``load_checkpoint`` (mmcv/mmseg helpers?)
        which are neither defined nor imported in this module — calling with a
        string would raise NameError. Confirm where they come from.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Return (outs, x_emb): the normalized per-stage feature pyramid
        (channels-first) and the raw stem/patch-embed features."""
        x_emb = self.patch_embed(x)
        Wh, Ww = x_emb.size(2), x_emb.size(3)

        x = x_emb.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)
        return outs, x_emb

    def train(self, mode=True):
        """Convert the model into training mode while keep layers freezed."""
        super(FocalNet, self).train(mode)
        self._freeze_stages()
498
+
499
+
500
+
501
def build_focalnet(modelname, **kw):
    """Build a named FocalNet variant; extra keyword args override the preset.

    Scalar ``focal_levels``/``focal_windows`` in *kw* are broadcast to all
    four stages before being merged into the preset config.
    """
    assert modelname in [
        'focalnet_L_384_22k',
        'focalnet_L_384_22k_fl4',
        'focalnet_XL_384_22k',
        'focalnet_XL_384_22k_fl4',
        'focalnet_H_224_22k',
        'focalnet_H_224_22k_fl4',
    ]

    if 'focal_levels' in kw:
        kw['focal_levels'] = [kw['focal_levels']] * 4

    if 'focal_windows' in kw:
        kw['focal_windows'] = [kw['focal_windows']] * 4

    # Per-variant presets ('fl4' variants use 4 focal levels and normalize the
    # modulator; 'H' variants use post-LN inside the modulation).
    model_para_dict = {
        'focalnet_L_384_22k': dict(
            embed_dim=192,
            depths=[ 2, 2, 18, 2 ],
            focal_levels=kw.get('focal_levels', [3, 3, 3, 3]),
            focal_windows=kw.get('focal_windows', [5, 5, 5, 5]),
            use_conv_embed=True,
            use_postln=True,
            use_postln_in_modulation=False,
            use_layerscale=True,
            normalize_modulator=False,
        ),
        'focalnet_L_384_22k_fl4': dict(
            embed_dim=192,
            depths=[ 2, 2, 18, 2 ],
            focal_levels=kw.get('focal_levels', [4, 4, 4, 4]),
            focal_windows=kw.get('focal_windows', [3, 3, 3, 3]),
            use_conv_embed=True,
            use_postln=True,
            use_postln_in_modulation=False,
            use_layerscale=True,
            normalize_modulator=True,
        ),
        'focalnet_XL_384_22k': dict(
            embed_dim=256,
            depths=[ 2, 2, 18, 2 ],
            focal_levels=kw.get('focal_levels', [3, 3, 3, 3]),
            focal_windows=kw.get('focal_windows', [5, 5, 5, 5]),
            use_conv_embed=True,
            use_postln=True,
            use_postln_in_modulation=False,
            use_layerscale=True,
            normalize_modulator=False,
        ),
        'focalnet_XL_384_22k_fl4': dict(
            embed_dim=256,
            depths=[ 2, 2, 18, 2 ],
            focal_levels=kw.get('focal_levels', [4, 4, 4, 4]),
            focal_windows=kw.get('focal_windows', [3, 3, 3, 3]),
            use_conv_embed=True,
            use_postln=True,
            use_postln_in_modulation=False,
            use_layerscale=True,
            normalize_modulator=True,
        ),
        'focalnet_H_224_22k': dict(
            embed_dim=352,
            depths=[ 2, 2, 18, 2 ],
            focal_levels=kw.get('focal_levels', [3, 3, 3, 3]),
            focal_windows=kw.get('focal_windows', [3, 3, 3, 3]),
            use_conv_embed=True,
            use_postln=True,
            use_layerscale=True,
            use_postln_in_modulation=True,
            normalize_modulator=False,
        ),
        'focalnet_H_224_22k_fl4': dict(
            embed_dim=352,
            depths=[ 2, 2, 18, 2 ],
            focal_levels=kw.get('focal_levels', [4, 4, 4, 4]),
            focal_windows=kw.get('focal_windows', [3, 3, 3, 3]),
            use_conv_embed=True,
            use_postln=True,
            use_postln_in_modulation=True,
            use_layerscale=True,
            normalize_modulator=False,
        ),
    }

    # Caller-supplied kwargs take precedence over the preset.
    kw_cgf = model_para_dict[modelname]
    kw_cgf.update(kw)
    model = FocalNet(**kw_cgf)
    return model
RDNet-main/RDNet-main/models/arch/modules_sig.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Reversible Column Networks
3
+ # Copyright (c) 2022 Megvii Inc.
4
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
5
+ # Written by Yuxuan Cai
6
+ # --------------------------------------------------------
7
+
8
+ import imp
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from timm.models.layers import DropPath
13
+
14
+
15
+
16
+
17
class LayerNormFunction(torch.autograd.Function):
    """Custom autograd Function: LayerNorm over the channel dim of an NCHW
    tensor with a hand-written backward pass (avoids autograd bookkeeping)."""

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        # x: (N, C, H, W); statistics are computed across C only.
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        # Save the *normalized* activations rather than x, so backward can be
        # expressed purely in terms of y, var and weight.
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_variables
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        # Standard LayerNorm input gradient: subtract the projections of g
        # onto y and onto the all-ones direction, then rescale by 1/std.
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        # weight grad: sum over N, H, W of grad*y; bias grad: sum of grad.
        # Final None matches the non-tensor `eps` argument of forward().
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
            dim=0), None
43
+
44
class LayerNorm2d(nn.Module):
    """Channel-wise LayerNorm for NCHW tensors, backed by LayerNormFunction."""

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        # Direct attribute assignment registers the parameters just like
        # register_parameter() would.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
54
+
55
class SimpleGate(nn.Module):
    """Gating unit from NAFNet: split channels into two halves, multiply them."""

    def forward(self, x):
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return first_half * second_half
59
+
60
class NAFBlock(nn.Module):
    """NAFNet block: depthwise-conv branch with SimpleGate + simplified
    channel attention, followed by a gated pointwise FFN branch.  Both
    branches are residual with learnable per-channel scales (beta/gamma).

    Args:
        dim: input/output channel count of the block.
        expand_dim: expanded channel count used inside both branches
            (halved again by SimpleGate).
        out_dim: output channels of the FFN's final 1x1 conv — callers should
            pass out_dim == dim for the residual sum to be well-formed.
        kernel_size: depthwise conv kernel size.
        layer_scale_init_value: init value for the residual scales.
        drop_path: accepted for signature compatibility; unused here.
    """

    def __init__(self, dim, expand_dim, out_dim, kernel_size=3, layer_scale_init_value=1e-6, drop_path=0.):
        super().__init__()
        drop_out_rate = 0.
        dw_channel = expand_dim
        self.conv1 = nn.Conv2d(in_channels=dim, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        # BUGFIX: padding was hard-coded to 1, which only preserves spatial
        # size when kernel_size == 3; derive it from the kernel size instead
        # (identical behavior for the default kernel_size=3).
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=kernel_size,
                               padding=kernel_size // 2, stride=1, groups=dw_channel, bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention: global average pool -> 1x1 conv,
        # multiplied back onto the feature map.
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate halves the channel count (x1 * x2).
        self.sg = SimpleGate()

        ffn_channel = expand_dim
        self.conv4 = nn.Conv2d(in_channels=dim, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=out_dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(dim)
        self.norm2 = LayerNorm2d(dim)

        # drop_out_rate is fixed at 0, so these resolve to Identity.
        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        # Layer-scale style residual weights, one per channel.
        self.beta = nn.Parameter(torch.ones((1, dim, 1, 1)) * layer_scale_init_value, requires_grad=True)
        self.gamma = nn.Parameter(torch.ones((1, dim, 1, 1)) * layer_scale_init_value, requires_grad=True)

    def forward(self, inp):
        # Branch 1: norm -> expand -> depthwise -> gate -> attention -> project.
        x = inp

        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        # Branch 2: gated pointwise FFN on the intermediate residual y.
        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma
115
+
116
+
117
class UpSampleConvnext(nn.Module):
    """Channel projection (Linear + LayerNorm in NHWC) followed by a
    2**ratio bilinear upsample, for ConvNeXt-style decoder paths."""

    def __init__(self, ratio, inchannel, outchannel):
        super().__init__()
        self.ratio = ratio
        self.channel_reschedule = nn.Sequential(
            # LayerNorm(inchannel, eps=1e-6, data_format="channels_last"),
            nn.Linear(inchannel, outchannel),
            LayerNorm(outchannel, eps=1e-6, data_format="channels_last"))
        self.upsample = nn.Upsample(scale_factor=2**ratio, mode='bilinear')

    def forward(self, x):
        # NCHW -> NHWC so Linear/LayerNorm act on the channel dim.
        x = x.permute(0, 2, 3, 1)
        x = self.channel_reschedule(x)
        # FIX: was the duplicated assignment `x = x = x.permute(...)`.
        x = x.permute(0, 3, 1, 2)

        return self.upsample(x)
132
+
133
class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).

    Args:
        normalized_shape: channel count to normalize over (stored as a 1-tuple).
        eps: numerical-stability epsilon.
        data_format: "channels_first" (default) or "channels_last".
        elementwise_affine: when False, no learnable weight/bias are applied.
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first", elementwise_affine = True):
        super().__init__()
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
        else:
            # BUGFIX: these were previously left undefined, so forward() on
            # the channels_last path raised AttributeError when
            # elementwise_affine=False.  F.layer_norm accepts None for both.
            self.weight = None
            self.bias = None
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Manual LayerNorm over dim 1 for NCHW input.
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            if self.elementwise_affine:
                x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
161
+
162
+
163
class ConvNextBlock(nn.Module):
    r"""ConvNeXt block, channels_last variant.

    Pipeline: depthwise conv -> permute to NHWC -> LayerNorm -> Linear ->
    GELU -> Linear -> optional layer scale -> permute back -> residual add.

    Args:
        in_channel: input (and residual) channel count.
        hidden_dim: width of the pointwise MLP.
        out_channel: output channels of the second Linear.
        kernel_size: depthwise conv kernel size.
        layer_scale_init_value: init for gamma; a value <= 0 disables layer scale.
        drop_path: stochastic-depth rate (Identity when 0).
    """

    def __init__(self, in_channel, hidden_dim, out_channel, kernel_size=3, layer_scale_init_value=1e-6, drop_path= 0.0):
        super().__init__()
        self.dwconv = nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size,
                                padding=(kernel_size - 1) // 2, groups=in_channel)  # depthwise conv
        self.norm = nn.LayerNorm(in_channel, eps=1e-6)
        # 1x1 convolutions expressed as Linear layers acting on NHWC data.
        self.pwconv1 = nn.Linear(in_channel, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, out_channel)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((out_channel)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        return shortcut + self.drop_path(x)
199
+
200
class Decoder(nn.Module):
    """Two-branch upsampling decoder.

    Builds 2 * (len(dim) - 1) decode stages: stages 0..2 form the "clean"
    branch and stages 3..5 the "reflection" branch.  Each stage projects
    channels down (1x1 conv + LayerNorm), bilinearly upsamples x2, and runs
    `depth[i]` blocks of `block_type`.  Each branch ends in a 1x1 conv +
    PixelShuffle(2) head producing an RGB image; the two RGB outputs are
    concatenated along the channel dim (6 channels total).

    NOTE(review): `depth`/`dim` are mutable default arguments — harmless here
    because they are never mutated, but worth confirming upstream.
    """
    def __init__(self, depth=[2,2,2,2], dim=[112, 72, 40, 24], block_type = None, kernel_size = 3) -> None:
        super().__init__()
        self.depth = depth
        self.dim = dim
        # block_type: callable (in, hidden, out, kernel_size) -> nn.Module,
        # e.g. NAFBlock or ConvNextBlock from this file.
        self.block_type = block_type
        self._build_decode_layer(dim, depth, kernel_size)
        self.pixelshuffle=nn.PixelShuffle(2)
        # self.star_relu=StarReLU()
        # Head for the clean branch: dim[-1] -> 12 channels -> PixelShuffle(2)
        # gives a 3-channel image at 2x spatial resolution.
        self.projback_ = nn.Sequential(
            nn.Conv2d(
                in_channels=dim[-1],
                out_channels=2 ** 2 * 3 , kernel_size=1),
            nn.PixelShuffle(2)
        )
        # Identical head for the reflection branch (separate weights).
        self.projback_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=dim[-1],
                out_channels=2 ** 2 * 3, kernel_size=1),
            nn.PixelShuffle(2)
        )

    def _build_decode_layer(self, dim, depth, kernel_size):
        # Stages 0..len(dim)-2: clean branch; the second loop appends the
        # reflection branch stages (indices len(dim)-1 ..).
        normal_layers = nn.ModuleList()
        upsample_layers = nn.ModuleList()
        proj_layers = nn.ModuleList()

        norm_layer = LayerNorm

        for i in range(1, len(dim)):
            module = [self.block_type(dim[i], dim[i], dim[i], kernel_size) for _ in range(depth[i])]
            normal_layers.append(nn.Sequential(*module))
            upsample_layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
            proj_layers.append(nn.Sequential(
                nn.Conv2d(dim[i-1], dim[i], 1, 1),
                norm_layer(dim[i]),
                # StarReLU() #self.star_relu()
                nn.GELU()
            ))
        for i in range(1, len(dim)):
            module = [self.block_type(dim[i], dim[i], dim[i], kernel_size) for _ in range(depth[i])]
            normal_layers.append(nn.Sequential(*module))
            upsample_layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
            # NOTE(review): unlike the first branch, these projections have no
            # GELU after the norm — looks deliberate, but worth confirming.
            proj_layers.append(nn.Sequential(
                nn.Conv2d(dim[i-1], dim[i], 1, 1),
                norm_layer(dim[i]),
            ))
        self.normal_layers = normal_layers
        self.upsample_layers = upsample_layers
        self.proj_layers = proj_layers

    def _forward_stage(self, stage, x):
        # One decode step: project channels, upsample x2, refine with blocks.
        x = self.proj_layers[stage](x)
        x = self.upsample_layers[stage](x)
        return self.normal_layers[stage](x)

    def forward(self, c3, c2, c1, c0):
        # c3 is the coarsest feature; c0 the finest.  Both branches start
        # from the same features (the *_clean/*_ref pairs alias one tensor).
        c0_clean, c0_ref = c0, c0
        c1_clean, c1_ref = c1, c1
        c2_clean, c2_ref = c2, c2
        c3_clean, c3_ref = c3, c3
        # Clean branch (stages 0-2): decode and modulate by skip features.
        x_clean = self._forward_stage(0, c3_clean) * c2_clean
        x_clean = self._forward_stage(1, x_clean) * c1_clean
        x_clean = self._forward_stage(2, x_clean) * c0_clean
        x_clean = self.projback_(x_clean)

        # Reflection branch (stages 3-5), same structure, separate weights.
        x_ref = self._forward_stage(3, c3_ref) * c2_ref
        x_ref = self._forward_stage(4, x_ref) * c1_ref
        x_ref = self._forward_stage(5, x_ref) * c0_ref
        x_ref = self.projback_2(x_ref)

        # Concatenate the two 3-channel predictions -> 6 channels.
        x=torch.cat((x_clean,x_ref),dim=1)
        return x
273
+
274
class SimDecoder(nn.Module):
    """Single-shot decoder: LayerNorm -> 1x1 conv to encoder_stride**2 * 3
    channels -> PixelShuffle back to a full-resolution 3-channel image."""

    def __init__(self, in_channel, encoder_stride) -> None:
        super().__init__()
        to_rgb = nn.Conv2d(
            in_channels=in_channel,
            out_channels=encoder_stride ** 2 * 3, kernel_size=1)
        self.projback = nn.Sequential(
            LayerNorm(in_channel),
            to_rgb,
            nn.PixelShuffle(encoder_stride),
        )

    def forward(self, c3):
        return self.projback(c3)
287
+
288
+
289
class StarReLU(nn.Module):
    """StarReLU activation: s * relu(x) ** 2 + b, with learnable s and b."""

    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=True):
        super().__init__()
        # `mode` is accepted for interface compatibility but unused.
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1),
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1),
                                 requires_grad=bias_learnable)

    def forward(self, x):
        rectified = self.relu(x)
        return self.scale * rectified * rectified + self.bias
RDNet-main/RDNet-main/models/arch/reverse_function.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from typing import Any, Iterable, List, Tuple, Callable
4
+ import torch.distributed as dist
5
+
6
def get_gpu_states(fwd_gpu_devices) -> Tuple[List[int], List[torch.Tensor]]:
    """Snapshot the CUDA RNG state of every device id in *fwd_gpu_devices*."""
    states = []
    for dev in fwd_gpu_devices:
        with torch.cuda.device(dev):
            states.append(torch.cuda.get_rng_state())
    return states
13
+
14
def get_gpu_device(*args):
    """Return the distinct CUDA device ids among the tensor arguments;
    non-tensor and CPU-resident arguments are ignored."""
    device_ids = {a.get_device() for a in args
                  if isinstance(a, torch.Tensor) and a.is_cuda}
    return list(device_ids)
19
+
20
def set_device_states(fwd_cpu_state, devices, states) -> None:
    """Restore the CPU RNG state, then each listed device's CUDA RNG state."""
    torch.set_rng_state(fwd_cpu_state)
    for dev, saved in zip(devices, states):
        with torch.cuda.device(dev):
            torch.cuda.set_rng_state(saved)
25
+
26
def detach_and_grad(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    """Detach every tensor in *inputs* from its graph and re-enable gradient
    tracking on the detached copy; non-tensor entries pass through unchanged.
    Raises RuntimeError when *inputs* is not a tuple."""
    if not isinstance(inputs, tuple):
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
    detached = []
    for item in inputs:
        if isinstance(item, torch.Tensor):
            fresh = item.detach()
            fresh.requires_grad = True
            detached.append(fresh)
        else:
            detached.append(item)
    return tuple(detached)
41
+
42
def get_cpu_and_gpu_states(gpu_devices):
    """Snapshot the CPU RNG state together with the CUDA RNG states of
    *gpu_devices*, as a (cpu_state, gpu_states) pair."""
    cpu_state = torch.get_rng_state()
    return cpu_state, get_gpu_states(gpu_devices)
44
+
45
class ReverseFunction(torch.autograd.Function):
    """Reversible-column forward/backward (RevCol).

    Forward runs four column levels l0..l3 under no_grad, saving only the
    level outputs plus RNG states.  Backward reconstructs each level's input
    by inverting the update rule (feature reverse), replays the level under
    the saved RNG state to rebuild its local graph, and accumulates gradients
    — trading recomputation for activation memory.  The exact statement order
    below is load-bearing; do not reorder.
    """
    @staticmethod
    def forward(ctx, run_functions, alpha, *args):
        l0, l1, l2, l3 = run_functions
        alpha0, alpha1, alpha2, alpha3 = alpha
        ctx.run_functions = run_functions
        ctx.alpha = alpha
        ctx.preserve_rng_state = True

        # Capture autocast settings so backward replays levels identically.
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        ctx.cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                                   "dtype": torch.get_autocast_cpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}

        assert len(args) == 5
        [x, c0, c1, c2, c3] = args
        # The first column receives int placeholders (e.g. 0) instead of
        # feature tensors; remember that so backward skips their grads.
        if type(c0) == int:
            ctx.first_col = True
        else:
            ctx.first_col = False
        with torch.no_grad():
            gpu_devices = get_gpu_device(*args)
            ctx.gpu_devices = gpu_devices
            # Snapshot RNG *before* each level so it can be replayed exactly.
            ctx.cpu_states_0, ctx.gpu_states_0 = get_cpu_and_gpu_states(gpu_devices)
            c0 = l0(x, c1) + c0*alpha0
            ctx.cpu_states_1, ctx.gpu_states_1 = get_cpu_and_gpu_states(gpu_devices)
            c1 = l1(c0, c2) + c1*alpha1
            ctx.cpu_states_2, ctx.gpu_states_2 = get_cpu_and_gpu_states(gpu_devices)
            c2 = l2(c1, c3) + c2*alpha2
            ctx.cpu_states_3, ctx.gpu_states_3 = get_cpu_and_gpu_states(gpu_devices)
            c3 = l3(c2, None) + c3*alpha3
        ctx.save_for_backward(x, c0, c1, c2, c3)
        return x, c0, c1 ,c2, c3

    @staticmethod
    def backward(ctx, *grad_outputs):
        x, c0, c1, c2, c3 = ctx.saved_tensors
        l0, l1, l2, l3 = ctx.run_functions
        alpha0, alpha1, alpha2, alpha3 = ctx.alpha
        gx_right, g0_right, g1_right, g2_right, g3_right = grad_outputs
        (x, c0, c1, c2, c3) = detach_and_grad((x, c0, c1, c2, c3))

        # Replay under the recorded RNG and autocast settings so recomputed
        # activations match the forward pass bit-for-bit.
        with torch.enable_grad(), \
                torch.random.fork_rng(devices=ctx.gpu_devices, enabled=ctx.preserve_rng_state), \
                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):

            # ---- level 3 ----
            g3_up = g3_right
            g3_left = g3_up*alpha3 ##shortcut
            set_device_states(ctx.cpu_states_3, ctx.gpu_devices, ctx.gpu_states_3)
            oup3 = l3(c2, None)
            torch.autograd.backward(oup3, g3_up, retain_graph=True)
            with torch.no_grad():
                # Invert c3 = l3(c2) + c3_prev*alpha3 to recover c3_prev.
                c3_left = (1/alpha3)*(c3 - oup3) ## feature reverse
            g2_up = g2_right+ c2.grad
            g2_left = g2_up*alpha2 ##shortcut

            # ---- level 2 ----
            (c3_left,) = detach_and_grad((c3_left,))
            set_device_states(ctx.cpu_states_2, ctx.gpu_devices, ctx.gpu_states_2)
            oup2 = l2(c1, c3_left)
            torch.autograd.backward(oup2, g2_up, retain_graph=True)
            c3_left.requires_grad = False
            cout3 = c3_left*alpha3 ##alpha3 update
            torch.autograd.backward(cout3, g3_up)

            with torch.no_grad():
                c2_left = (1/alpha2)*(c2 - oup2) ## feature reverse
            g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left
            g1_up = g1_right+c1.grad
            g1_left = g1_up*alpha1 ##shortcut

            # ---- level 1 ----
            (c2_left,) = detach_and_grad((c2_left,))
            set_device_states(ctx.cpu_states_1, ctx.gpu_devices, ctx.gpu_states_1)
            oup1 = l1(c0, c2_left)
            torch.autograd.backward(oup1, g1_up, retain_graph=True)
            c2_left.requires_grad = False
            cout2 = c2_left*alpha2 ##alpha2 update
            torch.autograd.backward(cout2, g2_up)

            with torch.no_grad():
                c1_left = (1/alpha1)*(c1 - oup1) ## feature reverse
            g0_up = g0_right + c0.grad
            g0_left = g0_up*alpha0 ##shortcut
            g2_left = g2_left + c2_left.grad if c2_left.grad is not None else g2_left ## Fusion

            # ---- level 0 ----
            (c1_left,) = detach_and_grad((c1_left,))
            set_device_states(ctx.cpu_states_0, ctx.gpu_devices, ctx.gpu_states_0)
            oup0 = l0(x, c1_left)
            torch.autograd.backward(oup0, g0_up, retain_graph=True)
            c1_left.requires_grad = False
            cout1 = c1_left*alpha1 ##alpha1 update
            torch.autograd.backward(cout1, g1_up)

            with torch.no_grad():
                c0_left = (1/alpha0)*(c0 - oup0) ## feature reverse
            gx_up = x.grad ## Fusion
            g1_left = g1_left + c1_left.grad if c1_left.grad is not None else g1_left ## Fusion
            c0_left.requires_grad = False
            cout0 = c0_left*alpha0 ##alpha0 update
            torch.autograd.backward(cout0, g0_up)

        # First column has no incoming feature tensors, so no grads for them;
        # the two leading Nones match `run_functions` and `alpha`.
        if ctx.first_col:
            return None, None, gx_up, None, None, None, None
        else:
            return None, None, gx_up, g0_left, g1_left, g2_left, g3_left
152
+
153
+
RDNet-main/RDNet-main/models/arch/vgg.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
+ import torch
4
+ from torchvision import models
5
+
6
+
7
class Vgg16(torch.nn.Module):
    """ImageNet-pretrained VGG-16 split into four stages; forward returns a
    namedtuple of the activations after feature layers 3, 8, 15 and 22."""

    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        features = models.vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        stage_bounds = ((self.slice1, 0, 4), (self.slice2, 4, 9),
                        (self.slice3, 9, 16), (self.slice4, 16, 23))
        for stage, start, stop in stage_bounds:
            for idx in range(start, stop):
                stage.add_module(str(idx), features[idx])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
39
+
40
+
41
class Vgg19(torch.nn.Module):
    """ImageNet-pretrained VGG-19 feature extractor.

    forward(X, indices) runs the feature layers sequentially and collects
    the output of every layer whose 1-based index appears in `indices`
    (default [2, 7, 12, 21, 30]).
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        self.vgg_pretrained_features = models.vgg19(pretrained=True).features
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X, indices=None):
        if indices is None:
            indices = [2, 7, 12, 21, 30]
        collected = []
        # Only run up to the last requested layer; indices are 1-based.
        for layer_idx in range(indices[-1]):
            X = self.vgg_pretrained_features[layer_idx](X)
            if (layer_idx + 1) in indices:
                collected.append(X)

        return collected
84
+
85
+
86
if __name__ == '__main__':
    # Manual smoke test: build the network (downloads pretrained weights)
    # and drop into an ipdb shell for interactive inspection.  Requires the
    # optional `ipdb` package; not part of any automated pipeline.
    vgg = Vgg19()
    import ipdb

    ipdb.set_trace()
RDNet-main/RDNet-main/models/base_model.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import util.util as util
4
+
5
+
6
class BaseModel:
    """Minimal model base class: checkpoint-directory resolution, save/load
    hooks and optimizer bookkeeping.  Subclasses override the no-op methods
    (forward/test/optimize_parameters/...)."""

    def name(self):
        # Lowercased class name, used in checkpoint filenames.
        return self.__class__.__name__.lower()

    def initialize(self, opt):
        """Record options and resolve save/model directories.

        When resuming from an explicit checkpoint path (opt.checkpoints_dir
        not ending in 'checkpoints' and differing from opt.name, or
        supp_eval set), outputs go to that path while model weights are
        saved next to it under opt.name; otherwise both default to
        <checkpoints_dir>/<name>.
        NOTE(review): assumes '/'-separated paths — would misbehave with
        Windows separators; confirm against callers.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        last_split = opt.checkpoints_dir.split('/')[-1]
        if opt.resume and last_split != 'checkpoints' and (last_split != opt.name or opt.supp_eval):

            self.save_dir = opt.checkpoints_dir
            self.model_save_dir = os.path.join(opt.checkpoints_dir.replace(opt.checkpoints_dir.split('/')[-1], ''),
                                               opt.name)
        else:
            self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
            self.model_save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        # Running counter used by subclasses when dumping numbered outputs.
        self._count = 0

    def set_input(self, input):
        self.input = input

    def forward(self, mode='train'):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def print_optimizer_param(self):
        # Prints the most recently registered optimizer.
        print(self.optimizers[-1])

    def save(self, label=None):
        """Save a full training checkpoint; filename carries epoch/iteration
        when no label is given."""
        epoch = self.epoch
        iterations = self.iterations

        if label is None:
            model_name = os.path.join(self.model_save_dir, self.opt.name + '_%03d_%08d.pt' % ((epoch), (iterations)))
        else:
            model_name = os.path.join(self.model_save_dir, self.opt.name + '_' + label + '.pt')

        torch.save(self.state_dict(), model_name)

    def save_eval(self, label=None):
        # Eval-only checkpoint (subclass provides state_dict_eval()).
        model_name = os.path.join(self.model_save_dir, label + '.pt')

        torch.save(self.state_dict_eval(), model_name)

    def _init_optimizer(self, optimizers):
        # Register optimizers and push lr / weight-decay from options into
        # every param group (util is a project-local helper module).
        self.optimizers = optimizers
        for optimizer in self.optimizers:
            util.set_opt_param(optimizer, 'initial_lr', self.opt.lr)
            util.set_opt_param(optimizer, 'weight_decay', self.opt.wd)
RDNet-main/RDNet-main/models/cls_model_eval_nocls_reg.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from models.losses import DINOLoss
5
+ import os
6
+ import numpy as np
7
+ from collections import OrderedDict
8
+ from ema_pytorch import EMA
9
+ from models.arch.classifier import PretrainedConvNext
10
+ import util.util as util
11
+ import util.index as index
12
+ import models.networks as networks
13
+ import models.losses as losses
14
+ from models import arch
15
+ #from models.arch.dncnn import effnetv2_s
16
+ from .base_model import BaseModel
17
+ from PIL import Image
18
+ from os.path import join
19
+ #from torchviz import make_dot
20
+ from models.arch.RDnet_ import FullNet_NLP
21
+ import timm
22
+
23
def tensor2im(image_tensor, imtype=np.uint8):
    """Convert the first image of a batched CHW tensor to an HWC numpy array
    scaled to [0, 255].  Values are clipped to [0, 1] first; single-channel
    inputs are tiled to 3 channels.  Note: the result stays floating point —
    the `imtype` cast is intentionally left to the caller."""
    first = image_tensor.detach()[0].cpu().float().numpy()
    first = np.clip(first, 0, 1)
    if first.shape[0] == 1:
        first = np.tile(first, (3, 1, 1))
    return np.transpose(first, (1, 2, 0)) * 255.0
32
+
33
+
34
class EdgeMap(nn.Module):
    """Single-channel edge-magnitude map for an image batch.

    Takes |forward differences| summed over channels along H and W, scatters
    each one-sided difference to both pixels it touches (halving interior
    pixels that receive two contributions), and returns their sum.
    """

    def __init__(self, scale=1):
        super(EdgeMap, self).__init__()
        self.scale = scale
        self.requires_grad = False

    def forward(self, img):
        img = img / self.scale

        N, C, H, W = img.shape
        edge_h = torch.zeros(N, 1, H, W, dtype=img.dtype, device=img.device)
        edge_w = torch.zeros(N, 1, H, W, dtype=img.dtype, device=img.device)

        diff_h = (img[..., 1:, :] - img[..., :-1, :]).abs().sum(dim=1, keepdim=True)
        diff_w = (img[..., 1:] - img[..., :-1]).abs().sum(dim=1, keepdim=True)

        edge_h[..., :-1, :] += diff_h
        edge_h[..., 1:, :] += diff_h
        edge_h[..., 1:-1, :] /= 2

        edge_w[..., :-1] += diff_w
        edge_w[..., 1:] += diff_w
        edge_w[..., 1:-1] /= 2

        return edge_h + edge_w
62
+
63
+
64
class YTMTNetBase(BaseModel):
    """Shared train/eval/test plumbing for the reflection-removal models:
    moves batches to GPU, computes edge maps, runs no-grad evaluation and
    writes prediction images to disk."""

    def _init_optimizer(self, optimizers):
        # Same as BaseModel: register optimizers and push lr/wd from options.
        self.optimizers = optimizers
        for optimizer in self.optimizers:
            util.set_opt_param(optimizer, 'initial_lr', self.opt.lr)
            util.set_opt_param(optimizer, 'weight_decay', self.opt.wd)

    def set_input(self, data, mode='train'):
        """Unpack a batch dict for the given mode and stash it on self.

        train: needs input/target_t/target_r; eval additionally 'fn';
        test: only input/'fn'.  Tensors are moved to the first GPU id.
        """
        target_t = None
        target_r = None
        data_name = None
        identity = False
        mode = mode.lower()
        if mode == 'train':
            input, target_t, target_r = data['input'], data['target_t'], data['target_r']
        elif mode == 'eval':
            input, target_t, target_r, data_name = data['input'], data['target_t'], data['target_r'], data['fn']
        elif mode == 'test':
            input, data_name = data['input'], data['fn']
        else:
            raise NotImplementedError('Mode [%s] is not implemented' % mode)

        if len(self.gpu_ids) > 0:  # transfer data into gpu
            input = input.to(device=self.gpu_ids[0])
            if target_t is not None:
                target_t = target_t.to(device=self.gpu_ids[0])
            if target_r is not None:
                target_r = target_r.to(device=self.gpu_ids[0])

        self.input = input
        self.identity = identity
        self.input_edge = self.edge_map(self.input)
        self.target_t = target_t
        self.target_r = target_r
        self.data_name = data_name

        # Presence of 'real'/'unaligned' keys flags real-world / unaligned
        # pairs; synthetic aligned data has neither.
        self.issyn = False if 'real' in data else True
        self.aligned = False if 'unaligned' in data else True

        if target_t is not None:
            self.target_edge = self.edge_map(self.target_t)

    def eval(self, data, savedir=None, suffix=None, pieapp=None):
        """No-grad evaluation of one batch; returns quality metrics (empty
        dict for unaligned pairs) and optionally writes result images."""
        self._eval()
        self.set_input(data, 'eval')
        with torch.no_grad():
            self.forward_eval()

            # NOTE(review): indices 6/7 of output_j are taken as the final
            # transmission/reflection predictions — set by the subclass's
            # forward_eval(); confirm against the network's output layout.
            output_i = tensor2im(self.output_j[6])
            output_j = tensor2im(self.output_j[7])
            target = tensor2im(self.target_t)
            target_r = tensor2im(self.target_r)

            if self.aligned:
                res = index.quality_assess(output_i, target)
            else:
                res = {}

            if savedir is not None:
                if self.data_name is not None:
                    # Named layout: <savedir>/<suffix>/<sample>/<files>.
                    name = os.path.splitext(os.path.basename(self.data_name[0]))[0]
                    savedir = join(savedir, suffix, name)
                    os.makedirs(savedir, exist_ok=True)
                    Image.fromarray(output_i.astype(np.uint8)).save(
                        join(savedir, '{}_t.png'.format(self.opt.name)))
                    Image.fromarray(output_j.astype(np.uint8)).save(
                        join(savedir, '{}_r.png'.format(self.opt.name)))
                    Image.fromarray(target.astype(np.uint8)).save(join(savedir, 't_label.png'))
                    Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(join(savedir, 'm_input.png'))
                else:
                    # Anonymous layout: numbered dumps of GT and input only.
                    if not os.path.exists(join(savedir, 'transmission_layer')):
                        os.makedirs(join(savedir, 'transmission_layer'))
                        os.makedirs(join(savedir, 'blended'))
                    Image.fromarray(target.astype(np.uint8)).save(
                        join(savedir, 'transmission_layer', str(self._count) + '.png'))
                    Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(
                        join(savedir, 'blended', str(self._count) + '.png'))
                    self._count += 1

            return res

    def test(self, data, savedir=None):
        """Inference without ground truth; writes predictions per sample and
        skips samples whose output file already exists."""
        # only the 1st input of the whole minibatch would be evaluated
        self._eval()
        self.set_input(data, 'test')

        if self.data_name is not None and savedir is not None:
            name = os.path.splitext(os.path.basename(self.data_name[0]))[0]
            if not os.path.exists(join(savedir, name)):
                os.makedirs(join(savedir, name))

            if os.path.exists(join(savedir, name, '{}.png'.format(self.opt.name))):
                return

        with torch.no_grad():
            output_i, output_j = self.forward()
            output_i = tensor2im(output_i)
            output_j = tensor2im(output_j)
            if self.data_name is not None and savedir is not None:
                Image.fromarray(output_i.astype(np.uint8)).save(join(savedir, name, '{}_l.png'.format(self.opt.name)))
                Image.fromarray(output_j.astype(np.uint8)).save(join(savedir, name, '{}_r.png'.format(self.opt.name)))
                Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(join(savedir, name, 'm_input.png'))
166
+
167
+
168
+ class ClsModel(YTMTNetBase):
169
+ def name(self):
170
+ return 'ytmtnet'
171
+
172
+ def __init__(self):
173
+ self.epoch = 0
174
+ self.iterations = 0
175
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
176
+ self.net_c = None
177
+
178
+ def print_network(self):
179
+ print('--------------------- Model ---------------------')
180
+ print('##################### NetG #####################')
181
+ networks.print_network(self.net_i)
182
+ if self.isTrain and self.opt.lambda_gan > 0:
183
+ print('##################### NetD #####################')
184
+ networks.print_network(self.netD)
185
+
186
+ def _eval(self):
187
+ self.net_i.eval()
188
+ self.net_c.eval()
189
+
190
    def _train(self):
        # Only the restoration network trains; the classifier net_c is kept
        # in eval mode (it is a frozen prompt provider — see forward(), which
        # runs it under torch.no_grad()).
        self.net_i.train()
        self.net_c.eval()
193
    def initialize(self, opt):
        """Build networks, loss functions and the optimizer from `opt`.

        Creates the frozen ConvNeXt prompt classifier (net_c), the main
        restoration network (net_i, a FullNet_NLP), and — in training mode —
        the loss dictionary, the AMP grad scaler and the Adam optimizer.
        """
        self.opt=opt
        BaseModel.initialize(self, opt)

        in_channels = 3
        self.vgg = None

        if opt.hyper:
            # Hypercolumn input: VGG-19 features are concatenated to the RGB
            # image, adding 1472 channels (see hyper_column()).
            self.vgg = losses.Vgg19(requires_grad=False).to(self.device)
            in_channels += 1472
        channels = [64, 128, 256, 512]
        layers = [2, 2, 4, 2]
        num_subnet = opt.num_subnet
        self.net_c = PretrainedConvNext("convnext_small_in22k").cuda()

        # Load frozen classifier weights used to produce the prompt vector.
        self.net_c.load_state_dict(torch.load('pretrained/cls_model.pth')['icnn'])

        self.net_i = FullNet_NLP(channels, layers, num_subnet, opt.loss_col,num_classes=1000, drop_path=0,save_memory=True, inter_supv=True, head_init_scale=None, kernel_size=3).to(self.device)

        self.edge_map = EdgeMap(scale=1).to(self.device)

        if self.isTrain:
            # Project-wide losses (pixel, GAN, recons, ...).
            self.loss_dic = losses.init_loss(opt, self.Tensor)
            vggloss = losses.ContentLoss()
            vggloss.initialize(losses.VGGLoss(self.vgg))
            self.loss_dic['t_vgg'] = vggloss

            # Loss used for unaligned (real) pairs; flavour picked by option.
            cxloss = losses.ContentLoss()
            if opt.unaligned_loss == 'vgg':
                cxloss.initialize(losses.VGGLoss(self.vgg, weights=[0.1], indices=[opt.vgg_layer]))
            elif opt.unaligned_loss == 'ctx':
                cxloss.initialize(losses.CXLoss(self.vgg, weights=[0.1, 0.1, 0.1], indices=[8, 13, 22]))
            elif opt.unaligned_loss == 'mse':
                cxloss.initialize(nn.MSELoss())
            elif opt.unaligned_loss == 'ctx_vgg':
                cxloss.initialize(losses.CXLoss(self.vgg, weights=[0.1, 0.1, 0.1, 0.1], indices=[8, 13, 22, 31],
                                                criterions=[losses.CX_loss] * 3 + [nn.L1Loss()]))
            else:
                raise NotImplementedError
            self.scaler=torch.cuda.amp.GradScaler()
            # NOTE(review): this autocast context only wraps the *construction*
            # of DINOLoss, not a forward pass — it likely has no effect.
            with torch.autocast(device_type='cuda',dtype=torch.float16):
                self.dinoloss=DINOLoss()
            self.loss_dic['t_cx'] = cxloss

            self.optimizer_G = torch.optim.Adam(self.net_i.parameters(),
                                                lr=opt.lr, betas=(0.9, 0.999), weight_decay=opt.wd)

            self._init_optimizer([self.optimizer_G])

        if opt.resume:
            self.load(self, opt.resume_epoch)
245
+
246
+
247
    def backward_D(self):
        """Discriminator step: GAN loss on each column's clean output,
        with progressively larger weights for later columns."""
        loss_D=[]
        weight=self.opt.weight_loss
        for p in self.netD.parameters():
            p.requires_grad = True
        # NOTE(review): the number of supervised columns is hard-coded to 4
        # here, while get_loss() iterates opt.loss_col — confirm they match.
        for i in range(4):
            # output_j is laid out [clean_0, refl_0, clean_1, refl_1, ...],
            # so 2*i selects the clean branch of column i.
            loss_D_1, pred_fake_1, pred_real_1 = self.loss_dic['gan'].get_loss(
                self.netD, self.input, self.output_j[2*i], self.target_t)
            loss_D.append(loss_D_1*weight)
            weight+=self.opt.weight_loss
        loss_sum=sum(loss_D)

        # Keep the last column's predictions for logging/visualization.
        self.loss_D, self.pred_fake, self.pred_real = (loss_sum, pred_fake_1, pred_real_1)

        (self.loss_D * self.opt.lambda_gan).backward(retain_graph=True)
262
+
263
+ def get_loss(self, out_l, out_r):
264
+ loss_G_GAN_sum=[]
265
+ loss_icnn_pixel_sum=[]
266
+ loss_rcnn_pixel_sum=[]
267
+ loss_icnn_vgg_sum=[]
268
+ weight=self.opt.weight_loss
269
+ for i in range(self.opt.loss_col):
270
+ out_r_clean=out_r[2*i]
271
+ out_r_reflection=out_r[2*i+1]
272
+ if i != self.opt.loss_col -1:
273
+ loss_G_GAN = 0
274
+ loss_icnn_pixel = self.loss_dic['t_pixel'].get_loss(out_r_clean, self.target_t)
275
+ loss_rcnn_pixel = self.loss_dic['r_pixel'].get_loss(out_r_reflection, self.target_r) * 1.5 * self.opt.r_pixel_weight
276
+ loss_icnn_vgg = self.loss_dic['t_vgg'].get_loss(out_r_clean, self.target_t) * self.opt.lambda_vgg
277
+ else:
278
+ if self.opt.lambda_gan>0:
279
+
280
+ loss_G_GAN=0
281
+ else:
282
+ loss_G_GAN=0
283
+ loss_icnn_pixel = self.loss_dic['t_pixel'].get_loss(out_r_clean, self.target_t)
284
+ loss_rcnn_pixel = self.loss_dic['r_pixel'].get_loss(out_r_reflection, self.target_r) * 1.5 * self.opt.r_pixel_weight
285
+ loss_icnn_vgg = self.loss_dic['t_vgg'].get_loss(out_r_clean, self.target_t) * self.opt.lambda_vgg
286
+
287
+ loss_G_GAN_sum.append(loss_G_GAN*weight)
288
+ loss_icnn_pixel_sum.append(loss_icnn_pixel*weight)
289
+ loss_rcnn_pixel_sum.append(loss_rcnn_pixel*weight)
290
+ loss_icnn_vgg_sum.append(loss_icnn_vgg*weight)
291
+ weight=weight+self.opt.weight_loss
292
+ return sum(loss_G_GAN_sum), sum(loss_icnn_pixel_sum), sum(loss_rcnn_pixel_sum), sum(loss_icnn_vgg_sum)
293
+
294
    def backward_G(self):
        """Generator step: combine weighted losses and backprop through the
        AMP grad scaler."""
        self.loss_G_GAN,self.loss_icnn_pixel, self.loss_rcnn_pixel, \
        self.loss_icnn_vgg = self.get_loss(self.output_i, self.output_j)

        self.loss_exclu = self.exclusion_loss(self.output_i, self.output_j, 3)

        self.loss_recons = self.loss_dic['recons'](self.output_i, self.output_j, self.input) * 0.2

        # NOTE(review): loss_exclu and loss_recons are computed above but are
        # NOT included in loss_G — confirm whether this is intentional
        # (e.g. an ablation) or a dropped term.
        self.loss_G = self.loss_G_GAN +self.loss_icnn_pixel + self.loss_rcnn_pixel + \
                      self.loss_icnn_vgg
        # Scale the loss for mixed-precision training; the matching
        # scaler.step()/update() must happen in the caller.
        self.scaler.scale(self.loss_G).backward()
306
+
307
+
308
+
309
+ def hyper_column(self, input_img):
310
+ hypercolumn = self.vgg(input_img)
311
+ _, C, H, W = input_img.shape
312
+ hypercolumn = [F.interpolate(feature.detach(), size=(H, W), mode='bilinear', align_corners=False) for
313
+ feature in hypercolumn]
314
+ input_i = [input_img]
315
+ input_i.extend(hypercolumn)
316
+ input_i = torch.cat(input_i, dim=1)
317
+ return input_i
318
+
319
+ def forward(self):
320
+ # without edge
321
+
322
+ self.output_j=[]
323
+ input_i = self.input
324
+ if self.vgg is not None:
325
+ input_i = self.hyper_column(input_i)
326
+ with torch.no_grad():
327
+ ipt = self.net_c(input_i)
328
+ output_i, output_j = self.net_i(input_i,ipt,prompt=True)
329
+ self.output_i = output_i
330
+ for i in range(self.opt.loss_col):
331
+ out_reflection, out_clean = output_j[i][:, :3, ...], output_j[i][:, 3:, ...]
332
+ self.output_j.append(out_clean)
333
+ self.output_j.append(out_reflection)
334
+ return self.output_i, self.output_j
335
+
336
+
337
    @torch.no_grad()
    def forward_eval(self):
        """Inference-only forward pass (no gradients anywhere); performs the
        same prompt + split logic as forward()."""
        self.output_j=[]
        input_i = self.input
        if self.vgg is not None:
            input_i = self.hyper_column(input_i)
        ipt = self.net_c(input_i)

        output_i, output_j = self.net_i(input_i,ipt,prompt=True)
        self.output_i = output_i #alpha * output_i + beta
        for i in range(self.opt.loss_col):
            # Each column emits 6 channels: first 3 reflection, last 3 clean.
            out_reflection, out_clean = output_j[i][:, :3, ...], output_j[i][:, 3:, ...]
            self.output_j.append(out_clean)
            self.output_j.append(out_reflection)
        return self.output_i, self.output_j
353
+
354
+ def optimize_parameters(self):
355
+ self._train()
356
+ self.forward()
357
+ self.optimizer_G.zero_grad()
358
+ self.backward_G()
359
+ self.optimizer_G.step()
360
+
361
+ def return_output(self):
362
+ output_clean = self.output_j[1]
363
+ output_reflection = self.output_j[0]
364
+ output_clean = tensor2im(output_clean).astype(np.uint8)
365
+ output_reflection = tensor2im(output_reflection).astype(np.uint8)
366
+ input=tensor2im(self.input)
367
+ return output_clean,output_reflection,input
368
    def exclusion_loss(self, img_T, img_R, level=3, eps=1e-6):
        """Multi-scale gradient exclusion loss between the predicted clean
        and reflection layers, summed over columns with growing weights.

        NOTE(review): the img_T/img_R arguments are immediately overwritten
        from self.output_j, and the column count is hard-coded to 4 — the
        parameters are effectively unused; confirm 4 matches opt.loss_col.
        """
        loss_gra=[]
        weight=0.25
        for i in range(4):
            grad_x_loss = []
            grad_y_loss = []
            # Column i's clean / reflection predictions.
            img_T=self.output_j[2*i]
            img_R=self.output_j[2*i+1]
            for l in range(level):
                grad_x_T, grad_y_T = self.compute_grad(img_T)
                grad_x_R, grad_y_R = self.compute_grad(img_R)

                # Scale R's gradients so both layers have comparable magnitude.
                alphax = (2.0 * torch.mean(torch.abs(grad_x_T))) / (torch.mean(torch.abs(grad_x_R)) + eps)
                alphay = (2.0 * torch.mean(torch.abs(grad_y_T))) / (torch.mean(torch.abs(grad_y_R)) + eps)

                gradx1_s = (torch.sigmoid(grad_x_T) * 2) - 1  # mul 2 minus 1 is to change sigmoid into tanh
                grady1_s = (torch.sigmoid(grad_y_T) * 2) - 1
                gradx2_s = (torch.sigmoid(grad_x_R * alphax) * 2) - 1
                grady2_s = (torch.sigmoid(grad_y_R * alphay) * 2) - 1

                # Penalize correlated edges between the two layers.
                grad_x_loss.append(((torch.mean(torch.mul(gradx1_s.pow(2), gradx2_s.pow(2)))) + eps) ** 0.25)
                grad_y_loss.append(((torch.mean(torch.mul(grady1_s.pow(2), grady2_s.pow(2)))) + eps) ** 0.25)

                # Repeat the term at progressively coarser scales.
                img_T = F.interpolate(img_T, scale_factor=0.5, mode='bilinear')
                img_R = F.interpolate(img_R, scale_factor=0.5, mode='bilinear')
            # NOTE(review): divisor 3 matches the default level=3 but is not
            # tied to the `level` argument.
            loss_gradxy = torch.sum(sum(grad_x_loss) / 3) + torch.sum(sum(grad_y_loss) / 3)
            loss_gra.append(loss_gradxy*weight)
            weight+=0.25

        return sum(loss_gra) / 2
399
+
400
+ def contain_loss(self, img_T, img_R, img_I, eps=1e-6):
401
+ pix_num = np.prod(img_I.shape)
402
+ predict_tx, predict_ty = self.compute_grad(img_T)
403
+ predict_tx, predict_ty = self.compute_grad(img_T)
404
+ predict_rx, predict_ry = self.compute_grad(img_R)
405
+ input_x, input_y = self.compute_grad(img_I)
406
+
407
+ out = torch.norm(predict_tx / (input_x + eps), 2) ** 2 + \
408
+ torch.norm(predict_ty / (input_y + eps), 2) ** 2 + \
409
+ torch.norm(predict_rx / (input_x + eps), 2) ** 2 + \
410
+ torch.norm(predict_ry / (input_y + eps), 2) ** 2
411
+
412
+ return out / pix_num
413
+
414
+ def compute_grad(self, img):
415
+ gradx = img[:, :, 1:, :] - img[:, :, :-1, :]
416
+ grady = img[:, :, :, 1:] - img[:, :, :, :-1]
417
+ return gradx, grady
418
+
419
    def load(self, model, resume_epoch=None):
        """Restore generator weights from ``model.opt.icnn_path``.

        Returns the full loaded checkpoint dict.

        NOTE(review): resume_epoch is accepted but ignored, and only the
        'icnn' weights are restored — the optimizer state saved by
        state_dict() is not reloaded; confirm this is intended for resume.
        """
        icnn_path = model.opt.icnn_path
        state_dict = torch.load(icnn_path)
        model.net_i.load_state_dict(state_dict['icnn'])
        return state_dict
424
+
425
+ def state_dict(self):
426
+ state_dict = {
427
+ 'icnn': self.net_i.state_dict(),
428
+ 'opt_g': self.optimizer_G.state_dict(),
429
+ #'ema' : self.ema.state_dict(),
430
+ 'epoch': self.epoch, 'iterations': self.iterations
431
+ }
432
+
433
+ if self.opt.lambda_gan > 0:
434
+ state_dict.update({
435
+ 'opt_d': self.optimizer_D.state_dict(),
436
+ 'netD': self.netD.state_dict(),
437
+ })
438
+
439
+ return state_dict
440
class AvgPool2d(nn.Module):
    """Average pooling whose window can be derived at run time from a
    training-time resolution.

    When `kernel_size` is None it is computed on the first forward pass from
    `base_size` and `train_size` so the pooling window covers the same
    *relative* area at the current input resolution as `base_size` did at
    training resolution.  `fast_imp` enables a subsampled, non-equivalent but
    faster approximation; otherwise an exact sliding-window mean (stride 1)
    is computed via 2-D prefix sums.  With `auto_pad`, the output is
    replicate-padded back to the input's spatial size.
    """

    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]  # candidate subsampling strides, largest first
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        # Shown by repr(module); stride equals the kernel size here.
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        if self.kernel_size is None and self.base_size:
            # Derive the kernel size once, scaled from train-time resolution.
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            # Window covers the whole feature map: plain global average.
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                # Pick the largest stride that evenly divides each dimension.
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction_constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                # Box sums on the subsampled map via 2-D prefix sums, then
                # upsample back to the original resolution.
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            # Exact sliding-window mean via an integral image: sum over any
            # k1 x k2 box is s4 + s1 - s2 - s3 on the padded prefix sums.
            n, c, h, w = x.shape
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            # Replicate-pad the (smaller) pooled map back to the input size.
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            # print(x.shape, self.kernel_size)
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out
507
+
508
+ def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
509
+ for n, m in model.named_children():
510
+ if len(list(m.children())) > 0:
511
+ ## compound module, go inside it
512
+ replace_layers(m, base_size, train_size, fast_imp, **kwargs)
513
+
514
+ if isinstance(m, nn.AdaptiveAvgPool2d):
515
+ pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
516
+ assert m.output_size == 1
517
+ setattr(model, n, pool)