lime-j commited on
Commit
dfd3b88
·
1 Parent(s): a0d8acf
Files changed (49) hide show
  1. =2.0 +26 -0
  2. README.md +1 -1
  3. data/VOC2012_224_train_png.txt +0 -0
  4. data/__pycache__/dataset_sir.cpython-38.pyc +0 -0
  5. data/__pycache__/image_folder.cpython-38.pyc +0 -0
  6. data/__pycache__/torchdata.cpython-38.pyc +0 -0
  7. data/__pycache__/transforms.cpython-38.pyc +0 -0
  8. data/dataset_sir.py +0 -332
  9. data/image_folder.py +0 -51
  10. data/real_test.txt +0 -20
  11. data/torchdata.py +0 -67
  12. data/transforms.py +0 -301
  13. engine.py +0 -178
  14. figures/Input_car.jpg +0 -0
  15. figures/Input_class.png +0 -3
  16. figures/Input_green.png +0 -3
  17. figures/Ours_car.png +0 -3
  18. figures/Ours_class.png +0 -3
  19. figures/Ours_green.png +0 -3
  20. figures/Ours_white.png +0 -3
  21. figures/Title.png +0 -0
  22. figures/input_white.jpg +0 -0
  23. figures/net.png +0 -3
  24. figures/result.png +0 -3
  25. figures/vis.png +0 -3
  26. models/__init__.py +0 -11
  27. models/__pycache__/__init__.cpython-310.pyc +0 -0
  28. models/__pycache__/cls_model_eval_nocls_reg.cpython-310.pyc +0 -0
  29. models/__pycache__/losses.cpython-310.pyc +0 -0
  30. models/base_model.py +0 -71
  31. models/cls_model_eval_nocls_reg.py +0 -517
  32. models/losses.py +0 -468
  33. models/losses_opt.py +0 -404
  34. models/networks.py +0 -335
  35. models/vgg.py +0 -66
  36. models/vit_feature_extractor.py +0 -164
  37. options/__init__.py +0 -0
  38. options/__pycache__/__init__.cpython-38.pyc +0 -0
  39. options/__pycache__/base_option.cpython-38.pyc +0 -0
  40. options/base_option.py +0 -47
  41. options/net_options/__init__.py +0 -0
  42. options/net_options/__pycache__/__init__.cpython-38.pyc +0 -0
  43. options/net_options/__pycache__/base_options.cpython-38.pyc +0 -0
  44. options/net_options/__pycache__/train_options.cpython-38.pyc +0 -0
  45. options/net_options/base_options.py +0 -71
  46. options/net_options/train_options.py +0 -75
  47. pretrained/README.md +0 -3
  48. script.py +0 -64
  49. test_sirs.py +0 -60
=2.0 ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Requirement already satisfied: torch in /usr/local/lib/python3.10/site-packages (2.6.0)
2
+ Requirement already satisfied: torchvision in /usr/local/lib/python3.10/site-packages (0.21.0)
3
+ Requirement already satisfied: filelock in /usr/local/lib/python3.10/site-packages (from torch) (3.17.0)
4
+ Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.10/site-packages (from torch) (4.12.2)
5
+ Requirement already satisfied: networkx in /usr/local/lib/python3.10/site-packages (from torch) (3.4.2)
6
+ Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/site-packages (from torch) (3.1.5)
7
+ Requirement already satisfied: fsspec in /usr/local/lib/python3.10/site-packages (from torch) (2024.12.0)
8
+ Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch) (12.4.127)
9
+ Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch) (12.4.127)
10
+ Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch) (12.4.127)
11
+ Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/site-packages (from torch) (9.1.0.70)
12
+ Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.10/site-packages (from torch) (12.4.5.8)
13
+ Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.10/site-packages (from torch) (11.2.1.3)
14
+ Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.10/site-packages (from torch) (10.3.5.147)
15
+ Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.10/site-packages (from torch) (11.6.1.9)
16
+ Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.10/site-packages (from torch) (12.3.1.170)
17
+ Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.10/site-packages (from torch) (0.6.2)
18
+ Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/site-packages (from torch) (2.21.5)
19
+ Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch) (12.4.127)
20
+ Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch) (12.4.127)
21
+ Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.10/site-packages (from torch) (3.2.0)
22
+ Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/site-packages (from torch) (1.13.1)
23
+ Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/site-packages (from sympy==1.13.1->torch) (1.3.0)
24
+ Requirement already satisfied: numpy in /usr/local/lib/python3.10/site-packages (from torchvision) (2.2.3)
25
+ Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/site-packages (from torchvision) (10.4.0)
26
+ Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/site-packages (from jinja2->torch) (2.1.5)
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Lime Evangelion
3
  emoji: 💻
4
  colorFrom: indigo
5
  colorTo: blue
 
1
  ---
2
+ title: RDNet
3
  emoji: 💻
4
  colorFrom: indigo
5
  colorTo: blue
data/VOC2012_224_train_png.txt DELETED
The diff for this file is too large to render. See raw diff
 
data/__pycache__/dataset_sir.cpython-38.pyc DELETED
Binary file (10.9 kB)
 
data/__pycache__/image_folder.cpython-38.pyc DELETED
Binary file (1.58 kB)
 
data/__pycache__/torchdata.cpython-38.pyc DELETED
Binary file (2.86 kB)
 
data/__pycache__/transforms.cpython-38.pyc DELETED
Binary file (9.37 kB)
 
data/dataset_sir.py DELETED
@@ -1,332 +0,0 @@
1
- import math
2
- import os.path
3
- import os.path
4
- import random
5
- from os.path import join
6
-
7
- import cv2
8
- import numpy as np
9
- import torch.utils.data
10
- import torchvision.transforms.functional as TF
11
- from PIL import Image
12
- from scipy.signal import convolve2d
13
-
14
- from data.image_folder import make_dataset
15
- from data.torchdata import Dataset as BaseDataset
16
- from data.transforms import to_tensor
17
-
18
-
19
- def __scale_width(img, target_width):
20
- ow, oh = img.size
21
- if (ow == target_width):
22
- return img
23
- w = target_width
24
- h = int(target_width * oh / ow)
25
- h = math.ceil(h / 2.) * 2 # round up to even
26
- return img.resize((w, h), Image.BICUBIC)
27
-
28
-
29
- def __scale_height(img, target_height):
30
- ow, oh = img.size
31
- if (oh == target_height):
32
- return img
33
- h = target_height
34
- w = int(target_height * ow / oh)
35
- w = math.ceil(w / 2.) * 2
36
- return img.resize((w, h), Image.BICUBIC)
37
-
38
-
39
- def paired_data_transforms(img_1, img_2, unaligned_transforms=False):
40
- def get_params(img, output_size):
41
- w, h = img.size
42
- th, tw = output_size
43
- if w == tw and h == th:
44
- return 0, 0, h, w
45
-
46
- i = random.randint(0, h - th)
47
- j = random.randint(0, w - tw)
48
- return i, j, th, tw
49
-
50
- target_size = int(random.randint(320, 640) / 2.) * 2
51
- ow, oh = img_1.size
52
- if ow >= oh:
53
- img_1 = __scale_height(img_1, target_size)
54
- img_2 = __scale_height(img_2, target_size)
55
- else:
56
- img_1 = __scale_width(img_1, target_size)
57
- img_2 = __scale_width(img_2, target_size)
58
-
59
- if random.random() < 0.5:
60
- img_1 = TF.hflip(img_1)
61
- img_2 = TF.hflip(img_2)
62
-
63
- if random.random() < 0.5:
64
- angle = random.choice([90, 180, 270])
65
- img_1 = TF.rotate(img_1, angle)
66
- img_2 = TF.rotate(img_2, angle)
67
-
68
- i, j, h, w = get_params(img_1, (320, 320))
69
- img_1 = TF.crop(img_1, i, j, h, w)
70
-
71
- if unaligned_transforms:
72
- # print('random shift')
73
- i_shift = random.randint(-10, 10)
74
- j_shift = random.randint(-10, 10)
75
- i += i_shift
76
- j += j_shift
77
-
78
- img_2 = TF.crop(img_2, i, j, h, w)
79
-
80
- return img_1, img_2
81
-
82
-
83
- class ReflectionSynthesis(object):
84
- def __init__(self):
85
- # Kernel Size of the Gaussian Blurry
86
- self.kernel_sizes = [5, 7, 9, 11]
87
- self.kernel_probs = [0.1, 0.2, 0.3, 0.4]
88
-
89
- # Sigma of the Gaussian Blurry
90
- self.sigma_range = [2, 5]
91
- self.alpha_range = [0.8, 1.0]
92
- self.beta_range = [0.4, 1.0]
93
-
94
- def __call__(self, T_, R_):
95
- T_ = np.asarray(T_, np.float32) / 255.
96
- R_ = np.asarray(R_, np.float32) / 255.
97
-
98
- kernel_size = np.random.choice(self.kernel_sizes, p=self.kernel_probs)
99
- sigma = np.random.uniform(self.sigma_range[0], self.sigma_range[1])
100
- kernel = cv2.getGaussianKernel(kernel_size, sigma)
101
- kernel2d = np.dot(kernel, kernel.T)
102
- for i in range(3):
103
- R_[..., i] = convolve2d(R_[..., i], kernel2d, mode='same')
104
-
105
- a = np.random.uniform(self.alpha_range[0], self.alpha_range[1])
106
- b = np.random.uniform(self.beta_range[0], self.beta_range[1])
107
- T, R = a * T_, b * R_
108
-
109
- if random.random() < 0.7:
110
- I = T + R - T * R
111
-
112
- else:
113
- I = T + R
114
- if np.max(I) > 1:
115
- m = I[I > 1]
116
- m = (np.mean(m) - 1) * 1.3
117
- I = np.clip(T + np.clip(R - m, 0, 1), 0, 1)
118
-
119
- return T_, R_, I
120
-
121
-
122
- class DataLoader(torch.utils.data.DataLoader):
123
- def __init__(self, dataset, batch_size, shuffle, *args, **kwargs):
124
- super(DataLoader, self).__init__(dataset, batch_size, shuffle, *args, **kwargs)
125
- self.shuffle = shuffle
126
-
127
- def reset(self):
128
- if self.shuffle:
129
- print('Reset Dataset...')
130
- self.dataset.reset()
131
-
132
-
133
- class DSRDataset(BaseDataset):
134
- def __init__(self, datadir, fns=None, size=None, enable_transforms=True):
135
- super(DSRDataset, self).__init__()
136
- self.size = size
137
- self.datadir = datadir
138
- self.enable_transforms = enable_transforms
139
- sortkey = lambda key: os.path.split(key)[-1]
140
- self.paths = sorted(make_dataset(datadir, fns), key=sortkey)
141
- if size is not None:
142
- self.paths = np.random.choice(self.paths, size)
143
-
144
- self.syn_model = ReflectionSynthesis()
145
- self.reset(shuffle=False)
146
-
147
- def reset(self, shuffle=True):
148
- if shuffle:
149
- random.shuffle(self.paths)
150
- num_paths = len(self.paths) // 2
151
- self.B_paths = self.paths[0:num_paths]
152
- self.R_paths = self.paths[num_paths:2 * num_paths]
153
-
154
- def data_synthesis(self, t_img, r_img):
155
- if self.enable_transforms:
156
- t_img, r_img = paired_data_transforms(t_img, r_img)
157
-
158
- t_img, r_img, m_img = self.syn_model(t_img, r_img)
159
-
160
- B = TF.to_tensor(t_img)
161
- R = TF.to_tensor(r_img)
162
- M = TF.to_tensor(m_img)
163
-
164
- return B, R, M
165
-
166
- def __getitem__(self, index):
167
- index_B = index % len(self.B_paths)
168
- index_R = index % len(self.R_paths)
169
-
170
- B_path = self.B_paths[index_B]
171
- R_path = self.R_paths[index_R]
172
-
173
- t_img = Image.open(B_path).convert('RGB')
174
- r_img = Image.open(R_path).convert('RGB')
175
-
176
- B, R, M = self.data_synthesis(t_img, r_img)
177
- fn = os.path.basename(B_path)
178
- return {'input': M, 'target_t': B, 'target_r': M-B, 'fn': fn, 'real': False}
179
-
180
- def __len__(self):
181
- if self.size is not None:
182
- return min(max(len(self.B_paths), len(self.R_paths)), self.size)
183
- else:
184
- return max(len(self.B_paths), len(self.R_paths))
185
-
186
-
187
- class DSRTestDataset(BaseDataset):
188
- def __init__(self, datadir, fns=None, size=None, enable_transforms=False, unaligned_transforms=False,
189
- round_factor=1, flag=None, if_align=True):
190
- super(DSRTestDataset, self).__init__()
191
- self.size = size
192
- self.datadir = datadir
193
- self.fns = fns or os.listdir(join(datadir, 'blended'))
194
- self.enable_transforms = enable_transforms
195
- self.unaligned_transforms = unaligned_transforms
196
- self.round_factor = round_factor
197
- self.flag = flag
198
- self.if_align = True # if_align
199
-
200
- if size is not None:
201
- self.fns = self.fns[:size]
202
-
203
- def align(self, x1, x2):
204
- h, w = x1.height, x1.width
205
- h, w = h // 32 * 32, w // 32 * 32
206
- x1 = x1.resize((w, h))
207
- x2 = x2.resize((w, h))
208
- return x1, x2
209
-
210
- def __getitem__(self, index):
211
- fn = self.fns[index]
212
-
213
- t_img = Image.open(join(self.datadir, 'transmission_layer', fn)).convert('RGB')
214
- m_img = Image.open(join(self.datadir, 'blended', fn)).convert('RGB')
215
-
216
- if self.if_align:
217
- t_img, m_img = self.align(t_img, m_img)
218
-
219
- if self.enable_transforms:
220
- t_img, m_img = paired_data_transforms(t_img, m_img, self.unaligned_transforms)
221
-
222
- B = TF.to_tensor(t_img)
223
- M = TF.to_tensor(m_img)
224
-
225
- dic = {'input': M, 'target_t': B, 'fn': fn, 'real': True, 'target_r': M - B}
226
- if self.flag is not None:
227
- dic.update(self.flag)
228
- return dic
229
-
230
- def __len__(self):
231
- if self.size is not None:
232
- return min(len(self.fns), self.size)
233
- else:
234
- return len(self.fns)
235
-
236
-
237
- class SIRTestDataset(BaseDataset):
238
- def __init__(self, datadir, fns=None, size=None, if_align=True):
239
- super(SIRTestDataset, self).__init__()
240
- self.size = size
241
- self.datadir = datadir
242
- self.fns = fns or os.listdir(join(datadir, 'blended'))
243
- self.if_align = if_align
244
-
245
- if size is not None:
246
- self.fns = self.fns[:size]
247
-
248
- def align(self, x1, x2, x3):
249
- h, w = x1.height, x1.width
250
- h, w = h // 32 * 32, w // 32 * 32
251
- x1 = x1.resize((w, h))
252
- x2 = x2.resize((w, h))
253
- x3 = x3.resize((w, h))
254
- return x1, x2, x3
255
-
256
- def __getitem__(self, index):
257
- fn = self.fns[index]
258
-
259
- t_img = Image.open(join(self.datadir, 'transmission_layer', fn)).convert('RGB')
260
- r_img = Image.open(join(self.datadir, 'reflection_layer', fn)).convert('RGB')
261
- m_img = Image.open(join(self.datadir, 'blended', fn)).convert('RGB')
262
-
263
- if self.if_align:
264
- t_img, r_img, m_img = self.align(t_img, r_img, m_img)
265
-
266
- B = TF.to_tensor(t_img)
267
- R = TF.to_tensor(r_img)
268
- M = TF.to_tensor(m_img)
269
-
270
- dic = {'input': M, 'target_t': B, 'fn': fn, 'real': True, 'target_r': R, 'target_r_hat': M - B}
271
- return dic
272
-
273
- def __len__(self):
274
- if self.size is not None:
275
- return min(len(self.fns), self.size)
276
- else:
277
- return len(self.fns)
278
-
279
-
280
- class RealDataset(BaseDataset):
281
- def __init__(self, datadir, fns=None, size=None):
282
- super(RealDataset, self).__init__()
283
- self.size = size
284
- self.datadir = datadir
285
- self.fns = fns or os.listdir(join(datadir))
286
-
287
- if size is not None:
288
- self.fns = self.fns[:size]
289
-
290
- def align(self, x):
291
- h, w = x.height, x.width
292
- h, w = h // 32 * 32, w // 32 * 32
293
- x = x.resize((w, h))
294
- return x
295
-
296
- def __getitem__(self, index):
297
- fn = self.fns[index]
298
- B = -1
299
- m_img = Image.open(join(self.datadir, fn)).convert('RGB')
300
- M = to_tensor(self.align(m_img))
301
- data = {'input': M, 'target_t': B, 'fn': fn}
302
- return data
303
-
304
- def __len__(self):
305
- if self.size is not None:
306
- return min(len(self.fns), self.size)
307
- else:
308
- return len(self.fns)
309
-
310
-
311
- class FusionDataset(BaseDataset):
312
- def __init__(self, datasets, fusion_ratios=None):
313
- self.datasets = datasets
314
- self.size = sum([len(dataset) for dataset in datasets])
315
- self.fusion_ratios = fusion_ratios or [1. / len(datasets)] * len(datasets)
316
- print('[i] using a fusion dataset: %d %s imgs fused with ratio %s' % (
317
- self.size, [len(dataset) for dataset in datasets], self.fusion_ratios))
318
-
319
- def reset(self):
320
- for dataset in self.datasets:
321
- dataset.reset()
322
-
323
- def __getitem__(self, index):
324
- residual = 1
325
- for i, ratio in enumerate(self.fusion_ratios):
326
- if random.random() < ratio / residual or i == len(self.fusion_ratios) - 1:
327
- dataset = self.datasets[i]
328
- return dataset[index % len(dataset)]
329
- residual -= ratio
330
-
331
- def __len__(self):
332
- return self.size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/image_folder.py DELETED
@@ -1,51 +0,0 @@
1
- ###############################################################################
2
- # Code from
3
- # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4
- # Modified the original code so that it also loads images from the current
5
- # directory as well as the subdirectories
6
- ###############################################################################
7
-
8
- import torch.utils.data as data
9
-
10
- from PIL import Image
11
- import os
12
- import os.path
13
-
14
- IMG_EXTENSIONS = [
15
- '.jpg', '.JPG', '.jpeg', '.JPEG',
16
- '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
17
- ]
18
-
19
-
20
- def read_fns(filename):
21
- with open(filename) as f:
22
- fns = f.readlines()
23
- fns = [fn.strip() for fn in fns]
24
- return fns
25
-
26
-
27
- def is_image_file(filename):
28
- return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
29
-
30
-
31
- def make_dataset(dir, fns=None):
32
- images = []
33
- assert os.path.isdir(dir), '%s is not a valid directory' % dir
34
-
35
- if fns is None:
36
- for root, _, fnames in sorted(os.walk(dir)):
37
- for fname in fnames:
38
- if is_image_file(fname):
39
- path = os.path.join(root, fname)
40
- images.append(path)
41
- else:
42
- for fname in fns:
43
- if is_image_file(fname):
44
- path = os.path.join(dir, fname)
45
- images.append(path)
46
-
47
- return images
48
-
49
-
50
- def default_loader(path):
51
- return Image.open(path).convert('RGB')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/real_test.txt DELETED
@@ -1,20 +0,0 @@
1
- 3.jpg
2
- 4.jpg
3
- 9.jpg
4
- 12.jpg
5
- 15.jpg
6
- 22.jpg
7
- 23.jpg
8
- 25.jpg
9
- 29.jpg
10
- 39.jpg
11
- 46.jpg
12
- 47.jpg
13
- 58.jpg
14
- 86.jpg
15
- 87.jpg
16
- 89.jpg
17
- 93.jpg
18
- 103.jpg
19
- 107.jpg
20
- 110.jpg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/torchdata.py DELETED
@@ -1,67 +0,0 @@
1
- import bisect
2
- import warnings
3
-
4
-
5
- class Dataset(object):
6
- """An abstract class representing a Dataset.
7
-
8
- All other datasets should subclass it. All subclasses should override
9
- ``__len__``, that provides the size of the dataset, and ``__getitem__``,
10
- supporting integer indexing in range from 0 to len(self) exclusive.
11
- """
12
-
13
- def __getitem__(self, index):
14
- raise NotImplementedError
15
-
16
- def __len__(self):
17
- raise NotImplementedError
18
-
19
- def __add__(self, other):
20
- return ConcatDataset([self, other])
21
-
22
- def reset(self):
23
- return
24
-
25
-
26
- class ConcatDataset(Dataset):
27
- """
28
- Dataset to concatenate multiple datasets.
29
- Purpose: useful to assemble different existing datasets, possibly
30
- large-scale datasets as the concatenation operation is done in an
31
- on-the-fly manner.
32
-
33
- Arguments:
34
- datasets (sequence): List of datasets to be concatenated
35
- """
36
-
37
- @staticmethod
38
- def cumsum(sequence):
39
- r, s = [], 0
40
- for e in sequence:
41
- l = len(e)
42
- r.append(l + s)
43
- s += l
44
- return r
45
-
46
- def __init__(self, datasets):
47
- super(ConcatDataset, self).__init__()
48
- assert len(datasets) > 0, 'datasets should not be an empty iterable'
49
- self.datasets = list(datasets)
50
- self.cumulative_sizes = self.cumsum(self.datasets)
51
-
52
- def __len__(self):
53
- return self.cumulative_sizes[-1]
54
-
55
- def __getitem__(self, idx):
56
- dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
57
- if dataset_idx == 0:
58
- sample_idx = idx
59
- else:
60
- sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
61
- return self.datasets[dataset_idx][sample_idx]
62
-
63
- @property
64
- def cummulative_sizes(self):
65
- warnings.warn("cummulative_sizes attribute is renamed to "
66
- "cumulative_sizes", DeprecationWarning, stacklevel=2)
67
- return self.cumulative_sizes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/transforms.py DELETED
@@ -1,301 +0,0 @@
1
- from __future__ import division
2
-
3
- import math
4
- import random
5
-
6
- import torch
7
- from PIL import Image
8
-
9
- try:
10
- import accimage
11
- except ImportError:
12
- accimage = None
13
- import numpy as np
14
- import scipy.stats as st
15
- import cv2
16
- import collections
17
- import torchvision.transforms as transforms
18
- import util.util as util
19
- from scipy.signal import convolve2d
20
-
21
-
22
- # utility
23
- def _is_pil_image(img):
24
- if accimage is not None:
25
- return isinstance(img, (Image.Image, accimage.Image))
26
- else:
27
- return isinstance(img, Image.Image)
28
-
29
-
30
- def _is_tensor_image(img):
31
- return torch.is_tensor(img) and img.ndimension() == 3
32
-
33
-
34
- def _is_numpy_image(img):
35
- return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
36
-
37
-
38
- def arrshow(arr):
39
- Image.fromarray(arr.astype(np.uint8)).show()
40
-
41
-
42
- def get_transform(opt):
43
- transform_list = []
44
- osizes = util.parse_args(opt.loadSize)
45
- fineSize = util.parse_args(opt.fineSize)
46
- if opt.resize_or_crop == 'resize_and_crop':
47
- transform_list.append(
48
- transforms.RandomChoice([
49
- transforms.Resize([osize, osize], Image.BICUBIC) for osize in osizes
50
- ]))
51
- transform_list.append(transforms.RandomCrop(fineSize))
52
- elif opt.resize_or_crop == 'crop':
53
- transform_list.append(transforms.RandomCrop(fineSize))
54
- elif opt.resize_or_crop == 'scale_width':
55
- transform_list.append(transforms.Lambda(
56
- lambda img: __scale_width(img, fineSize)))
57
- elif opt.resize_or_crop == 'scale_width_and_crop':
58
- transform_list.append(transforms.Lambda(
59
- lambda img: __scale_width(img, opt.loadSize)))
60
- transform_list.append(transforms.RandomCrop(opt.fineSize))
61
-
62
- if opt.isTrain and not opt.no_flip:
63
- transform_list.append(transforms.RandomHorizontalFlip())
64
-
65
- return transforms.Compose(transform_list)
66
-
67
-
68
- to_norm_tensor = transforms.Compose([
69
- transforms.ToTensor(),
70
- transforms.Normalize(
71
- (0.5, 0.5, 0.5),
72
- (0.5, 0.5, 0.5)
73
- )
74
- ])
75
-
76
- to_tensor = transforms.ToTensor()
77
-
78
-
79
- def __scale_width(img, target_width):
80
- ow, oh = img.size
81
- if (ow == target_width):
82
- return img
83
- w = target_width
84
- h = int(target_width * oh / ow)
85
- h = math.ceil(h / 2.) * 2 # round up to even
86
- return img.resize((w, h), Image.BICUBIC)
87
-
88
-
89
- # functional
90
- def gaussian_blur(img, kernel_size, sigma):
91
- if not _is_pil_image(img):
92
- raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
93
-
94
- img = np.asarray(img)
95
- # the 3rd dimension (i.e. inter-band) would be filtered which is unwanted for our purpose
96
- # new = gaussian_filter(img, sigma=sigma, truncate=truncate)
97
- if isinstance(kernel_size, int):
98
- kernel_size = (kernel_size, kernel_size)
99
- elif isinstance(kernel_size, collections.Sequence):
100
- assert len(kernel_size) == 2
101
- new = cv2.GaussianBlur(img, kernel_size, sigma) # apply gaussian filter band by band
102
- return Image.fromarray(new)
103
-
104
-
105
- # transforms
106
- class GaussianBlur(object):
107
- def __init__(self, kernel_size=11, sigma=3):
108
- self.kernel_size = kernel_size
109
- self.sigma = sigma
110
-
111
- def __call__(self, img):
112
- return gaussian_blur(img, self.kernel_size, self.sigma)
113
-
114
-
115
- class ReflectionSythesis_0(object):
116
- """Reflection image data synthesis for weakly-supervised learning
117
- of ICCV 2017 paper *"A Generic Deep Architecture for Single Image Reflection Removal and Image Smoothing"*
118
- """
119
-
120
- def __init__(self, kernel_sizes=None, low_sigma=2, high_sigma=5, low_gamma=1.3,
121
- high_gamma=1.3, low_delta=0.4, high_delta=1.8):
122
- self.kernel_sizes = kernel_sizes or [11]
123
- self.low_sigma = low_sigma
124
- self.high_sigma = high_sigma
125
- self.low_gamma = low_gamma
126
- self.high_gamma = high_gamma
127
- self.low_delta = low_delta
128
- self.high_delta = high_delta
129
- print('[i] reflection sythesis model: {}'.format({
130
- 'kernel_sizes': kernel_sizes, 'low_sigma': low_sigma, 'high_sigma': high_sigma,
131
- 'low_gamma': low_gamma, 'high_gamma': high_gamma}))
132
-
133
- def __call__(self, B, R):
134
- if not _is_pil_image(B):
135
- raise TypeError('B should be PIL Image. Got {}'.format(type(B)))
136
- if not _is_pil_image(R):
137
- raise TypeError('R should be PIL Image. Got {}'.format(type(R)))
138
- B_ = np.asarray(B, np.float32)
139
- if random.random() < 0.4:
140
- B_ = np.tile(np.random.uniform(0, 30, (1, 1, 1)), B_.shape) / 255.
141
- else:
142
- B_ = np.tile(np.random.normal(50, 50, (1, 1, 3)), (B_.shape[0], B_.shape[1], 1)).clip(0, 255) / 255.
143
- R_ = np.asarray(R, np.float32) / 255.
144
-
145
- kernel_size = np.random.choice(self.kernel_sizes)
146
- sigma = np.random.uniform(self.low_sigma, self.high_sigma)
147
- gamma = np.random.uniform(self.low_gamma, self.high_gamma)
148
- delta = np.random.uniform(self.low_delta, self.high_delta)
149
- R_blur = R_
150
- kernel = cv2.getGaussianKernel(11, sigma)
151
- kernel2d = np.dot(kernel, kernel.T)
152
-
153
- for i in range(3):
154
- R_blur[..., i] = convolve2d(R_blur[..., i], kernel2d, mode='same')
155
-
156
- R_blur = np.clip(R_blur - np.mean(R_blur) * gamma, 0, 1)
157
- R_blur = np.clip(R_blur * delta, 0, 1)
158
- M_ = np.clip(R_blur + B_, 0, 1)
159
-
160
- return B_, R_blur, M_
161
-
162
-
163
- class ReflectionSythesis_1(object):
164
- """Reflection image data synthesis for weakly-supervised learning
165
- of ICCV 2017 paper *"A Generic Deep Architecture for Single Image Reflection Removal and Image Smoothing"*
166
- """
167
-
168
- def __init__(self, kernel_sizes=None, low_sigma=2, high_sigma=5, low_gamma=1.3, high_gamma=1.3):
169
- self.kernel_sizes = kernel_sizes or [11]
170
- self.low_sigma = low_sigma
171
- self.high_sigma = high_sigma
172
- self.low_gamma = low_gamma
173
- self.high_gamma = high_gamma
174
- print('[i] reflection sythesis model: {}'.format({
175
- 'kernel_sizes': kernel_sizes, 'low_sigma': low_sigma, 'high_sigma': high_sigma,
176
- 'low_gamma': low_gamma, 'high_gamma': high_gamma}))
177
-
178
- def __call__(self, B, R):
179
- if not _is_pil_image(B):
180
- raise TypeError('B should be PIL Image. Got {}'.format(type(B)))
181
- if not _is_pil_image(R):
182
- raise TypeError('R should be PIL Image. Got {}'.format(type(R)))
183
-
184
- B_ = np.asarray(B, np.float32) / 255.
185
- R_ = np.asarray(R, np.float32) / 255.
186
-
187
- kernel_size = np.random.choice(self.kernel_sizes)
188
- sigma = np.random.uniform(self.low_sigma, self.high_sigma)
189
- gamma = np.random.uniform(self.low_gamma, self.high_gamma)
190
- R_blur = R_
191
- kernel = cv2.getGaussianKernel(11, sigma)
192
- kernel2d = np.dot(kernel, kernel.T)
193
-
194
- for i in range(3):
195
- R_blur[..., i] = convolve2d(R_blur[..., i], kernel2d, mode='same')
196
-
197
- M_ = B_ + R_blur
198
-
199
- if np.max(M_) > 1:
200
- m = M_[M_ > 1]
201
- m = (np.mean(m) - 1) * gamma
202
- R_blur = np.clip(R_blur - m, 0, 1)
203
- M_ = np.clip(R_blur + B_, 0, 1)
204
-
205
- return B_, R_blur, M_
206
-
207
-
208
- class Sobel(object):
209
- def __call__(self, img):
210
- if not _is_pil_image(img):
211
- raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
212
-
213
- gray_img = np.array(img.convert('L'))
214
- x = cv2.Sobel(gray_img, cv2.CV_16S, 1, 0)
215
- y = cv2.Sobel(gray_img, cv2.CV_16S, 0, 1)
216
-
217
- absX = cv2.convertScaleAbs(x)
218
- absY = cv2.convertScaleAbs(y)
219
-
220
- dst = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
221
- return Image.fromarray(dst)
222
-
223
-
224
- class ReflectionSythesis_2(object):
225
- """Reflection image data synthesis for weakly-supervised learning
226
- of CVPR 2018 paper *"Single Image Reflection Separation with Perceptual Losses"*
227
- """
228
-
229
- def __init__(self, kernel_sizes=None):
230
- self.kernel_sizes = kernel_sizes or np.linspace(1, 5, 80)
231
-
232
- @staticmethod
233
- def gkern(kernlen=100, nsig=1):
234
- """Returns a 2D Gaussian kernel array."""
235
- interval = (2 * nsig + 1.) / (kernlen)
236
- x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
237
- kern1d = np.diff(st.norm.cdf(x))
238
- kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
239
- kernel = kernel_raw / kernel_raw.sum()
240
- kernel = kernel / kernel.max()
241
- return kernel
242
-
243
- def __call__(self, t, r):
244
- t = np.float32(t) / 255.
245
- r = np.float32(r) / 255.
246
- ori_t = t
247
- # create a vignetting mask
248
- g_mask = self.gkern(560, 3)
249
- g_mask = np.dstack((g_mask, g_mask, g_mask))
250
- sigma = self.kernel_sizes[np.random.randint(0, len(self.kernel_sizes))]
251
-
252
- t = np.power(t, 2.2)
253
- r = np.power(r, 2.2)
254
-
255
- sz = int(2 * np.ceil(2 * sigma) + 1)
256
-
257
- r_blur = cv2.GaussianBlur(r, (sz, sz), sigma, sigma, 0)
258
- blend = r_blur + t
259
-
260
- att = 1.08 + np.random.random() / 10.0
261
-
262
- for i in range(3):
263
- maski = blend[:, :, i] > 1
264
- mean_i = max(1., np.sum(blend[:, :, i] * maski) / (maski.sum() + 1e-6))
265
- r_blur[:, :, i] = r_blur[:, :, i] - (mean_i - 1) * att
266
- r_blur[r_blur >= 1] = 1
267
- r_blur[r_blur <= 0] = 0
268
-
269
- h, w = r_blur.shape[0:2]
270
- neww = np.random.randint(0, 560 - w - 10)
271
- newh = np.random.randint(0, 560 - h - 10)
272
- alpha1 = g_mask[newh:newh + h, neww:neww + w, :]
273
- alpha2 = 1 - np.random.random() / 5.0
274
- r_blur_mask = np.multiply(r_blur, alpha1)
275
- blend = r_blur_mask + t * alpha2
276
-
277
- t = np.power(t, 1 / 2.2)
278
- r_blur_mask = np.power(r_blur_mask, 1 / 2.2)
279
- blend = np.power(blend, 1 / 2.2)
280
- blend[blend >= 1] = 1
281
- blend[blend <= 0] = 0
282
-
283
- return np.float32(ori_t), np.float32(r_blur_mask), np.float32(blend)
284
-
285
-
286
- # Examples
287
- if __name__ == '__main__':
288
- """cv2 imread"""
289
- # img = cv2.imread('testdata_reflection_real/19-input.png')
290
- # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
291
- # img2 = cv2.GaussianBlur(img, (11,11), 3)
292
-
293
- """Sobel Operator"""
294
- # img = np.array(Image.open('datasets/VOC224/train/B/2007_000250.png').convert('L'))
295
-
296
- """Reflection Sythesis"""
297
- b = Image.open('')
298
- r = Image.open('')
299
- G = ReflectionSythesis_0()
300
- m, r = G(b, r)
301
- r.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
engine.py DELETED
@@ -1,178 +0,0 @@
1
- import torch
2
- import util.util as util
3
- from models import make_model
4
- import time
5
- import os
6
- import sys
7
- from os.path import join
8
- from util.visualizer import Visualizer
9
- import tqdm
10
- import visdom
11
- import numpy as np
12
- from tools import mutils
13
-
14
class Engine(object):
    """Drive training, periodic evaluation and testing of a single model.

    The engine owns the model wrapper created by ``make_model`` and keeps the
    four evaluation loaders handed to the constructor.  ``epoch`` and
    ``iterations`` are thin properties that read and write the counters stored
    on the model itself.
    """

    def __init__(self, opt, eval_dataset_real, eval_dataset_solidobject,
                 eval_dataset_postcard, eval_dataloader_wild):
        """Store the options and evaluation loaders, then build the model.

        Args:
            opt: Option namespace; ``name`` and ``model`` are read here, more
                fields are read by ``train`` and ``eval``.
            eval_dataset_real: Loader evaluated as ``testdata_real20``.
            eval_dataset_solidobject: Loader evaluated as
                ``testdata_solidobject``.
            eval_dataset_postcard: Loader evaluated as ``testdata_postcard``.
            eval_dataloader_wild: Loader evaluated as ``testdata_wild``.
        """
        self.opt = opt
        self.writer = None
        self.visualizer = None
        self.model = None
        self.best_val_loss = 1e6
        self.eval_dataset_real = eval_dataset_real
        self.eval_dataset_solidobject = eval_dataset_solidobject
        self.eval_dataset_postcard = eval_dataset_postcard
        self.eval_dataloader_wild = eval_dataloader_wild
        # Results of one run are grouped under a time-stamped directory.
        self.result_dir = os.path.join(f'./experiment/{self.opt.name}/results',
                                       mutils.get_formatted_time())
        self.biggest_psnr = 0
        self.__setup()

    def __setup(self):
        """Create the experiment directory, the model, and the loggers."""
        self.basedir = join('experiment', self.opt.name)
        os.makedirs(self.basedir, exist_ok=True)

        opt = self.opt

        self.model = make_model(self.opt.model)
        self.model.initialize(opt)
        print("IN")
        self.writer = util.get_summary_writer(os.path.join(self.basedir, 'logs'))
        self.visualizer = Visualizer(opt)

    def train(self, train_loader, **kwargs):
        """Run one training epoch, then evaluate and checkpoint.

        Every ``opt.save_epoch_freq`` epochs the four evaluation loaders are
        scored; the model is saved under its default name whenever the
        weighted PSNR improves.  A ``latest`` checkpoint is written after
        every epoch.

        Args:
            train_loader: Iterable of training batches supporting ``len``.
            **kwargs: Forwarded to ``model.optimize_parameters``.
        """
        print('\nEpoch: %d' % self.epoch)
        avg_meters = util.AverageMeters()
        opt = self.opt
        model = self.model

        epoch_start_time = time.time()
        for i, data in tqdm.tqdm(enumerate(train_loader)):
            iterations = self.iterations

            model.set_input(data, mode='train')
            model.optimize_parameters(**kwargs)

            errors = model.get_current_errors()
            avg_meters.update(errors)
            util.progress_bar(i, len(train_loader), str(avg_meters))
            util.write_loss(self.writer, 'train', avg_meters, iterations)
            if iterations % 100 == 0:
                # Log the predicted clean image next to the network input;
                # both arrive as HWC arrays in [0, 255] and are converted to
                # CHW in [0, 1] for the writer.
                output_clean, _output_reflection, mixed = model.return_output()
                imgs = [
                    np.transpose(output_clean, (2, 0, 1)) / 255,
                    np.transpose(mixed, (2, 0, 1)) / 255,
                ]
                util.get_visual(self.writer, iterations, imgs)

            self.iterations += 1

        self.epoch += 1

        if self.epoch % opt.save_epoch_freq == 0:
            save_dir = os.path.join(self.result_dir, '%03d' % self.epoch)
            os.makedirs(save_dir, exist_ok=True)
            matrix_real = self.eval(self.eval_dataset_real, dataset_name='testdata_real20',
                                    savedir=save_dir, suffix='real20')
            matrix_solid = self.eval(self.eval_dataset_solidobject, dataset_name='testdata_solidobject',
                                     savedir=save_dir, suffix='solidobject')
            matrix_post = self.eval(self.eval_dataset_postcard, dataset_name='testdata_postcard',
                                    savedir=save_dir, suffix='postcard')
            matrix_wild = self.eval(self.eval_dataloader_wild, dataset_name='testdata_wild',
                                    savedir=save_dir, suffix='wild')
            # Weighted mean of the per-dataset PSNR.  The weights sum to 474
            # and are presumably the dataset sizes -- TODO confirm.
            sum_PSNR_real = matrix_real['PSNR'] * 20
            sum_PSNR_solid = matrix_solid['PSNR'] * 200
            sum_PSNR_post = matrix_post['PSNR'] * 199
            sum_PSNR_wild = matrix_wild['PSNR'] * 55
            print("sum_PSNR_real: ", matrix_real['PSNR'], "sum_PSNR_solid: ", matrix_solid['PSNR'],
                  "sum_PSNR_post: ", matrix_post['PSNR'], "sum_PSNR_wild: ", matrix_wild['PSNR'])
            sum_PSNR = float(sum_PSNR_real + sum_PSNR_solid + sum_PSNR_post + sum_PSNR_wild) / 474.0
            print('总PSNR:', sum_PSNR)
            if sum_PSNR > self.biggest_psnr:
                self.biggest_psnr = sum_PSNR
                print('saving the model at epoch %d, iters %d' % (self.epoch, self.iterations))
                model.save()
            print('highest: ', self.biggest_psnr, ' name: ', opt.name)

        print('saving the latest model at the end of epoch %d, iters %d' %
              (self.epoch, self.iterations))
        model.save(label='latest')

        print('Time Taken: %d sec' %
              (time.time() - epoch_start_time))

        # Resetting is best-effort: loaders without reset() are skipped
        # silently, any other failure is reported but does not abort training.
        try:
            train_loader.reset()
        except AttributeError:
            pass
        except Exception as exc:
            print('train_loader.reset() failed: %s' % exc)

    def eval(self, val_loader, dataset_name, savedir='./tmp', loss_key=None, **kwargs):
        """Evaluate the model on ``val_loader`` and return averaged metrics.

        When ``savedir`` is given, a section headed by ``dataset_name`` is
        appended to ``<savedir>/metrics.txt`` with one
        ``<file> <PSNR> <SSIM>`` line per sample.  Appending (rather than
        truncating) keeps the results of several datasets that share one
        ``savedir``, which is how ``train`` calls this method.

        Args:
            val_loader: Iterable of batches; each batch exposes ``data['fn']``.
            dataset_name: Label used in the metrics file and the log tag.
            savedir: Output directory, or ``None`` to write nothing here.
            loss_key: If set, the averaged metric with this key is compared
                against ``best_val_loss`` and a ``best_*`` checkpoint is
                saved when it decreases.
            **kwargs: Forwarded to ``model.eval``.

        Returns:
            The ``util.AverageMeters`` instance accumulated over the loader.
        """
        avg_meters = util.AverageMeters()
        model = self.model
        opt = self.opt

        metrics_file = None
        if savedir is not None:
            os.makedirs(savedir, exist_ok=True)
            metrics_file = open(os.path.join(savedir, 'metrics.txt'), 'a')

        try:
            if metrics_file is not None:
                metrics_file.write(dataset_name + '\n')
            with torch.no_grad():
                for i, data in enumerate(val_loader):
                    # Optional filter: evaluate only the selected file.
                    if self.opt.select is not None and data['fn'][0] != f'{self.opt.select}.jpg':
                        continue
                    index = model.eval(data, savedir=savedir, **kwargs)

                    # The model returns an empty result for unaligned data,
                    # so only log samples that actually carry both metrics.
                    if metrics_file is not None and 'PSNR' in index and 'SSIM' in index:
                        metrics_file.write(f"{data['fn'][0]} {index['PSNR']} {index['SSIM']}\n")
                    avg_meters.update(index)
                    util.progress_bar(i, len(val_loader), str(avg_meters))
        finally:
            if metrics_file is not None:
                metrics_file.close()

        if not opt.no_log:
            util.write_loss(self.writer, join('eval', dataset_name), avg_meters, self.epoch)

        if loss_key is not None:
            val_loss = avg_meters[loss_key]
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                print('saving the best model at the end of epoch %d, iters %d' %
                      (self.epoch, self.iterations))
                model.save(label='best_{}_{}'.format(loss_key, dataset_name))

        return avg_meters

    def test(self, test_loader, savedir=None, **kwargs):
        """Run ``model.test`` on every batch of ``test_loader`` without gradients."""
        model = self.model
        with torch.no_grad():
            for i, data in enumerate(test_loader):
                model.test(data, savedir=savedir, **kwargs)
                util.progress_bar(i, len(test_loader))

    def save_eval(self, label):
        """Delegate to ``model.save_eval`` with the given checkpoint label."""
        self.model.save_eval(label)

    @property
    def iterations(self):
        """Iteration counter stored on the model."""
        return self.model.iterations

    @iterations.setter
    def iterations(self, i):
        self.model.iterations = i

    @property
    def epoch(self):
        """Epoch counter stored on the model."""
        return self.model.epoch

    @epoch.setter
    def epoch(self, e):
        self.model.epoch = e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
figures/Input_car.jpg DELETED
Binary file (26.8 kB)
 
figures/Input_class.png DELETED

Git LFS Details

  • SHA256: 9b3823f5b2f4319e23470a1a747bb2974ddc63f323fed61eb8ceedfce4d48343
  • Pointer size: 131 Bytes
  • Size of remote file: 246 kB
figures/Input_green.png DELETED

Git LFS Details

  • SHA256: 62805a64a7167f0000a4ec1c8e92f0b45a2f7f6fdd9ec1bb7d623ae2f5d5cffe
  • Pointer size: 131 Bytes
  • Size of remote file: 418 kB
figures/Ours_car.png DELETED

Git LFS Details

  • SHA256: 313fbf8070c481775b44153eaea645f35ca8112d7616b5af8ab2e982a37e030e
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
figures/Ours_class.png DELETED

Git LFS Details

  • SHA256: e4d97e42e8953fb7c5af9b8d7cfd2123ffeb10e734f50f98bd40b7f531f2f02b
  • Pointer size: 131 Bytes
  • Size of remote file: 280 kB
figures/Ours_green.png DELETED

Git LFS Details

  • SHA256: ee3fb53a2f9f410c2e3b8d9679ba3296034786c922fcc70fcd6681af0ce43b36
  • Pointer size: 131 Bytes
  • Size of remote file: 414 kB
figures/Ours_white.png DELETED

Git LFS Details

  • SHA256: 9b79ca2d5c76f21e947ec93752ae21e33c301f4099edb8375925a6bb0274977d
  • Pointer size: 131 Bytes
  • Size of remote file: 187 kB
figures/Title.png DELETED
Binary file (98.8 kB)
 
figures/input_white.jpg DELETED
Binary file (24.9 kB)
 
figures/net.png DELETED

Git LFS Details

  • SHA256: d0293129d5ef9c40eb72c2cb33863f4a37b45062f4369285387081da3644a8bf
  • Pointer size: 131 Bytes
  • Size of remote file: 725 kB
figures/result.png DELETED

Git LFS Details

  • SHA256: 7bf2e5f68b691f3b0f6246d35f88ffe2a36a12b3c79b7020ba9483ce9fef355c
  • Pointer size: 131 Bytes
  • Size of remote file: 184 kB
figures/vis.png DELETED

Git LFS Details

  • SHA256: 325aed759f19aaae59e9a06c1ae4b8c1e4d3adf1cae2d8c092c1c836834828d8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.21 MB
models/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- import importlib
2
-
3
- from models.arch import *
4
-
5
- from models.cls_model_eval_nocls_reg import ClsModel as ClsReg
6
-
7
-
8
def make_model(name: str):
    """Return a freshly constructed model wrapper.

    Args:
        name: Requested model name.  It is currently ignored: every call
            builds a ``ClsReg`` instance.
    """
    return ClsReg()
 
 
 
 
 
 
 
 
 
 
 
 
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (388 Bytes). View file
 
models/__pycache__/cls_model_eval_nocls_reg.cpython-310.pyc ADDED
Binary file (17 kB). View file
 
models/__pycache__/losses.cpython-310.pyc ADDED
Binary file (15 kB). View file
 
models/base_model.py DELETED
@@ -1,71 +0,0 @@
1
- import os
2
- import torch
3
- import util.util as util
4
-
5
-
6
class BaseModel:
    """Common bookkeeping shared by the model wrappers.

    Subclasses are expected to provide ``state_dict`` (used by ``save``),
    ``state_dict_eval`` (used by ``save_eval``) and the ``epoch`` /
    ``iterations`` counters; none of them is defined in this class.
    """

    def name(self):
        """Return the lower-cased class name."""
        return self.__class__.__name__.lower()

    def initialize(self, opt):
        """Copy run options onto the instance and resolve checkpoint paths.

        When resuming from a ``checkpoints_dir`` whose last component is
        neither ``'checkpoints'`` nor the run name (or when ``supp_eval`` is
        set), checkpoints are read from ``checkpoints_dir`` itself and new
        ones are written to a sibling directory named after the run.
        Otherwise both paths are ``<checkpoints_dir>/<name>``.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        last_split = opt.checkpoints_dir.split('/')[-1]
        if opt.resume and last_split != 'checkpoints' and (last_split != opt.name or opt.supp_eval):
            self.save_dir = opt.checkpoints_dir
            # Strip only the final path component.  A plain str.replace of
            # that component would also remove earlier occurrences of the
            # same name (e.g. 'runs/exp/exp' would collapse to 'runs//').
            head, sep, _ = opt.checkpoints_dir.rpartition('/')
            self.model_save_dir = os.path.join(head + sep, opt.name)
        else:
            self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
            self.model_save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self._count = 0

    def set_input(self, input):
        """Store the raw input batch."""
        self.input = input

    def forward(self, mode='train'):
        """Hook for subclasses; does nothing here."""
        pass

    # used in test time, no backprop
    def test(self):
        """Hook for subclasses; does nothing here."""
        pass

    def get_image_paths(self):
        """Hook for subclasses; does nothing here."""
        pass

    def optimize_parameters(self):
        """Hook for subclasses; does nothing here."""
        pass

    def get_current_visuals(self):
        """Return the stored input batch."""
        return self.input

    def get_current_errors(self):
        """Return the current loss values; empty by default."""
        return {}

    def print_optimizer_param(self):
        """Print the last registered optimizer."""
        print(self.optimizers[-1])

    def save(self, label=None):
        """Write ``self.state_dict()`` into ``model_save_dir``.

        Args:
            label: If ``None`` the file is named
                ``<name>_<epoch:03d>_<iterations:08d>.pt``; otherwise
                ``<name>_<label>.pt``.
        """
        if label is None:
            filename = self.opt.name + '_%03d_%08d.pt' % (self.epoch, self.iterations)
        else:
            filename = self.opt.name + '_' + label + '.pt'

        # The target directory is not created anywhere else in this class.
        os.makedirs(self.model_save_dir, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(self.model_save_dir, filename))

    def save_eval(self, label=None):
        """Write ``self.state_dict_eval()`` to ``<model_save_dir>/<label>.pt``.

        ``label`` must be a string in practice: the default ``None`` cannot
        be concatenated with the file extension.
        """
        model_name = os.path.join(self.model_save_dir, label + '.pt')

        os.makedirs(self.model_save_dir, exist_ok=True)
        torch.save(self.state_dict_eval(), model_name)

    def _init_optimizer(self, optimizers):
        """Register optimizers and stamp them with ``opt.lr`` and ``opt.wd``."""
        self.optimizers = optimizers
        for optimizer in self.optimizers:
            util.set_opt_param(optimizer, 'initial_lr', self.opt.lr)
            util.set_opt_param(optimizer, 'weight_decay', self.opt.wd)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/cls_model_eval_nocls_reg.py DELETED
@@ -1,517 +0,0 @@
1
- import torch
2
- from torch import nn
3
- import torch.nn.functional as F
4
- from models.losses import DINOLoss
5
- import os
6
- import numpy as np
7
- from collections import OrderedDict
8
- from ema_pytorch import EMA
9
- from models.arch.classifier import PretrainedConvNext
10
- import util.util as util
11
- import util.index as index
12
- import models.networks as networks
13
- import models.losses as losses
14
- from models import arch
15
- #from models.arch.dncnn import effnetv2_s
16
- from .base_model import BaseModel
17
- from PIL import Image
18
- from os.path import join
19
- #from torchviz import make_dot
20
- from models.arch.RDnet_ import FullNet_NLP
21
- import timm
22
-
23
def tensor2im(image_tensor, imtype=np.uint8):
    """Convert the first image of a batch tensor to an HWC array in [0, 255].

    Values are clipped to [0, 1] before scaling, and a single-channel image
    is repeated to three channels.  The result stays floating point:
    ``imtype`` is accepted but not applied, so callers cast it themselves.
    """
    first = image_tensor.detach()[0].cpu().float().numpy()
    chw = np.clip(first, 0, 1)
    if chw.shape[0] == 1:
        chw = np.tile(chw, (3, 1, 1))
    return np.transpose(chw, (1, 2, 0)) * 255.0
32
-
33
-
34
class EdgeMap(nn.Module):
    """Single-channel edge magnitude built from absolute finite differences."""

    def __init__(self, scale=1):
        super().__init__()
        self.scale = scale
        self.requires_grad = False

    def forward(self, img):
        """Return an (N, 1, H, W) edge map for an (N, C, H, W) image.

        Neighbouring-pixel differences are taken along the height and width
        axes, summed in absolute value over channels, and spread onto both
        pixels of each pair; interior positions, which receive two
        contributions, are averaged.  The two directional maps are added.
        """
        scaled = img / self.scale
        batch, _, height, width = scaled.shape

        diff_h = (scaled[..., 1:, :] - scaled[..., :-1, :]).abs().sum(dim=1, keepdim=True)
        diff_w = (scaled[..., 1:] - scaled[..., :-1]).abs().sum(dim=1, keepdim=True)

        along_h = torch.zeros(batch, 1, height, width, dtype=scaled.dtype, device=scaled.device)
        along_h[..., :-1, :] += diff_h
        along_h[..., 1:, :] += diff_h
        along_h[..., 1:-1, :] /= 2

        along_w = torch.zeros(batch, 1, height, width, dtype=scaled.dtype, device=scaled.device)
        along_w[..., :-1] += diff_w
        along_w[..., 1:] += diff_w
        along_w[..., 1:-1] /= 2

        return along_h + along_w
62
-
63
-
64
class YTMTNetBase(BaseModel):
    """Input handling plus evaluation/test image export shared by the models.

    Subclasses provide ``_eval``, ``forward``, ``forward_eval`` and the
    ``edge_map`` module used by ``set_input``.
    """

    def _init_optimizer(self, optimizers):
        """Register optimizers and stamp them with ``opt.lr`` and ``opt.wd``."""
        self.optimizers = optimizers
        for optimizer in self.optimizers:
            util.set_opt_param(optimizer, 'initial_lr', self.opt.lr)
            util.set_opt_param(optimizer, 'weight_decay', self.opt.wd)

    def set_input(self, data, mode='train'):
        """Unpack one batch, move it to the first GPU if any, and cache it.

        Args:
            data: Mapping with ``'input'`` and, depending on ``mode``,
                ``'target_t'``, ``'target_r'`` and ``'fn'``.
            mode: ``'train'``, ``'eval'`` or ``'test'`` (case-insensitive).

        Raises:
            NotImplementedError: For any other mode.
        """
        target_t = None
        target_r = None
        data_name = None
        mode = mode.lower()
        if mode == 'train':
            mixed, target_t, target_r = data['input'], data['target_t'], data['target_r']
        elif mode == 'eval':
            mixed, target_t, target_r, data_name = data['input'], data['target_t'], data['target_r'], data['fn']
        elif mode == 'test':
            mixed, data_name = data['input'], data['fn']
        else:
            raise NotImplementedError('Mode [%s] is not implemented' % mode)

        if len(self.gpu_ids) > 0:  # transfer data into gpu
            mixed = mixed.to(device=self.gpu_ids[0])
            if target_t is not None:
                target_t = target_t.to(device=self.gpu_ids[0])
            if target_r is not None:
                target_r = target_r.to(device=self.gpu_ids[0])

        self.input = mixed
        self.identity = False
        self.input_edge = self.edge_map(self.input)
        self.target_t = target_t
        self.target_r = target_r
        self.data_name = data_name

        # A batch is treated as synthetic / aligned unless it carries the
        # marker keys 'real' / 'unaligned'.
        self.issyn = 'real' not in data
        self.aligned = 'unaligned' not in data

        if target_t is not None:
            self.target_edge = self.edge_map(self.target_t)

    def eval(self, data, savedir=None, suffix=None, pieapp=None):
        """Run the model on one evaluation batch and optionally save images.

        Args:
            data: Evaluation batch (see ``set_input`` with mode ``'eval'``).
            savedir: Root output directory, or ``None`` to save nothing.
            suffix: Optional sub-directory inserted between ``savedir`` and
                the sample name.
            pieapp: Unused.

        Returns:
            The result of ``index.quality_assess`` for aligned data, or an
            empty dict for unaligned data.
        """
        self._eval()
        self.set_input(data, 'eval')
        with torch.no_grad():
            self.forward_eval()

        # Entries 6 and 7 of output_j are exported as the transmission and
        # reflection predictions.
        output_i = tensor2im(self.output_j[6])
        output_j = tensor2im(self.output_j[7])
        target = tensor2im(self.target_t)

        if self.aligned:
            res = index.quality_assess(output_i, target)
        else:
            res = {}

        if savedir is not None:
            if self.data_name is not None:
                name = os.path.splitext(os.path.basename(self.data_name[0]))[0]
                # Without a suffix the sample directory sits directly under
                # savedir; joining None into the path would raise.
                if suffix is None:
                    savedir = join(savedir, name)
                else:
                    savedir = join(savedir, suffix, name)
                os.makedirs(savedir, exist_ok=True)
                Image.fromarray(output_i.astype(np.uint8)).save(
                    join(savedir, '{}_t.png'.format(self.opt.name)))
                Image.fromarray(output_j.astype(np.uint8)).save(
                    join(savedir, '{}_r.png'.format(self.opt.name)))
                Image.fromarray(target.astype(np.uint8)).save(join(savedir, 't_label.png'))
                Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(join(savedir, 'm_input.png'))
            else:
                # Create both folders independently so a partially populated
                # savedir cannot leave 'blended' missing.
                os.makedirs(join(savedir, 'transmission_layer'), exist_ok=True)
                os.makedirs(join(savedir, 'blended'), exist_ok=True)
                Image.fromarray(target.astype(np.uint8)).save(
                    join(savedir, 'transmission_layer', str(self._count) + '.png'))
                Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(
                    join(savedir, 'blended', str(self._count) + '.png'))
                self._count += 1

        return res

    def test(self, data, savedir=None):
        """Run the model on one test batch and save the predictions.

        Only the first sample of the batch is exported.
        """
        self._eval()
        self.set_input(data, 'test')

        if self.data_name is not None and savedir is not None:
            name = os.path.splitext(os.path.basename(self.data_name[0]))[0]
            os.makedirs(join(savedir, name), exist_ok=True)

            # NOTE(review): this skip looks for '<name>.png', while the files
            # written below are '<name>_l.png' and '<name>_r.png' -- confirm
            # which file name is meant to mark a finished sample.
            if os.path.exists(join(savedir, name, '{}.png'.format(self.opt.name))):
                return

        with torch.no_grad():
            output_i, output_j = self.forward()
            output_i = tensor2im(output_i)
            output_j = tensor2im(output_j)
            if self.data_name is not None and savedir is not None:
                Image.fromarray(output_i.astype(np.uint8)).save(join(savedir, name, '{}_l.png'.format(self.opt.name)))
                Image.fromarray(output_j.astype(np.uint8)).save(join(savedir, name, '{}_r.png'.format(self.opt.name)))
                Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(join(savedir, name, 'm_input.png'))
166
-
167
-
168
class ClsModel(YTMTNetBase):
    """Reflection-removal model: a frozen classifier feeding a multi-column net.

    ``net_c`` is always kept in eval mode and produces the prompt passed to
    ``net_i``.  After a forward pass ``self.output_j`` is a flat list in which
    entry ``2 * i`` is the clean (transmission) prediction of column ``i`` and
    entry ``2 * i + 1`` is its reflection prediction.
    """

    def name(self):
        return 'ytmtnet'

    def __init__(self):
        self.epoch = 0
        self.iterations = 0
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.net_c = None

    def print_network(self):
        """Print the generator and, when GAN training is on, the discriminator."""
        print('--------------------- Model ---------------------')
        print('##################### NetG #####################')
        networks.print_network(self.net_i)
        if self.isTrain and self.opt.lambda_gan > 0:
            print('##################### NetD #####################')
            networks.print_network(self.netD)

    def _eval(self):
        self.net_i.eval()
        self.net_c.eval()

    def _train(self):
        # The classifier stays in eval mode; only net_i is trained.
        self.net_i.train()
        self.net_c.eval()

    def initialize(self, opt):
        """Build the networks and, in training mode, losses and optimizer."""
        self.opt = opt
        BaseModel.initialize(self, opt)

        self.vgg = None
        if opt.hyper:
            self.vgg = losses.Vgg19(requires_grad=False).to(self.device)

        channels = [64, 128, 256, 512]
        layers = [2, 2, 4, 2]
        num_subnet = opt.num_subnet

        # Follow self.device and remap the checkpoint so that CPU-only
        # machines can load weights that were saved from a GPU.
        self.net_c = PretrainedConvNext("convnext_small_in22k").to(self.device)
        self.net_c.load_state_dict(
            torch.load('pretrained/cls_model.pth', map_location=self.device)['icnn'])

        self.net_i = FullNet_NLP(channels, layers, num_subnet, opt.loss_col, num_classes=1000,
                                 drop_path=0, save_memory=True, inter_supv=True,
                                 head_init_scale=None, kernel_size=3).to(self.device)

        self.edge_map = EdgeMap(scale=1).to(self.device)

        if self.isTrain:
            self.loss_dic = losses.init_loss(opt, self.Tensor)
            vggloss = losses.ContentLoss()
            vggloss.initialize(losses.VGGLoss(self.vgg))
            self.loss_dic['t_vgg'] = vggloss

            cxloss = losses.ContentLoss()
            if opt.unaligned_loss == 'vgg':
                cxloss.initialize(losses.VGGLoss(self.vgg, weights=[0.1], indices=[opt.vgg_layer]))
            elif opt.unaligned_loss == 'ctx':
                cxloss.initialize(losses.CXLoss(self.vgg, weights=[0.1, 0.1, 0.1], indices=[8, 13, 22]))
            elif opt.unaligned_loss == 'mse':
                cxloss.initialize(nn.MSELoss())
            elif opt.unaligned_loss == 'ctx_vgg':
                cxloss.initialize(losses.CXLoss(self.vgg, weights=[0.1, 0.1, 0.1, 0.1], indices=[8, 13, 22, 31],
                                                criterions=[losses.CX_loss] * 3 + [nn.L1Loss()]))
            else:
                raise NotImplementedError
            self.scaler = torch.cuda.amp.GradScaler()
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                self.dinoloss = DINOLoss()
            self.loss_dic['t_cx'] = cxloss

            self.optimizer_G = torch.optim.Adam(self.net_i.parameters(),
                                                lr=opt.lr, betas=(0.9, 0.999), weight_decay=opt.wd)

            self._init_optimizer([self.optimizer_G])

        if opt.resume:
            self.load(self, opt.resume_epoch)

    def backward_D(self):
        """Back-propagate the weighted discriminator loss over four columns.

        NOTE(review): relies on ``self.netD`` and ``self.loss_dic['gan']``,
        neither of which is created in ``initialize``.
        """
        loss_D = []
        weight = self.opt.weight_loss
        for p in self.netD.parameters():
            p.requires_grad = True
        for i in range(4):
            loss_D_1, pred_fake_1, pred_real_1 = self.loss_dic['gan'].get_loss(
                self.netD, self.input, self.output_j[2 * i], self.target_t)
            loss_D.append(loss_D_1 * weight)
            weight += self.opt.weight_loss
        loss_sum = sum(loss_D)

        self.loss_D, self.pred_fake, self.pred_real = (loss_sum, pred_fake_1, pred_real_1)

        (self.loss_D * self.opt.lambda_gan).backward(retain_graph=True)

    def get_loss(self, out_l, out_r):
        """Sum the per-column losses with linearly increasing weights.

        Column ``i`` is weighted by ``(i + 1) * opt.weight_loss``.

        Args:
            out_l: Unused.
            out_r: Flat list laid out like ``self.output_j``.

        Returns:
            Tuple ``(gan, t_pixel, r_pixel, t_vgg)``; the GAN term is always 0.
        """
        gan_terms = []
        t_pixel_terms = []
        r_pixel_terms = []
        t_vgg_terms = []
        weight = self.opt.weight_loss
        for i in range(self.opt.loss_col):
            out_r_clean = out_r[2 * i]
            out_r_reflection = out_r[2 * i + 1]

            loss_G_GAN = 0
            loss_icnn_pixel = self.loss_dic['t_pixel'].get_loss(out_r_clean, self.target_t)
            loss_rcnn_pixel = self.loss_dic['r_pixel'].get_loss(
                out_r_reflection, self.target_r) * 1.5 * self.opt.r_pixel_weight
            loss_icnn_vgg = self.loss_dic['t_vgg'].get_loss(out_r_clean, self.target_t) * self.opt.lambda_vgg

            gan_terms.append(loss_G_GAN * weight)
            t_pixel_terms.append(loss_icnn_pixel * weight)
            r_pixel_terms.append(loss_rcnn_pixel * weight)
            t_vgg_terms.append(loss_icnn_vgg * weight)
            weight = weight + self.opt.weight_loss
        return sum(gan_terms), sum(t_pixel_terms), sum(r_pixel_terms), sum(t_vgg_terms)

    def backward_G(self):
        """Compute the generator losses and back-propagate the scaled total.

        The exclusion and reconstruction terms are computed and stored but
        are not part of ``loss_G``.
        """
        self.loss_G_GAN, self.loss_icnn_pixel, self.loss_rcnn_pixel, \
            self.loss_icnn_vgg = self.get_loss(self.output_i, self.output_j)

        self.loss_exclu = self.exclusion_loss(self.output_i, self.output_j, 3)

        self.loss_recons = self.loss_dic['recons'](self.output_i, self.output_j, self.input) * 0.2

        self.loss_G = self.loss_G_GAN + self.loss_icnn_pixel + self.loss_rcnn_pixel + \
            self.loss_icnn_vgg
        self.scaler.scale(self.loss_G).backward()

    def hyper_column(self, input_img):
        """Concatenate the image with its VGG features resized to image size."""
        hypercolumn = self.vgg(input_img)
        _, C, H, W = input_img.shape
        hypercolumn = [F.interpolate(feature.detach(), size=(H, W), mode='bilinear', align_corners=False)
                       for feature in hypercolumn]
        input_i = [input_img]
        input_i.extend(hypercolumn)
        input_i = torch.cat(input_i, dim=1)
        return input_i

    def _run_networks(self):
        """Run net_c and net_i on ``self.input`` and fill ``output_i``/``output_j``."""
        self.output_j = []
        input_i = self.input
        if self.vgg is not None:
            input_i = self.hyper_column(input_i)
        # The classifier only supplies a prompt and never receives gradients.
        with torch.no_grad():
            ipt = self.net_c(input_i)
        output_i, output_j = self.net_i(input_i, ipt, prompt=True)
        self.output_i = output_i
        for i in range(self.opt.loss_col):
            # Channels 0-2 hold the reflection, the remaining ones the clean image.
            out_reflection, out_clean = output_j[i][:, :3, ...], output_j[i][:, 3:, ...]
            self.output_j.append(out_clean)
            self.output_j.append(out_reflection)
        return self.output_i, self.output_j

    def forward(self):
        """Training-time forward pass; returns ``(output_i, output_j)``."""
        return self._run_networks()

    @torch.no_grad()
    def forward_eval(self):
        """Same as ``forward`` but entirely without gradient tracking."""
        return self._run_networks()

    def optimize_parameters(self):
        """Run one optimization step on ``net_i``."""
        self._train()
        self.forward()
        self.optimizer_G.zero_grad()
        self.backward_G()
        # The loss was scaled by the GradScaler in backward_G, so the step
        # must go through the scaler to unscale the gradients first.
        self.scaler.step(self.optimizer_G)
        self.scaler.update()

    def return_output(self):
        """Return first-column clean/reflection predictions and the input.

        The two predictions are HWC ``uint8`` arrays; the input is the float
        array produced by ``tensor2im``.
        """
        # Even entries of output_j are clean predictions, odd ones reflections.
        output_clean = tensor2im(self.output_j[0]).astype(np.uint8)
        output_reflection = tensor2im(self.output_j[1]).astype(np.uint8)
        mixed = tensor2im(self.input)
        return output_clean, output_reflection, mixed

    def exclusion_loss(self, img_T, img_R, level=3, eps=1e-6):
        """Multi-scale gradient exclusion loss over the first four columns.

        The ``img_T`` / ``img_R`` arguments are ignored: every column reads
        its pair from ``self.output_j``.  Column ``i`` is weighted by
        ``0.25 * (i + 1)``.
        """
        loss_gra = []
        weight = 0.25
        for i in range(4):
            grad_x_loss = []
            grad_y_loss = []
            img_T = self.output_j[2 * i]
            img_R = self.output_j[2 * i + 1]
            for _ in range(level):
                grad_x_T, grad_y_T = self.compute_grad(img_T)
                grad_x_R, grad_y_R = self.compute_grad(img_R)

                alphax = (2.0 * torch.mean(torch.abs(grad_x_T))) / (torch.mean(torch.abs(grad_x_R)) + eps)
                alphay = (2.0 * torch.mean(torch.abs(grad_y_T))) / (torch.mean(torch.abs(grad_y_R)) + eps)

                gradx1_s = (torch.sigmoid(grad_x_T) * 2) - 1  # mul 2 minus 1 is to change sigmoid into tanh
                grady1_s = (torch.sigmoid(grad_y_T) * 2) - 1
                gradx2_s = (torch.sigmoid(grad_x_R * alphax) * 2) - 1
                grady2_s = (torch.sigmoid(grad_y_R * alphay) * 2) - 1

                grad_x_loss.append(((torch.mean(torch.mul(gradx1_s.pow(2), gradx2_s.pow(2)))) + eps) ** 0.25)
                grad_y_loss.append(((torch.mean(torch.mul(grady1_s.pow(2), grady2_s.pow(2)))) + eps) ** 0.25)

                # Halve the resolution for the next pyramid level.
                img_T = F.interpolate(img_T, scale_factor=0.5, mode='bilinear')
                img_R = F.interpolate(img_R, scale_factor=0.5, mode='bilinear')
            loss_gradxy = torch.sum(sum(grad_x_loss) / 3) + torch.sum(sum(grad_y_loss) / 3)
            loss_gra.append(loss_gradxy * weight)
            weight += 0.25

        return sum(loss_gra) / 2

    def contain_loss(self, img_T, img_R, img_I, eps=1e-6):
        """Squared gradient ratios of both predictions to the input, per element."""
        pix_num = np.prod(img_I.shape)
        predict_tx, predict_ty = self.compute_grad(img_T)
        predict_rx, predict_ry = self.compute_grad(img_R)
        input_x, input_y = self.compute_grad(img_I)

        out = torch.norm(predict_tx / (input_x + eps), 2) ** 2 + \
            torch.norm(predict_ty / (input_y + eps), 2) ** 2 + \
            torch.norm(predict_rx / (input_x + eps), 2) ** 2 + \
            torch.norm(predict_ry / (input_y + eps), 2) ** 2

        return out / pix_num

    def compute_grad(self, img):
        """Return forward differences along dims 2 and 3 of a 4-D tensor."""
        gradx = img[:, :, 1:, :] - img[:, :, :-1, :]
        grady = img[:, :, :, 1:] - img[:, :, :, :-1]
        return gradx, grady

    def load(self, model, resume_epoch=None):
        """Load ``net_i`` weights from ``model.opt.icnn_path``.

        Args:
            model: Model whose ``net_i`` receives the weights.
            resume_epoch: Unused.

        Returns:
            The loaded checkpoint dictionary.
        """
        icnn_path = model.opt.icnn_path
        # Remap onto the model's device when it has one (None keeps the
        # default behaviour of torch.load).
        state_dict = torch.load(icnn_path, map_location=getattr(model, 'device', None))
        model.net_i.load_state_dict(state_dict['icnn'])
        return state_dict

    def state_dict(self):
        """Return a checkpoint dict with weights, optimizer state and counters."""
        state_dict = {
            'icnn': self.net_i.state_dict(),
            'opt_g': self.optimizer_G.state_dict(),
            'epoch': self.epoch, 'iterations': self.iterations
        }

        # This class never creates a discriminator, so only store its state
        # when something else has attached one.
        if self.opt.lambda_gan > 0 and hasattr(self, 'netD') and hasattr(self, 'optimizer_D'):
            state_dict.update({
                'opt_d': self.optimizer_D.state_dict(),
                'netD': self.netD.state_dict(),
            })

        return state_dict
440
class AvgPool2d(nn.Module):
    """Stride-1 box-filter average pooling computed from cumulative sums.

    The kernel is either given directly or derived on the first forward call
    from ``base_size`` scaled by the ratio of the input size to
    ``train_size``.  When the kernel covers the whole input the module falls
    back to global average pooling.
    """

    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # The attributes below matter only for the fast implementation.
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        template = 'kernel_size={}, base_size={}, stride={}, fast_imp={}'
        return template.format(self.kernel_size, self.base_size, self.kernel_size, self.fast_imp)

    def _derive_kernel_size(self, x):
        """Set ``kernel_size`` (and the fast-path limits) from ``base_size``."""
        train_size = self.train_size
        if isinstance(self.base_size, int):
            self.base_size = (self.base_size, self.base_size)
        self.kernel_size = list(self.base_size)
        self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
        self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

        self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
        self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

    def _pool_fast(self, x):
        """Approximate pooling on a subsampled grid (not equivalent, but faster)."""
        h, w = x.shape[2:]
        # Largest listed stride that divides each side, capped by max_r1/max_r2.
        step_h = min(self.max_r1, [r for r in self.rs if h % r == 0][0])
        step_w = min(self.max_r2, [r for r in self.rs if w % r == 0][0])
        s = x[:, :, ::step_h, ::step_w].cumsum(dim=-1).cumsum(dim=-2)
        n, c, h, w = s.shape
        k1 = min(h - 1, self.kernel_size[0] // step_h)
        k2 = min(w - 1, self.kernel_size[1] // step_w)
        out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
        return F.interpolate(out, scale_factor=(step_h, step_w))

    def _pool_exact(self, x):
        """Window means from a zero-padded 2-D cumulative sum."""
        n, c, h, w = x.shape
        s = x.cumsum(dim=-1).cumsum_(dim=-2)
        s = F.pad(s, (1, 0, 1, 0))  # leading zero row/column for convenience
        k1 = min(h, self.kernel_size[0])
        k2 = min(w, self.kernel_size[1])
        top_left = s[:, :, :-k1, :-k2]
        top_right = s[:, :, :-k1, k2:]
        bottom_left = s[:, :, k1:, :-k2]
        bottom_right = s[:, :, k1:, k2:]
        window_sum = bottom_right + top_left - top_right - bottom_left
        return window_sum / (k1 * k2)

    def forward(self, x):
        if self.kernel_size is None and self.base_size:
            self._derive_kernel_size(x)

        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:
            out = self._pool_fast(x)
        else:
            out = self._pool_exact(x)

        if self.auto_pad:
            # Replicate-pad back to the input's spatial size.
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = F.pad(out, pad2d, mode='replicate')

        return out
507
-
508
def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    """Recursively swap every ``nn.AdaptiveAvgPool2d`` child for ``AvgPool2d``.

    Args:
        model: Module whose children are rewritten in place.
        base_size: Passed to each new ``AvgPool2d``.
        train_size: Passed to each new ``AvgPool2d``.
        fast_imp: Passed to each new ``AvgPool2d``.
        **kwargs: Only forwarded to the recursive calls.

    Raises:
        AssertionError: If a replaced layer has an ``output_size`` other than 1.
    """
    for child_name, child in model.named_children():
        has_children = len(list(child.children())) > 0
        if has_children:
            # compound module: descend into it first
            replace_layers(child, base_size, train_size, fast_imp, **kwargs)

        if not isinstance(child, nn.AdaptiveAvgPool2d):
            continue
        replacement = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
        assert child.output_size == 1
        setattr(model, child_name, replacement)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/losses.py DELETED
@@ -1,468 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from pytorch_msssim import SSIM
6
- from models.vit_feature_extractor import VitExtractor
7
- from models.vgg import Vgg19
8
-
9
-
10
- ###############################################################################
11
- # Functions
12
- ###############################################################################
13
def compute_gradient(img):
    """Return forward differences of ``img`` along its last two axes.

    Returns:
        ``(gradx, grady)``: ``gradx`` is the difference between neighbouring
        entries along the second-to-last axis, ``grady`` along the last axis.
        Each result is one element shorter along the axis it differences.
    """
    diff_second_last = img[..., 1:, :] - img[..., :-1, :]
    diff_last = img[..., 1:] - img[..., :-1]
    return diff_second_last, diff_last
17
-
18
-
19
class GradientLoss(nn.Module):
    """L1 distance between the gradient maps of a prediction and a target."""

    def __init__(self):
        super().__init__()
        self.loss = nn.L1Loss()

    def forward(self, predict, target):
        """Sum the L1 errors of both gradient maps from ``compute_gradient``."""
        pred_grads = compute_gradient(predict)
        ref_grads = compute_gradient(target)
        first_axis_term = self.loss(pred_grads[0], ref_grads[0])
        second_axis_term = self.loss(pred_grads[1], ref_grads[1])
        return first_axis_term + second_axis_term
29
-
30
-
31
class ContainLoss(nn.Module):
    """Penalise predicted gradients relative to the gradients of the input.

    Each gradient map of the two predictions is divided element-wise by the
    matching gradient map of the input (offset by ``eps``).  The squared L2
    norms of the four ratios are summed and divided by the number of elements
    in the input image.
    """

    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps

    def forward(self, predict_t, predict_r, input_image):
        """Return the scalar loss for the two predictions against the input."""
        element_count = np.prod(input_image.shape)
        t_gx, t_gy = compute_gradient(predict_t)
        r_gx, r_gy = compute_gradient(predict_r)
        ref_gx, ref_gy = compute_gradient(input_image)

        denom_x = ref_gx + self.eps
        denom_y = ref_gy + self.eps

        total = torch.norm(t_gx / denom_x, 2) ** 2
        total = total + torch.norm(t_gy / denom_y, 2) ** 2
        total = total + torch.norm(r_gx / denom_x, 2) ** 2
        total = total + torch.norm(r_gy / denom_y, 2) ** 2
        return total / element_count
48
-
49
-
50
class MultipleLoss(nn.Module):
    """Weighted sum of several ``(predict, target)`` loss modules.

    Args:
        losses: iterable of loss modules, each called as ``loss(predict, target)``.
        weight: one weight per loss.  When omitted (or empty) every loss gets
            the uniform weight ``1 / len(losses)``.

    Raises:
        ValueError: if ``weight`` and ``losses`` differ in length.  Pairing them
            with ``zip`` would otherwise silently drop the surplus entries.
    """

    def __init__(self, losses, weight=None):
        super(MultipleLoss, self).__init__()
        self.losses = nn.ModuleList(losses)
        self.weight = weight or [1 / len(self.losses)] * len(self.losses)
        if len(self.weight) != len(self.losses):
            raise ValueError(
                "MultipleLoss got %d weights for %d losses"
                % (len(self.weight), len(self.losses)))

    def forward(self, predict, target):
        """Return ``sum(weight_i * loss_i(predict, target))``."""
        total_loss = 0
        for weight, loss in zip(self.weight, self.losses):
            total_loss += loss(predict, target) * weight
        return total_loss
61
-
62
-
63
class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that (de)normalises each channel.

    With ``norm=True`` the layer computes ``(x - data_range * mean) / std`` per
    channel; with ``norm=False`` it computes ``x * std + data_range * mean``.
    The weight and bias are frozen so the layer is never trained.
    """

    def __init__(self, data_mean, data_std, data_range=1, norm=True):
        """norm (bool): normalize/denormalize the stats"""
        c = len(data_mean)
        super(MeanShift, self).__init__(c, c, kernel_size=1)
        mean = torch.Tensor(data_mean)
        std = torch.Tensor(data_std)
        # Start from an identity mapping between channels.
        self.weight.data = torch.eye(c).view(c, c, 1, 1)
        if norm:
            self.weight.data.div_(std.view(c, 1, 1, 1))
            self.bias.data = -1 * data_range * mean
            self.bias.data.div_(std)
        else:
            self.weight.data.mul_(std.view(c, 1, 1, 1))
            self.bias.data = data_range * mean
        # Assigning ``self.requires_grad = False`` only creates a plain
        # attribute on the module; requires_grad_ actually freezes the
        # parameters.
        self.requires_grad_(False)
78
-
79
-
80
class VGGLoss(nn.Module):
    """Weighted L1 distance between feature maps of ``x`` and ``y``.

    Args:
        vgg: callable invoked as ``vgg(image, indices)`` that returns a sequence
            of feature maps.  When ``None`` a compiled ``Vgg19`` is created.
        weights: one weight per returned feature map.
        indices: passed through to ``vgg`` unchanged.
        normalize: when true, inputs go through a ``MeanShift`` built from the
            mean/std constants below before feature extraction.

    The default extractor and the normaliser are placed on CUDA when it is
    available and on the CPU otherwise, so construction no longer fails on
    CPU-only machines.
    """

    def __init__(self, vgg=None, weights=None, indices=None, normalize=True):
        super(VGGLoss, self).__init__()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if vgg is None:
            self.vgg = torch.compile(Vgg19().to(device))
        else:
            self.vgg = vgg
        self.criterion = nn.L1Loss()
        self.weights = weights or [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5]
        self.indices = indices or [2, 7, 12, 21, 30]
        if normalize:
            self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).to(device)
        else:
            self.normalize = None

    def forward(self, x, y):
        """Return the weighted feature loss; no gradient flows through ``y``'s features."""
        if self.normalize is not None:
            x = self.normalize(x)
            y = self.normalize(y)
        # Target features need no autograd graph.
        with torch.no_grad():
            y_vgg = self.vgg(y, self.indices)
        x_vgg = self.vgg(x, self.indices)
        loss = 0
        for i, (feat_x, feat_y) in enumerate(zip(x_vgg, y_vgg)):
            loss += self.weights[i] * self.criterion(feat_x, feat_y)

        return loss
107
-
108
-
109
def l1_norm_dim(x, dim):
    """Return the mean absolute value of ``x`` reduced over ``dim``."""
    magnitudes = x.abs()
    return magnitudes.mean(dim=dim)
111
-
112
-
113
def l1_norm(x):
    """Return the mean absolute value over all elements of ``x``."""
    magnitudes = x.abs()
    return magnitudes.mean()
115
-
116
-
117
def l2_norm(x):
    """Return the mean squared value over all elements of ``x``."""
    squared = torch.square(x)
    return squared.mean()
119
-
120
-
121
def gradient_norm_kernel(x, kernel_size=10):
    """Average absolute gradient inside every sliding window of ``x``.

    Both gradient maps from ``compute_gradient`` are unfolded into
    ``kernel_size`` x ``kernel_size`` windows with stride 1, and the mean
    absolute value of each window is taken.

    Returns:
        ``(out_h, out_v)``, each reshaped to ``(dim0, dim1, num_windows)``
        where ``dim0`` and ``dim1`` are the first two sizes of the gradient
        map.  Assumes ``x`` is a 4-D batch of images -- TODO confirm.
    """
    grad_h, grad_v = compute_gradient(x)
    lead0, lead1 = grad_h.shape[0], grad_h.shape[1]
    window = (kernel_size, kernel_size)

    def _window_mean_abs(grad):
        patches = F.unfold(grad, kernel_size=window, stride=(1, 1))
        patches = patches.reshape(lead0, lead1, kernel_size * kernel_size, -1)
        return l1_norm_dim(patches, 2)

    return _window_mean_abs(grad_h), _window_mean_abs(grad_v)
131
-
132
-
133
class KTVLoss(nn.Module):
    """Windowed gradient-ratio term plus an L1 gradient reconstruction term.

    The first term divides the windowed gradient norms of both outputs (from
    ``gradient_norm_kernel``) by those of the input, offset by ``eps``, and
    averages the result.  The second term is the L1 error between the summed
    output gradients and the input gradients.
    """

    def __init__(self, kernel_size=10):
        super().__init__()
        self.kernel_size = kernel_size
        self.criterion = nn.L1Loss()
        self.eps = 1e-6

    def forward(self, out_l, out_r, input_i):
        """Return ``1e-4 * ratio_term + gradient_term``."""
        k = self.kernel_size
        l_nx, l_ny = gradient_norm_kernel(out_l, k)
        r_nx, r_ny = gradient_norm_kernel(out_r, k)
        in_nx, in_ny = gradient_norm_kernel(input_i, k)

        denom = in_nx + in_ny + self.eps
        ratio = (l_nx + l_ny) / denom + (r_nx + r_ny) / denom
        ratio_term = ratio.mean()

        l_gx, l_gy = compute_gradient(out_l)
        r_gx, r_gy = compute_gradient(out_r)
        in_gx, in_gy = compute_gradient(input_i)
        gradient_term = self.criterion(l_gx + r_gx, in_gx) + self.criterion(l_gy + r_gy, in_gy)

        return ratio_term * 1e-4 + gradient_term
158
-
159
-
160
class MTVLoss(nn.Module):
    """Global gradient-ratio term plus an L1 gradient reconstruction term.

    The ratio term divides the mean absolute gradient of each output by that of
    the input.  The denominator is clamped to ``eps`` from below so a constant
    input (zero gradient) no longer yields NaN or inf; any denominator at or
    above ``eps`` is used unchanged.

    Args:
        kernel_size: accepted for signature parity with ``KTVLoss``; unused.
    """

    def __init__(self, kernel_size=10):
        super().__init__()
        self.criterion = nn.L1Loss()
        self.norm = l1_norm
        self.eps = 1e-6

    def forward(self, out_l, out_r, input_i):
        """Return ``1e-5 * ratio_term + gradient_term``."""
        out_lx, out_ly = compute_gradient(out_l)
        out_rx, out_ry = compute_gradient(out_r)
        input_x, input_y = compute_gradient(input_i)

        norm_l = self.norm(out_lx) + self.norm(out_ly)
        norm_r = self.norm(out_rx) + self.norm(out_ry)
        # Guard the division below against a zero gradient norm.
        norm_target = (self.norm(input_x) + self.norm(input_y)).clamp_min(self.eps)

        gradient_diffx = self.criterion(out_lx + out_rx, input_x)
        gradient_diffy = self.criterion(out_ly + out_ry, input_y)

        loss = (norm_l / norm_target + norm_r / norm_target) * 1e-5 + gradient_diffx + gradient_diffy

        return loss
181
-
182
-
183
class ReconsLoss(nn.Module):
    """Weighted MSE reconstruction loss over four (clean, reflection) pairs.

    ``out_r`` is indexed at positions 0..7: entries ``2*i`` and ``2*i + 1`` are
    summed and compared with ``input_i`` by MSE.  The four errors are weighted
    0.25, 0.5, 0.75 and 1.0 and added up.  ``out_l`` is accepted but not used,
    and ``criterion``, ``norm`` and ``edge_recons`` are stored but not read by
    ``forward``.
    """

    def __init__(self, edge_recons=True):
        super().__init__()
        self.criterion = nn.L1Loss()
        self.norm = l1_norm
        self.edge_recons = edge_recons
        self.mse_loss = nn.MSELoss()

    def forward(self, out_l, out_r, input_i):
        """Return the weighted sum of the four reconstruction errors."""
        weighted_terms = []
        for level in range(4):
            clean = out_r[2 * level]
            reflection = out_r[2 * level + 1]
            error = self.mse_loss(clean + reflection, input_i)
            # Weights grow linearly: 0.25, 0.5, 0.75, 1.0.
            weighted_terms.append(error * (0.25 * (level + 1)))
        return sum(weighted_terms)
216
-
217
-
218
class ReconsLossX(nn.Module):
    """MSE between ``out`` and ``input_i``, optionally plus MSE of their gradients.

    Args:
        edge_recons: when true, the MSE errors of both gradient maps from
            ``compute_gradient`` are added with weight 1.0.
    """

    def __init__(self, edge_recons=True):
        super().__init__()
        self.criterion = nn.MSELoss()
        self.norm = l1_norm
        self.edge_recons = edge_recons

    def forward(self, out, input_i):
        """Return the content error, plus the gradient errors if enabled."""
        content_term = self.criterion(out, input_i)
        if not self.edge_recons:
            return content_term

        out_gx, out_gy = compute_gradient(out)
        ref_gx, ref_gy = compute_gradient(input_i)
        err_x = self.criterion(out_gx, ref_gx)
        err_y = self.criterion(out_gy, ref_gy)
        return content_term + (err_x + err_y) * 1.0
238
-
239
-
240
class ContentLoss():
    """Thin holder that forwards ``get_loss`` to a criterion set by ``initialize``."""

    def initialize(self, loss):
        """Store ``loss`` as the criterion used by ``get_loss``."""
        self.criterion = loss

    def get_loss(self, fakeIm, realIm):
        """Return ``criterion(fakeIm, realIm)``."""
        criterion = self.criterion
        return criterion(fakeIm, realIm)
246
-
247
-
248
class GANLoss(nn.Module):
    """GAN criterion that builds its own constant real/fake target tensors.

    Args:
        use_l1: use ``nn.L1Loss`` when true, otherwise ``nn.BCEWithLogitsLoss``.
        target_real_label: fill value of the "real" target.
        target_fake_label: fill value of the "fake" target.
        tensor: kept for backward compatibility and stored as ``self.Tensor``.
            Targets are now created to match each prediction's shape, dtype
            and device, so it is no longer used to build them.
    """

    def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_l1:
            self.loss = nn.L1Loss()
        else:
            self.loss = nn.BCEWithLogitsLoss()  # absorb sigmoid into BCELoss

    @staticmethod
    def _is_stale(cached, like):
        """Return True when ``cached`` cannot serve as a target for ``like``."""
        return (cached is None
                or cached.shape != like.shape
                or cached.device != like.device
                or cached.dtype != like.dtype)

    def get_target_tensor(self, input, target_is_real):
        """Return a cached constant target matching ``input``.

        The cache is refreshed whenever shape, device or dtype change.
        Comparing only ``numel`` could hand back a target of the wrong shape
        or on the wrong device.
        """
        if target_is_real:
            if self._is_stale(self.real_label_var, input):
                self.real_label_var = torch.full_like(input, self.real_label)
            return self.real_label_var
        if self._is_stale(self.fake_label_var, input):
            self.fake_label_var = torch.full_like(input, self.fake_label)
        return self.fake_label_var

    def forward(self, input, target_is_real):
        """Return the loss for one prediction, or the sum over a list of them."""
        if isinstance(input, list):
            loss = 0
            for input_i in input:
                target_tensor = self.get_target_tensor(input_i, target_is_real)
                loss += self.loss(input_i, target_tensor)
            return loss
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
290
-
291
-
292
class DiscLoss():
    """Standard GAN losses computed through a discriminator ``net``."""

    def name(self):
        return 'SGAN'

    def initialize(self, opt, tensor):
        """Create the GAN criterion (non-L1 variant of ``GANLoss``)."""
        self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)

    def get_g_loss(self, net, realA, fakeB, realB):
        """Generator loss: the discriminator output on ``fakeB`` against the real label."""
        return self.criterionGAN(net.forward(fakeB), 1)

    def get_loss(self, net, realA=None, fakeB=None, realB=None):
        """Discriminator loss.

        Returns:
            ``(loss_D, pred_fake, pred_real)``.  A prediction is ``None`` and
            its loss term 0 when the matching input is ``None``.
        """
        pred_fake = pred_real = None
        fake_term = real_term = 0

        if fakeB is not None:
            # Detach so nothing back-propagates into the generator.
            pred_fake = net.forward(fakeB.detach())
            fake_term = self.criterionGAN(pred_fake, 0)

        if realB is not None:
            pred_real = net.forward(realB)
            real_term = self.criterionGAN(pred_real, 1)

        return (fake_term + real_term) * 0.5, pred_fake, pred_real
325
-
326
-
327
class DiscLossR(DiscLoss):
    """Relativistic GAN losses (RSGAN, https://arxiv.org/abs/1807.00734)."""

    def name(self):
        return 'RSGAN'

    def initialize(self, opt, tensor):
        """Run the base initialisation, then (re)create the GAN criterion."""
        super().initialize(opt, tensor)
        self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)

    def get_g_loss(self, net, realA, fakeB, realB, pred_real=None):
        """Generator loss on ``D(fake) - D(real)`` against the real label.

        ``pred_real`` may be supplied to avoid running ``net`` on ``realB`` again.
        """
        real_score = net.forward(realB) if pred_real is None else pred_real
        fake_score = net.forward(fakeB)
        return self.criterionGAN(fake_score - real_score, 1)

    def get_loss(self, net, realA, fakeB, realB):
        """Discriminator loss on ``D(real) - D(fake)``; returns ``(loss, pred_fake, pred_real)``."""
        real_score = net.forward(realB)
        fake_score = net.forward(fakeB.detach())
        return self.criterionGAN(real_score - fake_score, 1), fake_score, real_score
349
-
350
-
351
class DiscLossRa(DiscLoss):
    """Relativistic average GAN losses (RaSGAN, https://arxiv.org/abs/1807.00734)."""

    def name(self):
        return 'RaSGAN'

    def initialize(self, opt, tensor):
        """Run the base initialisation, then (re)create the GAN criterion."""
        super().initialize(opt, tensor)
        self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)

    def get_g_loss(self, net, realA, fakeB, realB, pred_real=None):
        """Generator loss: each score is compared with the dim-0 mean of the other."""
        real_score = net.forward(realB) if pred_real is None else pred_real
        fake_score = net.forward(fakeB)

        fake_mean = torch.mean(fake_score, dim=0, keepdim=True)
        real_mean = torch.mean(real_score, dim=0, keepdim=True)
        real_term = self.criterionGAN(real_score - fake_mean, 0)
        fake_term = self.criterionGAN(fake_score - real_mean, 1)
        return (real_term + fake_term) * 0.5

    def get_loss(self, net, realA, fakeB, realB):
        """Discriminator loss; returns ``(loss, pred_fake, pred_real)``."""
        real_score = net.forward(realB)
        fake_score = net.forward(fakeB.detach())

        fake_mean = torch.mean(fake_score, dim=0, keepdim=True)
        real_mean = torch.mean(real_score, dim=0, keepdim=True)
        real_term = self.criterionGAN(real_score - fake_mean, 1)
        fake_term = self.criterionGAN(fake_score - real_mean, 0)
        return (real_term + fake_term) * 0.5, fake_score, real_score
378
-
379
-
380
class SSIM_Loss(nn.Module):
    """Dissimilarity loss: ``1 - SSIM(output, target)``.

    The ``SSIM`` module is configured for 3 channels, a data range of 1 and
    averaging over the batch (``size_average=True``).
    """

    def __init__(self):
        super(SSIM_Loss, self).__init__()
        self.ssim = SSIM(channel=3, data_range=1, size_average=True)

    def forward(self, output, target):
        similarity = self.ssim(output, target)
        return 1 - similarity
387
-
388
-
389
def init_loss(opt, tensor):
    """Build the dictionary of loss objects used for training.

    Args:
        opt: options object; ``opt.lambda_gan`` and ``opt.gan_type`` are read.
        tensor: tensor type forwarded to the GAN loss ``initialize``.

    Returns:
        dict with keys ``t_pixel``, ``r_pixel``, ``t_ssim``, ``r_ssim``,
        ``mtv``, ``ktv``, ``recons``, ``reconsx`` and, when
        ``opt.lambda_gan > 0``, ``gan``.

    Raises:
        ValueError: if ``opt.gan_type`` is not one of the recognised names.
    """
    loss_dic = {}

    pixel_loss = ContentLoss()
    pixel_loss.initialize(MultipleLoss([nn.MSELoss(), GradientLoss()], [0.3, 0.6]))
    loss_dic['t_pixel'] = pixel_loss

    r_loss = ContentLoss()
    r_loss.initialize(MultipleLoss([nn.MSELoss()], [0.9]))
    # The dedicated r_loss was previously built and then discarded: 'r_pixel'
    # was bound to pixel_loss instead.
    loss_dic['r_pixel'] = r_loss

    loss_dic['t_ssim'] = SSIM_Loss()
    loss_dic['r_ssim'] = SSIM_Loss()

    loss_dic['mtv'] = MTVLoss()
    loss_dic['ktv'] = KTVLoss()
    loss_dic['recons'] = ReconsLoss(edge_recons=False)
    loss_dic['reconsx'] = ReconsLossX(edge_recons=False)

    if opt.lambda_gan > 0:
        if opt.gan_type == 'sgan' or opt.gan_type == 'gan':
            disc_loss = DiscLoss()
        elif opt.gan_type == 'rsgan':
            disc_loss = DiscLossR()
        elif opt.gan_type == 'rasgan':
            disc_loss = DiscLossRa()
        else:
            raise ValueError("GAN [%s] not recognized." % opt.gan_type)

        disc_loss.initialize(opt, tensor)
        loss_dic['gan'] = disc_loss

    return loss_dic
426
-
427
class DINOLoss(nn.Module):
    '''
    DINO-ViT as perceptual loss

    Images are resized to 224x224, normalised with ``MeanShift`` and passed to
    a ``VitExtractor``; the loss is the MSE between one token of the last
    returned feature for the output and for the target.
    '''

    def __init__(self):
        super(DINOLoss, self).__init__()
        # Fall back to the CPU instead of failing when CUDA is unavailable.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.extractor = VitExtractor(model_name='dino_vits8', device=device)
        self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).to(device)

    def resize_to_dino(self, feature, size=(224, 224)):
        """Bilinearly resize ``feature`` to ``size``."""
        return F.interpolate(feature, size=size, mode='bilinear', align_corners=False)

    def _preprocess(self, image):
        """Resize then normalise a batched image, as done in ``forward``."""
        return self.normalize(self.resize_to_dino(image))

    def calculate_crop_cls_loss(self, outputs, inputs):
        """Sum the per-sample token MSE, one sample at a time.

        NOTE(review): this method used ``self.global_transform``, which is
        never assigned in this class, so every call raised ``AttributeError``.
        The preprocessing from ``forward`` is used instead -- confirm this
        matches the intended crop transform.
        """
        loss = 0.0
        for a, b in zip(outputs, inputs):  # avoid memory limitations
            a = self._preprocess(a.unsqueeze(0))
            b = self._preprocess(b.unsqueeze(0))
            cls_token = self.extractor.get_feature_from_input(a)[-1][0, 0, :]
            with torch.no_grad():
                target_cls_token = self.extractor.get_feature_from_input(b)[-1][0, 0, :]
            loss += F.mse_loss(cls_token, target_cls_token)
        return loss

    def forward(self, output, target):
        """Return the token MSE between ``output`` and ``target``."""
        output = self._preprocess(output)
        # NOTE(review): [0, 0, :] presumably selects the CLS token of the first
        # batch item only -- confirm against VitExtractor's output layout.
        output_cls_token = self.extractor.get_feature_from_input(output)[-1][0, 0, :]
        with torch.no_grad():
            target = self._preprocess(target)
            target_cls_token = self.extractor.get_feature_from_input(target)[-1][0, 0, :]

        return F.mse_loss(output_cls_token, target_cls_token)
459
-
460
if __name__ == '__main__':
    # Quick timing check of gradient_norm_kernel on a random batch.
    import time

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(3, 32, 224, 224, device=device)

    # CUDA kernels run asynchronously; synchronise so the timer measures the
    # computation rather than just the kernel launches.
    if device == 'cuda':
        torch.cuda.synchronize()
    s = time.perf_counter()
    out1, out2 = gradient_norm_kernel(x)
    if device == 'cuda':
        torch.cuda.synchronize()
    t = time.perf_counter()
    print(t - s)
    print(out1.shape, out2.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/losses_opt.py DELETED
@@ -1,404 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from pytorch_msssim import MS_SSIM, SSIM
6
-
7
- from models.vgg import Vgg19
8
-
9
-
10
- ###############################################################################
11
- # Functions
12
- ###############################################################################
13
def compute_gradient(img):
    """Return forward differences of ``img`` along its last two axes.

    Returns:
        ``(gradx, grady)``: ``gradx`` is the difference between neighbouring
        entries along the second-to-last axis, ``grady`` along the last axis.
        Each result is one element shorter along the axis it differences.
    """
    diff_second_last = img[..., 1:, :] - img[..., :-1, :]
    diff_last = img[..., 1:] - img[..., :-1]
    return diff_second_last, diff_last
17
-
18
-
19
class GradientLoss(nn.Module):
    """L1 distance between the gradient maps of a prediction and a target."""

    def __init__(self):
        super().__init__()
        self.loss = nn.L1Loss()

    def forward(self, predict, target):
        """Sum the L1 errors of both gradient maps from ``compute_gradient``."""
        pred_grads = compute_gradient(predict)
        ref_grads = compute_gradient(target)
        first_axis_term = self.loss(pred_grads[0], ref_grads[0])
        second_axis_term = self.loss(pred_grads[1], ref_grads[1])
        return first_axis_term + second_axis_term
29
-
30
-
31
class ContainLoss(nn.Module):
    """Penalise predicted gradients relative to the gradients of the input.

    Each gradient map of the two predictions is divided element-wise by the
    matching gradient map of the input (offset by ``eps``).  The squared L2
    norms of the four ratios are summed and divided by the number of elements
    in the input image.
    """

    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps

    def forward(self, predict_t, predict_r, input_image):
        """Return the scalar loss for the two predictions against the input."""
        element_count = np.prod(input_image.shape)
        t_gx, t_gy = compute_gradient(predict_t)
        r_gx, r_gy = compute_gradient(predict_r)
        ref_gx, ref_gy = compute_gradient(input_image)

        denom_x = ref_gx + self.eps
        denom_y = ref_gy + self.eps

        total = torch.norm(t_gx / denom_x, 2) ** 2
        total = total + torch.norm(t_gy / denom_y, 2) ** 2
        total = total + torch.norm(r_gx / denom_x, 2) ** 2
        total = total + torch.norm(r_gy / denom_y, 2) ** 2
        return total / element_count
48
-
49
-
50
class MultipleLoss(nn.Module):
    """Weighted sum of several ``(predict, target)`` loss modules.

    Args:
        losses: iterable of loss modules, each called as ``loss(predict, target)``.
        weight: one weight per loss.  When omitted (or empty) every loss gets
            the uniform weight ``1 / len(losses)``.

    Raises:
        ValueError: if ``weight`` and ``losses`` differ in length.  Pairing them
            with ``zip`` would otherwise silently drop the surplus entries.
    """

    def __init__(self, losses, weight=None):
        super(MultipleLoss, self).__init__()
        self.losses = nn.ModuleList(losses)
        self.weight = weight or [1 / len(self.losses)] * len(self.losses)
        if len(self.weight) != len(self.losses):
            raise ValueError(
                "MultipleLoss got %d weights for %d losses"
                % (len(self.weight), len(self.losses)))

    def forward(self, predict, target):
        """Return ``sum(weight_i * loss_i(predict, target))``."""
        total_loss = 0
        for weight, loss in zip(self.weight, self.losses):
            total_loss += loss(predict, target) * weight
        return total_loss
61
-
62
-
63
class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that (de)normalises each channel.

    With ``norm=True`` the layer computes ``(x - data_range * mean) / std`` per
    channel; with ``norm=False`` it computes ``x * std + data_range * mean``.
    The weight and bias are frozen so the layer is never trained.
    """

    def __init__(self, data_mean, data_std, data_range=1, norm=True):
        """norm (bool): normalize/denormalize the stats"""
        c = len(data_mean)
        super(MeanShift, self).__init__(c, c, kernel_size=1)
        mean = torch.Tensor(data_mean)
        std = torch.Tensor(data_std)
        # Start from an identity mapping between channels.
        self.weight.data = torch.eye(c).view(c, c, 1, 1)
        if norm:
            self.weight.data.div_(std.view(c, 1, 1, 1))
            self.bias.data = -1 * data_range * mean
            self.bias.data.div_(std)
        else:
            self.weight.data.mul_(std.view(c, 1, 1, 1))
            self.bias.data = data_range * mean
        # Assigning ``self.requires_grad = False`` only creates a plain
        # attribute on the module; requires_grad_ actually freezes the
        # parameters.
        self.requires_grad_(False)
78
-
79
-
80
class VGGLoss(nn.Module):
    """Weighted L1 distance between feature maps of ``x`` and ``y``.

    Args:
        vgg: callable invoked as ``vgg(image, indices)`` that returns a sequence
            of feature maps.  When ``None`` a ``Vgg19`` is created.
        weights: one weight per returned feature map.
        indices: passed through to ``vgg`` unchanged.
        normalize: when true, inputs go through a ``MeanShift`` built from the
            mean/std constants below before feature extraction.

    The default extractor and the normaliser are placed on CUDA when it is
    available and on the CPU otherwise, so construction no longer fails on
    CPU-only machines.
    """

    def __init__(self, vgg=None, weights=None, indices=None, normalize=True):
        super(VGGLoss, self).__init__()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if vgg is None:
            self.vgg = Vgg19().to(device)
        else:
            self.vgg = vgg
        self.criterion = nn.L1Loss()
        self.weights = weights or [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5]
        self.indices = indices or [2, 7, 12, 21, 30]
        if normalize:
            self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).to(device)
        else:
            self.normalize = None

    def forward(self, x, y):
        """Return the weighted feature loss; no gradient flows through ``y``'s features."""
        if self.normalize is not None:
            x = self.normalize(x)
            y = self.normalize(y)
        x_vgg = self.vgg(x, self.indices)
        # The target features were detached after the fact; computing them
        # under no_grad gives the same values without building a graph.
        with torch.no_grad():
            y_vgg = self.vgg(y, self.indices)
        loss = 0
        for i, (feat_x, feat_y) in enumerate(zip(x_vgg, y_vgg)):
            loss += self.weights[i] * self.criterion(feat_x, feat_y.detach())

        return loss
105
-
106
-
107
def l1_norm_dim(x, dim):
    """Return the mean absolute value of ``x`` reduced over ``dim``."""
    magnitudes = x.abs()
    return magnitudes.mean(dim=dim)
109
-
110
-
111
def l1_norm(x):
    """Return the mean absolute value over all elements of ``x``."""
    magnitudes = x.abs()
    return magnitudes.mean()
113
-
114
-
115
def l2_norm(x):
    """Return the mean squared value over all elements of ``x``."""
    squared = torch.square(x)
    return squared.mean()
117
-
118
-
119
def gradient_norm_kernel(x, kernel_size=10):
    """Average absolute gradient inside every sliding window of ``x``.

    Both gradient maps from ``compute_gradient`` are unfolded into
    ``kernel_size`` x ``kernel_size`` windows with stride 1, and the mean
    absolute value of each window is taken.

    Returns:
        ``(out_h, out_v)``, each reshaped to ``(dim0, dim1, num_windows)``
        where ``dim0`` and ``dim1`` are the first two sizes of the gradient
        map.  Assumes ``x`` is a 4-D batch of images -- TODO confirm.
    """
    grad_h, grad_v = compute_gradient(x)
    lead0, lead1 = grad_h.shape[0], grad_h.shape[1]
    window = (kernel_size, kernel_size)

    def _window_mean_abs(grad):
        patches = F.unfold(grad, kernel_size=window, stride=(1, 1))
        patches = patches.reshape(lead0, lead1, kernel_size * kernel_size, -1)
        return l1_norm_dim(patches, 2)

    return _window_mean_abs(grad_h), _window_mean_abs(grad_v)
129
-
130
-
131
class KTVLoss(nn.Module):
    """Windowed gradient-ratio term plus an L1 gradient reconstruction term.

    The first term divides the windowed gradient norms of both outputs (from
    ``gradient_norm_kernel``) by those of the input, offset by ``eps``, and
    averages the result.  The second term is the L1 error between the summed
    output gradients and the input gradients.
    """

    def __init__(self, kernel_size=10):
        super().__init__()
        self.kernel_size = kernel_size
        self.criterion = nn.L1Loss()
        self.eps = 1e-6

    def forward(self, out_l, out_r, input_i):
        """Return ``1e-4 * ratio_term + gradient_term``."""
        k = self.kernel_size
        l_nx, l_ny = gradient_norm_kernel(out_l, k)
        r_nx, r_ny = gradient_norm_kernel(out_r, k)
        in_nx, in_ny = gradient_norm_kernel(input_i, k)

        denom = in_nx + in_ny + self.eps
        ratio = (l_nx + l_ny) / denom + (r_nx + r_ny) / denom
        ratio_term = ratio.mean()

        l_gx, l_gy = compute_gradient(out_l)
        r_gx, r_gy = compute_gradient(out_r)
        in_gx, in_gy = compute_gradient(input_i)
        gradient_term = self.criterion(l_gx + r_gx, in_gx) + self.criterion(l_gy + r_gy, in_gy)

        return ratio_term * 1e-4 + gradient_term
156
-
157
-
158
class MTVLoss(nn.Module):
    """Global gradient-ratio term plus an L1 gradient reconstruction term.

    The ratio term divides the mean absolute gradient of each output by that of
    the input.  The denominator is clamped to ``eps`` from below so a constant
    input (zero gradient) no longer yields NaN or inf; any denominator at or
    above ``eps`` is used unchanged.

    Args:
        kernel_size: accepted for signature parity with ``KTVLoss``; unused.
    """

    def __init__(self, kernel_size=10):
        super().__init__()
        self.criterion = nn.L1Loss()
        self.norm = l1_norm
        self.eps = 1e-6

    def forward(self, out_l, out_r, input_i):
        """Return ``1e-5 * ratio_term + gradient_term``."""
        out_lx, out_ly = compute_gradient(out_l)
        out_rx, out_ry = compute_gradient(out_r)
        input_x, input_y = compute_gradient(input_i)

        norm_l = self.norm(out_lx) + self.norm(out_ly)
        norm_r = self.norm(out_rx) + self.norm(out_ry)
        # Guard the division below against a zero gradient norm.
        norm_target = (self.norm(input_x) + self.norm(input_y)).clamp_min(self.eps)

        gradient_diffx = self.criterion(out_lx + out_rx, input_x)
        gradient_diffy = self.criterion(out_ly + out_ry, input_y)

        loss = (norm_l / norm_target + norm_r / norm_target) * 1e-5 + gradient_diffx + gradient_diffy

        return loss
179
-
180
-
181
class ReconsLoss(nn.Module):
    """L1 reconstruction loss on the sum of two outputs and on its gradients.

    The content term compares ``out_l + out_r`` with ``input_i``; the gradient
    terms compare the summed output gradients with the input gradients and are
    weighted by 0.5.
    """

    def __init__(self):
        super(ReconsLoss, self).__init__()
        self.criterion = nn.L1Loss()
        self.norm = l1_norm

    def forward(self, out_l, out_r, input_i):
        """Return ``content_term + 0.5 * (grad_x_term + grad_y_term)``."""
        content_term = self.criterion(out_l + out_r, input_i)

        l_gx, l_gy = compute_gradient(out_l)
        r_gx, r_gy = compute_gradient(out_r)
        in_gx, in_gy = compute_gradient(input_i)

        err_x = self.criterion(l_gx + r_gx, in_gx)
        err_y = self.criterion(l_gy + r_gy, in_gy)

        return content_term + (err_x + err_y) * 0.5
199
-
200
-
201
class ContentLoss():
    """Thin holder that forwards ``get_loss`` to a criterion set by ``initialize``."""

    def initialize(self, loss):
        """Store ``loss`` as the criterion used by ``get_loss``."""
        self.criterion = loss

    def get_loss(self, fakeIm, realIm):
        """Return ``criterion(fakeIm, realIm)``."""
        criterion = self.criterion
        return criterion(fakeIm, realIm)
207
-
208
-
209
class GANLoss(nn.Module):
    """GAN criterion that builds its own constant real/fake target tensors.

    Args:
        use_l1: use ``nn.L1Loss`` when true, otherwise ``nn.BCEWithLogitsLoss``.
        target_real_label: fill value of the "real" target.
        target_fake_label: fill value of the "fake" target.
        tensor: kept for backward compatibility and stored as ``self.Tensor``.
            Targets are now created to match each prediction's shape, dtype
            and device, so it is no longer used to build them.
    """

    def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_l1:
            self.loss = nn.L1Loss()
        else:
            self.loss = nn.BCEWithLogitsLoss()  # absorb sigmoid into BCELoss

    @staticmethod
    def _is_stale(cached, like):
        """Return True when ``cached`` cannot serve as a target for ``like``."""
        return (cached is None
                or cached.shape != like.shape
                or cached.device != like.device
                or cached.dtype != like.dtype)

    def get_target_tensor(self, input, target_is_real):
        """Return a cached constant target matching ``input``.

        The cache is refreshed whenever shape, device or dtype change.
        Comparing only ``numel`` could hand back a target of the wrong shape
        or on the wrong device.
        """
        if target_is_real:
            if self._is_stale(self.real_label_var, input):
                self.real_label_var = torch.full_like(input, self.real_label)
            return self.real_label_var
        if self._is_stale(self.fake_label_var, input):
            self.fake_label_var = torch.full_like(input, self.fake_label)
        return self.fake_label_var

    def forward(self, input, target_is_real):
        """Return the loss for one prediction, or the sum over a list of them."""
        if isinstance(input, list):
            loss = 0
            for input_i in input:
                target_tensor = self.get_target_tensor(input_i, target_is_real)
                loss += self.loss(input_i, target_tensor)
            return loss
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
251
-
252
-
253
class DiscLoss():
    """Standard GAN losses computed through a discriminator ``net``."""

    def name(self):
        return 'SGAN'

    def initialize(self, opt, tensor):
        """Create the GAN criterion (non-L1 variant of ``GANLoss``)."""
        self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)

    def get_g_loss(self, net, realA, fakeB, realB):
        """Generator loss: the discriminator output on ``fakeB`` against the real label."""
        return self.criterionGAN(net.forward(fakeB), 1)

    def get_loss(self, net, realA=None, fakeB=None, realB=None):
        """Discriminator loss.

        Returns:
            ``(loss_D, pred_fake, pred_real)``.  A prediction is ``None`` and
            its loss term 0 when the matching input is ``None``.
        """
        pred_fake = pred_real = None
        fake_term = real_term = 0

        if fakeB is not None:
            # Detach so nothing back-propagates into the generator.
            pred_fake = net.forward(fakeB.detach())
            fake_term = self.criterionGAN(pred_fake, 0)

        if realB is not None:
            pred_real = net.forward(realB)
            real_term = self.criterionGAN(pred_real, 1)

        return (fake_term + real_term) * 0.5, pred_fake, pred_real
286
-
287
-
288
class DiscLossR(DiscLoss):
    """Relativistic GAN losses (RSGAN, https://arxiv.org/abs/1807.00734)."""

    def name(self):
        return 'RSGAN'

    def initialize(self, opt, tensor):
        """Run the base initialisation, then (re)create the GAN criterion."""
        super().initialize(opt, tensor)
        self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)

    def get_g_loss(self, net, realA, fakeB, realB, pred_real=None):
        """Generator loss on ``D(fake) - D(real)`` against the real label.

        ``pred_real`` may be supplied to avoid running ``net`` on ``realB`` again.
        """
        real_score = net.forward(realB) if pred_real is None else pred_real
        fake_score = net.forward(fakeB)
        return self.criterionGAN(fake_score - real_score, 1)

    def get_loss(self, net, realA, fakeB, realB):
        """Discriminator loss on ``D(real) - D(fake)``; returns ``(loss, pred_fake, pred_real)``."""
        real_score = net.forward(realB)
        fake_score = net.forward(fakeB.detach())
        return self.criterionGAN(real_score - fake_score, 1), fake_score, real_score
310
-
311
-
312
class DiscLossRa(DiscLoss):
    """Relativistic average GAN losses (RaSGAN, https://arxiv.org/abs/1807.00734)."""

    def name(self):
        return 'RaSGAN'

    def initialize(self, opt, tensor):
        """Run the base initialisation, then (re)create the GAN criterion."""
        super().initialize(opt, tensor)
        self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)

    def get_g_loss(self, net, realA, fakeB, realB, pred_real=None):
        """Generator loss: each score is compared with the dim-0 mean of the other."""
        real_score = net.forward(realB) if pred_real is None else pred_real
        fake_score = net.forward(fakeB)

        fake_mean = torch.mean(fake_score, dim=0, keepdim=True)
        real_mean = torch.mean(real_score, dim=0, keepdim=True)
        real_term = self.criterionGAN(real_score - fake_mean, 0)
        fake_term = self.criterionGAN(fake_score - real_mean, 1)
        return (real_term + fake_term) * 0.5

    def get_loss(self, net, realA, fakeB, realB):
        """Discriminator loss; returns ``(loss, pred_fake, pred_real)``."""
        real_score = net.forward(realB)
        fake_score = net.forward(fakeB.detach())

        fake_mean = torch.mean(fake_score, dim=0, keepdim=True)
        real_mean = torch.mean(real_score, dim=0, keepdim=True)
        real_term = self.criterionGAN(real_score - fake_mean, 1)
        fake_term = self.criterionGAN(fake_score - real_mean, 0)
        return (real_term + fake_term) * 0.5, fake_score, real_score
338
-
339
-
340
class MS_SSIM_Loss(nn.Module):
    """Dissimilarity loss: ``1 - MS_SSIM(output, target)``.

    The ``MS_SSIM`` module is configured for 3 channels, a data range of 1 and
    averaging over the batch (``size_average=True``).
    """

    def __init__(self):
        super(MS_SSIM_Loss, self).__init__()
        self.ms_ssim = MS_SSIM(channel=3, data_range=1, size_average=True)

    def forward(self, output, target):
        similarity = self.ms_ssim(output, target)
        return 1 - similarity
347
-
348
-
349
class SSIM_Loss(nn.Module):
    """Dissimilarity loss: ``1 - SSIM(output, target)``.

    The ``SSIM`` module is configured for 3 channels, a data range of 1 and
    averaging over the batch (``size_average=True``).
    """

    def __init__(self):
        super(SSIM_Loss, self).__init__()
        self.ssim = SSIM(channel=3, data_range=1, size_average=True)

    def forward(self, output, target):
        similarity = self.ssim(output, target)
        return 1 - similarity
356
-
357
-
358
def init_loss(opt, tensor):
    """Build the dictionary of loss objects used for training.

    Args:
        opt: options object; ``opt.lambda_gan`` and ``opt.gan_type`` are read.
        tensor: tensor type forwarded to the GAN loss ``initialize``.

    Returns:
        dict with keys ``t_pixel``, ``r_pixel``, ``recons``, ``mtv``, ``ktv``
        and, when ``opt.lambda_gan > 0``, ``gan``.

    Raises:
        ValueError: if ``opt.gan_type`` is not one of the recognised names.
    """
    transmission_pixel = ContentLoss()
    transmission_pixel.initialize(
        MultipleLoss([nn.MSELoss(), MS_SSIM_Loss(), GradientLoss()], [1.0, 0.4, 0.6]))

    reflection_pixel = ContentLoss()
    reflection_pixel.initialize(MultipleLoss([nn.MSELoss()], [4.0]))

    losses = {
        't_pixel': transmission_pixel,
        'r_pixel': reflection_pixel,
        'recons': ReconsLoss(),
        'mtv': MTVLoss(),
        'ktv': KTVLoss(),
    }

    if opt.lambda_gan > 0:
        gan_variants = {
            'sgan': DiscLoss,
            'gan': DiscLoss,
            'rsgan': DiscLossR,
            'rasgan': DiscLossRa,
        }
        variant = gan_variants.get(opt.gan_type)
        if variant is None:
            raise ValueError("GAN [%s] not recognized." % opt.gan_type)

        gan_loss = variant()
        gan_loss.initialize(opt, tensor)
        losses['gan'] = gan_loss

    return losses
394
-
395
-
396
if __name__ == '__main__':
    # Quick timing check of gradient_norm_kernel on a random batch.
    import time

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(3, 32, 224, 224, device=device)

    # CUDA kernels run asynchronously; synchronise so the timer measures the
    # computation rather than just the kernel launches.
    if device == 'cuda':
        torch.cuda.synchronize()
    s = time.perf_counter()
    out1, out2 = gradient_norm_kernel(x)
    if device == 'cuda':
        torch.cuda.synchronize()
    t = time.perf_counter()
    print(t - s)
    print(out1.shape, out2.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/networks.py DELETED
@@ -1,335 +0,0 @@
1
- import functools
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- from torch.nn import init
7
- from torch.nn.utils import spectral_norm
8
- from torch.nn import functional as F
9
- ###############################################################################
10
- # Functions
11
- ###############################################################################
12
-
13
-
14
def weights_init_normal(m):
    """Initialise one module in place from a narrow normal distribution.

    Meant to be handed to ``net.apply``. Conv2d, ConvTranspose2d and Linear
    weights are drawn from N(0.0, 0.02); BatchNorm2d weights are drawn from
    N(1.0, 0.02) and its bias is zeroed. Any other module (including
    container modules such as ``nn.Sequential``) is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        init.normal_(m.weight.data, 0.0, 0.02)
        return
    if isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
26
-
27
-
28
def weights_init_xavier(m):
    """Initialise one module in place with Xavier-normal weights (gain 0.02).

    Meant to be handed to ``net.apply``. Conv2d, ConvTranspose2d and Linear
    weights use ``xavier_normal_`` with ``gain=0.02``; BatchNorm2d weights are
    drawn from N(1.0, 0.02) and its bias is zeroed. Other modules are ignored.
    """
    if isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        init.xavier_normal_(m.weight.data, gain=0.02)
38
-
39
-
40
def weights_init_kaiming(m):
    """Initialise one module in place with Kaiming-normal (fan-in) weights.

    Meant to be handed to ``net.apply``. Conv2d, ConvTranspose2d and Linear
    weights use ``kaiming_normal_`` with ``a=0`` and ``mode='fan_in'``;
    BatchNorm2d weights are drawn from N(1.0, 0.02) and its bias is zeroed.
    Other modules are ignored.
    """
    if isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
        return
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
50
-
51
-
52
def weights_init_orthogonal(m):
    """Initialise one module in place with orthogonal weights (gain 1).

    Meant to be handed to ``net.apply``. Conv2d, ConvTranspose2d and Linear
    weights use ``orthogonal_`` with ``gain=1``; BatchNorm2d weights are drawn
    from N(1.0, 0.02) and its bias is zeroed. Other modules are ignored.

    The in-place ``*_`` initialisers are used throughout: the non-underscore
    aliases (``init.orthogonal`` / ``init.normal``) are deprecated. The
    per-module class-name print was dropped to match the sibling
    ``weights_init_*`` helpers, where it is commented out.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        init.orthogonal_(m.weight.data, gain=1)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
62
-
63
-
64
def init_weights(net, init_type='normal'):
    """Apply the named weight-initialisation scheme to every module of ``net``.

    Supported names are 'normal', 'xavier', 'kaiming' and 'orthogonal', each
    dispatched through ``net.apply``; 'edsr' is accepted and leaves the
    network exactly as constructed. Any other name raises
    ``NotImplementedError``. The chosen name is always printed first.
    """
    print('[i] initialization method [%s]' % init_type)

    appliers = (
        ('normal', weights_init_normal),
        ('xavier', weights_init_xavier),
        ('kaiming', weights_init_kaiming),
        ('orthogonal', weights_init_orthogonal),
    )
    for name, applier in appliers:
        if init_type == name:
            net.apply(applier)
            return

    if init_type == 'edsr':
        # Deliberate no-op: keep the modules' construction-time weights.
        return

    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
78
-
79
-
80
def get_norm_layer(norm_type='instance'):
    """Return a constructor for the requested 2-D normalisation layer.

    'batch' gives ``BatchNorm2d`` with ``affine=True``, 'instance' gives
    ``InstanceNorm2d`` with ``affine=False`` (both as ``functools.partial``
    objects), and 'none' gives ``None``. Any other value raises
    ``NotImplementedError``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == 'none':
        return None
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
90
-
91
-
92
def define_D(opt, in_channels=3):
    """Build the discriminator selected by ``opt.which_model_D``.

    'disc_vgg' builds ``Discriminator_VGG`` with kaiming initialisation,
    'disc_patch' builds a 3-layer ``NLayerDiscriminator`` (64 base filters,
    InstanceNorm2d) with normal initialisation, and 'disc_unet' builds
    ``UNetDiscriminatorSN`` with its construction-time weights. Any other
    value raises ``NotImplementedError``. When ``opt.gpu_ids`` is non-empty
    the network is moved to the first listed CUDA device.
    """
    # No Sigmoid on the output: it is incorporated into the BCE_stable loss.
    use_sigmoid = False
    kind = opt.which_model_D

    if kind == 'disc_unet':
        netD = UNetDiscriminatorSN(in_channels)
    elif kind == 'disc_patch':
        netD = NLayerDiscriminator(in_channels, 64, 3, nn.InstanceNorm2d, use_sigmoid, getIntermFeat=False)
        init_weights(netD, init_type='normal')
    elif kind == 'disc_vgg':
        netD = Discriminator_VGG(in_channels, use_sigmoid=use_sigmoid)
        init_weights(netD, init_type='kaiming')
    else:
        raise NotImplementedError('%s is not implemented' % opt.which_model_D)

    if len(opt.gpu_ids) > 0:
        assert(torch.cuda.is_available())
        netD.cuda(opt.gpu_ids[0])

    return netD
112
-
113
-
114
def print_network(net):
    """Print the network, its total parameter count and its receptive field."""
    total = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
    print('The size of receptive field: %d' % receptive_field(net))
121
-
122
-
123
def receptive_field(net):
    """Estimate the receptive-field size of ``net`` from its Conv2d layers.

    Every ``nn.Conv2d`` found by ``net.modules()`` is treated as one step of
    a purely sequential stack (traversal order, no branches), and only the
    first element of each tuple-valued kernel size / stride / dilation is
    used. Walking the layers from last to first, the size is grown with
    ``(size - 1) * stride + ksize * dilation - dilation + 1``.

    Returns 1 when the network contains no Conv2d layers.
    """
    def _first(value):
        # Conv2d stores these hyper-parameters as tuples; keep the first entry.
        return value[0] if isinstance(value, tuple) else value

    def _grow(output_size, ksize, stride, dilation):
        return (output_size - 1) * stride + ksize * dilation - dilation + 1

    stats = [
        (m.kernel_size, m.stride, m.dilation)
        for m in net.modules()
        if isinstance(m, nn.Conv2d)
    ]

    rsize = 1
    for ksize, stride, dilation in reversed(stats):
        rsize = _grow(rsize, _first(ksize), _first(stride), _first(dilation))
    return rsize
139
-
140
-
141
def debug_network(net):
    """Attach a forward hook to every module of ``net`` that prints its output size.

    The hooks are registered permanently (the handles are not kept), and the
    hook calls ``.size()`` on the module output.
    """
    def _print_output_size(module, inputs, output):
        print(output.size())

    for module in net.modules():
        module.register_forward_hook(_print_output_size)
146
-
147
-
148
- ##############################################################################
149
- # Classes
150
- ##############################################################################
151
-
152
- # Defines the PatchGAN discriminator with the specified arguments.
153
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator built from 4x4 convolutions.

    Layout: one stride-2 conv + LeakyReLU, then ``n_layers - 1`` stride-2
    conv/norm/LeakyReLU stages, one stride-1 conv/norm/LeakyReLU stage, and a
    final stride-1 conv down to ``branch`` output channels (optionally
    followed by a Sigmoid). Channel width doubles per stage and is capped at
    512. All convs are grouped by ``branch``. The first and last convs always
    carry a bias; ``bias`` only controls the intermediate stages.

    With ``getIntermFeat=True`` each stage is stored as ``model0``,
    ``model1``, ... and ``forward`` returns the list of the first
    ``n_layers + 2`` stage outputs; otherwise everything is fused into
    ``self.model`` and ``forward`` returns its single output.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3,
                 norm_layer=nn.BatchNorm2d, use_sigmoid=False,
                 branch=1, bias=True, getIntermFeat=False):
        super().__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kernel = 4
        pad = int(np.ceil((kernel - 1.0) / 2))

        def _stage(c_in, c_out, stride):
            # conv -> norm -> LeakyReLU, constructed in that order.
            return [
                nn.Conv2d(c_in * branch, c_out * branch, groups=branch,
                          kernel_size=kernel, stride=stride, padding=pad, bias=bias),
                norm_layer(c_out * branch),
                nn.LeakyReLU(0.2, True),
            ]

        stages = [[
            nn.Conv2d(input_nc * branch, ndf * branch, kernel_size=kernel,
                      stride=2, padding=pad, groups=branch, bias=True),
            nn.LeakyReLU(0.2, True),
        ]]

        width = ndf
        for _ in range(1, n_layers):
            previous, width = width, min(width * 2, 512)
            stages.append(_stage(previous, width, 2))

        previous, width = width, min(width * 2, 512)
        stages.append(_stage(previous, width, 1))

        stages.append([
            nn.Conv2d(width * branch, 1 * branch, groups=branch,
                      kernel_size=kernel, stride=1, padding=pad, bias=True),
        ])

        if use_sigmoid:
            stages.append([nn.Sigmoid()])

        if getIntermFeat:
            for idx, layers in enumerate(stages):
                setattr(self, 'model' + str(idx), nn.Sequential(*layers))
        else:
            flat = [layer for layers in stages for layer in layers]
            self.model = nn.Sequential(*flat)

    def forward(self, input):
        if not self.getIntermFeat:
            return self.model(input)

        features = []
        current = input
        for idx in range(self.n_layers + 2):
            current = getattr(self, 'model' + str(idx))(current)
            features.append(current)
        return features
204
-
205
-
206
class Discriminator_VGG(nn.Module):
    """VGG-style discriminator: 3x3 conv stack, global pooling, 1x1 conv head.

    ``body`` is a stem conv + LeakyReLU followed by nine
    conv / GroupNorm(32) / LeakyReLU stages, five of which use stride 2.
    ``tail`` average-pools to 1x1 and maps 512 -> 1024 -> 1 channels with 1x1
    convs, optionally ending in a Sigmoid.
    """

    def __init__(self, in_channels=3, use_sigmoid=True):
        super().__init__()

        num_groups = 32
        # (in_channels, out_channels, stride) for every stage after the stem.
        # Size comments assume a 224x224 input.
        plan = (
            (64, 64, 2),     # 112
            (64, 128, 1),
            (128, 128, 2),   # 56
            (128, 256, 1),
            (256, 256, 2),   # 28
            (256, 512, 1),
            (512, 512, 2),   # 14
            (512, 512, 1),
            (512, 512, 2),   # 7
        )

        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),  # 224
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out, stride in plan:
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.GroupNorm(num_groups, c_out),
                nn.LeakyReLU(0.2),
            ]

        head = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]
        if use_sigmoid:
            head.append(nn.Sigmoid())

        self.body = nn.Sequential(*layers)
        self.tail = nn.Sequential(*head)

    def forward(self, x):
        return self.tail(self.body(x))
272
-
273
class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN)

    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm
        # the first convolution
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        # downsample
        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # upsample
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
        # extra convolutions
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x, illu=None):
        """Return a per-pixel realness map with one channel.

        Args:
            x: input image batch; height and width should be divisible by 8
                so the three stride-2 stages and three x2 upsamplings line up
                with the skip connections.
            illu: optional modulation map. When given, the first-conv
                features are scaled by ``1 - 2 * illu`` before activation.
                It must broadcast against the ``conv0`` output.
                NOTE(review): expected shape/range of ``illu`` is not visible
                here -- confirm against the caller.
        """
        # Stem. Previously the modulated features were computed and then
        # discarded (conv0 was simply re-run on x), so ``illu`` had no effect.
        ingress = self.conv0(x)
        if illu is not None:
            ingress = ingress * (1 - illu * 2)
        x0 = F.leaky_relu(ingress, negative_slope=0.2, inplace=True)

        # downsample
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x4 = x4 + x2
        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x5 = x5 + x1
        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x6 = x6 + x0

        # extra convolutions
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/vgg.py DELETED
@@ -1,66 +0,0 @@
1
- from collections import namedtuple
2
-
3
- import torch
4
- from torchvision import models
5
-
6
-
7
class Vgg16(torch.nn.Module):
    """Pretrained VGG-16 feature extractor split into four sequential slices.

    The torchvision ``features`` layers 0-3, 4-8, 9-15 and 16-22 become
    ``slice1`` .. ``slice4``; each sub-module keeps its original index as its
    name. Unless ``requires_grad`` is true, all parameters are frozen.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        features = models.vgg16(pretrained=True).features

        # (attribute name, first layer index, one-past-last layer index)
        layout = (
            ('slice1', 0, 4),
            ('slice2', 4, 9),
            ('slice3', 9, 16),
            ('slice4', 16, 23),
        )
        for attr, start, stop in layout:
            block = torch.nn.Sequential()
            for idx in range(start, stop):
                block.add_module(str(idx), features[idx])
            setattr(self, attr, block)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the four slice outputs as a ``VggOutputs`` named tuple."""
        relu1_2 = self.slice1(X)
        relu2_2 = self.slice2(relu1_2)
        relu3_3 = self.slice3(relu2_2)
        relu4_3 = self.slice4(relu3_3)
        outputs_type = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        return outputs_type(relu1_2, relu2_2, relu3_3, relu4_3)
39
-
40
-
41
class Vgg19(torch.nn.Module):
    """Pretrained VGG-19 feature stack that returns selected intermediate outputs.

    Unless ``requires_grad`` is true, all parameters are frozen.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        self.vgg_pretrained_features = models.vgg19(pretrained=True).features

        if requires_grad:
            return
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X, indices=None):
        """Run the layers up to ``indices[-1]`` and collect the requested outputs.

        ``indices`` holds 1-based layer counts: the output after the k-th
        layer is collected when ``k`` is in ``indices``. Defaults to
        ``[2, 7, 12, 21, 30]``. Only the first ``indices[-1]`` layers are run.
        """
        wanted = [2, 7, 12, 21, 30] if indices is None else indices
        collected = []
        for depth in range(1, wanted[-1] + 1):
            X = self.vgg_pretrained_features[depth - 1](X)
            if depth in wanted:
                collected.append(X)
        return collected
60
-
61
-
62
if __name__ == '__main__':
    # Manual inspection entry point: build the network, then drop into ipdb.
    vgg = Vgg19()
    from ipdb import set_trace

    set_trace()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/vit_feature_extractor.py DELETED
@@ -1,164 +0,0 @@
1
- import torch
2
-
3
-
4
def attn_cosine_sim(x, eps=1e-08):
    """Return the pairwise cosine-similarity matrix of the rows of ``x[0]``.

    ``x`` must have a leading dimension of size 1, which is dropped; the
    remaining tensor is indexed as (batch, tokens, dim). The product of norms
    in the denominator is clamped to at least ``eps`` so all-zero rows give a
    similarity of 0 instead of NaN.
    """
    assert x.shape[0] == 1, 'x.shape[0] must eqs 1'
    feats = x[0]  # TEMP: getting rid of redundant dimension, TBF
    lengths = feats.norm(dim=2, keepdim=True)
    denom = torch.clamp(lengths @ lengths.permute(0, 2, 1), min=eps)
    return (feats @ feats.permute(0, 2, 1)) / denom
11
-
12
-
13
class VitExtractor:
    """Collect intermediate activations of a DINO ViT through forward hooks.

    The model is loaded with ``torch.hub`` from ``facebookresearch/dino:main``
    and put in eval mode. Each ``get_*_feature_from_input`` call registers
    hooks on the transformer blocks, runs one forward pass, returns the
    captured list for the requested key, and always removes the hooks and
    resets the capture buffers afterwards -- even when the forward pass
    raises.
    """

    BLOCK_KEY = 'block'
    ATTN_KEY = 'attn'
    PATCH_IMD_KEY = 'patch_imd'
    QKV_KEY = 'qkv'
    KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]

    def __init__(self, model_name, device):
        self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
        self.model.eval()
        self.model_name = model_name
        self.hook_handlers = []
        self.layers_dict = {}
        self.outputs_dict = {}
        for key in VitExtractor.KEY_LIST:
            self.layers_dict[key] = []
            self.outputs_dict[key] = []
        self._init_hooks_data()

    def _init_hooks_data(self):
        """Select block indices 0-11 for every key and start fresh output lists."""
        for key in VitExtractor.KEY_LIST:
            self.layers_dict[key] = list(range(12))
            # A new list each time, so lists already handed to callers stay intact.
            self.outputs_dict[key] = []

    def _register_hooks(self, **kwargs):
        """Register the block / attn / qkv / patch_imd hooks on the selected blocks."""
        for block_idx, block in enumerate(self.model.blocks):
            if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]:
                self.hook_handlers.append(block.register_forward_hook(self._get_block_hook()))
            if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]:
                self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
            if block_idx in self.layers_dict[VitExtractor.QKV_KEY]:
                self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
            if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]:
                self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))

    def _clear_hooks(self):
        """Remove every registered hook."""
        for handler in self.hook_handlers:
            handler.remove()
        self.hook_handlers = []

    def _get_block_hook(self):
        def _get_block_output(model, input, output):
            self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)

        return _get_block_output

    def _get_attn_hook(self):
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.ATTN_KEY].append(output)

        return _get_attn_output

    def _get_qkv_hook(self):
        def _get_qkv_output(model, inp, output):
            self.outputs_dict[VitExtractor.QKV_KEY].append(output)

        return _get_qkv_output

    # TODO: CHECK ATTN OUTPUT TUPLE
    def _get_patch_imd_hook(self):
        def _get_attn_output(model, inp, output):
            # Keeps only element 0 of the attention module's output.
            self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])

        return _get_attn_output

    def _collect(self, input_img, key):
        """Run one hooked forward pass and return the outputs captured for ``key``."""
        self._register_hooks()
        try:
            self.model(input_img)
            return self.outputs_dict[key]
        finally:
            # Always detach hooks and reset buffers, so a failed forward pass
            # cannot leave stale hooks or partial outputs behind.
            self._clear_hooks()
            self._init_hooks_data()

    def get_feature_from_input(self, input_img):  # List([B, N, D])
        """Return the list of per-block outputs for ``input_img``."""
        return self._collect(input_img, VitExtractor.BLOCK_KEY)

    def get_qkv_feature_from_input(self, input_img):
        """Return the list of per-block ``attn.qkv`` outputs for ``input_img``."""
        return self._collect(input_img, VitExtractor.QKV_KEY)

    def get_attn_feature_from_input(self, input_img):
        """Return the list of per-block ``attn.attn_drop`` outputs for ``input_img``."""
        return self._collect(input_img, VitExtractor.ATTN_KEY)

    def get_patch_size(self):
        # Derived from the model name: 8 when the name contains "8", else 16.
        return 8 if "8" in self.model_name else 16

    def get_width_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return w // patch_size

    def get_height_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return h // patch_size

    def get_patch_num(self, input_img_shape):
        """Return the token count: one extra token plus height x width patches."""
        patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
        return patch_num

    def get_head_num(self):
        # Name-based heuristic: any "s" in a dino model name selects the small
        # configuration (6 heads); otherwise 12.
        if "dino" in self.model_name:
            return 6 if "s" in self.model_name else 12
        return 6 if "small" in self.model_name else 12

    def get_embedding_dim(self):
        # Same name-based heuristic as get_head_num: 384 for small, else 768.
        if "dino" in self.model_name:
            return 384 if "s" in self.model_name else 768
        return 384 if "small" in self.model_name else 768

    def _split_qkv(self, qkv, input_img_shape, index):
        """Reshape ``qkv`` to (3, heads, tokens, head_dim) and return slot ``index``."""
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        return qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[index]

    def get_queries_from_qkv(self, qkv, input_img_shape):
        return self._split_qkv(qkv, input_img_shape, 0)

    def get_keys_from_qkv(self, qkv, input_img_shape):
        return self._split_qkv(qkv, input_img_shape, 1)

    def get_values_from_qkv(self, qkv, input_img_shape):
        return self._split_qkv(qkv, input_img_shape, 2)

    def get_keys_from_input(self, input_img, layer_num):
        """Return the keys of block ``layer_num`` as (heads, tokens, head_dim)."""
        qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
        keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
        return keys

    def get_keys_self_sim_from_input(self, input_img, layer_num):
        """Return the token-by-token cosine self-similarity of the layer's keys."""
        keys = self.get_keys_from_input(input_img, layer_num=layer_num)
        h, t, d = keys.shape
        # Merge the heads so each token is described by one (h * d) vector.
        concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
        ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...])
        return ssim_map
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
options/__init__.py DELETED
File without changes
options/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (151 Bytes)
 
options/__pycache__/base_option.cpython-38.pyc DELETED
Binary file (2.68 kB)
 
options/base_option.py DELETED
@@ -1,47 +0,0 @@
1
- import argparse
2
- import models
3
-
4
# Sorted names of every lowercase, non-dunder, callable attribute exposed by
# the ``models`` package.
model_names = sorted(
    name
    for name, member in models.__dict__.items()
    if name.islower() and not name.startswith("__") and callable(member)
)
7
-
8
-
9
class BaseOptions():
    """Shared command-line options: experiment identity, input loading, display."""

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False

    def initialize(self):
        """Register the common arguments on ``self.parser`` and mark as initialized."""
        add = self.parser.add_argument

        # experiment specifics
        add('--name', type=str, default='ytmt_ucs_sirs',
            help='name of the experiment. It decides where to store samples and models')
        add('--gpu_ids', type=str, default='0',
            help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        add('--model', type=str, default='revcol', help='chooses which model to use.')
        add('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        add('--resume', '-r', action='store_true', help='resume from checkpoint')
        add('--resume_epoch', '-re', type=int, default=None,
            help='checkpoint to use. (default: latest')
        add('--seed', type=int, default=2018, help='random seed to use. Default=2018')
        add('--supp_eval', action='store_true', help='supplementary evaluation')
        add('--start_now', action='store_true', help='supplementary evaluation')
        add('--testr', action='store_true', help='test for reflections')
        add('--select', type=str, default=None)

        # for setting input
        add('--serial_batches', action='store_true',
            help='if true, takes images in order to make batches, otherwise takes them randomly')
        add('--nThreads', default=8, type=int, help='# threads for loading data')
        add('--max_dataset_size', type=int, default=None,
            help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')

        # for display
        add('--no-log', action='store_true', help='disable tf logger?')
        add('--no-verbose', action='store_true', help='disable verbose info?')
        add('--display_winsize', type=int, default=256, help='display window size')
        add('--display_port', type=int, default=8097, help='visdom port of the web display')
        add('--display_id', type=int, default=0,
            help='window id of the web display (use 0 to disable visdom)')
        add('--display_single_pane_ncols', type=int, default=0,
            help='if positive, display all images in a single visdom web panel with certain number of images per row.')

        self.initialized = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
options/net_options/__init__.py DELETED
File without changes
options/net_options/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (163 Bytes)
 
options/net_options/__pycache__/base_options.cpython-38.pyc DELETED
Binary file (2.4 kB)
 
options/net_options/__pycache__/train_options.cpython-38.pyc DELETED
Binary file (3.54 kB)
 
options/net_options/base_options.py DELETED
@@ -1,71 +0,0 @@
1
- from options.base_option import BaseOptions as Base
2
- from util import util
3
- import os
4
- import torch
5
- import numpy as np
6
- import random
7
-
8
class BaseOptions(Base):
    """Network-level options plus the parse routine shared by train/test options.

    Subclasses are expected to set ``self.isTrain`` in their ``initialize``.
    """

    def initialize(self):
        """Register the network-specific arguments on top of the common ones."""
        Base.initialize(self)
        # experiment specifics
        self.parser.add_argument('--inet', type=str, default='ytmt_ucs', help='chooses which architecture to use for inet.')
        self.parser.add_argument('--icnn_path', type=str, default=None, help='icnn checkpoint to use.')
        self.parser.add_argument('--init_type', type=str, default='edsr', help='network initialization [normal|xavier|kaiming|orthogonal|uniform]')
        # for network
        self.parser.add_argument('--hyper', action='store_true', help='if true, augment input with vgg hypercolumn feature')

        self.initialized = True

    def parse(self):
        """Parse the command line, seed the RNGs, select the GPU and persist the options.

        Side effects: seeds ``torch`` / ``numpy`` / ``random`` with
        ``opt.seed``, enables deterministic cuDNN, calls
        ``torch.cuda.set_device`` on the first non-negative GPU id, prints all
        options, and writes them to ``<checkpoints_dir>/<name>/opt.txt``.

        Returns:
            The parsed options namespace, with ``gpu_ids`` converted to a list
            of non-negative ints and ``isTrain`` attached.
        """
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()
        self.opt.isTrain = self.isTrain  # train or test

        torch.backends.cudnn.deterministic = True
        torch.manual_seed(self.opt.seed)
        np.random.seed(self.opt.seed)  # seed for every module
        random.seed(self.opt.seed)

        # Negative ids (e.g. "-1") mean CPU and are dropped.
        gpu_ids = []
        for str_id in self.opt.gpu_ids.split(','):
            gpu_id = int(str_id)
            if gpu_id >= 0:
                gpu_ids.append(gpu_id)
        self.opt.gpu_ids = gpu_ids

        # set gpu ids
        if len(self.opt.gpu_ids) > 0:
            torch.cuda.set_device(self.opt.gpu_ids[0])

        args = vars(self.opt)

        print('------------ Options -------------')
        for k, v in sorted(args.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')

        # save to the disk
        self.opt.name = self.opt.name or '_'.join([self.opt.model])
        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.write('------------ Options -------------\n')
            for k, v in sorted(args.items()):
                opt_file.write('%s: %s\n' % (str(k), str(v)))
            opt_file.write('-------------- End ----------------\n')

        # ``--debug`` is only registered by subclasses that define it; treat a
        # missing flag as "not debugging" instead of raising AttributeError.
        if getattr(self.opt, 'debug', False):
            self.opt.display_freq = 20
            self.opt.print_freq = 20
            self.opt.nEpochs = 40
            self.opt.max_dataset_size = 100
            self.opt.no_log = False
            self.opt.nThreads = 0
            self.opt.decay_iter = 0
            self.opt.serial_batches = True
            self.opt.no_flip = True

        return self.opt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
options/net_options/train_options.py DELETED
@@ -1,75 +0,0 @@
1
- from .base_options import BaseOptions
2
-
3
-
4
class TrainOptions(BaseOptions):
    """Training-time options: logging cadence, optimisation, augmentation, losses."""

    def initialize(self):
        """Register the training arguments and mark this option set as training."""
        BaseOptions.initialize(self)
        add = self.parser.add_argument

        # for displays
        add('--display_freq', type=int, default=100,
            help='frequency of showing training results on screen')
        add('--update_html_freq', type=int, default=1000,
            help='frequency of saving training results to html')
        add('--print_freq', type=int, default=100,
            help='frequency of showing training results on console')
        add('--eval_freq', type=int, default=1, help='frequency of evaluation')
        add('--save_freq', type=int, default=1, help='frequency of save eval samples')
        add('--no_html', action='store_true',
            help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        add('--save_epoch_freq', type=int, default=1,
            help='frequency of saving checkpoints at the end of epochs')
        add('--debug', action='store_true',
            help='only do one epoch and displays at each iteration')
        add('--finetune', action='store_true',
            help='finetune the network using identity inputs and outputs')
        add('--if_align', action='store_true',
            help='if align 4x')

        # for training (Note: in train_sirs.py, we manually tune the training protocol, but you can also use following setting by modifying the code in errnet_model.py)
        add('--nEpochs', '-n', type=int, default=60, help='# of epochs to run')
        add('--lr', type=float, default=1e-4, help='initial learning rate for adam')
        add('--wd', type=float, default=0, help='weight decay for adam')

        add('--r_pixel_weight', '-rw', type=float, default=1.0, help='weight for r_pixel loss')

        add('--low_sigma', type=float, default=2, help='min sigma in synthetic dataset')
        add('--high_sigma', type=float, default=5, help='max sigma in synthetic dataset')
        add('--low_gamma', type=float, default=1.3, help='min gamma in synthetic dataset')
        add('--high_gamma', type=float, default=1.3, help='max gamma in synthetic dataset')

        # data augmentation
        add('--real20_size', type=int, default=420, help='scale images to compat size')
        add('--batchSize', '-b', type=int, default=2, help='input batch size')
        add('--loadSize', type=str, default='224,336,448', help='scale images to multiple size')
        add('--fineSize', type=str, default='224,224', help='then crop to this size')
        add('--no_flip', action='store_true',
            help='if specified, do not flip the images for data augmentation')
        add('--resize_or_crop', type=str, default='resize_and_crop',
            help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
        # NOTE(review): this help text is identical to --no_flip's and looks
        # copy-pasted; the flag's real effect is not visible here -- confirm.
        add('--debug_eval', action='store_true',
            help='if specified, do not flip the images for data augmentation')
        add('--graph', action='store_true', help='print graph')

        # for discriminator ('disc_unet' added: define_D also builds it)
        add('--which_model_D', type=str, default='disc_vgg',
            choices=['disc_vgg', 'disc_patch', 'disc_unet'])
        add('--gan_type', type=str, default='rasgan',
            help='gan/sgan : Vanilla GAN; rasgan : relativistic gan')

        # loss weight
        add('--unaligned_loss', type=str, default='vgg',
            help='unaligned loss type: vgg|mse|ctx|ctx_vgg')
        add('--tv_type', type=str, default=None, choices=['ktv', 'mtv'])
        add('--vgg_layer', type=int, default=31, help='vgg layer of unaligned loss')
        add('--init_lr', type=float, default=1e-2, help='initial learning rate')
        add('--fixed_lr', type=float, default=0, help='initial learning rate')
        add('--lambda_gan', type=float, default=0.01, help='weight for gan loss')
        add('--lambda_vgg', type=float, default=0.1, help='weight for vgg loss')
        add('--weight_loss', type=float, default=0.25, help='weight for overall loss')
        add('--num_subnet', type=int, default=4, help='num_number of subnet')
        add('--dataset', type=float, default=0.5, help='the setting of dataset')
        add('--loss_col', type=int, default=4, help='numcol for loss')
        add('--drop_path', type=float, default=0.6, help='drop_path')

        self.isTrain = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pretrained/README.md DELETED
@@ -1,3 +0,0 @@
1
- # Pretrained models
2
-
3
- This folder is for pretrained models.
 
 
 
 
script.py DELETED
@@ -1,64 +0,0 @@
1
- import torch
2
-
3
- # Load the original weights file
4
- original_weights = torch.load('/home/xteam/zhaohao/pycharmproject/YTMT/merge_stem_reg_014_00055524.pt')
5
-
6
- # Create a new weights dictionary
7
- # new_weights = {}
8
-
9
- # # Iterate through the original weights dictionary
10
- # for key, value in original_weights.items():
11
- # # Check if the key contains 'projec_shit'
12
- # if 'projback_shit' in key:
13
- # # Replace 'projec_shit' with 'project_'
14
- # new_key = key.replace('projback_shit', 'projback_')
15
- # new_weights[new_key] = value
16
- # else:
17
- # # If the key doesn't contain 'projec_shit', keep it unchanged
18
- # new_weights[key] = value
19
- # if 'projback_shit_2' in key:
20
- # # Replace 'projec_shit' with 'project_'
21
- # new_key = key.replace('projback_shit_2', 'projback_2')
22
- # new_weights[new_key] = value
23
- # else:
24
- # # If the key doesn't contain 'projec_shit', keep it unchanged
25
- # new_weights[key] = value
26
-
27
- # # Save the modified weights
28
- # torch.save(new_weights, '/home/xteam/zhaohao/pycharmproject/RDNet/new_weights.pth')
29
-
30
- # print("Weights file has been updated.")
31
-
32
- # # 打印原始权重字典中的所有键,以检查确切的层名称
33
- # print("原始权重文件中的层名:")
34
- # for key in original_weights['icnn'].keys():
35
- # print(key)
36
-
37
- # 创建一个新的权重字典
38
- new_weights = {'icnn': {}}
39
-
40
- # 遍历原始权重字典
41
- for key, value in original_weights['icnn'].items():
42
- # 检查并替换包含 'projback_shit' 的键
43
- if 'projback_shit_2' in key:
44
- new_key = key.replace('projback_shit_2', 'projback_2')
45
- new_weights['icnn'][new_key] = value
46
-
47
- # 检查并替换包含 'projback_shit_2' 的键
48
- elif 'projback_shit' in key:
49
- new_key = key.replace('projback_shit', 'projback_')
50
- new_weights['icnn'][new_key] = value
51
- else:
52
- # 如果键不包含上述字符串,保持不变
53
- new_weights['icnn'][key] = value
54
-
55
- # 打印新的权重字典中的所有键,以验证更改
56
- print("\n更新后的权重文件中的层名:")
57
- for key in new_weights['icnn'].keys():
58
- print(key)
59
-
60
- # 保存修改后的权重
61
- torch.save(new_weights, '/home/xteam/zhaohao/pycharmproject/RDNet/new_weights_4.pth')
62
-
63
- print("\n权重文件已更新。")
64
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_sirs.py DELETED
@@ -1,60 +0,0 @@
1
- import os
2
- from os.path import join
3
-
4
- import torch.backends.cudnn as cudnn
5
-
6
- # import data.sirs_dataset as datasets
7
- import data.dataset_sir as datasets
8
- from data.image_folder import read_fns
9
- from engine import Engine
10
- from options.net_options.train_options import TrainOptions
11
- from tools import mutils
12
-
13
- opt = TrainOptions().parse()
14
-
15
- opt.isTrain = False
16
- cudnn.benchmark = True
17
- opt.no_log = True
18
- opt.display_id = 0
19
- opt.verbose = False
20
- datadir = os.path.join(os.path.expanduser('~'), '/opt/datasets/sirs')
21
-
22
- eval_dataset_real = datasets.DSRTestDataset(join(datadir, f'test/real20_{opt.real20_size}'),
23
- fns=read_fns('data/real_test.txt'), if_align=opt.if_align)
24
- eval_dataset_solidobject = datasets.DSRTestDataset(join(datadir, 'test/SIR2/SolidObjectDataset'),
25
- if_align=opt.if_align)
26
- eval_dataset_postcard = datasets.DSRTestDataset(join(datadir, 'test/SIR2/PostcardDataset'), if_align=opt.if_align)
27
- eval_dataset_wild = datasets.DSRTestDataset(join(datadir, 'test/SIR2/WildSceneDataset'), if_align=opt.if_align)
28
-
29
- eval_dataloader_real = datasets.DataLoader(
30
- eval_dataset_real, batch_size=1, shuffle=True,
31
- num_workers=opt.nThreads, pin_memory=True)
32
-
33
- eval_dataloader_solidobject = datasets.DataLoader(
34
- eval_dataset_solidobject, batch_size=1, shuffle=False,
35
- num_workers=opt.nThreads, pin_memory=True)
36
-
37
- eval_dataloader_postcard = datasets.DataLoader(
38
- eval_dataset_postcard, batch_size=1, shuffle=False,
39
- num_workers=opt.nThreads, pin_memory=True)
40
-
41
- eval_dataloader_wild = datasets.DataLoader(
42
- eval_dataset_wild, batch_size=1, shuffle=False,
43
- num_workers=opt.nThreads, pin_memory=True)
44
-
45
- engine = Engine(opt, eval_dataset_real, eval_dataset_solidobject, eval_dataset_postcard, eval_dataloader_wild)
46
-
47
- """Main Loop"""
48
- result_dir = os.path.join('./results', opt.name, mutils.get_formatted_time())
49
-
50
- res1 = engine.eval(eval_dataloader_real, dataset_name='testdata_real',
51
- savedir=join(result_dir, 'real20'), suffix='real20')
52
-
53
- res2 = engine.eval(eval_dataloader_solidobject, dataset_name='testdata_solidobject',
54
- savedir=join(result_dir, 'solidobject'), suffix='solidobject')
55
- res3 = engine.eval(eval_dataloader_postcard, dataset_name='testdata_postcard',
56
- savedir=join(result_dir, 'postcard'), suffix='postcard')
57
-
58
- res4 = engine.eval(eval_dataloader_wild, dataset_name='testdata_wild',
59
- savedir=join(result_dir, 'wild'), suffix='wild')
60
-