update
Browse files- =2.0 +26 -0
- README.md +1 -1
- data/VOC2012_224_train_png.txt +0 -0
- data/__pycache__/dataset_sir.cpython-38.pyc +0 -0
- data/__pycache__/image_folder.cpython-38.pyc +0 -0
- data/__pycache__/torchdata.cpython-38.pyc +0 -0
- data/__pycache__/transforms.cpython-38.pyc +0 -0
- data/dataset_sir.py +0 -332
- data/image_folder.py +0 -51
- data/real_test.txt +0 -20
- data/torchdata.py +0 -67
- data/transforms.py +0 -301
- engine.py +0 -178
- figures/Input_car.jpg +0 -0
- figures/Input_class.png +0 -3
- figures/Input_green.png +0 -3
- figures/Ours_car.png +0 -3
- figures/Ours_class.png +0 -3
- figures/Ours_green.png +0 -3
- figures/Ours_white.png +0 -3
- figures/Title.png +0 -0
- figures/input_white.jpg +0 -0
- figures/net.png +0 -3
- figures/result.png +0 -3
- figures/vis.png +0 -3
- models/__init__.py +0 -11
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/__pycache__/cls_model_eval_nocls_reg.cpython-310.pyc +0 -0
- models/__pycache__/losses.cpython-310.pyc +0 -0
- models/base_model.py +0 -71
- models/cls_model_eval_nocls_reg.py +0 -517
- models/losses.py +0 -468
- models/losses_opt.py +0 -404
- models/networks.py +0 -335
- models/vgg.py +0 -66
- models/vit_feature_extractor.py +0 -164
- options/__init__.py +0 -0
- options/__pycache__/__init__.cpython-38.pyc +0 -0
- options/__pycache__/base_option.cpython-38.pyc +0 -0
- options/base_option.py +0 -47
- options/net_options/__init__.py +0 -0
- options/net_options/__pycache__/__init__.cpython-38.pyc +0 -0
- options/net_options/__pycache__/base_options.cpython-38.pyc +0 -0
- options/net_options/__pycache__/train_options.cpython-38.pyc +0 -0
- options/net_options/base_options.py +0 -71
- options/net_options/train_options.py +0 -75
- pretrained/README.md +0 -3
- script.py +0 -64
- test_sirs.py +0 -60
=2.0
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Requirement already satisfied: torch in /usr/local/lib/python3.10/site-packages (2.6.0)
|
| 2 |
+
Requirement already satisfied: torchvision in /usr/local/lib/python3.10/site-packages (0.21.0)
|
| 3 |
+
Requirement already satisfied: filelock in /usr/local/lib/python3.10/site-packages (from torch) (3.17.0)
|
| 4 |
+
Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.10/site-packages (from torch) (4.12.2)
|
| 5 |
+
Requirement already satisfied: networkx in /usr/local/lib/python3.10/site-packages (from torch) (3.4.2)
|
| 6 |
+
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/site-packages (from torch) (3.1.5)
|
| 7 |
+
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/site-packages (from torch) (2024.12.0)
|
| 8 |
+
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch) (12.4.127)
|
| 9 |
+
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch) (12.4.127)
|
| 10 |
+
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch) (12.4.127)
|
| 11 |
+
Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/site-packages (from torch) (9.1.0.70)
|
| 12 |
+
Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.10/site-packages (from torch) (12.4.5.8)
|
| 13 |
+
Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.10/site-packages (from torch) (11.2.1.3)
|
| 14 |
+
Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.10/site-packages (from torch) (10.3.5.147)
|
| 15 |
+
Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.10/site-packages (from torch) (11.6.1.9)
|
| 16 |
+
Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.10/site-packages (from torch) (12.3.1.170)
|
| 17 |
+
Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.10/site-packages (from torch) (0.6.2)
|
| 18 |
+
Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/site-packages (from torch) (2.21.5)
|
| 19 |
+
Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch) (12.4.127)
|
| 20 |
+
Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.10/site-packages (from torch) (12.4.127)
|
| 21 |
+
Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.10/site-packages (from torch) (3.2.0)
|
| 22 |
+
Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/site-packages (from torch) (1.13.1)
|
| 23 |
+
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/site-packages (from sympy==1.13.1->torch) (1.3.0)
|
| 24 |
+
Requirement already satisfied: numpy in /usr/local/lib/python3.10/site-packages (from torchvision) (2.2.3)
|
| 25 |
+
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/site-packages (from torchvision) (10.4.0)
|
| 26 |
+
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/site-packages (from jinja2->torch) (2.1.5)
|
README.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
emoji: 💻
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: blue
|
|
|
|
| 1 |
---
|
| 2 |
+
title: RDNet
|
| 3 |
emoji: 💻
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: blue
|
data/VOC2012_224_train_png.txt
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/__pycache__/dataset_sir.cpython-38.pyc
DELETED
|
Binary file (10.9 kB)
|
|
|
data/__pycache__/image_folder.cpython-38.pyc
DELETED
|
Binary file (1.58 kB)
|
|
|
data/__pycache__/torchdata.cpython-38.pyc
DELETED
|
Binary file (2.86 kB)
|
|
|
data/__pycache__/transforms.cpython-38.pyc
DELETED
|
Binary file (9.37 kB)
|
|
|
data/dataset_sir.py
DELETED
|
@@ -1,332 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
import os.path
|
| 3 |
-
import os.path
|
| 4 |
-
import random
|
| 5 |
-
from os.path import join
|
| 6 |
-
|
| 7 |
-
import cv2
|
| 8 |
-
import numpy as np
|
| 9 |
-
import torch.utils.data
|
| 10 |
-
import torchvision.transforms.functional as TF
|
| 11 |
-
from PIL import Image
|
| 12 |
-
from scipy.signal import convolve2d
|
| 13 |
-
|
| 14 |
-
from data.image_folder import make_dataset
|
| 15 |
-
from data.torchdata import Dataset as BaseDataset
|
| 16 |
-
from data.transforms import to_tensor
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def __scale_width(img, target_width):
|
| 20 |
-
ow, oh = img.size
|
| 21 |
-
if (ow == target_width):
|
| 22 |
-
return img
|
| 23 |
-
w = target_width
|
| 24 |
-
h = int(target_width * oh / ow)
|
| 25 |
-
h = math.ceil(h / 2.) * 2 # round up to even
|
| 26 |
-
return img.resize((w, h), Image.BICUBIC)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def __scale_height(img, target_height):
|
| 30 |
-
ow, oh = img.size
|
| 31 |
-
if (oh == target_height):
|
| 32 |
-
return img
|
| 33 |
-
h = target_height
|
| 34 |
-
w = int(target_height * ow / oh)
|
| 35 |
-
w = math.ceil(w / 2.) * 2
|
| 36 |
-
return img.resize((w, h), Image.BICUBIC)
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def paired_data_transforms(img_1, img_2, unaligned_transforms=False):
|
| 40 |
-
def get_params(img, output_size):
|
| 41 |
-
w, h = img.size
|
| 42 |
-
th, tw = output_size
|
| 43 |
-
if w == tw and h == th:
|
| 44 |
-
return 0, 0, h, w
|
| 45 |
-
|
| 46 |
-
i = random.randint(0, h - th)
|
| 47 |
-
j = random.randint(0, w - tw)
|
| 48 |
-
return i, j, th, tw
|
| 49 |
-
|
| 50 |
-
target_size = int(random.randint(320, 640) / 2.) * 2
|
| 51 |
-
ow, oh = img_1.size
|
| 52 |
-
if ow >= oh:
|
| 53 |
-
img_1 = __scale_height(img_1, target_size)
|
| 54 |
-
img_2 = __scale_height(img_2, target_size)
|
| 55 |
-
else:
|
| 56 |
-
img_1 = __scale_width(img_1, target_size)
|
| 57 |
-
img_2 = __scale_width(img_2, target_size)
|
| 58 |
-
|
| 59 |
-
if random.random() < 0.5:
|
| 60 |
-
img_1 = TF.hflip(img_1)
|
| 61 |
-
img_2 = TF.hflip(img_2)
|
| 62 |
-
|
| 63 |
-
if random.random() < 0.5:
|
| 64 |
-
angle = random.choice([90, 180, 270])
|
| 65 |
-
img_1 = TF.rotate(img_1, angle)
|
| 66 |
-
img_2 = TF.rotate(img_2, angle)
|
| 67 |
-
|
| 68 |
-
i, j, h, w = get_params(img_1, (320, 320))
|
| 69 |
-
img_1 = TF.crop(img_1, i, j, h, w)
|
| 70 |
-
|
| 71 |
-
if unaligned_transforms:
|
| 72 |
-
# print('random shift')
|
| 73 |
-
i_shift = random.randint(-10, 10)
|
| 74 |
-
j_shift = random.randint(-10, 10)
|
| 75 |
-
i += i_shift
|
| 76 |
-
j += j_shift
|
| 77 |
-
|
| 78 |
-
img_2 = TF.crop(img_2, i, j, h, w)
|
| 79 |
-
|
| 80 |
-
return img_1, img_2
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
class ReflectionSynthesis(object):
|
| 84 |
-
def __init__(self):
|
| 85 |
-
# Kernel Size of the Gaussian Blurry
|
| 86 |
-
self.kernel_sizes = [5, 7, 9, 11]
|
| 87 |
-
self.kernel_probs = [0.1, 0.2, 0.3, 0.4]
|
| 88 |
-
|
| 89 |
-
# Sigma of the Gaussian Blurry
|
| 90 |
-
self.sigma_range = [2, 5]
|
| 91 |
-
self.alpha_range = [0.8, 1.0]
|
| 92 |
-
self.beta_range = [0.4, 1.0]
|
| 93 |
-
|
| 94 |
-
def __call__(self, T_, R_):
|
| 95 |
-
T_ = np.asarray(T_, np.float32) / 255.
|
| 96 |
-
R_ = np.asarray(R_, np.float32) / 255.
|
| 97 |
-
|
| 98 |
-
kernel_size = np.random.choice(self.kernel_sizes, p=self.kernel_probs)
|
| 99 |
-
sigma = np.random.uniform(self.sigma_range[0], self.sigma_range[1])
|
| 100 |
-
kernel = cv2.getGaussianKernel(kernel_size, sigma)
|
| 101 |
-
kernel2d = np.dot(kernel, kernel.T)
|
| 102 |
-
for i in range(3):
|
| 103 |
-
R_[..., i] = convolve2d(R_[..., i], kernel2d, mode='same')
|
| 104 |
-
|
| 105 |
-
a = np.random.uniform(self.alpha_range[0], self.alpha_range[1])
|
| 106 |
-
b = np.random.uniform(self.beta_range[0], self.beta_range[1])
|
| 107 |
-
T, R = a * T_, b * R_
|
| 108 |
-
|
| 109 |
-
if random.random() < 0.7:
|
| 110 |
-
I = T + R - T * R
|
| 111 |
-
|
| 112 |
-
else:
|
| 113 |
-
I = T + R
|
| 114 |
-
if np.max(I) > 1:
|
| 115 |
-
m = I[I > 1]
|
| 116 |
-
m = (np.mean(m) - 1) * 1.3
|
| 117 |
-
I = np.clip(T + np.clip(R - m, 0, 1), 0, 1)
|
| 118 |
-
|
| 119 |
-
return T_, R_, I
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
class DataLoader(torch.utils.data.DataLoader):
|
| 123 |
-
def __init__(self, dataset, batch_size, shuffle, *args, **kwargs):
|
| 124 |
-
super(DataLoader, self).__init__(dataset, batch_size, shuffle, *args, **kwargs)
|
| 125 |
-
self.shuffle = shuffle
|
| 126 |
-
|
| 127 |
-
def reset(self):
|
| 128 |
-
if self.shuffle:
|
| 129 |
-
print('Reset Dataset...')
|
| 130 |
-
self.dataset.reset()
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
class DSRDataset(BaseDataset):
|
| 134 |
-
def __init__(self, datadir, fns=None, size=None, enable_transforms=True):
|
| 135 |
-
super(DSRDataset, self).__init__()
|
| 136 |
-
self.size = size
|
| 137 |
-
self.datadir = datadir
|
| 138 |
-
self.enable_transforms = enable_transforms
|
| 139 |
-
sortkey = lambda key: os.path.split(key)[-1]
|
| 140 |
-
self.paths = sorted(make_dataset(datadir, fns), key=sortkey)
|
| 141 |
-
if size is not None:
|
| 142 |
-
self.paths = np.random.choice(self.paths, size)
|
| 143 |
-
|
| 144 |
-
self.syn_model = ReflectionSynthesis()
|
| 145 |
-
self.reset(shuffle=False)
|
| 146 |
-
|
| 147 |
-
def reset(self, shuffle=True):
|
| 148 |
-
if shuffle:
|
| 149 |
-
random.shuffle(self.paths)
|
| 150 |
-
num_paths = len(self.paths) // 2
|
| 151 |
-
self.B_paths = self.paths[0:num_paths]
|
| 152 |
-
self.R_paths = self.paths[num_paths:2 * num_paths]
|
| 153 |
-
|
| 154 |
-
def data_synthesis(self, t_img, r_img):
|
| 155 |
-
if self.enable_transforms:
|
| 156 |
-
t_img, r_img = paired_data_transforms(t_img, r_img)
|
| 157 |
-
|
| 158 |
-
t_img, r_img, m_img = self.syn_model(t_img, r_img)
|
| 159 |
-
|
| 160 |
-
B = TF.to_tensor(t_img)
|
| 161 |
-
R = TF.to_tensor(r_img)
|
| 162 |
-
M = TF.to_tensor(m_img)
|
| 163 |
-
|
| 164 |
-
return B, R, M
|
| 165 |
-
|
| 166 |
-
def __getitem__(self, index):
|
| 167 |
-
index_B = index % len(self.B_paths)
|
| 168 |
-
index_R = index % len(self.R_paths)
|
| 169 |
-
|
| 170 |
-
B_path = self.B_paths[index_B]
|
| 171 |
-
R_path = self.R_paths[index_R]
|
| 172 |
-
|
| 173 |
-
t_img = Image.open(B_path).convert('RGB')
|
| 174 |
-
r_img = Image.open(R_path).convert('RGB')
|
| 175 |
-
|
| 176 |
-
B, R, M = self.data_synthesis(t_img, r_img)
|
| 177 |
-
fn = os.path.basename(B_path)
|
| 178 |
-
return {'input': M, 'target_t': B, 'target_r': M-B, 'fn': fn, 'real': False}
|
| 179 |
-
|
| 180 |
-
def __len__(self):
|
| 181 |
-
if self.size is not None:
|
| 182 |
-
return min(max(len(self.B_paths), len(self.R_paths)), self.size)
|
| 183 |
-
else:
|
| 184 |
-
return max(len(self.B_paths), len(self.R_paths))
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
class DSRTestDataset(BaseDataset):
|
| 188 |
-
def __init__(self, datadir, fns=None, size=None, enable_transforms=False, unaligned_transforms=False,
|
| 189 |
-
round_factor=1, flag=None, if_align=True):
|
| 190 |
-
super(DSRTestDataset, self).__init__()
|
| 191 |
-
self.size = size
|
| 192 |
-
self.datadir = datadir
|
| 193 |
-
self.fns = fns or os.listdir(join(datadir, 'blended'))
|
| 194 |
-
self.enable_transforms = enable_transforms
|
| 195 |
-
self.unaligned_transforms = unaligned_transforms
|
| 196 |
-
self.round_factor = round_factor
|
| 197 |
-
self.flag = flag
|
| 198 |
-
self.if_align = True # if_align
|
| 199 |
-
|
| 200 |
-
if size is not None:
|
| 201 |
-
self.fns = self.fns[:size]
|
| 202 |
-
|
| 203 |
-
def align(self, x1, x2):
|
| 204 |
-
h, w = x1.height, x1.width
|
| 205 |
-
h, w = h // 32 * 32, w // 32 * 32
|
| 206 |
-
x1 = x1.resize((w, h))
|
| 207 |
-
x2 = x2.resize((w, h))
|
| 208 |
-
return x1, x2
|
| 209 |
-
|
| 210 |
-
def __getitem__(self, index):
|
| 211 |
-
fn = self.fns[index]
|
| 212 |
-
|
| 213 |
-
t_img = Image.open(join(self.datadir, 'transmission_layer', fn)).convert('RGB')
|
| 214 |
-
m_img = Image.open(join(self.datadir, 'blended', fn)).convert('RGB')
|
| 215 |
-
|
| 216 |
-
if self.if_align:
|
| 217 |
-
t_img, m_img = self.align(t_img, m_img)
|
| 218 |
-
|
| 219 |
-
if self.enable_transforms:
|
| 220 |
-
t_img, m_img = paired_data_transforms(t_img, m_img, self.unaligned_transforms)
|
| 221 |
-
|
| 222 |
-
B = TF.to_tensor(t_img)
|
| 223 |
-
M = TF.to_tensor(m_img)
|
| 224 |
-
|
| 225 |
-
dic = {'input': M, 'target_t': B, 'fn': fn, 'real': True, 'target_r': M - B}
|
| 226 |
-
if self.flag is not None:
|
| 227 |
-
dic.update(self.flag)
|
| 228 |
-
return dic
|
| 229 |
-
|
| 230 |
-
def __len__(self):
|
| 231 |
-
if self.size is not None:
|
| 232 |
-
return min(len(self.fns), self.size)
|
| 233 |
-
else:
|
| 234 |
-
return len(self.fns)
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
class SIRTestDataset(BaseDataset):
|
| 238 |
-
def __init__(self, datadir, fns=None, size=None, if_align=True):
|
| 239 |
-
super(SIRTestDataset, self).__init__()
|
| 240 |
-
self.size = size
|
| 241 |
-
self.datadir = datadir
|
| 242 |
-
self.fns = fns or os.listdir(join(datadir, 'blended'))
|
| 243 |
-
self.if_align = if_align
|
| 244 |
-
|
| 245 |
-
if size is not None:
|
| 246 |
-
self.fns = self.fns[:size]
|
| 247 |
-
|
| 248 |
-
def align(self, x1, x2, x3):
|
| 249 |
-
h, w = x1.height, x1.width
|
| 250 |
-
h, w = h // 32 * 32, w // 32 * 32
|
| 251 |
-
x1 = x1.resize((w, h))
|
| 252 |
-
x2 = x2.resize((w, h))
|
| 253 |
-
x3 = x3.resize((w, h))
|
| 254 |
-
return x1, x2, x3
|
| 255 |
-
|
| 256 |
-
def __getitem__(self, index):
|
| 257 |
-
fn = self.fns[index]
|
| 258 |
-
|
| 259 |
-
t_img = Image.open(join(self.datadir, 'transmission_layer', fn)).convert('RGB')
|
| 260 |
-
r_img = Image.open(join(self.datadir, 'reflection_layer', fn)).convert('RGB')
|
| 261 |
-
m_img = Image.open(join(self.datadir, 'blended', fn)).convert('RGB')
|
| 262 |
-
|
| 263 |
-
if self.if_align:
|
| 264 |
-
t_img, r_img, m_img = self.align(t_img, r_img, m_img)
|
| 265 |
-
|
| 266 |
-
B = TF.to_tensor(t_img)
|
| 267 |
-
R = TF.to_tensor(r_img)
|
| 268 |
-
M = TF.to_tensor(m_img)
|
| 269 |
-
|
| 270 |
-
dic = {'input': M, 'target_t': B, 'fn': fn, 'real': True, 'target_r': R, 'target_r_hat': M - B}
|
| 271 |
-
return dic
|
| 272 |
-
|
| 273 |
-
def __len__(self):
|
| 274 |
-
if self.size is not None:
|
| 275 |
-
return min(len(self.fns), self.size)
|
| 276 |
-
else:
|
| 277 |
-
return len(self.fns)
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
class RealDataset(BaseDataset):
|
| 281 |
-
def __init__(self, datadir, fns=None, size=None):
|
| 282 |
-
super(RealDataset, self).__init__()
|
| 283 |
-
self.size = size
|
| 284 |
-
self.datadir = datadir
|
| 285 |
-
self.fns = fns or os.listdir(join(datadir))
|
| 286 |
-
|
| 287 |
-
if size is not None:
|
| 288 |
-
self.fns = self.fns[:size]
|
| 289 |
-
|
| 290 |
-
def align(self, x):
|
| 291 |
-
h, w = x.height, x.width
|
| 292 |
-
h, w = h // 32 * 32, w // 32 * 32
|
| 293 |
-
x = x.resize((w, h))
|
| 294 |
-
return x
|
| 295 |
-
|
| 296 |
-
def __getitem__(self, index):
|
| 297 |
-
fn = self.fns[index]
|
| 298 |
-
B = -1
|
| 299 |
-
m_img = Image.open(join(self.datadir, fn)).convert('RGB')
|
| 300 |
-
M = to_tensor(self.align(m_img))
|
| 301 |
-
data = {'input': M, 'target_t': B, 'fn': fn}
|
| 302 |
-
return data
|
| 303 |
-
|
| 304 |
-
def __len__(self):
|
| 305 |
-
if self.size is not None:
|
| 306 |
-
return min(len(self.fns), self.size)
|
| 307 |
-
else:
|
| 308 |
-
return len(self.fns)
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
class FusionDataset(BaseDataset):
|
| 312 |
-
def __init__(self, datasets, fusion_ratios=None):
|
| 313 |
-
self.datasets = datasets
|
| 314 |
-
self.size = sum([len(dataset) for dataset in datasets])
|
| 315 |
-
self.fusion_ratios = fusion_ratios or [1. / len(datasets)] * len(datasets)
|
| 316 |
-
print('[i] using a fusion dataset: %d %s imgs fused with ratio %s' % (
|
| 317 |
-
self.size, [len(dataset) for dataset in datasets], self.fusion_ratios))
|
| 318 |
-
|
| 319 |
-
def reset(self):
|
| 320 |
-
for dataset in self.datasets:
|
| 321 |
-
dataset.reset()
|
| 322 |
-
|
| 323 |
-
def __getitem__(self, index):
|
| 324 |
-
residual = 1
|
| 325 |
-
for i, ratio in enumerate(self.fusion_ratios):
|
| 326 |
-
if random.random() < ratio / residual or i == len(self.fusion_ratios) - 1:
|
| 327 |
-
dataset = self.datasets[i]
|
| 328 |
-
return dataset[index % len(dataset)]
|
| 329 |
-
residual -= ratio
|
| 330 |
-
|
| 331 |
-
def __len__(self):
|
| 332 |
-
return self.size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/image_folder.py
DELETED
|
@@ -1,51 +0,0 @@
|
|
| 1 |
-
###############################################################################
|
| 2 |
-
# Code from
|
| 3 |
-
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
|
| 4 |
-
# Modified the original code so that it also loads images from the current
|
| 5 |
-
# directory as well as the subdirectories
|
| 6 |
-
###############################################################################
|
| 7 |
-
|
| 8 |
-
import torch.utils.data as data
|
| 9 |
-
|
| 10 |
-
from PIL import Image
|
| 11 |
-
import os
|
| 12 |
-
import os.path
|
| 13 |
-
|
| 14 |
-
IMG_EXTENSIONS = [
|
| 15 |
-
'.jpg', '.JPG', '.jpeg', '.JPEG',
|
| 16 |
-
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
|
| 17 |
-
]
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def read_fns(filename):
|
| 21 |
-
with open(filename) as f:
|
| 22 |
-
fns = f.readlines()
|
| 23 |
-
fns = [fn.strip() for fn in fns]
|
| 24 |
-
return fns
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def is_image_file(filename):
|
| 28 |
-
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def make_dataset(dir, fns=None):
|
| 32 |
-
images = []
|
| 33 |
-
assert os.path.isdir(dir), '%s is not a valid directory' % dir
|
| 34 |
-
|
| 35 |
-
if fns is None:
|
| 36 |
-
for root, _, fnames in sorted(os.walk(dir)):
|
| 37 |
-
for fname in fnames:
|
| 38 |
-
if is_image_file(fname):
|
| 39 |
-
path = os.path.join(root, fname)
|
| 40 |
-
images.append(path)
|
| 41 |
-
else:
|
| 42 |
-
for fname in fns:
|
| 43 |
-
if is_image_file(fname):
|
| 44 |
-
path = os.path.join(dir, fname)
|
| 45 |
-
images.append(path)
|
| 46 |
-
|
| 47 |
-
return images
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def default_loader(path):
|
| 51 |
-
return Image.open(path).convert('RGB')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/real_test.txt
DELETED
|
@@ -1,20 +0,0 @@
|
|
| 1 |
-
3.jpg
|
| 2 |
-
4.jpg
|
| 3 |
-
9.jpg
|
| 4 |
-
12.jpg
|
| 5 |
-
15.jpg
|
| 6 |
-
22.jpg
|
| 7 |
-
23.jpg
|
| 8 |
-
25.jpg
|
| 9 |
-
29.jpg
|
| 10 |
-
39.jpg
|
| 11 |
-
46.jpg
|
| 12 |
-
47.jpg
|
| 13 |
-
58.jpg
|
| 14 |
-
86.jpg
|
| 15 |
-
87.jpg
|
| 16 |
-
89.jpg
|
| 17 |
-
93.jpg
|
| 18 |
-
103.jpg
|
| 19 |
-
107.jpg
|
| 20 |
-
110.jpg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/torchdata.py
DELETED
|
@@ -1,67 +0,0 @@
|
|
| 1 |
-
import bisect
|
| 2 |
-
import warnings
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class Dataset(object):
|
| 6 |
-
"""An abstract class representing a Dataset.
|
| 7 |
-
|
| 8 |
-
All other datasets should subclass it. All subclasses should override
|
| 9 |
-
``__len__``, that provides the size of the dataset, and ``__getitem__``,
|
| 10 |
-
supporting integer indexing in range from 0 to len(self) exclusive.
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
def __getitem__(self, index):
|
| 14 |
-
raise NotImplementedError
|
| 15 |
-
|
| 16 |
-
def __len__(self):
|
| 17 |
-
raise NotImplementedError
|
| 18 |
-
|
| 19 |
-
def __add__(self, other):
|
| 20 |
-
return ConcatDataset([self, other])
|
| 21 |
-
|
| 22 |
-
def reset(self):
|
| 23 |
-
return
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class ConcatDataset(Dataset):
|
| 27 |
-
"""
|
| 28 |
-
Dataset to concatenate multiple datasets.
|
| 29 |
-
Purpose: useful to assemble different existing datasets, possibly
|
| 30 |
-
large-scale datasets as the concatenation operation is done in an
|
| 31 |
-
on-the-fly manner.
|
| 32 |
-
|
| 33 |
-
Arguments:
|
| 34 |
-
datasets (sequence): List of datasets to be concatenated
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
@staticmethod
|
| 38 |
-
def cumsum(sequence):
|
| 39 |
-
r, s = [], 0
|
| 40 |
-
for e in sequence:
|
| 41 |
-
l = len(e)
|
| 42 |
-
r.append(l + s)
|
| 43 |
-
s += l
|
| 44 |
-
return r
|
| 45 |
-
|
| 46 |
-
def __init__(self, datasets):
|
| 47 |
-
super(ConcatDataset, self).__init__()
|
| 48 |
-
assert len(datasets) > 0, 'datasets should not be an empty iterable'
|
| 49 |
-
self.datasets = list(datasets)
|
| 50 |
-
self.cumulative_sizes = self.cumsum(self.datasets)
|
| 51 |
-
|
| 52 |
-
def __len__(self):
|
| 53 |
-
return self.cumulative_sizes[-1]
|
| 54 |
-
|
| 55 |
-
def __getitem__(self, idx):
|
| 56 |
-
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
| 57 |
-
if dataset_idx == 0:
|
| 58 |
-
sample_idx = idx
|
| 59 |
-
else:
|
| 60 |
-
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
| 61 |
-
return self.datasets[dataset_idx][sample_idx]
|
| 62 |
-
|
| 63 |
-
@property
|
| 64 |
-
def cummulative_sizes(self):
|
| 65 |
-
warnings.warn("cummulative_sizes attribute is renamed to "
|
| 66 |
-
"cumulative_sizes", DeprecationWarning, stacklevel=2)
|
| 67 |
-
return self.cumulative_sizes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/transforms.py
DELETED
|
@@ -1,301 +0,0 @@
|
|
| 1 |
-
from __future__ import division
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
import random
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
from PIL import Image
|
| 8 |
-
|
| 9 |
-
try:
|
| 10 |
-
import accimage
|
| 11 |
-
except ImportError:
|
| 12 |
-
accimage = None
|
| 13 |
-
import numpy as np
|
| 14 |
-
import scipy.stats as st
|
| 15 |
-
import cv2
|
| 16 |
-
import collections
|
| 17 |
-
import torchvision.transforms as transforms
|
| 18 |
-
import util.util as util
|
| 19 |
-
from scipy.signal import convolve2d
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
# utility
|
| 23 |
-
def _is_pil_image(img):
|
| 24 |
-
if accimage is not None:
|
| 25 |
-
return isinstance(img, (Image.Image, accimage.Image))
|
| 26 |
-
else:
|
| 27 |
-
return isinstance(img, Image.Image)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def _is_tensor_image(img):
|
| 31 |
-
return torch.is_tensor(img) and img.ndimension() == 3
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def _is_numpy_image(img):
|
| 35 |
-
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def arrshow(arr):
|
| 39 |
-
Image.fromarray(arr.astype(np.uint8)).show()
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def get_transform(opt):
|
| 43 |
-
transform_list = []
|
| 44 |
-
osizes = util.parse_args(opt.loadSize)
|
| 45 |
-
fineSize = util.parse_args(opt.fineSize)
|
| 46 |
-
if opt.resize_or_crop == 'resize_and_crop':
|
| 47 |
-
transform_list.append(
|
| 48 |
-
transforms.RandomChoice([
|
| 49 |
-
transforms.Resize([osize, osize], Image.BICUBIC) for osize in osizes
|
| 50 |
-
]))
|
| 51 |
-
transform_list.append(transforms.RandomCrop(fineSize))
|
| 52 |
-
elif opt.resize_or_crop == 'crop':
|
| 53 |
-
transform_list.append(transforms.RandomCrop(fineSize))
|
| 54 |
-
elif opt.resize_or_crop == 'scale_width':
|
| 55 |
-
transform_list.append(transforms.Lambda(
|
| 56 |
-
lambda img: __scale_width(img, fineSize)))
|
| 57 |
-
elif opt.resize_or_crop == 'scale_width_and_crop':
|
| 58 |
-
transform_list.append(transforms.Lambda(
|
| 59 |
-
lambda img: __scale_width(img, opt.loadSize)))
|
| 60 |
-
transform_list.append(transforms.RandomCrop(opt.fineSize))
|
| 61 |
-
|
| 62 |
-
if opt.isTrain and not opt.no_flip:
|
| 63 |
-
transform_list.append(transforms.RandomHorizontalFlip())
|
| 64 |
-
|
| 65 |
-
return transforms.Compose(transform_list)
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
to_norm_tensor = transforms.Compose([
|
| 69 |
-
transforms.ToTensor(),
|
| 70 |
-
transforms.Normalize(
|
| 71 |
-
(0.5, 0.5, 0.5),
|
| 72 |
-
(0.5, 0.5, 0.5)
|
| 73 |
-
)
|
| 74 |
-
])
|
| 75 |
-
|
| 76 |
-
to_tensor = transforms.ToTensor()
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def __scale_width(img, target_width):
|
| 80 |
-
ow, oh = img.size
|
| 81 |
-
if (ow == target_width):
|
| 82 |
-
return img
|
| 83 |
-
w = target_width
|
| 84 |
-
h = int(target_width * oh / ow)
|
| 85 |
-
h = math.ceil(h / 2.) * 2 # round up to even
|
| 86 |
-
return img.resize((w, h), Image.BICUBIC)
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
# functional
|
| 90 |
-
def gaussian_blur(img, kernel_size, sigma):
|
| 91 |
-
if not _is_pil_image(img):
|
| 92 |
-
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
| 93 |
-
|
| 94 |
-
img = np.asarray(img)
|
| 95 |
-
# the 3rd dimension (i.e. inter-band) would be filtered which is unwanted for our purpose
|
| 96 |
-
# new = gaussian_filter(img, sigma=sigma, truncate=truncate)
|
| 97 |
-
if isinstance(kernel_size, int):
|
| 98 |
-
kernel_size = (kernel_size, kernel_size)
|
| 99 |
-
elif isinstance(kernel_size, collections.Sequence):
|
| 100 |
-
assert len(kernel_size) == 2
|
| 101 |
-
new = cv2.GaussianBlur(img, kernel_size, sigma) # apply gaussian filter band by band
|
| 102 |
-
return Image.fromarray(new)
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
# transforms
|
| 106 |
-
class GaussianBlur(object):
|
| 107 |
-
def __init__(self, kernel_size=11, sigma=3):
|
| 108 |
-
self.kernel_size = kernel_size
|
| 109 |
-
self.sigma = sigma
|
| 110 |
-
|
| 111 |
-
def __call__(self, img):
|
| 112 |
-
return gaussian_blur(img, self.kernel_size, self.sigma)
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
class ReflectionSythesis_0(object):
|
| 116 |
-
"""Reflection image data synthesis for weakly-supervised learning
|
| 117 |
-
of ICCV 2017 paper *"A Generic Deep Architecture for Single Image Reflection Removal and Image Smoothing"*
|
| 118 |
-
"""
|
| 119 |
-
|
| 120 |
-
def __init__(self, kernel_sizes=None, low_sigma=2, high_sigma=5, low_gamma=1.3,
|
| 121 |
-
high_gamma=1.3, low_delta=0.4, high_delta=1.8):
|
| 122 |
-
self.kernel_sizes = kernel_sizes or [11]
|
| 123 |
-
self.low_sigma = low_sigma
|
| 124 |
-
self.high_sigma = high_sigma
|
| 125 |
-
self.low_gamma = low_gamma
|
| 126 |
-
self.high_gamma = high_gamma
|
| 127 |
-
self.low_delta = low_delta
|
| 128 |
-
self.high_delta = high_delta
|
| 129 |
-
print('[i] reflection sythesis model: {}'.format({
|
| 130 |
-
'kernel_sizes': kernel_sizes, 'low_sigma': low_sigma, 'high_sigma': high_sigma,
|
| 131 |
-
'low_gamma': low_gamma, 'high_gamma': high_gamma}))
|
| 132 |
-
|
| 133 |
-
def __call__(self, B, R):
|
| 134 |
-
if not _is_pil_image(B):
|
| 135 |
-
raise TypeError('B should be PIL Image. Got {}'.format(type(B)))
|
| 136 |
-
if not _is_pil_image(R):
|
| 137 |
-
raise TypeError('R should be PIL Image. Got {}'.format(type(R)))
|
| 138 |
-
B_ = np.asarray(B, np.float32)
|
| 139 |
-
if random.random() < 0.4:
|
| 140 |
-
B_ = np.tile(np.random.uniform(0, 30, (1, 1, 1)), B_.shape) / 255.
|
| 141 |
-
else:
|
| 142 |
-
B_ = np.tile(np.random.normal(50, 50, (1, 1, 3)), (B_.shape[0], B_.shape[1], 1)).clip(0, 255) / 255.
|
| 143 |
-
R_ = np.asarray(R, np.float32) / 255.
|
| 144 |
-
|
| 145 |
-
kernel_size = np.random.choice(self.kernel_sizes)
|
| 146 |
-
sigma = np.random.uniform(self.low_sigma, self.high_sigma)
|
| 147 |
-
gamma = np.random.uniform(self.low_gamma, self.high_gamma)
|
| 148 |
-
delta = np.random.uniform(self.low_delta, self.high_delta)
|
| 149 |
-
R_blur = R_
|
| 150 |
-
kernel = cv2.getGaussianKernel(11, sigma)
|
| 151 |
-
kernel2d = np.dot(kernel, kernel.T)
|
| 152 |
-
|
| 153 |
-
for i in range(3):
|
| 154 |
-
R_blur[..., i] = convolve2d(R_blur[..., i], kernel2d, mode='same')
|
| 155 |
-
|
| 156 |
-
R_blur = np.clip(R_blur - np.mean(R_blur) * gamma, 0, 1)
|
| 157 |
-
R_blur = np.clip(R_blur * delta, 0, 1)
|
| 158 |
-
M_ = np.clip(R_blur + B_, 0, 1)
|
| 159 |
-
|
| 160 |
-
return B_, R_blur, M_
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
class ReflectionSythesis_1(object):
|
| 164 |
-
"""Reflection image data synthesis for weakly-supervised learning
|
| 165 |
-
of ICCV 2017 paper *"A Generic Deep Architecture for Single Image Reflection Removal and Image Smoothing"*
|
| 166 |
-
"""
|
| 167 |
-
|
| 168 |
-
def __init__(self, kernel_sizes=None, low_sigma=2, high_sigma=5, low_gamma=1.3, high_gamma=1.3):
|
| 169 |
-
self.kernel_sizes = kernel_sizes or [11]
|
| 170 |
-
self.low_sigma = low_sigma
|
| 171 |
-
self.high_sigma = high_sigma
|
| 172 |
-
self.low_gamma = low_gamma
|
| 173 |
-
self.high_gamma = high_gamma
|
| 174 |
-
print('[i] reflection sythesis model: {}'.format({
|
| 175 |
-
'kernel_sizes': kernel_sizes, 'low_sigma': low_sigma, 'high_sigma': high_sigma,
|
| 176 |
-
'low_gamma': low_gamma, 'high_gamma': high_gamma}))
|
| 177 |
-
|
| 178 |
-
def __call__(self, B, R):
|
| 179 |
-
if not _is_pil_image(B):
|
| 180 |
-
raise TypeError('B should be PIL Image. Got {}'.format(type(B)))
|
| 181 |
-
if not _is_pil_image(R):
|
| 182 |
-
raise TypeError('R should be PIL Image. Got {}'.format(type(R)))
|
| 183 |
-
|
| 184 |
-
B_ = np.asarray(B, np.float32) / 255.
|
| 185 |
-
R_ = np.asarray(R, np.float32) / 255.
|
| 186 |
-
|
| 187 |
-
kernel_size = np.random.choice(self.kernel_sizes)
|
| 188 |
-
sigma = np.random.uniform(self.low_sigma, self.high_sigma)
|
| 189 |
-
gamma = np.random.uniform(self.low_gamma, self.high_gamma)
|
| 190 |
-
R_blur = R_
|
| 191 |
-
kernel = cv2.getGaussianKernel(11, sigma)
|
| 192 |
-
kernel2d = np.dot(kernel, kernel.T)
|
| 193 |
-
|
| 194 |
-
for i in range(3):
|
| 195 |
-
R_blur[..., i] = convolve2d(R_blur[..., i], kernel2d, mode='same')
|
| 196 |
-
|
| 197 |
-
M_ = B_ + R_blur
|
| 198 |
-
|
| 199 |
-
if np.max(M_) > 1:
|
| 200 |
-
m = M_[M_ > 1]
|
| 201 |
-
m = (np.mean(m) - 1) * gamma
|
| 202 |
-
R_blur = np.clip(R_blur - m, 0, 1)
|
| 203 |
-
M_ = np.clip(R_blur + B_, 0, 1)
|
| 204 |
-
|
| 205 |
-
return B_, R_blur, M_
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
class Sobel(object):
|
| 209 |
-
def __call__(self, img):
|
| 210 |
-
if not _is_pil_image(img):
|
| 211 |
-
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
| 212 |
-
|
| 213 |
-
gray_img = np.array(img.convert('L'))
|
| 214 |
-
x = cv2.Sobel(gray_img, cv2.CV_16S, 1, 0)
|
| 215 |
-
y = cv2.Sobel(gray_img, cv2.CV_16S, 0, 1)
|
| 216 |
-
|
| 217 |
-
absX = cv2.convertScaleAbs(x)
|
| 218 |
-
absY = cv2.convertScaleAbs(y)
|
| 219 |
-
|
| 220 |
-
dst = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
|
| 221 |
-
return Image.fromarray(dst)
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
class ReflectionSythesis_2(object):
|
| 225 |
-
"""Reflection image data synthesis for weakly-supervised learning
|
| 226 |
-
of CVPR 2018 paper *"Single Image Reflection Separation with Perceptual Losses"*
|
| 227 |
-
"""
|
| 228 |
-
|
| 229 |
-
def __init__(self, kernel_sizes=None):
|
| 230 |
-
self.kernel_sizes = kernel_sizes or np.linspace(1, 5, 80)
|
| 231 |
-
|
| 232 |
-
@staticmethod
|
| 233 |
-
def gkern(kernlen=100, nsig=1):
|
| 234 |
-
"""Returns a 2D Gaussian kernel array."""
|
| 235 |
-
interval = (2 * nsig + 1.) / (kernlen)
|
| 236 |
-
x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
|
| 237 |
-
kern1d = np.diff(st.norm.cdf(x))
|
| 238 |
-
kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
|
| 239 |
-
kernel = kernel_raw / kernel_raw.sum()
|
| 240 |
-
kernel = kernel / kernel.max()
|
| 241 |
-
return kernel
|
| 242 |
-
|
| 243 |
-
def __call__(self, t, r):
|
| 244 |
-
t = np.float32(t) / 255.
|
| 245 |
-
r = np.float32(r) / 255.
|
| 246 |
-
ori_t = t
|
| 247 |
-
# create a vignetting mask
|
| 248 |
-
g_mask = self.gkern(560, 3)
|
| 249 |
-
g_mask = np.dstack((g_mask, g_mask, g_mask))
|
| 250 |
-
sigma = self.kernel_sizes[np.random.randint(0, len(self.kernel_sizes))]
|
| 251 |
-
|
| 252 |
-
t = np.power(t, 2.2)
|
| 253 |
-
r = np.power(r, 2.2)
|
| 254 |
-
|
| 255 |
-
sz = int(2 * np.ceil(2 * sigma) + 1)
|
| 256 |
-
|
| 257 |
-
r_blur = cv2.GaussianBlur(r, (sz, sz), sigma, sigma, 0)
|
| 258 |
-
blend = r_blur + t
|
| 259 |
-
|
| 260 |
-
att = 1.08 + np.random.random() / 10.0
|
| 261 |
-
|
| 262 |
-
for i in range(3):
|
| 263 |
-
maski = blend[:, :, i] > 1
|
| 264 |
-
mean_i = max(1., np.sum(blend[:, :, i] * maski) / (maski.sum() + 1e-6))
|
| 265 |
-
r_blur[:, :, i] = r_blur[:, :, i] - (mean_i - 1) * att
|
| 266 |
-
r_blur[r_blur >= 1] = 1
|
| 267 |
-
r_blur[r_blur <= 0] = 0
|
| 268 |
-
|
| 269 |
-
h, w = r_blur.shape[0:2]
|
| 270 |
-
neww = np.random.randint(0, 560 - w - 10)
|
| 271 |
-
newh = np.random.randint(0, 560 - h - 10)
|
| 272 |
-
alpha1 = g_mask[newh:newh + h, neww:neww + w, :]
|
| 273 |
-
alpha2 = 1 - np.random.random() / 5.0
|
| 274 |
-
r_blur_mask = np.multiply(r_blur, alpha1)
|
| 275 |
-
blend = r_blur_mask + t * alpha2
|
| 276 |
-
|
| 277 |
-
t = np.power(t, 1 / 2.2)
|
| 278 |
-
r_blur_mask = np.power(r_blur_mask, 1 / 2.2)
|
| 279 |
-
blend = np.power(blend, 1 / 2.2)
|
| 280 |
-
blend[blend >= 1] = 1
|
| 281 |
-
blend[blend <= 0] = 0
|
| 282 |
-
|
| 283 |
-
return np.float32(ori_t), np.float32(r_blur_mask), np.float32(blend)
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
# Examples
|
| 287 |
-
if __name__ == '__main__':
|
| 288 |
-
"""cv2 imread"""
|
| 289 |
-
# img = cv2.imread('testdata_reflection_real/19-input.png')
|
| 290 |
-
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 291 |
-
# img2 = cv2.GaussianBlur(img, (11,11), 3)
|
| 292 |
-
|
| 293 |
-
"""Sobel Operator"""
|
| 294 |
-
# img = np.array(Image.open('datasets/VOC224/train/B/2007_000250.png').convert('L'))
|
| 295 |
-
|
| 296 |
-
"""Reflection Sythesis"""
|
| 297 |
-
b = Image.open('')
|
| 298 |
-
r = Image.open('')
|
| 299 |
-
G = ReflectionSythesis_0()
|
| 300 |
-
m, r = G(b, r)
|
| 301 |
-
r.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
engine.py
DELETED
|
@@ -1,178 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import util.util as util
|
| 3 |
-
from models import make_model
|
| 4 |
-
import time
|
| 5 |
-
import os
|
| 6 |
-
import sys
|
| 7 |
-
from os.path import join
|
| 8 |
-
from util.visualizer import Visualizer
|
| 9 |
-
import tqdm
|
| 10 |
-
import visdom
|
| 11 |
-
import numpy as np
|
| 12 |
-
from tools import mutils
|
| 13 |
-
|
| 14 |
-
class Engine(object):
|
| 15 |
-
def __init__(self, opt,eval_dataset_real,eval_dataset_solidobject,eval_dataset_postcard,eval_dataloader_wild):
|
| 16 |
-
self.opt = opt
|
| 17 |
-
self.writer = None
|
| 18 |
-
self.visualizer = None
|
| 19 |
-
self.model = None
|
| 20 |
-
self.best_val_loss = 1e6
|
| 21 |
-
self.eval_dataset_real = eval_dataset_real
|
| 22 |
-
self.eval_dataset_solidobject = eval_dataset_solidobject
|
| 23 |
-
self.eval_dataset_postcard = eval_dataset_postcard
|
| 24 |
-
self.eval_dataloader_wild = eval_dataloader_wild
|
| 25 |
-
self.result_dir = os.path.join(f'./experiment/{self.opt.name}/results',
|
| 26 |
-
mutils.get_formatted_time())
|
| 27 |
-
self.biggest_psnr=0
|
| 28 |
-
self.__setup()
|
| 29 |
-
|
| 30 |
-
def __setup(self):
|
| 31 |
-
self.basedir = join('experiment', self.opt.name)
|
| 32 |
-
os.makedirs(self.basedir, exist_ok=True)
|
| 33 |
-
|
| 34 |
-
opt = self.opt
|
| 35 |
-
|
| 36 |
-
"""Model"""
|
| 37 |
-
self.model = make_model(self.opt.model) # models.__dict__[self.opt.model]()
|
| 38 |
-
self.model.initialize(opt)
|
| 39 |
-
if True:
|
| 40 |
-
print("IN")
|
| 41 |
-
self.writer = util.get_summary_writer(os.path.join(self.basedir, 'logs'))
|
| 42 |
-
self.visualizer = Visualizer(opt)
|
| 43 |
-
|
| 44 |
-
def train(self, train_loader, **kwargs):
|
| 45 |
-
print('\nEpoch: %d' % self.epoch)
|
| 46 |
-
avg_meters = util.AverageMeters()
|
| 47 |
-
opt = self.opt
|
| 48 |
-
model = self.model
|
| 49 |
-
epoch = self.epoch
|
| 50 |
-
|
| 51 |
-
epoch_start_time = time.time()
|
| 52 |
-
for i, data in tqdm.tqdm(enumerate(train_loader)):
|
| 53 |
-
|
| 54 |
-
iter_start_time = time.time()
|
| 55 |
-
iterations = self.iterations
|
| 56 |
-
|
| 57 |
-
model.set_input(data, mode='train')
|
| 58 |
-
model.optimize_parameters(**kwargs)
|
| 59 |
-
|
| 60 |
-
errors = model.get_current_errors()
|
| 61 |
-
avg_meters.update(errors)
|
| 62 |
-
util.progress_bar(i, len(train_loader), str(avg_meters))
|
| 63 |
-
util.write_loss(self.writer, 'train', avg_meters, iterations)
|
| 64 |
-
if iterations%100==0:
|
| 65 |
-
imgs=[]
|
| 66 |
-
output_clean,output_reflection,input=model.return_output()
|
| 67 |
-
# output_clean,input=model.return_output()
|
| 68 |
-
|
| 69 |
-
output_clean=np.transpose(output_clean,(2,0,1))/255
|
| 70 |
-
#output_reflection = np.transpose(output_reflection, (2, 0, 1))/255
|
| 71 |
-
input = np.transpose(input, (2, 0, 1))/255
|
| 72 |
-
imgs.append(output_clean)
|
| 73 |
-
#imgs.append(output_reflection)
|
| 74 |
-
imgs.append(input)
|
| 75 |
-
util.get_visual(self.writer,iterations,imgs)
|
| 76 |
-
if iterations % opt.print_freq == 0 and opt.display_id != 0:
|
| 77 |
-
t = (time.time() - iter_start_time)
|
| 78 |
-
|
| 79 |
-
self.iterations += 1
|
| 80 |
-
|
| 81 |
-
self.epoch += 1
|
| 82 |
-
|
| 83 |
-
if True:#not self.opt.no_log:
|
| 84 |
-
if self.epoch % opt.save_epoch_freq == 0:
|
| 85 |
-
save_dir = os.path.join(self.result_dir, '%03d' % self.epoch)
|
| 86 |
-
os.makedirs(save_dir, exist_ok=True)
|
| 87 |
-
matrix_real=self.eval(self.eval_dataset_real, dataset_name='testdata_real20', savedir=save_dir, suffix='real20')
|
| 88 |
-
matrix_solid=self.eval(self.eval_dataset_solidobject, dataset_name='testdata_solidobject', savedir=save_dir,
|
| 89 |
-
suffix='solidobject')
|
| 90 |
-
matrix_post=self.eval(self.eval_dataset_postcard, dataset_name='testdata_postcard', savedir=save_dir, suffix='postcard')
|
| 91 |
-
matrix_wild=self.eval(self.eval_dataloader_wild, dataset_name='testdata_wild', savedir=save_dir, suffix='wild')
|
| 92 |
-
sum_PSNR_real=matrix_real['PSNR']*20
|
| 93 |
-
sum_PSNR_solid=matrix_solid['PSNR']*200
|
| 94 |
-
sum_PSNR_post=matrix_post['PSNR']*199
|
| 95 |
-
sum_PSNR_wild=matrix_wild['PSNR']*55
|
| 96 |
-
print("sum_PSNR_real: ",matrix_real['PSNR'],"sum_PSNR_solid: ",matrix_solid['PSNR'],"sum_PSNR_post: ",matrix_post['PSNR'],"sum_PSNR_wild: ",matrix_wild['PSNR'])
|
| 97 |
-
sum_PSNR = float(sum_PSNR_real + sum_PSNR_solid + sum_PSNR_post + sum_PSNR_wild)/474.0
|
| 98 |
-
print('总PSNR:', sum_PSNR)
|
| 99 |
-
if sum_PSNR>self.biggest_psnr:
|
| 100 |
-
self.biggest_psnr=sum_PSNR
|
| 101 |
-
print('saving the model at epoch %d, iters %d' %(self.epoch, self.iterations))
|
| 102 |
-
model.save()
|
| 103 |
-
print('highest: ',self.biggest_psnr,' name: ',opt.name)
|
| 104 |
-
|
| 105 |
-
print('saving the latest model at the end of epoch %d, iters %d' %
|
| 106 |
-
(self.epoch, self.iterations))
|
| 107 |
-
model.save(label='latest')
|
| 108 |
-
|
| 109 |
-
print('Time Taken: %d sec' %
|
| 110 |
-
(time.time() - epoch_start_time))
|
| 111 |
-
|
| 112 |
-
# model.update_learning_rate()
|
| 113 |
-
try:
|
| 114 |
-
train_loader.reset()
|
| 115 |
-
except:
|
| 116 |
-
pass
|
| 117 |
-
|
| 118 |
-
def eval(self, val_loader, dataset_name, savedir='./tmp', loss_key=None, **kwargs):
|
| 119 |
-
# print(dataset_name)
|
| 120 |
-
if savedir is not None:
|
| 121 |
-
os.makedirs(savedir, exist_ok=True)
|
| 122 |
-
self.f = open(os.path.join(savedir, 'metrics.txt'), 'w+')
|
| 123 |
-
self.f.write(dataset_name + '\n')
|
| 124 |
-
avg_meters = util.AverageMeters()
|
| 125 |
-
model = self.model
|
| 126 |
-
opt = self.opt
|
| 127 |
-
with torch.no_grad():
|
| 128 |
-
for i, data in enumerate(val_loader):
|
| 129 |
-
if self.opt.select is not None and data['fn'][0] not in [f'{self.opt.select}.jpg']:
|
| 130 |
-
continue
|
| 131 |
-
#print(data.shape())
|
| 132 |
-
index = model.eval(data, savedir=savedir, **kwargs)
|
| 133 |
-
|
| 134 |
-
# print(data['fn'][0], index)
|
| 135 |
-
if savedir is not None:
|
| 136 |
-
self.f.write(f"{data['fn'][0]} {index['PSNR']} {index['SSIM']}\n")
|
| 137 |
-
avg_meters.update(index)
|
| 138 |
-
util.progress_bar(i, len(val_loader), str(avg_meters))
|
| 139 |
-
|
| 140 |
-
if not opt.no_log:
|
| 141 |
-
util.write_loss(self.writer, join('eval', dataset_name), avg_meters, self.epoch)
|
| 142 |
-
|
| 143 |
-
if loss_key is not None:
|
| 144 |
-
val_loss = avg_meters[loss_key]
|
| 145 |
-
if val_loss < self.best_val_loss:
|
| 146 |
-
self.best_val_loss = val_loss
|
| 147 |
-
print('saving the best model at the end of epoch %d, iters %d' %
|
| 148 |
-
(self.epoch, self.iterations))
|
| 149 |
-
model.save(label='best_{}_{}'.format(loss_key, dataset_name))
|
| 150 |
-
|
| 151 |
-
return avg_meters
|
| 152 |
-
|
| 153 |
-
def test(self, test_loader, savedir=None, **kwargs):
|
| 154 |
-
model = self.model
|
| 155 |
-
opt = self.opt
|
| 156 |
-
with torch.no_grad():
|
| 157 |
-
for i, data in enumerate(test_loader):
|
| 158 |
-
model.test(data, savedir=savedir, **kwargs)
|
| 159 |
-
util.progress_bar(i, len(test_loader))
|
| 160 |
-
|
| 161 |
-
def save_eval(self, label):
|
| 162 |
-
self.model.save_eval(label)
|
| 163 |
-
|
| 164 |
-
@property
|
| 165 |
-
def iterations(self):
|
| 166 |
-
return self.model.iterations
|
| 167 |
-
|
| 168 |
-
@iterations.setter
|
| 169 |
-
def iterations(self, i):
|
| 170 |
-
self.model.iterations = i
|
| 171 |
-
|
| 172 |
-
@property
|
| 173 |
-
def epoch(self):
|
| 174 |
-
return self.model.epoch
|
| 175 |
-
|
| 176 |
-
@epoch.setter
|
| 177 |
-
def epoch(self, e):
|
| 178 |
-
self.model.epoch = e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
figures/Input_car.jpg
DELETED
|
Binary file (26.8 kB)
|
|
|
figures/Input_class.png
DELETED
Git LFS Details
|
figures/Input_green.png
DELETED
Git LFS Details
|
figures/Ours_car.png
DELETED
Git LFS Details
|
figures/Ours_class.png
DELETED
Git LFS Details
|
figures/Ours_green.png
DELETED
Git LFS Details
|
figures/Ours_white.png
DELETED
Git LFS Details
|
figures/Title.png
DELETED
|
Binary file (98.8 kB)
|
|
|
figures/input_white.jpg
DELETED
|
Binary file (24.9 kB)
|
|
|
figures/net.png
DELETED
Git LFS Details
|
figures/result.png
DELETED
Git LFS Details
|
figures/vis.png
DELETED
Git LFS Details
|
models/__init__.py
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
import importlib
|
| 2 |
-
|
| 3 |
-
from models.arch import *
|
| 4 |
-
|
| 5 |
-
from models.cls_model_eval_nocls_reg import ClsModel as ClsReg
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def make_model(name: str):
|
| 9 |
-
|
| 10 |
-
model = ClsReg()
|
| 11 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (388 Bytes). View file
|
|
|
models/__pycache__/cls_model_eval_nocls_reg.cpython-310.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
models/__pycache__/losses.cpython-310.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
models/base_model.py
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import util.util as util
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class BaseModel:
|
| 7 |
-
def name(self):
|
| 8 |
-
return self.__class__.__name__.lower()
|
| 9 |
-
|
| 10 |
-
def initialize(self, opt):
|
| 11 |
-
self.opt = opt
|
| 12 |
-
self.gpu_ids = opt.gpu_ids
|
| 13 |
-
self.isTrain = opt.isTrain
|
| 14 |
-
self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
|
| 15 |
-
last_split = opt.checkpoints_dir.split('/')[-1]
|
| 16 |
-
if opt.resume and last_split != 'checkpoints' and (last_split != opt.name or opt.supp_eval):
|
| 17 |
-
|
| 18 |
-
self.save_dir = opt.checkpoints_dir
|
| 19 |
-
self.model_save_dir = os.path.join(opt.checkpoints_dir.replace(opt.checkpoints_dir.split('/')[-1], ''),
|
| 20 |
-
opt.name)
|
| 21 |
-
else:
|
| 22 |
-
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
| 23 |
-
self.model_save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
| 24 |
-
self._count = 0
|
| 25 |
-
|
| 26 |
-
def set_input(self, input):
|
| 27 |
-
self.input = input
|
| 28 |
-
|
| 29 |
-
def forward(self, mode='train'):
|
| 30 |
-
pass
|
| 31 |
-
|
| 32 |
-
# used in test time, no backprop
|
| 33 |
-
def test(self):
|
| 34 |
-
pass
|
| 35 |
-
|
| 36 |
-
def get_image_paths(self):
|
| 37 |
-
pass
|
| 38 |
-
|
| 39 |
-
def optimize_parameters(self):
|
| 40 |
-
pass
|
| 41 |
-
|
| 42 |
-
def get_current_visuals(self):
|
| 43 |
-
return self.input
|
| 44 |
-
|
| 45 |
-
def get_current_errors(self):
|
| 46 |
-
return {}
|
| 47 |
-
|
| 48 |
-
def print_optimizer_param(self):
|
| 49 |
-
print(self.optimizers[-1])
|
| 50 |
-
|
| 51 |
-
def save(self, label=None):
|
| 52 |
-
epoch = self.epoch
|
| 53 |
-
iterations = self.iterations
|
| 54 |
-
|
| 55 |
-
if label is None:
|
| 56 |
-
model_name = os.path.join(self.model_save_dir, self.opt.name + '_%03d_%08d.pt' % ((epoch), (iterations)))
|
| 57 |
-
else:
|
| 58 |
-
model_name = os.path.join(self.model_save_dir, self.opt.name + '_' + label + '.pt')
|
| 59 |
-
|
| 60 |
-
torch.save(self.state_dict(), model_name)
|
| 61 |
-
|
| 62 |
-
def save_eval(self, label=None):
|
| 63 |
-
model_name = os.path.join(self.model_save_dir, label + '.pt')
|
| 64 |
-
|
| 65 |
-
torch.save(self.state_dict_eval(), model_name)
|
| 66 |
-
|
| 67 |
-
def _init_optimizer(self, optimizers):
|
| 68 |
-
self.optimizers = optimizers
|
| 69 |
-
for optimizer in self.optimizers:
|
| 70 |
-
util.set_opt_param(optimizer, 'initial_lr', self.opt.lr)
|
| 71 |
-
util.set_opt_param(optimizer, 'weight_decay', self.opt.wd)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/cls_model_eval_nocls_reg.py
DELETED
|
@@ -1,517 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from torch import nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
from models.losses import DINOLoss
|
| 5 |
-
import os
|
| 6 |
-
import numpy as np
|
| 7 |
-
from collections import OrderedDict
|
| 8 |
-
from ema_pytorch import EMA
|
| 9 |
-
from models.arch.classifier import PretrainedConvNext
|
| 10 |
-
import util.util as util
|
| 11 |
-
import util.index as index
|
| 12 |
-
import models.networks as networks
|
| 13 |
-
import models.losses as losses
|
| 14 |
-
from models import arch
|
| 15 |
-
#from models.arch.dncnn import effnetv2_s
|
| 16 |
-
from .base_model import BaseModel
|
| 17 |
-
from PIL import Image
|
| 18 |
-
from os.path import join
|
| 19 |
-
#from torchviz import make_dot
|
| 20 |
-
from models.arch.RDnet_ import FullNet_NLP
|
| 21 |
-
import timm
|
| 22 |
-
|
| 23 |
-
def tensor2im(image_tensor, imtype=np.uint8):
|
| 24 |
-
image_tensor = image_tensor.detach()
|
| 25 |
-
image_numpy = image_tensor[0].cpu().float().numpy()
|
| 26 |
-
image_numpy = np.clip(image_numpy, 0, 1)
|
| 27 |
-
if image_numpy.shape[0] == 1:
|
| 28 |
-
image_numpy = np.tile(image_numpy, (3, 1, 1))
|
| 29 |
-
image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0
|
| 30 |
-
# image_numpy = image_numpy.astype(imtype)
|
| 31 |
-
return image_numpy
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
class EdgeMap(nn.Module):
|
| 35 |
-
def __init__(self, scale=1):
|
| 36 |
-
super(EdgeMap, self).__init__()
|
| 37 |
-
self.scale = scale
|
| 38 |
-
self.requires_grad = False
|
| 39 |
-
|
| 40 |
-
def forward(self, img):
|
| 41 |
-
img = img / self.scale
|
| 42 |
-
|
| 43 |
-
N, C, H, W = img.shape
|
| 44 |
-
gradX = torch.zeros(N, 1, H, W, dtype=img.dtype, device=img.device)
|
| 45 |
-
gradY = torch.zeros(N, 1, H, W, dtype=img.dtype, device=img.device)
|
| 46 |
-
|
| 47 |
-
gradx = (img[..., 1:, :] - img[..., :-1, :]).abs().sum(dim=1, keepdim=True)
|
| 48 |
-
grady = (img[..., 1:] - img[..., :-1]).abs().sum(dim=1, keepdim=True)
|
| 49 |
-
|
| 50 |
-
gradX[..., :-1, :] += gradx
|
| 51 |
-
gradX[..., 1:, :] += gradx
|
| 52 |
-
gradX[..., 1:-1, :] /= 2
|
| 53 |
-
|
| 54 |
-
gradY[..., :-1] += grady
|
| 55 |
-
gradY[..., 1:] += grady
|
| 56 |
-
gradY[..., 1:-1] /= 2
|
| 57 |
-
|
| 58 |
-
# edge = (gradX + gradY) / 2
|
| 59 |
-
edge = (gradX + gradY)
|
| 60 |
-
|
| 61 |
-
return edge
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
class YTMTNetBase(BaseModel):
|
| 65 |
-
def _init_optimizer(self, optimizers):
|
| 66 |
-
self.optimizers = optimizers
|
| 67 |
-
for optimizer in self.optimizers:
|
| 68 |
-
util.set_opt_param(optimizer, 'initial_lr', self.opt.lr)
|
| 69 |
-
util.set_opt_param(optimizer, 'weight_decay', self.opt.wd)
|
| 70 |
-
|
| 71 |
-
def set_input(self, data, mode='train'):
|
| 72 |
-
target_t = None
|
| 73 |
-
target_r = None
|
| 74 |
-
data_name = None
|
| 75 |
-
identity = False
|
| 76 |
-
mode = mode.lower()
|
| 77 |
-
if mode == 'train':
|
| 78 |
-
input, target_t, target_r = data['input'], data['target_t'], data['target_r']
|
| 79 |
-
elif mode == 'eval':
|
| 80 |
-
input, target_t, target_r, data_name = data['input'], data['target_t'], data['target_r'], data['fn']
|
| 81 |
-
elif mode == 'test':
|
| 82 |
-
input, data_name = data['input'], data['fn']
|
| 83 |
-
else:
|
| 84 |
-
raise NotImplementedError('Mode [%s] is not implemented' % mode)
|
| 85 |
-
|
| 86 |
-
if len(self.gpu_ids) > 0: # transfer data into gpu
|
| 87 |
-
input = input.to(device=self.gpu_ids[0])
|
| 88 |
-
if target_t is not None:
|
| 89 |
-
target_t = target_t.to(device=self.gpu_ids[0])
|
| 90 |
-
if target_r is not None:
|
| 91 |
-
target_r = target_r.to(device=self.gpu_ids[0])
|
| 92 |
-
|
| 93 |
-
self.input = input
|
| 94 |
-
self.identity = identity
|
| 95 |
-
self.input_edge = self.edge_map(self.input)
|
| 96 |
-
self.target_t = target_t
|
| 97 |
-
self.target_r = target_r
|
| 98 |
-
self.data_name = data_name
|
| 99 |
-
|
| 100 |
-
self.issyn = False if 'real' in data else True
|
| 101 |
-
self.aligned = False if 'unaligned' in data else True
|
| 102 |
-
|
| 103 |
-
if target_t is not None:
|
| 104 |
-
self.target_edge = self.edge_map(self.target_t)
|
| 105 |
-
|
| 106 |
-
def eval(self, data, savedir=None, suffix=None, pieapp=None):
|
| 107 |
-
self._eval()
|
| 108 |
-
self.set_input(data, 'eval')
|
| 109 |
-
with torch.no_grad():
|
| 110 |
-
self.forward_eval()
|
| 111 |
-
|
| 112 |
-
output_i = tensor2im(self.output_j[6])
|
| 113 |
-
output_j = tensor2im(self.output_j[7])
|
| 114 |
-
target = tensor2im(self.target_t)
|
| 115 |
-
target_r = tensor2im(self.target_r)
|
| 116 |
-
|
| 117 |
-
if self.aligned:
|
| 118 |
-
res = index.quality_assess(output_i, target)
|
| 119 |
-
else:
|
| 120 |
-
res = {}
|
| 121 |
-
|
| 122 |
-
if savedir is not None:
|
| 123 |
-
if self.data_name is not None:
|
| 124 |
-
name = os.path.splitext(os.path.basename(self.data_name[0]))[0]
|
| 125 |
-
savedir = join(savedir, suffix, name)
|
| 126 |
-
os.makedirs(savedir, exist_ok=True)
|
| 127 |
-
Image.fromarray(output_i.astype(np.uint8)).save(
|
| 128 |
-
join(savedir, '{}_t.png'.format(self.opt.name)))
|
| 129 |
-
Image.fromarray(output_j.astype(np.uint8)).save(
|
| 130 |
-
join(savedir, '{}_r.png'.format(self.opt.name)))
|
| 131 |
-
Image.fromarray(target.astype(np.uint8)).save(join(savedir, 't_label.png'))
|
| 132 |
-
Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(join(savedir, 'm_input.png'))
|
| 133 |
-
else:
|
| 134 |
-
if not os.path.exists(join(savedir, 'transmission_layer')):
|
| 135 |
-
os.makedirs(join(savedir, 'transmission_layer'))
|
| 136 |
-
os.makedirs(join(savedir, 'blended'))
|
| 137 |
-
Image.fromarray(target.astype(np.uint8)).save(
|
| 138 |
-
join(savedir, 'transmission_layer', str(self._count) + '.png'))
|
| 139 |
-
Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(
|
| 140 |
-
join(savedir, 'blended', str(self._count) + '.png'))
|
| 141 |
-
self._count += 1
|
| 142 |
-
|
| 143 |
-
return res
|
| 144 |
-
|
| 145 |
-
def test(self, data, savedir=None):
|
| 146 |
-
# only the 1st input of the whole minibatch would be evaluated
|
| 147 |
-
self._eval()
|
| 148 |
-
self.set_input(data, 'test')
|
| 149 |
-
|
| 150 |
-
if self.data_name is not None and savedir is not None:
|
| 151 |
-
name = os.path.splitext(os.path.basename(self.data_name[0]))[0]
|
| 152 |
-
if not os.path.exists(join(savedir, name)):
|
| 153 |
-
os.makedirs(join(savedir, name))
|
| 154 |
-
|
| 155 |
-
if os.path.exists(join(savedir, name, '{}.png'.format(self.opt.name))):
|
| 156 |
-
return
|
| 157 |
-
|
| 158 |
-
with torch.no_grad():
|
| 159 |
-
output_i, output_j = self.forward()
|
| 160 |
-
output_i = tensor2im(output_i)
|
| 161 |
-
output_j = tensor2im(output_j)
|
| 162 |
-
if self.data_name is not None and savedir is not None:
|
| 163 |
-
Image.fromarray(output_i.astype(np.uint8)).save(join(savedir, name, '{}_l.png'.format(self.opt.name)))
|
| 164 |
-
Image.fromarray(output_j.astype(np.uint8)).save(join(savedir, name, '{}_r.png'.format(self.opt.name)))
|
| 165 |
-
Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(join(savedir, name, 'm_input.png'))
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
class ClsModel(YTMTNetBase):
|
| 169 |
-
def name(self):
|
| 170 |
-
return 'ytmtnet'
|
| 171 |
-
|
| 172 |
-
def __init__(self):
|
| 173 |
-
self.epoch = 0
|
| 174 |
-
self.iterations = 0
|
| 175 |
-
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 176 |
-
self.net_c = None
|
| 177 |
-
|
| 178 |
-
def print_network(self):
|
| 179 |
-
print('--------------------- Model ---------------------')
|
| 180 |
-
print('##################### NetG #####################')
|
| 181 |
-
networks.print_network(self.net_i)
|
| 182 |
-
if self.isTrain and self.opt.lambda_gan > 0:
|
| 183 |
-
print('##################### NetD #####################')
|
| 184 |
-
networks.print_network(self.netD)
|
| 185 |
-
|
| 186 |
-
def _eval(self):
|
| 187 |
-
self.net_i.eval()
|
| 188 |
-
self.net_c.eval()
|
| 189 |
-
|
| 190 |
-
def _train(self):
|
| 191 |
-
self.net_i.train()
|
| 192 |
-
self.net_c.eval()
|
| 193 |
-
def initialize(self, opt):
|
| 194 |
-
self.opt=opt
|
| 195 |
-
BaseModel.initialize(self, opt)
|
| 196 |
-
|
| 197 |
-
in_channels = 3
|
| 198 |
-
self.vgg = None
|
| 199 |
-
|
| 200 |
-
if opt.hyper:
|
| 201 |
-
self.vgg = losses.Vgg19(requires_grad=False).to(self.device)
|
| 202 |
-
in_channels += 1472
|
| 203 |
-
channels = [64, 128, 256, 512]
|
| 204 |
-
layers = [2, 2, 4, 2]
|
| 205 |
-
num_subnet = opt.num_subnet
|
| 206 |
-
self.net_c = PretrainedConvNext("convnext_small_in22k").cuda()
|
| 207 |
-
|
| 208 |
-
self.net_c.load_state_dict(torch.load('pretrained/cls_model.pth')['icnn'])
|
| 209 |
-
|
| 210 |
-
self.net_i = FullNet_NLP(channels, layers, num_subnet, opt.loss_col,num_classes=1000, drop_path=0,save_memory=True, inter_supv=True, head_init_scale=None, kernel_size=3).to(self.device)
|
| 211 |
-
|
| 212 |
-
self.edge_map = EdgeMap(scale=1).to(self.device)
|
| 213 |
-
|
| 214 |
-
if self.isTrain:
|
| 215 |
-
self.loss_dic = losses.init_loss(opt, self.Tensor)
|
| 216 |
-
vggloss = losses.ContentLoss()
|
| 217 |
-
vggloss.initialize(losses.VGGLoss(self.vgg))
|
| 218 |
-
self.loss_dic['t_vgg'] = vggloss
|
| 219 |
-
|
| 220 |
-
cxloss = losses.ContentLoss()
|
| 221 |
-
if opt.unaligned_loss == 'vgg':
|
| 222 |
-
cxloss.initialize(losses.VGGLoss(self.vgg, weights=[0.1], indices=[opt.vgg_layer]))
|
| 223 |
-
elif opt.unaligned_loss == 'ctx':
|
| 224 |
-
cxloss.initialize(losses.CXLoss(self.vgg, weights=[0.1, 0.1, 0.1], indices=[8, 13, 22]))
|
| 225 |
-
elif opt.unaligned_loss == 'mse':
|
| 226 |
-
cxloss.initialize(nn.MSELoss())
|
| 227 |
-
elif opt.unaligned_loss == 'ctx_vgg':
|
| 228 |
-
cxloss.initialize(losses.CXLoss(self.vgg, weights=[0.1, 0.1, 0.1, 0.1], indices=[8, 13, 22, 31],
|
| 229 |
-
criterions=[losses.CX_loss] * 3 + [nn.L1Loss()]))
|
| 230 |
-
else:
|
| 231 |
-
raise NotImplementedError
|
| 232 |
-
self.scaler=torch.cuda.amp.GradScaler()
|
| 233 |
-
with torch.autocast(device_type='cuda',dtype=torch.float16):
|
| 234 |
-
self.dinoloss=DINOLoss()
|
| 235 |
-
self.loss_dic['t_cx'] = cxloss
|
| 236 |
-
|
| 237 |
-
self.optimizer_G = torch.optim.Adam(self.net_i.parameters(),
|
| 238 |
-
lr=opt.lr, betas=(0.9, 0.999), weight_decay=opt.wd)
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
self._init_optimizer([self.optimizer_G])
|
| 242 |
-
|
| 243 |
-
if opt.resume:
|
| 244 |
-
self.load(self, opt.resume_epoch)
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
def backward_D(self):
|
| 248 |
-
loss_D=[]
|
| 249 |
-
weight=self.opt.weight_loss
|
| 250 |
-
for p in self.netD.parameters():
|
| 251 |
-
p.requires_grad = True
|
| 252 |
-
for i in range(4):
|
| 253 |
-
loss_D_1, pred_fake_1, pred_real_1 = self.loss_dic['gan'].get_loss(
|
| 254 |
-
self.netD, self.input, self.output_j[2*i], self.target_t)
|
| 255 |
-
loss_D.append(loss_D_1*weight)
|
| 256 |
-
weight+=self.opt.weight_loss
|
| 257 |
-
loss_sum=sum(loss_D)
|
| 258 |
-
|
| 259 |
-
self.loss_D, self.pred_fake, self.pred_real = (loss_sum, pred_fake_1, pred_real_1)
|
| 260 |
-
|
| 261 |
-
(self.loss_D * self.opt.lambda_gan).backward(retain_graph=True)
|
| 262 |
-
|
| 263 |
-
def get_loss(self, out_l, out_r):
|
| 264 |
-
loss_G_GAN_sum=[]
|
| 265 |
-
loss_icnn_pixel_sum=[]
|
| 266 |
-
loss_rcnn_pixel_sum=[]
|
| 267 |
-
loss_icnn_vgg_sum=[]
|
| 268 |
-
weight=self.opt.weight_loss
|
| 269 |
-
for i in range(self.opt.loss_col):
|
| 270 |
-
out_r_clean=out_r[2*i]
|
| 271 |
-
out_r_reflection=out_r[2*i+1]
|
| 272 |
-
if i != self.opt.loss_col -1:
|
| 273 |
-
loss_G_GAN = 0
|
| 274 |
-
loss_icnn_pixel = self.loss_dic['t_pixel'].get_loss(out_r_clean, self.target_t)
|
| 275 |
-
loss_rcnn_pixel = self.loss_dic['r_pixel'].get_loss(out_r_reflection, self.target_r) * 1.5 * self.opt.r_pixel_weight
|
| 276 |
-
loss_icnn_vgg = self.loss_dic['t_vgg'].get_loss(out_r_clean, self.target_t) * self.opt.lambda_vgg
|
| 277 |
-
else:
|
| 278 |
-
if self.opt.lambda_gan>0:
|
| 279 |
-
|
| 280 |
-
loss_G_GAN=0
|
| 281 |
-
else:
|
| 282 |
-
loss_G_GAN=0
|
| 283 |
-
loss_icnn_pixel = self.loss_dic['t_pixel'].get_loss(out_r_clean, self.target_t)
|
| 284 |
-
loss_rcnn_pixel = self.loss_dic['r_pixel'].get_loss(out_r_reflection, self.target_r) * 1.5 * self.opt.r_pixel_weight
|
| 285 |
-
loss_icnn_vgg = self.loss_dic['t_vgg'].get_loss(out_r_clean, self.target_t) * self.opt.lambda_vgg
|
| 286 |
-
|
| 287 |
-
loss_G_GAN_sum.append(loss_G_GAN*weight)
|
| 288 |
-
loss_icnn_pixel_sum.append(loss_icnn_pixel*weight)
|
| 289 |
-
loss_rcnn_pixel_sum.append(loss_rcnn_pixel*weight)
|
| 290 |
-
loss_icnn_vgg_sum.append(loss_icnn_vgg*weight)
|
| 291 |
-
weight=weight+self.opt.weight_loss
|
| 292 |
-
return sum(loss_G_GAN_sum), sum(loss_icnn_pixel_sum), sum(loss_rcnn_pixel_sum), sum(loss_icnn_vgg_sum)
|
| 293 |
-
|
| 294 |
-
def backward_G(self):
|
| 295 |
-
|
| 296 |
-
self.loss_G_GAN,self.loss_icnn_pixel, self.loss_rcnn_pixel, \
|
| 297 |
-
self.loss_icnn_vgg = self.get_loss(self.output_i, self.output_j)
|
| 298 |
-
|
| 299 |
-
self.loss_exclu = self.exclusion_loss(self.output_i, self.output_j, 3)
|
| 300 |
-
|
| 301 |
-
self.loss_recons = self.loss_dic['recons'](self.output_i, self.output_j, self.input) * 0.2
|
| 302 |
-
|
| 303 |
-
self.loss_G = self.loss_G_GAN +self.loss_icnn_pixel + self.loss_rcnn_pixel + \
|
| 304 |
-
self.loss_icnn_vgg
|
| 305 |
-
self.scaler.scale(self.loss_G).backward()
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
def hyper_column(self, input_img):
|
| 310 |
-
hypercolumn = self.vgg(input_img)
|
| 311 |
-
_, C, H, W = input_img.shape
|
| 312 |
-
hypercolumn = [F.interpolate(feature.detach(), size=(H, W), mode='bilinear', align_corners=False) for
|
| 313 |
-
feature in hypercolumn]
|
| 314 |
-
input_i = [input_img]
|
| 315 |
-
input_i.extend(hypercolumn)
|
| 316 |
-
input_i = torch.cat(input_i, dim=1)
|
| 317 |
-
return input_i
|
| 318 |
-
|
| 319 |
-
def forward(self):
|
| 320 |
-
# without edge
|
| 321 |
-
|
| 322 |
-
self.output_j=[]
|
| 323 |
-
input_i = self.input
|
| 324 |
-
if self.vgg is not None:
|
| 325 |
-
input_i = self.hyper_column(input_i)
|
| 326 |
-
with torch.no_grad():
|
| 327 |
-
ipt = self.net_c(input_i)
|
| 328 |
-
output_i, output_j = self.net_i(input_i,ipt,prompt=True)
|
| 329 |
-
self.output_i = output_i
|
| 330 |
-
for i in range(self.opt.loss_col):
|
| 331 |
-
out_reflection, out_clean = output_j[i][:, :3, ...], output_j[i][:, 3:, ...]
|
| 332 |
-
self.output_j.append(out_clean)
|
| 333 |
-
self.output_j.append(out_reflection)
|
| 334 |
-
return self.output_i, self.output_j
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
@torch.no_grad()
|
| 338 |
-
def forward_eval(self):
|
| 339 |
-
|
| 340 |
-
self.output_j=[]
|
| 341 |
-
input_i = self.input
|
| 342 |
-
if self.vgg is not None:
|
| 343 |
-
input_i = self.hyper_column(input_i)
|
| 344 |
-
ipt = self.net_c(input_i)
|
| 345 |
-
|
| 346 |
-
output_i, output_j = self.net_i(input_i,ipt,prompt=True)
|
| 347 |
-
self.output_i = output_i #alpha * output_i + beta
|
| 348 |
-
for i in range(self.opt.loss_col):
|
| 349 |
-
out_reflection, out_clean = output_j[i][:, :3, ...], output_j[i][:, 3:, ...]
|
| 350 |
-
self.output_j.append(out_clean)
|
| 351 |
-
self.output_j.append(out_reflection)
|
| 352 |
-
return self.output_i, self.output_j
|
| 353 |
-
|
| 354 |
-
def optimize_parameters(self):
|
| 355 |
-
self._train()
|
| 356 |
-
self.forward()
|
| 357 |
-
self.optimizer_G.zero_grad()
|
| 358 |
-
self.backward_G()
|
| 359 |
-
self.optimizer_G.step()
|
| 360 |
-
|
| 361 |
-
def return_output(self):
|
| 362 |
-
output_clean = self.output_j[1]
|
| 363 |
-
output_reflection = self.output_j[0]
|
| 364 |
-
output_clean = tensor2im(output_clean).astype(np.uint8)
|
| 365 |
-
output_reflection = tensor2im(output_reflection).astype(np.uint8)
|
| 366 |
-
input=tensor2im(self.input)
|
| 367 |
-
return output_clean,output_reflection,input
|
| 368 |
-
def exclusion_loss(self, img_T, img_R, level=3, eps=1e-6):
|
| 369 |
-
loss_gra=[]
|
| 370 |
-
weight=0.25
|
| 371 |
-
for i in range(4):
|
| 372 |
-
grad_x_loss = []
|
| 373 |
-
grad_y_loss = []
|
| 374 |
-
img_T=self.output_j[2*i]
|
| 375 |
-
img_R=self.output_j[2*i+1]
|
| 376 |
-
for l in range(level):
|
| 377 |
-
grad_x_T, grad_y_T = self.compute_grad(img_T)
|
| 378 |
-
grad_x_R, grad_y_R = self.compute_grad(img_R)
|
| 379 |
-
|
| 380 |
-
alphax = (2.0 * torch.mean(torch.abs(grad_x_T))) / (torch.mean(torch.abs(grad_x_R)) + eps)
|
| 381 |
-
alphay = (2.0 * torch.mean(torch.abs(grad_y_T))) / (torch.mean(torch.abs(grad_y_R)) + eps)
|
| 382 |
-
|
| 383 |
-
gradx1_s = (torch.sigmoid(grad_x_T) * 2) - 1 # mul 2 minus 1 is to change sigmoid into tanh
|
| 384 |
-
grady1_s = (torch.sigmoid(grad_y_T) * 2) - 1
|
| 385 |
-
gradx2_s = (torch.sigmoid(grad_x_R * alphax) * 2) - 1
|
| 386 |
-
grady2_s = (torch.sigmoid(grad_y_R * alphay) * 2) - 1
|
| 387 |
-
|
| 388 |
-
grad_x_loss.append(((torch.mean(torch.mul(gradx1_s.pow(2), gradx2_s.pow(2)))) + eps) ** 0.25)
|
| 389 |
-
grad_y_loss.append(((torch.mean(torch.mul(grady1_s.pow(2), grady2_s.pow(2)))) + eps) ** 0.25)
|
| 390 |
-
|
| 391 |
-
img_T = F.interpolate(img_T, scale_factor=0.5, mode='bilinear')
|
| 392 |
-
img_R = F.interpolate(img_R, scale_factor=0.5, mode='bilinear')
|
| 393 |
-
loss_gradxy = torch.sum(sum(grad_x_loss) / 3) + torch.sum(sum(grad_y_loss) / 3)
|
| 394 |
-
loss_gra.append(loss_gradxy*weight)
|
| 395 |
-
weight+=0.25
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
return sum(loss_gra) / 2
|
| 399 |
-
|
| 400 |
-
def contain_loss(self, img_T, img_R, img_I, eps=1e-6):
|
| 401 |
-
pix_num = np.prod(img_I.shape)
|
| 402 |
-
predict_tx, predict_ty = self.compute_grad(img_T)
|
| 403 |
-
predict_tx, predict_ty = self.compute_grad(img_T)
|
| 404 |
-
predict_rx, predict_ry = self.compute_grad(img_R)
|
| 405 |
-
input_x, input_y = self.compute_grad(img_I)
|
| 406 |
-
|
| 407 |
-
out = torch.norm(predict_tx / (input_x + eps), 2) ** 2 + \
|
| 408 |
-
torch.norm(predict_ty / (input_y + eps), 2) ** 2 + \
|
| 409 |
-
torch.norm(predict_rx / (input_x + eps), 2) ** 2 + \
|
| 410 |
-
torch.norm(predict_ry / (input_y + eps), 2) ** 2
|
| 411 |
-
|
| 412 |
-
return out / pix_num
|
| 413 |
-
|
| 414 |
-
def compute_grad(self, img):
|
| 415 |
-
gradx = img[:, :, 1:, :] - img[:, :, :-1, :]
|
| 416 |
-
grady = img[:, :, :, 1:] - img[:, :, :, :-1]
|
| 417 |
-
return gradx, grady
|
| 418 |
-
|
| 419 |
-
def load(self, model, resume_epoch=None):
|
| 420 |
-
icnn_path = model.opt.icnn_path
|
| 421 |
-
state_dict = torch.load(icnn_path)
|
| 422 |
-
model.net_i.load_state_dict(state_dict['icnn'])
|
| 423 |
-
return state_dict
|
| 424 |
-
|
| 425 |
-
def state_dict(self):
|
| 426 |
-
state_dict = {
|
| 427 |
-
'icnn': self.net_i.state_dict(),
|
| 428 |
-
'opt_g': self.optimizer_G.state_dict(),
|
| 429 |
-
#'ema' : self.ema.state_dict(),
|
| 430 |
-
'epoch': self.epoch, 'iterations': self.iterations
|
| 431 |
-
}
|
| 432 |
-
|
| 433 |
-
if self.opt.lambda_gan > 0:
|
| 434 |
-
state_dict.update({
|
| 435 |
-
'opt_d': self.optimizer_D.state_dict(),
|
| 436 |
-
'netD': self.netD.state_dict(),
|
| 437 |
-
})
|
| 438 |
-
|
| 439 |
-
return state_dict
|
| 440 |
-
class AvgPool2d(nn.Module):
|
| 441 |
-
def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
|
| 442 |
-
super().__init__()
|
| 443 |
-
self.kernel_size = kernel_size
|
| 444 |
-
self.base_size = base_size
|
| 445 |
-
self.auto_pad = auto_pad
|
| 446 |
-
|
| 447 |
-
# only used for fast implementation
|
| 448 |
-
self.fast_imp = fast_imp
|
| 449 |
-
self.rs = [5, 4, 3, 2, 1]
|
| 450 |
-
self.max_r1 = self.rs[0]
|
| 451 |
-
self.max_r2 = self.rs[0]
|
| 452 |
-
self.train_size = train_size
|
| 453 |
-
|
| 454 |
-
def extra_repr(self) -> str:
|
| 455 |
-
return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
|
| 456 |
-
self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
|
| 457 |
-
)
|
| 458 |
-
|
| 459 |
-
def forward(self, x):
|
| 460 |
-
if self.kernel_size is None and self.base_size:
|
| 461 |
-
train_size = self.train_size
|
| 462 |
-
if isinstance(self.base_size, int):
|
| 463 |
-
self.base_size = (self.base_size, self.base_size)
|
| 464 |
-
self.kernel_size = list(self.base_size)
|
| 465 |
-
self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
|
| 466 |
-
self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]
|
| 467 |
-
|
| 468 |
-
# only used for fast implementation
|
| 469 |
-
self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
|
| 470 |
-
self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])
|
| 471 |
-
|
| 472 |
-
if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
|
| 473 |
-
return F.adaptive_avg_pool2d(x, 1)
|
| 474 |
-
|
| 475 |
-
if self.fast_imp: # Non-equivalent implementation but faster
|
| 476 |
-
h, w = x.shape[2:]
|
| 477 |
-
if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
|
| 478 |
-
out = F.adaptive_avg_pool2d(x, 1)
|
| 479 |
-
else:
|
| 480 |
-
r1 = [r for r in self.rs if h % r == 0][0]
|
| 481 |
-
r2 = [r for r in self.rs if w % r == 0][0]
|
| 482 |
-
# reduction_constraint
|
| 483 |
-
r1 = min(self.max_r1, r1)
|
| 484 |
-
r2 = min(self.max_r2, r2)
|
| 485 |
-
s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
|
| 486 |
-
n, c, h, w = s.shape
|
| 487 |
-
k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
|
| 488 |
-
out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
|
| 489 |
-
out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
|
| 490 |
-
else:
|
| 491 |
-
n, c, h, w = x.shape
|
| 492 |
-
s = x.cumsum(dim=-1).cumsum_(dim=-2)
|
| 493 |
-
s = torch.nn.functional.pad(s, (1, 0, 1, 0)) # pad 0 for convenience
|
| 494 |
-
k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
|
| 495 |
-
s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
|
| 496 |
-
out = s4 + s1 - s2 - s3
|
| 497 |
-
out = out / (k1 * k2)
|
| 498 |
-
|
| 499 |
-
if self.auto_pad:
|
| 500 |
-
n, c, h, w = x.shape
|
| 501 |
-
_h, _w = out.shape[2:]
|
| 502 |
-
# print(x.shape, self.kernel_size)
|
| 503 |
-
pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
|
| 504 |
-
out = torch.nn.functional.pad(out, pad2d, mode='replicate')
|
| 505 |
-
|
| 506 |
-
return out
|
| 507 |
-
|
| 508 |
-
def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
|
| 509 |
-
for n, m in model.named_children():
|
| 510 |
-
if len(list(m.children())) > 0:
|
| 511 |
-
## compound module, go inside it
|
| 512 |
-
replace_layers(m, base_size, train_size, fast_imp, **kwargs)
|
| 513 |
-
|
| 514 |
-
if isinstance(m, nn.AdaptiveAvgPool2d):
|
| 515 |
-
pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
|
| 516 |
-
assert m.output_size == 1
|
| 517 |
-
setattr(model, n, pool)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/losses.py
DELETED
|
@@ -1,468 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
from pytorch_msssim import SSIM
|
| 6 |
-
from models.vit_feature_extractor import VitExtractor
|
| 7 |
-
from models.vgg import Vgg19
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
###############################################################################
|
| 11 |
-
# Functions
|
| 12 |
-
###############################################################################
|
| 13 |
-
def compute_gradient(img):
|
| 14 |
-
gradx = img[..., 1:, :] - img[..., :-1, :]
|
| 15 |
-
grady = img[..., 1:] - img[..., :-1]
|
| 16 |
-
return gradx, grady
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class GradientLoss(nn.Module):
|
| 20 |
-
def __init__(self):
|
| 21 |
-
super(GradientLoss, self).__init__()
|
| 22 |
-
self.loss = nn.L1Loss()
|
| 23 |
-
|
| 24 |
-
def forward(self, predict, target):
|
| 25 |
-
predict_gradx, predict_grady = compute_gradient(predict)
|
| 26 |
-
target_gradx, target_grady = compute_gradient(target)
|
| 27 |
-
|
| 28 |
-
return self.loss(predict_gradx, target_gradx) + self.loss(predict_grady, target_grady)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
class ContainLoss(nn.Module):
|
| 32 |
-
def __init__(self, eps=1e-12):
|
| 33 |
-
super(ContainLoss, self).__init__()
|
| 34 |
-
self.eps = eps
|
| 35 |
-
|
| 36 |
-
def forward(self, predict_t, predict_r, input_image):
|
| 37 |
-
pix_num = np.prod(input_image.shape)
|
| 38 |
-
predict_tx, predict_ty = compute_gradient(predict_t)
|
| 39 |
-
predict_rx, predict_ry = compute_gradient(predict_r)
|
| 40 |
-
input_x, input_y = compute_gradient(input_image)
|
| 41 |
-
|
| 42 |
-
out = torch.norm(predict_tx / (input_x + self.eps), 2) ** 2 + \
|
| 43 |
-
torch.norm(predict_ty / (input_y + self.eps), 2) ** 2 + \
|
| 44 |
-
torch.norm(predict_rx / (input_x + self.eps), 2) ** 2 + \
|
| 45 |
-
torch.norm(predict_ry / (input_y + self.eps), 2) ** 2
|
| 46 |
-
|
| 47 |
-
return out / pix_num
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
class MultipleLoss(nn.Module):
|
| 51 |
-
def __init__(self, losses, weight=None):
|
| 52 |
-
super(MultipleLoss, self).__init__()
|
| 53 |
-
self.losses = nn.ModuleList(losses)
|
| 54 |
-
self.weight = weight or [1 / len(self.losses)] * len(self.losses)
|
| 55 |
-
|
| 56 |
-
def forward(self, predict, target):
|
| 57 |
-
total_loss = 0
|
| 58 |
-
for weight, loss in zip(self.weight, self.losses):
|
| 59 |
-
total_loss += loss(predict, target) * weight
|
| 60 |
-
return total_loss
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
class MeanShift(nn.Conv2d):
|
| 64 |
-
def __init__(self, data_mean, data_std, data_range=1, norm=True):
|
| 65 |
-
"""norm (bool): normalize/denormalize the stats"""
|
| 66 |
-
c = len(data_mean)
|
| 67 |
-
super(MeanShift, self).__init__(c, c, kernel_size=1)
|
| 68 |
-
std = torch.Tensor(data_std)
|
| 69 |
-
self.weight.data = torch.eye(c).view(c, c, 1, 1)
|
| 70 |
-
if norm:
|
| 71 |
-
self.weight.data.div_(std.view(c, 1, 1, 1))
|
| 72 |
-
self.bias.data = -1 * data_range * torch.Tensor(data_mean)
|
| 73 |
-
self.bias.data.div_(std)
|
| 74 |
-
else:
|
| 75 |
-
self.weight.data.mul_(std.view(c, 1, 1, 1))
|
| 76 |
-
self.bias.data = data_range * torch.Tensor(data_mean)
|
| 77 |
-
self.requires_grad = False
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class VGGLoss(nn.Module):
|
| 81 |
-
def __init__(self, vgg=None, weights=None, indices=None, normalize=True):
|
| 82 |
-
super(VGGLoss, self).__init__()
|
| 83 |
-
if vgg is None:
|
| 84 |
-
self.vgg = torch.compile(Vgg19().cuda())
|
| 85 |
-
else:
|
| 86 |
-
self.vgg = vgg
|
| 87 |
-
self.criterion = nn.L1Loss()
|
| 88 |
-
self.weights = weights or [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5]
|
| 89 |
-
self.indices = indices or [2, 7, 12, 21, 30]
|
| 90 |
-
if normalize:
|
| 91 |
-
self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
|
| 92 |
-
else:
|
| 93 |
-
self.normalize = None
|
| 94 |
-
|
| 95 |
-
def forward(self, x, y):
|
| 96 |
-
if self.normalize is not None:
|
| 97 |
-
x = self.normalize(x)
|
| 98 |
-
y = self.normalize(y)
|
| 99 |
-
with torch.no_grad():
|
| 100 |
-
y_vgg = self.vgg(y, self.indices)
|
| 101 |
-
x_vgg = self.vgg(x, self.indices) #, self.vgg(y, self.indices)
|
| 102 |
-
loss = 0
|
| 103 |
-
for i in range(len(x_vgg)):
|
| 104 |
-
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i]) #.detach())
|
| 105 |
-
|
| 106 |
-
return loss
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def l1_norm_dim(x, dim):
|
| 110 |
-
return torch.mean(torch.abs(x), dim=dim)
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
def l1_norm(x):
|
| 114 |
-
return torch.mean(torch.abs(x))
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def l2_norm(x):
|
| 118 |
-
return torch.mean(torch.square(x))
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def gradient_norm_kernel(x, kernel_size=10):
|
| 122 |
-
out_h, out_v = compute_gradient(x)
|
| 123 |
-
shape = out_h.shape
|
| 124 |
-
out_h = F.unfold(out_h, kernel_size=(kernel_size, kernel_size), stride=(1, 1))
|
| 125 |
-
out_h = out_h.reshape(shape[0], shape[1], kernel_size * kernel_size, -1)
|
| 126 |
-
out_h = l1_norm_dim(out_h, 2)
|
| 127 |
-
out_v = F.unfold(out_v, kernel_size=(kernel_size, kernel_size), stride=(1, 1))
|
| 128 |
-
out_v = out_v.reshape(shape[0], shape[1], kernel_size * kernel_size, -1)
|
| 129 |
-
out_v = l1_norm_dim(out_v, 2)
|
| 130 |
-
return out_h, out_v
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
class KTVLoss(nn.Module):
|
| 134 |
-
def __init__(self, kernel_size=10):
|
| 135 |
-
super().__init__()
|
| 136 |
-
self.kernel_size = kernel_size
|
| 137 |
-
self.criterion = nn.L1Loss()
|
| 138 |
-
self.eps = 1e-6
|
| 139 |
-
|
| 140 |
-
def forward(self, out_l, out_r, input_i):
|
| 141 |
-
out_l_normx, out_l_normy = gradient_norm_kernel(out_l, self.kernel_size)
|
| 142 |
-
out_r_normx, out_r_normy = gradient_norm_kernel(out_r, self.kernel_size)
|
| 143 |
-
input_normx, input_normy = gradient_norm_kernel(input_i, self.kernel_size)
|
| 144 |
-
norm_l = out_l_normx + out_l_normy
|
| 145 |
-
norm_r = out_r_normx + out_r_normy
|
| 146 |
-
norm_target = input_normx + input_normy + self.eps
|
| 147 |
-
norm_loss = (norm_l / norm_target + norm_r / norm_target).mean()
|
| 148 |
-
|
| 149 |
-
out_lx, out_ly = compute_gradient(out_l)
|
| 150 |
-
out_rx, out_ry = compute_gradient(out_r)
|
| 151 |
-
input_x, input_y = compute_gradient(input_i)
|
| 152 |
-
gradient_diffx = self.criterion(out_lx + out_rx, input_x)
|
| 153 |
-
gradient_diffy = self.criterion(out_ly + out_ry, input_y)
|
| 154 |
-
grad_loss = gradient_diffx + gradient_diffy
|
| 155 |
-
|
| 156 |
-
loss = norm_loss * 1e-4 + grad_loss
|
| 157 |
-
return loss
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
class MTVLoss(nn.Module):
|
| 161 |
-
def __init__(self, kernel_size=10):
|
| 162 |
-
super().__init__()
|
| 163 |
-
self.criterion = nn.L1Loss()
|
| 164 |
-
self.norm = l1_norm
|
| 165 |
-
|
| 166 |
-
def forward(self, out_l, out_r, input_i):
|
| 167 |
-
out_lx, out_ly = compute_gradient(out_l)
|
| 168 |
-
out_rx, out_ry = compute_gradient(out_r)
|
| 169 |
-
input_x, input_y = compute_gradient(input_i)
|
| 170 |
-
|
| 171 |
-
norm_l = self.norm(out_lx) + self.norm(out_ly)
|
| 172 |
-
norm_r = self.norm(out_rx) + self.norm(out_ry)
|
| 173 |
-
norm_target = self.norm(input_x) + self.norm(input_y)
|
| 174 |
-
|
| 175 |
-
gradient_diffx = self.criterion(out_lx + out_rx, input_x)
|
| 176 |
-
gradient_diffy = self.criterion(out_ly + out_ry, input_y)
|
| 177 |
-
|
| 178 |
-
loss = (norm_l / norm_target + norm_r / norm_target) * 1e-5 + gradient_diffx + gradient_diffy
|
| 179 |
-
|
| 180 |
-
return loss
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
class ReconsLoss(nn.Module):
|
| 184 |
-
def __init__(self, edge_recons=True):
|
| 185 |
-
super().__init__()
|
| 186 |
-
self.criterion = nn.L1Loss()
|
| 187 |
-
self.norm = l1_norm
|
| 188 |
-
self.edge_recons = edge_recons
|
| 189 |
-
self.mse_loss=nn.MSELoss()
|
| 190 |
-
|
| 191 |
-
def forward(self, out_l, out_r, input_i):
|
| 192 |
-
loss_sum=[]
|
| 193 |
-
weight=0.25
|
| 194 |
-
for i in range(4):
|
| 195 |
-
#out_res = out_l[i]
|
| 196 |
-
out_clean=out_r[2*i]
|
| 197 |
-
out_reflection=out_r[2*i+1]
|
| 198 |
-
#content_diff = self.criterion(out_clean + out_reflection, input_i)
|
| 199 |
-
# if self.edge_recons:
|
| 200 |
-
# out_lx, out_ly = compute_gradient(out_clean)
|
| 201 |
-
# out_rx, out_ry = compute_gradient(out_reflection)
|
| 202 |
-
# #out_resx, out_resy = compute_gradient(out_res)
|
| 203 |
-
# input_x, input_y = compute_gradient(input_i)
|
| 204 |
-
|
| 205 |
-
# gradient_diffx = self.criterion(out_lx + out_rx, input_x)
|
| 206 |
-
# gradient_diffy = self.criterion(out_ly + out_ry, input_y)
|
| 207 |
-
|
| 208 |
-
# loss = content_diff + (gradient_diffx + gradient_diffy) * 5.0
|
| 209 |
-
# else:
|
| 210 |
-
# loss = content_diff
|
| 211 |
-
loss=self.mse_loss(out_clean+out_reflection,input_i)
|
| 212 |
-
loss_sum.append(loss*weight)
|
| 213 |
-
weight=weight+0.25
|
| 214 |
-
|
| 215 |
-
return sum(loss_sum)
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
class ReconsLossX(nn.Module):
|
| 219 |
-
def __init__(self, edge_recons=True):
|
| 220 |
-
super().__init__()
|
| 221 |
-
self.criterion = nn.MSELoss()
|
| 222 |
-
self.norm = l1_norm
|
| 223 |
-
self.edge_recons = edge_recons
|
| 224 |
-
|
| 225 |
-
def forward(self, out, input_i):
|
| 226 |
-
content_diff = self.criterion(out, input_i)
|
| 227 |
-
if self.edge_recons:
|
| 228 |
-
out_x, out_y = compute_gradient(out)
|
| 229 |
-
input_x, input_y = compute_gradient(input_i)
|
| 230 |
-
|
| 231 |
-
gradient_diffx = self.criterion(out_x, input_x)
|
| 232 |
-
gradient_diffy = self.criterion(out_y, input_y)
|
| 233 |
-
|
| 234 |
-
loss = content_diff + (gradient_diffx + gradient_diffy) * 1.0
|
| 235 |
-
else:
|
| 236 |
-
loss = content_diff
|
| 237 |
-
return loss
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
class ContentLoss():
|
| 241 |
-
def initialize(self, loss):
|
| 242 |
-
self.criterion = loss
|
| 243 |
-
|
| 244 |
-
def get_loss(self, fakeIm, realIm):
|
| 245 |
-
return self.criterion(fakeIm, realIm)
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
class GANLoss(nn.Module):
|
| 249 |
-
def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0,
|
| 250 |
-
tensor=torch.FloatTensor):
|
| 251 |
-
super(GANLoss, self).__init__()
|
| 252 |
-
self.real_label = target_real_label
|
| 253 |
-
self.fake_label = target_fake_label
|
| 254 |
-
self.real_label_var = None
|
| 255 |
-
self.fake_label_var = None
|
| 256 |
-
self.Tensor = tensor
|
| 257 |
-
if use_l1:
|
| 258 |
-
self.loss = nn.L1Loss()
|
| 259 |
-
else:
|
| 260 |
-
self.loss = nn.BCEWithLogitsLoss() # absorb sigmoid into BCELoss
|
| 261 |
-
|
| 262 |
-
def get_target_tensor(self, input, target_is_real):
|
| 263 |
-
target_tensor = None
|
| 264 |
-
if target_is_real:
|
| 265 |
-
create_label = ((self.real_label_var is None) or
|
| 266 |
-
(self.real_label_var.numel() != input.numel()))
|
| 267 |
-
if create_label:
|
| 268 |
-
real_tensor = self.Tensor(input.size()).fill_(self.real_label)
|
| 269 |
-
self.real_label_var = real_tensor
|
| 270 |
-
target_tensor = self.real_label_var
|
| 271 |
-
else:
|
| 272 |
-
create_label = ((self.fake_label_var is None) or
|
| 273 |
-
(self.fake_label_var.numel() != input.numel()))
|
| 274 |
-
if create_label:
|
| 275 |
-
fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
|
| 276 |
-
self.fake_label_var = fake_tensor
|
| 277 |
-
target_tensor = self.fake_label_var
|
| 278 |
-
return target_tensor
|
| 279 |
-
|
| 280 |
-
def __call__(self, input, target_is_real):
|
| 281 |
-
if isinstance(input, list):
|
| 282 |
-
loss = 0
|
| 283 |
-
for input_i in input:
|
| 284 |
-
target_tensor = self.get_target_tensor(input_i, target_is_real)
|
| 285 |
-
loss += self.loss(input_i, target_tensor)
|
| 286 |
-
return loss
|
| 287 |
-
else:
|
| 288 |
-
target_tensor = self.get_target_tensor(input, target_is_real)
|
| 289 |
-
return self.loss(input, target_tensor)
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
class DiscLoss():
|
| 293 |
-
def name(self):
|
| 294 |
-
return 'SGAN'
|
| 295 |
-
|
| 296 |
-
def initialize(self, opt, tensor):
|
| 297 |
-
self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)
|
| 298 |
-
|
| 299 |
-
def get_g_loss(self, net, realA, fakeB, realB):
|
| 300 |
-
# First, G(A) should fake the discriminator
|
| 301 |
-
pred_fake = net.forward(fakeB)
|
| 302 |
-
return self.criterionGAN(pred_fake, 1)
|
| 303 |
-
|
| 304 |
-
def get_loss(self, net, realA=None, fakeB=None, realB=None):
|
| 305 |
-
pred_fake = None
|
| 306 |
-
pred_real = None
|
| 307 |
-
loss_D_fake = 0
|
| 308 |
-
loss_D_real = 0
|
| 309 |
-
# Fake
|
| 310 |
-
# stop backprop to the generator by detaching fake_B
|
| 311 |
-
# Generated Image Disc Output should be close to zero
|
| 312 |
-
|
| 313 |
-
if fakeB is not None:
|
| 314 |
-
pred_fake = net.forward(fakeB.detach())
|
| 315 |
-
loss_D_fake = self.criterionGAN(pred_fake, 0)
|
| 316 |
-
|
| 317 |
-
# Real
|
| 318 |
-
if realB is not None:
|
| 319 |
-
pred_real = net.forward(realB)
|
| 320 |
-
loss_D_real = self.criterionGAN(pred_real, 1)
|
| 321 |
-
|
| 322 |
-
# Combined loss
|
| 323 |
-
loss_D = (loss_D_fake + loss_D_real) * 0.5
|
| 324 |
-
return loss_D, pred_fake, pred_real
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
class DiscLossR(DiscLoss):
|
| 328 |
-
# RSGAN from
|
| 329 |
-
# https://arxiv.org/abs/1807.00734
|
| 330 |
-
def name(self):
|
| 331 |
-
return 'RSGAN'
|
| 332 |
-
|
| 333 |
-
def initialize(self, opt, tensor):
|
| 334 |
-
DiscLoss.initialize(self, opt, tensor)
|
| 335 |
-
self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)
|
| 336 |
-
|
| 337 |
-
def get_g_loss(self, net, realA, fakeB, realB, pred_real=None):
|
| 338 |
-
if pred_real is None:
|
| 339 |
-
pred_real = net.forward(realB)
|
| 340 |
-
pred_fake = net.forward(fakeB)
|
| 341 |
-
return self.criterionGAN(pred_fake - pred_real, 1)
|
| 342 |
-
|
| 343 |
-
def get_loss(self, net, realA, fakeB, realB):
|
| 344 |
-
pred_real = net.forward(realB)
|
| 345 |
-
pred_fake = net.forward(fakeB.detach())
|
| 346 |
-
|
| 347 |
-
loss_D = self.criterionGAN(pred_real - pred_fake, 1) # BCE_stable loss
|
| 348 |
-
return loss_D, pred_fake, pred_real
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
class DiscLossRa(DiscLoss):
|
| 352 |
-
# RaSGAN from
|
| 353 |
-
# https://arxiv.org/abs/1807.00734
|
| 354 |
-
def name(self):
|
| 355 |
-
return 'RaSGAN'
|
| 356 |
-
|
| 357 |
-
def initialize(self, opt, tensor):
|
| 358 |
-
DiscLoss.initialize(self, opt, tensor)
|
| 359 |
-
self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)
|
| 360 |
-
|
| 361 |
-
def get_g_loss(self, net, realA, fakeB, realB, pred_real=None):
|
| 362 |
-
if pred_real is None:
|
| 363 |
-
pred_real = net.forward(realB)
|
| 364 |
-
pred_fake = net.forward(fakeB)
|
| 365 |
-
|
| 366 |
-
loss_G = self.criterionGAN(pred_real - torch.mean(pred_fake, dim=0, keepdim=True), 0)
|
| 367 |
-
loss_G += self.criterionGAN(pred_fake - torch.mean(pred_real, dim=0, keepdim=True), 1)
|
| 368 |
-
return loss_G * 0.5
|
| 369 |
-
|
| 370 |
-
def get_loss(self, net, realA, fakeB, realB):
|
| 371 |
-
pred_real = net.forward(realB)
|
| 372 |
-
|
| 373 |
-
pred_fake = net.forward(fakeB.detach())
|
| 374 |
-
|
| 375 |
-
loss_D = self.criterionGAN(pred_real - torch.mean(pred_fake, dim=0, keepdim=True), 1)
|
| 376 |
-
loss_D += self.criterionGAN(pred_fake - torch.mean(pred_real, dim=0, keepdim=True), 0)
|
| 377 |
-
return loss_D * 0.5, pred_fake, pred_real
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
class SSIM_Loss(nn.Module):
|
| 381 |
-
def __init__(self):
|
| 382 |
-
super().__init__()
|
| 383 |
-
self.ssim = SSIM(data_range=1, size_average=True, channel=3)
|
| 384 |
-
|
| 385 |
-
def forward(self, output, target):
|
| 386 |
-
return 1 - self.ssim(output, target)
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
def init_loss(opt, tensor):
|
| 390 |
-
disc_loss = None
|
| 391 |
-
content_loss = None
|
| 392 |
-
|
| 393 |
-
loss_dic = {}
|
| 394 |
-
|
| 395 |
-
pixel_loss = ContentLoss()
|
| 396 |
-
pixel_loss.initialize(MultipleLoss([nn.MSELoss(), GradientLoss()], [0.3, 0.6]))
|
| 397 |
-
|
| 398 |
-
loss_dic['t_pixel'] = pixel_loss
|
| 399 |
-
|
| 400 |
-
r_loss = ContentLoss()
|
| 401 |
-
r_loss.initialize(MultipleLoss([nn.MSELoss()], [0.9]))
|
| 402 |
-
loss_dic['r_pixel'] = pixel_loss
|
| 403 |
-
|
| 404 |
-
loss_dic['t_ssim'] = SSIM_Loss()
|
| 405 |
-
loss_dic['r_ssim'] = SSIM_Loss()
|
| 406 |
-
|
| 407 |
-
loss_dic['mtv'] = MTVLoss()
|
| 408 |
-
loss_dic['ktv'] = KTVLoss()
|
| 409 |
-
loss_dic['recons'] = ReconsLoss(edge_recons=False)
|
| 410 |
-
loss_dic['reconsx'] = ReconsLossX(edge_recons=False)
|
| 411 |
-
|
| 412 |
-
if opt.lambda_gan > 0:
|
| 413 |
-
if opt.gan_type == 'sgan' or opt.gan_type == 'gan':
|
| 414 |
-
disc_loss = DiscLoss()
|
| 415 |
-
elif opt.gan_type == 'rsgan':
|
| 416 |
-
disc_loss = DiscLossR()
|
| 417 |
-
elif opt.gan_type == 'rasgan':
|
| 418 |
-
disc_loss = DiscLossRa()
|
| 419 |
-
else:
|
| 420 |
-
raise ValueError("GAN [%s] not recognized." % opt.gan_type)
|
| 421 |
-
|
| 422 |
-
disc_loss.initialize(opt, tensor)
|
| 423 |
-
loss_dic['gan'] = disc_loss
|
| 424 |
-
|
| 425 |
-
return loss_dic
|
| 426 |
-
|
| 427 |
-
class DINOLoss(nn.Module):
    '''
    DINO-ViT as perceptual loss
    '''
    # Compares the [CLS] token of the last ViT block between output and
    # target; the target branch runs under no_grad so gradients only flow
    # through the generated image.

    def resize_to_dino(self, feature, size = (224, 224)):
        # Bilinear resize to the ViT input resolution (default 224x224).
        return F.interpolate(feature, size = size, mode='bilinear', align_corners=False)

    def calculate_crop_cls_loss(self, outputs, inputs):
        # Per-sample CLS-token MSE, accumulated one image at a time to
        # bound peak memory.
        # NOTE(review): `self.global_transform` is never assigned in this
        # class — presumably set externally; verify before calling this.
        loss = 0.0
        for a, b in zip(outputs, inputs):  # avoid memory limitations
            a = self.global_transform(a).unsqueeze(0)
            b = self.global_transform(b).unsqueeze(0)
            # [-1] = last transformer block; [0, 0, :] = CLS token features.
            cls_token = self.extractor.get_feature_from_input(a)[-1][0, 0, :]
            with torch.no_grad():
                target_cls_token = self.extractor.get_feature_from_input(b)[-1][0, 0, :]
            loss += F.mse_loss(cls_token, target_cls_token)
        return loss

    def __init__(self) :
        super(DINOLoss, self).__init__()
        # Frozen DINO ViT-S/8 feature extractor (project-local VitExtractor).
        self.extractor = VitExtractor(model_name = 'dino_vits8', device = 'cuda')
        # ImageNet normalization applied before feeding the ViT.
        self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()

    def forward(self, output, target):
        # Resize + normalize both images, then compare last-block CLS tokens.
        output = self.normalize(self.resize_to_dino(output))
        output_cls_token = self.extractor.get_feature_from_input(output)[-1][0, 0, :]
        with torch.no_grad():
            target = self.normalize(self.resize_to_dino(target))
            target_cls_token = self.extractor.get_feature_from_input(target)[-1][0, 0, :]

        return F.mse_loss(output_cls_token, target_cls_token)
|
| 459 |
-
|
| 460 |
-
if __name__ == '__main__':
    # Smoke test / micro-benchmark: time gradient_norm_kernel on a random
    # GPU batch and print the timing plus output shapes.
    x = torch.randn(3, 32, 224, 224).cuda()
    import time

    s = time.time()
    out1, out2 = gradient_norm_kernel(x)
    t = time.time()
    print(t - s)
    print(out1.shape, out2.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/losses_opt.py
DELETED
|
@@ -1,404 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
from pytorch_msssim import MS_SSIM, SSIM
|
| 6 |
-
|
| 7 |
-
from models.vgg import Vgg19
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
###############################################################################
|
| 11 |
-
# Functions
|
| 12 |
-
###############################################################################
|
| 13 |
-
def compute_gradient(img):
    """Forward differences along the last two axes of `img`.

    Returns (gradx, grady): gradx differences along dim -2 (rows),
    grady along dim -1 (columns); each one element shorter on its axis.
    """
    rows_diff = img[..., 1:, :] - img[..., :-1, :]
    cols_diff = img[..., 1:] - img[..., :-1]
    return rows_diff, cols_diff
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class GradientLoss(nn.Module):
    """L1 distance between the forward-difference gradients of two images."""

    def __init__(self):
        super(GradientLoss, self).__init__()
        self.loss = nn.L1Loss()

    def forward(self, predict, target):
        pred_gx, pred_gy = compute_gradient(predict)
        tgt_gx, tgt_gy = compute_gradient(target)
        # Penalize mismatch in both gradient directions.
        return self.loss(pred_gx, tgt_gx) + self.loss(pred_gy, tgt_gy)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
class ContainLoss(nn.Module):
    """Exclusion-style penalty: squared L2 of each layer's gradients divided
    elementwise by the input's gradients, averaged over all input pixels."""

    def __init__(self, eps=1e-12):
        super(ContainLoss, self).__init__()
        self.eps = eps  # guards the division where the input gradient is zero

    def forward(self, predict_t, predict_r, input_image):
        pix_num = np.prod(input_image.shape)
        input_x, input_y = compute_gradient(input_image)

        total = 0
        # Same four terms as the reference: t-x, t-y, then r-x, r-y.
        for layer in (predict_t, predict_r):
            gx, gy = compute_gradient(layer)
            total = total + torch.norm(gx / (input_x + self.eps), 2) ** 2
            total = total + torch.norm(gy / (input_y + self.eps), 2) ** 2

        return total / pix_num
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
class MultipleLoss(nn.Module):
    """Weighted sum of several criteria applied to one (predict, target) pair.

    When `weight` is falsy, every criterion gets an equal 1/len share.
    """

    def __init__(self, losses, weight=None):
        super(MultipleLoss, self).__init__()
        self.losses = nn.ModuleList(losses)
        self.weight = weight or [1 / len(self.losses)] * len(self.losses)

    def forward(self, predict, target):
        terms = [w * fn(predict, target)
                 for w, fn in zip(self.weight, self.losses)]
        return sum(terms)
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
class MeanShift(nn.Conv2d):
    """Fixed 1x1 conv that (de)normalizes images by per-channel statistics.

    norm=True : y = (x - mean * data_range) / std
    norm=False: y = x * std + mean * data_range

    Fix: the original ended with `self.requires_grad = False`, which only
    creates a plain attribute on the module — the weight/bias Parameters
    stayed trainable and the optimizer could drift the fixed statistics.
    The parameters themselves are now frozen.
    """

    def __init__(self, data_mean, data_std, data_range=1, norm=True):
        """norm (bool): normalize/denormalize the stats"""
        c = len(data_mean)
        super(MeanShift, self).__init__(c, c, kernel_size=1)
        std = torch.Tensor(data_std)
        # Identity kernel per channel; scaled by 1/std or std below.
        self.weight.data = torch.eye(c).view(c, c, 1, 1)
        if norm:
            self.weight.data.div_(std.view(c, 1, 1, 1))
            self.bias.data = -1 * data_range * torch.Tensor(data_mean)
            self.bias.data.div_(std)
        else:
            self.weight.data.mul_(std.view(c, 1, 1, 1))
            self.bias.data = data_range * torch.Tensor(data_mean)
        # Freeze the fixed statistics so no optimizer ever updates them.
        for p in self.parameters():
            p.requires_grad = False
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class VGGLoss(nn.Module):
    """Perceptual loss: weighted L1 distance between VGG-19 feature maps
    of `x` and `y` at a fixed set of layer indices."""

    def __init__(self, vgg=None, weights=None, indices=None, normalize=True):
        super(VGGLoss, self).__init__()
        if vgg is None:
            # Default to the project's Vgg19 wrapper on GPU.
            self.vgg = Vgg19().cuda()
        else:
            self.vgg = vgg
        self.criterion = nn.L1Loss()
        # Per-layer weights paired with `indices` below (relu1_1..relu5_1
        # style taps — presumably matching the Vgg19 wrapper; verify there).
        self.weights = weights or [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5]
        self.indices = indices or [2, 7, 12, 21, 30]
        if normalize:
            # ImageNet mean/std normalization before the VGG forward pass.
            self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
        else:
            self.normalize = None

    def forward(self, x, y):
        if self.normalize is not None:
            x = self.normalize(x)
            y = self.normalize(y)
        # The Vgg19 wrapper returns the feature maps at `self.indices`.
        x_vgg, y_vgg = self.vgg(x, self.indices), self.vgg(y, self.indices)
        loss = 0
        for i in range(len(x_vgg)):
            # Target features are detached: gradients flow only through x.
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())

        return loss
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def l1_norm_dim(x, dim):
    """Mean absolute value of `x`, reduced along `dim`."""
    return x.abs().mean(dim=dim)
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
def l1_norm(x):
    """Mean absolute value over all elements."""
    return x.abs().mean()
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
def l2_norm(x):
    """Mean of squared elements over all entries."""
    return x.square().mean()
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
def gradient_norm_kernel(x, kernel_size=10):
    """Per-pixel mean |gradient| over sliding kernel_size x kernel_size
    windows, for both gradient directions.

    Returns (horizontal, vertical) maps of shape (B, C, num_windows).
    """
    grad_h, grad_v = compute_gradient(x)
    batch, channels = grad_h.shape[0], grad_h.shape[1]

    def _windowed_l1(g):
        # Extract all sliding patches, regroup per channel, then take the
        # mean absolute value over each patch.
        patches = F.unfold(g, kernel_size=(kernel_size, kernel_size), stride=(1, 1))
        patches = patches.reshape(batch, channels, kernel_size * kernel_size, -1)
        return l1_norm_dim(patches, 2)

    return _windowed_l1(grad_h), _windowed_l1(grad_v)
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
class KTVLoss(nn.Module):
    """Kernel total-variation loss for two-layer decomposition.

    Combines (a) windowed gradient-magnitude ratios of both layers vs. the
    input (scaled by 1e-4) and (b) an L1 gradient-reconstruction term
    requiring the layers' gradients to sum to the input's.
    """

    def __init__(self, kernel_size=10):
        super().__init__()
        self.kernel_size = kernel_size      # sliding-window size for the ratio term
        self.criterion = nn.L1Loss()
        self.eps = 1e-6                     # avoids division by zero in the ratio

    def forward(self, out_l, out_r, input_i):
        # Windowed mean-|gradient| maps for both layers and the input.
        out_l_normx, out_l_normy = gradient_norm_kernel(out_l, self.kernel_size)
        out_r_normx, out_r_normy = gradient_norm_kernel(out_r, self.kernel_size)
        input_normx, input_normy = gradient_norm_kernel(input_i, self.kernel_size)
        norm_l = out_l_normx + out_l_normy
        norm_r = out_r_normx + out_r_normy
        norm_target = input_normx + input_normy + self.eps
        # Penalize each layer carrying more local gradient energy than the input.
        norm_loss = (norm_l / norm_target + norm_r / norm_target).mean()

        # Gradient reconstruction: grad(l) + grad(r) should match grad(input).
        out_lx, out_ly = compute_gradient(out_l)
        out_rx, out_ry = compute_gradient(out_r)
        input_x, input_y = compute_gradient(input_i)
        gradient_diffx = self.criterion(out_lx + out_rx, input_x)
        gradient_diffy = self.criterion(out_ly + out_ry, input_y)
        grad_loss = gradient_diffx + gradient_diffy

        loss = norm_loss * 1e-4 + grad_loss
        return loss
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
class MTVLoss(nn.Module):
    """Mean total-variation loss: global gradient-energy ratio of each layer
    vs. the input (scaled by 1e-5) plus L1 gradient reconstruction."""

    def __init__(self, kernel_size=10):
        super().__init__()
        # NOTE(review): `kernel_size` is accepted but never used here —
        # presumably kept for signature parity with KTVLoss.
        self.criterion = nn.L1Loss()
        self.norm = l1_norm  # scalar mean-|.| over the whole tensor

    def forward(self, out_l, out_r, input_i):
        out_lx, out_ly = compute_gradient(out_l)
        out_rx, out_ry = compute_gradient(out_r)
        input_x, input_y = compute_gradient(input_i)

        # Global (scalar) gradient energies.
        norm_l = self.norm(out_lx) + self.norm(out_ly)
        norm_r = self.norm(out_rx) + self.norm(out_ry)
        norm_target = self.norm(input_x) + self.norm(input_y)

        # Gradients of the two layers should sum to the input's gradients.
        gradient_diffx = self.criterion(out_lx + out_rx, input_x)
        gradient_diffy = self.criterion(out_ly + out_ry, input_y)

        loss = (norm_l / norm_target + norm_r / norm_target) * 1e-5 + gradient_diffx + gradient_diffy

        return loss
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
class ReconsLoss(nn.Module):
    """Linear reconstruction loss: the two layers must sum to the input in
    both intensity (weight 1) and gradient space (weight 0.5)."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()
        self.norm = l1_norm  # kept for attribute parity; unused in forward

    def forward(self, out_l, out_r, input_i):
        recomposed = out_l + out_r
        content_diff = self.criterion(recomposed, input_i)

        lx, ly = compute_gradient(out_l)
        rx, ry = compute_gradient(out_r)
        ix, iy = compute_gradient(input_i)
        grad_diff = self.criterion(lx + rx, ix) + self.criterion(ly + ry, iy)

        return content_diff + grad_diff * 0.5
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
class ContentLoss():
    """Thin two-phase wrapper around an injected pixel criterion."""

    def initialize(self, loss):
        # `loss` must be callable as loss(fake, real).
        self.criterion = loss

    def get_loss(self, fakeIm, realIm):
        criterion = self.criterion
        return criterion(fakeIm, realIm)
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
class GANLoss(nn.Module):
    """GAN objective with cached label tensors.

    use_l1=True selects L1 target matching; otherwise BCE-with-logits (the
    sigmoid is folded into the loss). Real/fake label tensors are created
    lazily and reused while the prediction's element count is unchanged.
    Accepts either a single prediction tensor or a list of them.
    """

    def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        # absorb sigmoid into BCELoss
        self.loss = nn.L1Loss() if use_l1 else nn.BCEWithLogitsLoss()

    def get_target_tensor(self, input, target_is_real):
        # Rebuild the cached label tensor only when the size changed.
        if target_is_real:
            if (self.real_label_var is None
                    or self.real_label_var.numel() != input.numel()):
                self.real_label_var = self.Tensor(input.size()).fill_(self.real_label)
            return self.real_label_var
        if (self.fake_label_var is None
                or self.fake_label_var.numel() != input.numel()):
            self.fake_label_var = self.Tensor(input.size()).fill_(self.fake_label)
        return self.fake_label_var

    def __call__(self, input, target_is_real):
        if isinstance(input, list):
            total = 0
            for pred in input:
                total += self.loss(pred, self.get_target_tensor(pred, target_is_real))
            return total
        return self.loss(input, self.get_target_tensor(input, target_is_real))
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
class DiscLoss():
    """Standard (non-relativistic) GAN loss wrapper around a discriminator.

    `net` is any module exposing forward(); predictions are scored against
    constant real(1)/fake(0) labels via GANLoss (BCE-with-logits).
    """

    def name(self):
        return 'SGAN'

    def initialize(self, opt, tensor):
        # use_l1=False -> BCEWithLogits; `tensor` builds the label tensors.
        self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)

    def get_g_loss(self, net, realA, fakeB, realB):
        # Generator loss: G(A) should fake the discriminator.
        # First, G(A) should fake the discriminator
        pred_fake = net.forward(fakeB)
        return self.criterionGAN(pred_fake, 1)

    def get_loss(self, net, realA=None, fakeB=None, realB=None):
        # Discriminator loss; fake/real halves are each optional so the
        # caller can score only one side. Returns (loss, pred_fake, pred_real).
        pred_fake = None
        pred_real = None
        loss_D_fake = 0
        loss_D_real = 0
        # Fake
        # stop backprop to the generator by detaching fake_B
        # Generated Image Disc Output should be close to zero

        if fakeB is not None:
            pred_fake = net.forward(fakeB.detach())
            loss_D_fake = self.criterionGAN(pred_fake, 0)

        # Real
        if realB is not None:
            pred_real = net.forward(realB)
            loss_D_real = self.criterionGAN(pred_real, 1)

        # Combined loss
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        return loss_D, pred_fake, pred_real
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
class DiscLossR(DiscLoss):
    """Relativistic standard GAN (RSGAN): scores the *difference* between
    real and fake discriminator outputs instead of each in isolation."""
    # RSGAN from
    # https://arxiv.org/abs/1807.00734
    def name(self):
        return 'RSGAN'

    def initialize(self, opt, tensor):
        DiscLoss.initialize(self, opt, tensor)
        self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)

    def get_g_loss(self, net, realA, fakeB, realB, pred_real=None):
        # Generator: fake should look "more real than real"; pred_real may
        # be passed in to avoid recomputing the real-branch forward.
        if pred_real is None:
            pred_real = net.forward(realB)
        pred_fake = net.forward(fakeB)
        return self.criterionGAN(pred_fake - pred_real, 1)

    def get_loss(self, net, realA, fakeB, realB):
        # Discriminator: real should outscore the (detached) fake.
        pred_real = net.forward(realB)
        pred_fake = net.forward(fakeB.detach())

        loss_D = self.criterionGAN(pred_real - pred_fake, 1)  # BCE_stable loss
        return loss_D, pred_fake, pred_real
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
class DiscLossRa(DiscLoss):
    """Relativistic *average* GAN (RaSGAN): each prediction is compared to
    the batch-mean of the opposite side's predictions."""
    # RaSGAN from
    # https://arxiv.org/abs/1807.00734
    def name(self):
        return 'RaSGAN'

    def initialize(self, opt, tensor):
        DiscLoss.initialize(self, opt, tensor)
        self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)

    def get_g_loss(self, net, realA, fakeB, realB, pred_real=None):
        # Generator: push real below the mean fake and fake above the mean real.
        if pred_real is None:
            pred_real = net.forward(realB)
        pred_fake = net.forward(fakeB)

        loss_G = self.criterionGAN(pred_real - torch.mean(pred_fake, dim=0, keepdim=True), 0)
        loss_G += self.criterionGAN(pred_fake - torch.mean(pred_real, dim=0, keepdim=True), 1)
        return loss_G * 0.5

    def get_loss(self, net, realA, fakeB, realB):
        # Discriminator: mirror image of get_g_loss with labels swapped;
        # fakeB is detached so gradients never reach the generator here.
        pred_real = net.forward(realB)
        pred_fake = net.forward(fakeB.detach())

        loss_D = self.criterionGAN(pred_real - torch.mean(pred_fake, dim=0, keepdim=True), 1)
        loss_D += self.criterionGAN(pred_fake - torch.mean(pred_real, dim=0, keepdim=True), 0)
        return loss_D * 0.5, pred_fake, pred_real
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
class MS_SSIM_Loss(nn.Module):
    """Multi-scale structural dissimilarity: 1 - MS-SSIM(output, target)."""

    def __init__(self):
        super().__init__()
        # MS_SSIM from pytorch_msssim; 3-channel images in [0, 1].
        self.ms_ssim = MS_SSIM(data_range=1, size_average=True, channel=3)

    def forward(self, output, target):
        score = self.ms_ssim(output, target)
        return 1 - score
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
class SSIM_Loss(nn.Module):
    """Single-scale structural dissimilarity: 1 - SSIM(output, target)."""

    def __init__(self):
        super().__init__()
        # SSIM from pytorch_msssim; 3-channel images in [0, 1].
        self.ssim = SSIM(data_range=1, size_average=True, channel=3)

    def forward(self, output, target):
        score = self.ssim(output, target)
        return 1 - score
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
def init_loss(opt, tensor):
    """Build the dictionary of training losses used by this variant.

    Fix: removed the unused `content_loss` local.

    Args:
        opt: options namespace; reads `lambda_gan` and `gan_type`.
        tensor: tensor constructor handed to the GAN loss for label tensors.

    Returns:
        dict mapping loss names to initialized loss objects.
    """
    disc_loss = None

    loss_dic = {}

    # Transmission layer: MSE + MS-SSIM + gradient matching.
    t_pixel_loss = ContentLoss()
    t_pixel_loss.initialize(
        MultipleLoss([nn.MSELoss(), MS_SSIM_Loss(), GradientLoss()], [1.0, 0.4, 0.6]))

    loss_dic['t_pixel'] = t_pixel_loss

    # Reflection layer: weighted MSE only.
    r_pixel_loss = ContentLoss()
    r_pixel_loss.initialize(
        MultipleLoss([nn.MSELoss()], [4.0]))

    loss_dic['r_pixel'] = r_pixel_loss
    loss_dic['recons'] = ReconsLoss()

    loss_dic['mtv'] = MTVLoss()
    loss_dic['ktv'] = KTVLoss()

    # Optional adversarial loss, selected by gan_type.
    if opt.lambda_gan > 0:
        if opt.gan_type == 'sgan' or opt.gan_type == 'gan':
            disc_loss = DiscLoss()
        elif opt.gan_type == 'rsgan':
            disc_loss = DiscLossR()
        elif opt.gan_type == 'rasgan':
            disc_loss = DiscLossRa()
        else:
            raise ValueError("GAN [%s] not recognized." % opt.gan_type)

        disc_loss.initialize(opt, tensor)
        loss_dic['gan'] = disc_loss

    return loss_dic
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
if __name__ == '__main__':
    # Smoke test / micro-benchmark: time gradient_norm_kernel on a random
    # GPU batch and print the timing plus output shapes.
    x = torch.randn(3, 32, 224, 224).cuda()
    import time

    s = time.time()
    out1, out2 = gradient_norm_kernel(x)
    t = time.time()
    print(t - s)
    print(out1.shape, out2.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/networks.py
DELETED
|
@@ -1,335 +0,0 @@
|
|
| 1 |
-
import functools
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
from torch.nn import init
|
| 7 |
-
from torch.nn.utils import spectral_norm
|
| 8 |
-
from torch.nn import functional as F
|
| 9 |
-
###############################################################################
|
| 10 |
-
# Functions
|
| 11 |
-
###############################################################################
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def weights_init_normal(m):
    """`net.apply` hook: N(0, 0.02) weights for conv/linear layers,
    N(1, 0.02) weights and zero bias for BatchNorm2d; containers skipped."""
    if isinstance(m, nn.Sequential):
        return
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def weights_init_xavier(m):
    """`net.apply` hook: Xavier-normal (gain 0.02) weights for conv/linear,
    N(1, 0.02) weights and zero bias for BatchNorm2d."""
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        init.xavier_normal_(m.weight.data, gain=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def weights_init_kaiming(m):
    """`net.apply` hook: Kaiming-normal (fan_in) weights for conv/linear,
    N(1, 0.02) weights and zero bias for BatchNorm2d."""
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def weights_init_orthogonal(m):
    """`net.apply` hook: orthogonal weights for conv/linear layers,
    N(1, 0.02) weights and zero bias for BatchNorm2d.

    Fix: `init.orthogonal` and `init.normal` are the deprecated aliases
    (removed in current PyTorch) — replaced with the in-place `_` variants.
    """
    classname = m.__class__.__name__
    print(classname)
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        init.orthogonal_(m.weight.data, gain=1)
    elif isinstance(m, nn.Linear):
        init.orthogonal_(m.weight.data, gain=1)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def init_weights(net, init_type='normal'):
    """Apply the chosen weight-initialization scheme to every module of `net`.

    'edsr' is a deliberate no-op (keeps the default/pretrained weights);
    any other unknown name raises NotImplementedError.
    """
    print('[i] initialization method [%s]' % init_type)
    initializers = {
        'normal': weights_init_normal,
        'xavier': weights_init_xavier,
        'kaiming': weights_init_kaiming,
        'orthogonal': weights_init_orthogonal,
    }
    if init_type in initializers:
        net.apply(initializers[init_type])
    elif init_type == 'edsr':
        pass
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def get_norm_layer(norm_type='instance'):
    """Map a norm name to a layer constructor.

    'batch' -> affine BatchNorm2d, 'instance' -> non-affine InstanceNorm2d,
    'none' -> None; anything else raises NotImplementedError.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == 'none':
        return None
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def define_D(opt, in_channels=3):
    """Build the discriminator selected by `opt.which_model_D` and move it
    to the first GPU in `opt.gpu_ids` when GPUs are configured."""
    # use_sigmoid = opt.gan_type == 'gan'
    use_sigmoid = False  # incorporate sigmoid into BCE_stable loss

    if opt.which_model_D == 'disc_vgg':
        netD = Discriminator_VGG(in_channels, use_sigmoid=use_sigmoid)
        init_weights(netD, init_type='kaiming')
    elif opt.which_model_D == 'disc_patch':
        # 70x70-style PatchGAN with instance norm and no intermediate features.
        netD = NLayerDiscriminator(in_channels, 64, 3, nn.InstanceNorm2d, use_sigmoid, getIntermFeat=False)
        init_weights(netD, init_type='normal')
    elif opt.which_model_D == 'disc_unet':
        # NOTE(review): UNetDiscriminatorSN is not defined in this file's
        # visible portion — presumably imported/defined elsewhere; verify.
        netD = UNetDiscriminatorSN(in_channels)
    else:
        raise NotImplementedError('%s is not implemented' %opt.which_model_D)

    if len(opt.gpu_ids) > 0:
        assert(torch.cuda.is_available())
        netD.cuda(opt.gpu_ids[0])

    return netD
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def print_network(net):
    """Dump the network, its total parameter count, and its receptive field."""
    num_params = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % num_params)
    print('The size of receptive field: %d' % receptive_field(net))
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
def receptive_field(net):
    """Receptive field (in input pixels) of one output unit.

    Folds kernel size / stride / dilation of every Conv2d backwards,
    starting from a single output pixel.
    """
    def _grow(size, ksize, stride, dilation):
        return (size - 1) * stride + ksize * dilation - dilation + 1

    convs = [(m.kernel_size, m.stride, m.dilation)
             for m in net.modules() if isinstance(m, nn.Conv2d)]

    rsize = 1
    for ksize, stride, dilation in reversed(convs):
        # Conv2d stores these as tuples; only the first axis is used.
        if isinstance(ksize, tuple):
            ksize = ksize[0]
        if isinstance(stride, tuple):
            stride = stride[0]
        if isinstance(dilation, tuple):
            dilation = dilation[0]
        rsize = _grow(rsize, ksize, stride, dilation)
    return rsize
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
def debug_network(net):
    """Attach a forward hook to every module that prints its output size."""
    def _print_size(module, inputs, output):
        print(output.size())

    for module in net.modules():
        module.register_forward_hook(_print_size)
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
##############################################################################
|
| 149 |
-
# Classes
|
| 150 |
-
##############################################################################
|
| 151 |
-
|
| 152 |
-
# Defines the PatchGAN discriminator with the specified arguments.
|
| 153 |
-
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: n_layers stride-2 4x4 convs, then a stride-1
    conv pair, ending in a 1-channel (per branch) score map.

    `branch > 1` runs that many independent discriminators via grouped
    convolutions. With getIntermFeat=True, each stage is kept as a
    separate submodule so forward() can return intermediate features.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3,
                 norm_layer=nn.BatchNorm2d, use_sigmoid=False,
                 branch=1, bias=True, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers
        kw = 4                                # kernel size of every conv
        padw = int(np.ceil((kw-1.0)/2))       # "same-ish" padding for k=4
        # Stage 0: no normalization on the first conv (common GAN practice).
        sequence = [[nn.Conv2d(input_nc*branch, ndf*branch, kernel_size=kw, stride=2, padding=padw, groups=branch, bias=True), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        # Stages 1..n_layers-1: stride-2 downsampling, widths capped at 512.
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                nn.Conv2d(nf_prev*branch, nf*branch, groups=branch, kernel_size=kw, stride=2, padding=padw, bias=bias),
                norm_layer(nf*branch), nn.LeakyReLU(0.2, True)
            ]]

        # One more stride-1 widening stage before the score head.
        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv2d(nf_prev*branch, nf*branch, groups=branch, kernel_size=kw, stride=1, padding=padw, bias=bias),
            norm_layer(nf*branch),
            nn.LeakyReLU(0.2, True)
        ]]

        # Score head: one output channel per branch.
        sequence += [[nn.Conv2d(nf*branch, 1*branch, groups=branch, kernel_size=kw, stride=1, padding=padw, bias=True)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if getIntermFeat:
            # Register each stage separately so intermediate outputs are reachable.
            for n in range(len(sequence)):
                setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
        else:
            # Flatten all stages into a single Sequential.
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        if self.getIntermFeat:
            # Chain the per-stage submodules, collecting every intermediate
            # output; the raw input itself is dropped from the result.
            res = [input]
            for n in range(self.n_layers+2):
                model = getattr(self, 'model'+str(n))
                res.append(model(res[-1]))
            return res[1:]
        else:
            return self.model(input)
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
class Discriminator_VGG(nn.Module):
    """VGG-style image discriminator.

    A stack of 3x3 convolutions (GroupNorm + LeakyReLU after every conv but
    the first) downsamples the input by 32x, then a pooled 1x1-conv head
    produces one score per image.

    Args:
        in_channels: number of channels in the input image. Default: 3.
        use_sigmoid: append a final ``nn.Sigmoid`` so the score is in [0, 1]
            instead of a raw logit. Default: True.
    """

    def __init__(self, in_channels=3, use_sigmoid=True):
        super(Discriminator_VGG, self).__init__()

        groups = 32  # group count shared by every GroupNorm below

        # (out_channels, stride) of each conv after the stem; stride 2 halves
        # the spatial size (224 -> 112 -> 56 -> 28 -> 14 -> 7 for 224 input).
        plan = [
            (64, 2), (128, 1), (128, 2), (256, 1), (256, 2),
            (512, 1), (512, 2), (512, 1), (512, 2),
        ]

        # Stem: conv + activation only (no normalization on raw pixels).
        trunk = [nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
                 nn.LeakyReLU(0.2)]
        prev_ch = 64
        for out_ch, stride in plan:
            trunk.append(nn.Conv2d(prev_ch, out_ch, kernel_size=3,
                                   stride=stride, padding=1))
            trunk.append(nn.GroupNorm(groups, out_ch))
            trunk.append(nn.LeakyReLU(0.2))
            prev_ch = out_ch

        # Head: global pooling then a tiny 1x1-conv MLP down to one channel.
        head = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]
        if use_sigmoid:
            head.append(nn.Sigmoid())

        self.body = nn.Sequential(*trunk)
        self.tail = nn.Sequential(*head)

    def forward(self, x):
        """Return the discriminator score for a batch of images."""
        features = self.body(x)
        return self.tail(features)
|
| 272 |
-
|
| 273 |
-
class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN).

    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm
        # the first convolution (kept without SN, matching the reference implementation)
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        # downsample path: H -> H/2 -> H/4 -> H/8
        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # upsample path (bilinear interpolate + conv)
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
        # extra convolutions before the 1-channel output map
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x, illu=None):
        """Return a per-pixel realness map of shape (B, 1, H, W).

        Args:
            x: input image batch (B, C, H, W); H and W should be divisible
                by 8 so the three down/up-sampling rounds line up.
            illu: optional illumination map broadcastable to the conv0
                output; when given, first-layer features are scaled by
                ``1 - 2 * illu``.
        """
        ingress = self.conv0(x)
        if illu is not None:
            ingress = ingress * (1 - illu * 2)
        # BUG FIX: the original recomputed self.conv0(x) on the next line,
        # which both ran conv0 twice and silently discarded the illumination
        # gating above (illu had no effect on the output).
        x0 = F.leaky_relu(ingress, negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x4 = x4 + x2
        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x5 = x5 + x1
        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x6 = x6 + x0

        # extra convolutions
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/vgg.py
DELETED
|
@@ -1,66 +0,0 @@
|
|
| 1 |
-
from collections import namedtuple
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torchvision import models
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
class Vgg16(torch.nn.Module):
    """Pretrained VGG-16 sliced into four feature stages.

    The torchvision feature stack is split at the relu1_2 / relu2_2 /
    relu3_3 / relu4_3 boundaries; ``forward`` returns all four activations
    as a named tuple. Parameters are frozen unless ``requires_grad=True``.
    """

    # (start, stop) indices into vgg16.features for each stage.
    _STAGES = ((0, 4), (4, 9), (9, 16), (16, 23))

    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        features = models.vgg16(pretrained=True).features
        for stage_num, (start, stop) in enumerate(self._STAGES, start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                # Keep the original torchvision indices as module names so
                # state_dict keys stay unchanged.
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, 'slice%d' % stage_num, stage)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return (relu1_2, relu2_2, relu3_3, relu4_3) activations for X."""
        activations = []
        h = X
        for stage_num in range(1, 5):
            h = getattr(self, 'slice%d' % stage_num)(h)
            activations.append(h)
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        return vgg_outputs(*activations)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class Vgg19(torch.nn.Module):
    """Pretrained VGG-19 feature extractor returning selected activations.

    Runs the torchvision feature stack layer by layer and collects the
    outputs whose 1-based layer index appears in *indices*. Parameters are
    frozen unless ``requires_grad=True``.
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        self.vgg_pretrained_features = models.vgg19(pretrained=True).features

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X, indices=None):
        """Return the activations after each layer listed in *indices*.

        Args:
            X: input image batch.
            indices: ascending 1-based layer indices to tap; defaults to
                [2, 7, 12, 21, 30] (the relu outputs commonly used for
                perceptual losses). Only layers up to the last index run.
        """
        taps = [2, 7, 12, 21, 30] if indices is None else indices
        collected = []
        for layer_idx in range(taps[-1]):
            X = self.vgg_pretrained_features[layer_idx](X)
            if (layer_idx + 1) in taps:
                collected.append(X)

        return collected
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
if __name__ == '__main__':
    # Ad-hoc inspection entry point: instantiate the extractor, then drop
    # into an interactive ipdb session to poke at it by hand.
    network = Vgg19()
    import ipdb

    ipdb.set_trace()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/vit_feature_extractor.py
DELETED
|
@@ -1,164 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
def attn_cosine_sim(x, eps=1e-08):
    """Pairwise cosine similarity between token vectors.

    Args:
        x: tensor of shape ``(1, B, N, D)``; the leading singleton
            dimension is asserted and stripped.
        eps: floor applied to the norm product to avoid division by zero.

    Returns:
        Tensor of shape ``(B, N, N)`` of cosine similarities.
    """
    assert x.shape[0] == 1, 'x.shape[0] must eqs 1'
    tokens = x.squeeze(0)  # drop the redundant leading dimension
    norms = tokens.norm(dim=2, keepdim=True)
    # Outer product of norms, floored at eps so zero vectors do not divide by 0.
    denom = torch.clamp(norms @ norms.transpose(1, 2), min=eps)
    return (tokens @ tokens.transpose(1, 2)) / denom
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class VitExtractor:
    """Hook-based feature extractor for DINO ViT backbones.

    Loads a DINO model from ``torch.hub`` and temporarily attaches forward
    hooks to capture per-layer block outputs, attention maps and qkv
    projections, plus helpers to reshape the flat qkv tensor into
    per-head queries / keys / values.
    """

    BLOCK_KEY = 'block'
    ATTN_KEY = 'attn'
    PATCH_IMD_KEY = 'patch_imd'
    QKV_KEY = 'qkv'
    KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]

    def __init__(self, model_name, device):
        # NOTE: requires network access the first time (torch.hub download).
        self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
        self.model.eval()
        self.model_name = model_name
        self.hook_handlers = []
        self.layers_dict = {key: [] for key in VitExtractor.KEY_LIST}
        self.outputs_dict = {key: [] for key in VitExtractor.KEY_LIST}
        self._init_hooks_data()

    def _init_hooks_data(self):
        """Select all 12 transformer layers for every key and clear captures."""
        for key in VitExtractor.KEY_LIST:
            self.layers_dict[key] = list(range(12))
            self.outputs_dict[key] = []

    def _register_hooks(self, **kwargs):
        """Attach forward hooks to every selected layer/submodule."""
        # (key, submodule picker, hook factory) for each capture kind.
        hook_specs = (
            (VitExtractor.BLOCK_KEY, lambda blk: blk, self._get_block_hook),
            (VitExtractor.ATTN_KEY, lambda blk: blk.attn.attn_drop, self._get_attn_hook),
            (VitExtractor.QKV_KEY, lambda blk: blk.attn.qkv, self._get_qkv_hook),
            (VitExtractor.PATCH_IMD_KEY, lambda blk: blk.attn, self._get_patch_imd_hook),
        )
        for block_idx, block in enumerate(self.model.blocks):
            for key, pick_module, make_hook in hook_specs:
                if block_idx in self.layers_dict[key]:
                    handle = pick_module(block).register_forward_hook(make_hook())
                    self.hook_handlers.append(handle)

    def _clear_hooks(self):
        """Detach every registered hook."""
        while self.hook_handlers:
            self.hook_handlers.pop().remove()

    def _get_block_hook(self):
        def _capture(model, input, output):
            self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)

        return _capture

    def _get_attn_hook(self):
        def _capture(model, inp, output):
            self.outputs_dict[VitExtractor.ATTN_KEY].append(output)

        return _capture

    def _get_qkv_hook(self):
        def _capture(model, inp, output):
            self.outputs_dict[VitExtractor.QKV_KEY].append(output)

        return _capture

    # TODO: CHECK ATTN OUTPUT TUPLE
    def _get_patch_imd_hook(self):
        def _capture(model, inp, output):
            self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])

        return _capture

    def _collect(self, input_img, key):
        """Run *input_img* through the model and return the captures for *key*."""
        self._register_hooks()
        self.model(input_img)
        captured = self.outputs_dict[key]
        self._clear_hooks()
        self._init_hooks_data()
        return captured

    def get_feature_from_input(self, input_img):  # List([B, N, D])
        return self._collect(input_img, VitExtractor.BLOCK_KEY)

    def get_qkv_feature_from_input(self, input_img):
        return self._collect(input_img, VitExtractor.QKV_KEY)

    def get_attn_feature_from_input(self, input_img):
        return self._collect(input_img, VitExtractor.ATTN_KEY)

    def get_patch_size(self):
        """Patch size implied by the hub model name (8 or 16)."""
        return 8 if "8" in self.model_name else 16

    def get_width_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        return w // self.get_patch_size()

    def get_height_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        return h // self.get_patch_size()

    def get_patch_num(self, input_img_shape):
        # +1 accounts for the CLS token.
        return 1 + self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape)

    def get_head_num(self):
        marker = "s" if "dino" in self.model_name else "small"
        return 6 if marker in self.model_name else 12

    def get_embedding_dim(self):
        marker = "s" if "dino" in self.model_name else "small"
        return 384 if marker in self.model_name else 768

    def _split_qkv(self, qkv, input_img_shape, component):
        """Reshape flat qkv to (3, heads, tokens, head_dim) and pick one slice.

        component: 0 = queries, 1 = keys, 2 = values.
        (Assumes batch size 1 — the batch axis is folded into tokens; TODO confirm.)
        """
        tokens = self.get_patch_num(input_img_shape)
        heads = self.get_head_num()
        head_dim = self.get_embedding_dim() // heads
        return qkv.reshape(tokens, 3, heads, head_dim).permute(1, 2, 0, 3)[component]

    def get_queries_from_qkv(self, qkv, input_img_shape):
        return self._split_qkv(qkv, input_img_shape, 0)

    def get_keys_from_qkv(self, qkv, input_img_shape):
        return self._split_qkv(qkv, input_img_shape, 1)

    def get_values_from_qkv(self, qkv, input_img_shape):
        return self._split_qkv(qkv, input_img_shape, 2)

    def get_keys_from_input(self, input_img, layer_num):
        qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
        return self.get_keys_from_qkv(qkv_features, input_img.shape)

    def get_keys_self_sim_from_input(self, input_img, layer_num):
        """Cosine self-similarity of per-token keys (heads concatenated)."""
        keys = self.get_keys_from_input(input_img, layer_num=layer_num)
        heads, tokens, dim = keys.shape
        flat_keys = keys.transpose(0, 1).reshape(tokens, heads * dim)
        return attn_cosine_sim(flat_keys[None, None, ...])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
options/__init__.py
DELETED
|
File without changes
|
options/__pycache__/__init__.cpython-38.pyc
DELETED
|
Binary file (151 Bytes)
|
|
|
options/__pycache__/base_option.cpython-38.pyc
DELETED
|
Binary file (2.68 kB)
|
|
|
options/base_option.py
DELETED
|
@@ -1,47 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import models
|
| 3 |
-
|
| 4 |
-
model_names = sorted(name for name in models.__dict__
|
| 5 |
-
if name.islower() and not name.startswith("__")
|
| 6 |
-
and callable(models.__dict__[name]))
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class BaseOptions():
    """Command-line options shared by every entry point.

    Subclasses extend ``initialize`` with their own flags and provide a
    ``parse`` method that produces the final options namespace.
    """

    def __init__(self):
        # Show default values automatically in --help output.
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False  # flipped to True once initialize() has run

    def initialize(self):
        """Register the experiment / input / display options on ``self.parser``."""
        # experiment specifics
        self.parser.add_argument('--name', type=str, default='ytmt_ucs_sirs',
                                 help='name of the experiment. It decides where to store samples and models')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        self.parser.add_argument('--model', type=str, default='revcol', help='chooses which model to use.')
        self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        self.parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
        # BUG FIX: closed the unbalanced parenthesis in the help text.
        self.parser.add_argument('--resume_epoch', '-re', type=int, default=None,
                                 help='checkpoint to use. (default: latest)')
        self.parser.add_argument('--seed', type=int, default=2018, help='random seed to use. Default=2018')
        self.parser.add_argument('--supp_eval', action='store_true', help='supplementary evaluation')
        # NOTE(review): help text below looks copy-pasted from --supp_eval; confirm the intended description.
        self.parser.add_argument('--start_now', action='store_true', help='supplementary evaluation')
        self.parser.add_argument('--testr', action='store_true', help='test for reflections')
        self.parser.add_argument('--select', type=str, default=None)

        # for setting input
        self.parser.add_argument('--serial_batches', action='store_true',
                                 help='if true, takes images in order to make batches, otherwise takes them randomly')
        self.parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data')
        self.parser.add_argument('--max_dataset_size', type=int, default=None,
                                 help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')

        # for display
        self.parser.add_argument('--no-log', action='store_true', help='disable tf logger?')
        self.parser.add_argument('--no-verbose', action='store_true', help='disable verbose info?')
        self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
        self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        self.parser.add_argument('--display_id', type=int, default=0,
                                 help='window id of the web display (use 0 to disable visdom)')
        self.parser.add_argument('--display_single_pane_ncols', type=int, default=0,
                                 help='if positive, display all images in a single visdom web panel with certain number of images per row.')

        self.initialized = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
options/net_options/__init__.py
DELETED
|
File without changes
|
options/net_options/__pycache__/__init__.cpython-38.pyc
DELETED
|
Binary file (163 Bytes)
|
|
|
options/net_options/__pycache__/base_options.cpython-38.pyc
DELETED
|
Binary file (2.4 kB)
|
|
|
options/net_options/__pycache__/train_options.cpython-38.pyc
DELETED
|
Binary file (3.54 kB)
|
|
|
options/net_options/base_options.py
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
from options.base_option import BaseOptions as Base
|
| 2 |
-
from util import util
|
| 3 |
-
import os
|
| 4 |
-
import torch
|
| 5 |
-
import numpy as np
|
| 6 |
-
import random
|
| 7 |
-
|
| 8 |
-
class BaseOptions(Base):
    """Network-level options layered on top of the common base options."""

    def initialize(self):
        """Add architecture / initialization flags to the shared parser."""
        Base.initialize(self)
        # experiment specifics
        self.parser.add_argument('--inet', type=str, default='ytmt_ucs', help='chooses which architecture to use for inet.')
        self.parser.add_argument('--icnn_path', type=str, default=None, help='icnn checkpoint to use.')
        self.parser.add_argument('--init_type', type=str, default='edsr', help='network initialization [normal|xavier|kaiming|orthogonal|uniform]')
        # for network
        self.parser.add_argument('--hyper', action='store_true', help='if true, augment input with vgg hypercolumn feature')

        self.initialized = True

    def parse(self):
        """Parse CLI args, seed RNGs, pick GPUs, echo and persist the options."""
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()
        self.opt.isTrain = self.isTrain  # train or test

        # Deterministic runs: a single seed drives torch, numpy and random.
        torch.backends.cudnn.deterministic = True
        torch.manual_seed(self.opt.seed)
        np.random.seed(self.opt.seed)  # seed for every module
        random.seed(self.opt.seed)

        # Turn the comma-separated id string into a list of non-negative ints.
        self.opt.gpu_ids = [gid for gid in map(int, self.opt.gpu_ids.split(',')) if gid >= 0]

        # set gpu ids
        if self.opt.gpu_ids:
            torch.cuda.set_device(self.opt.gpu_ids[0])

        args = vars(self.opt)

        print('------------ Options -------------')
        for key, value in sorted(args.items()):
            print('%s: %s' % (str(key), str(value)))
        print('-------------- End ----------------')

        # Persist the options next to the checkpoints for reproducibility.
        self.opt.name = self.opt.name or '_'.join([self.opt.model])
        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.write('------------ Options -------------\n')
            for key, value in sorted(args.items()):
                opt_file.write('%s: %s\n' % (str(key), str(value)))
            opt_file.write('-------------- End ----------------\n')

        # --debug shrinks every frequency/size so a run finishes quickly.
        # (NOTE(review): --debug itself is defined in TrainOptions — confirm
        # parse() is only reached through a subclass that registers it.)
        if self.opt.debug:
            self.opt.display_freq = 20
            self.opt.print_freq = 20
            self.opt.nEpochs = 40
            self.opt.max_dataset_size = 100
            self.opt.no_log = False
            self.opt.nThreads = 0
            self.opt.decay_iter = 0
            self.opt.serial_batches = True
            self.opt.no_flip = True

        return self.opt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
options/net_options/train_options.py
DELETED
|
@@ -1,75 +0,0 @@
|
|
| 1 |
-
from .base_options import BaseOptions
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
class TrainOptions(BaseOptions):
    """Training-time options: display cadence, optimizer, augmentation and GAN losses."""

    def initialize(self):
        """Register the training flags on top of the shared network options."""
        BaseOptions.initialize(self)
        # for displays
        self.parser.add_argument('--display_freq', type=int, default=100,
                                 help='frequency of showing training results on screen')
        self.parser.add_argument('--update_html_freq', type=int, default=1000,
                                 help='frequency of saving training results to html')
        self.parser.add_argument('--print_freq', type=int, default=100,
                                 help='frequency of showing training results on console')
        self.parser.add_argument('--eval_freq', type=int, default=1, help='frequency of evaluation')
        self.parser.add_argument('--save_freq', type=int, default=1, help='frequency of save eval samples')
        self.parser.add_argument('--no_html', action='store_true',
                                 help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        self.parser.add_argument('--save_epoch_freq', type=int, default=1,
                                 help='frequency of saving checkpoints at the end of epochs')
        self.parser.add_argument('--debug', action='store_true',
                                 help='only do one epoch and displays at each iteration')
        self.parser.add_argument('--finetune', action='store_true',
                                 help='finetune the network using identity inputs and outputs')
        self.parser.add_argument('--if_align', action='store_true',
                                 help='if align 4x')

        # for training (Note: in train_sirs.py, we manually tune the training protocol,
        # but you can also use the following setting by modifying the code in errnet_model.py)
        self.parser.add_argument('--nEpochs', '-n', type=int, default=60, help='# of epochs to run')
        self.parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate for adam')
        self.parser.add_argument('--wd', type=float, default=0, help='weight decay for adam')

        self.parser.add_argument('--r_pixel_weight', '-rw', type=float, default=1.0, help='weight for r_pixel loss')

        self.parser.add_argument('--low_sigma', type=float, default=2, help='min sigma in synthetic dataset')
        self.parser.add_argument('--high_sigma', type=float, default=5, help='max sigma in synthetic dataset')
        # BUG FIX: --low_gamma help previously said "max gamma" (copy-paste from --high_gamma).
        self.parser.add_argument('--low_gamma', type=float, default=1.3, help='min gamma in synthetic dataset')
        self.parser.add_argument('--high_gamma', type=float, default=1.3, help='max gamma in synthetic dataset')

        # data augmentation
        self.parser.add_argument('--real20_size', type=int, default=420, help='scale images to compat size')
        self.parser.add_argument('--batchSize', '-b', type=int, default=2, help='input batch size')
        self.parser.add_argument('--loadSize', type=str, default='224,336,448', help='scale images to multiple size')
        self.parser.add_argument('--fineSize', type=str, default='224,224', help='then crop to this size')
        self.parser.add_argument('--no_flip', action='store_true',
                                 help='if specified, do not flip the images for data augmentation')
        self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop',
                                 help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
        # NOTE(review): help text below looks copy-pasted from --no_flip; confirm the intended description.
        self.parser.add_argument('--debug_eval', action='store_true',
                                 help='if specified, do not flip the images for data augmentation')
        self.parser.add_argument('--graph', action='store_true', help='print graph')

        # for discriminator
        self.parser.add_argument('--which_model_D', type=str, default='disc_vgg', choices=['disc_vgg', 'disc_patch'])
        self.parser.add_argument('--gan_type', type=str, default='rasgan',
                                 help='gan/sgan : Vanilla GAN; rasgan : relativistic gan')
        # loss weight
        self.parser.add_argument('--unaligned_loss', type=str, default='vgg',
                                 help='learning rate policy: vgg|mse|ctx|ctx_vgg')
        self.parser.add_argument('--tv_type', type=str, default=None, choices=['ktv', 'mtv'])
        self.parser.add_argument('--vgg_layer', type=int, default=31, help='vgg layer of unaligned loss')
        self.parser.add_argument('--init_lr', type=float, default=1e-2, help='initial learning rate')
        self.parser.add_argument('--fixed_lr', type=float, default=0, help='initial learning rate')
        self.parser.add_argument('--lambda_gan', type=float, default=0.01, help='weight for gan loss')
        self.parser.add_argument('--lambda_vgg', type=float, default=0.1, help='weight for vgg loss')
        # BUG FIX: "weight fot overall loss" typo corrected.
        self.parser.add_argument('--weight_loss', type=float, default=0.25, help='weight for overall loss')
        self.parser.add_argument('--num_subnet', type=int, default=4, help='num_number of subnet')
        self.parser.add_argument('--dataset', type=float, default=0.5, help='the setting of dataset')
        self.parser.add_argument('--loss_col', type=int, default=4, help='numcol for loss')
        self.parser.add_argument('--drop_path', type=float, default=0.6, help='drop_path')

        self.isTrain = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pretrained/README.md
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
# Pretrained models
|
| 2 |
-
|
| 3 |
-
This folder is for pretrained models.
|
|
|
|
|
|
|
|
|
|
|
|
script.py
DELETED
|
@@ -1,64 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
|
| 3 |
-
# Load the original weights file
|
| 4 |
-
original_weights = torch.load('/home/xteam/zhaohao/pycharmproject/YTMT/merge_stem_reg_014_00055524.pt')
|
| 5 |
-
|
| 6 |
-
# Create a new weights dictionary
|
| 7 |
-
# new_weights = {}
|
| 8 |
-
|
| 9 |
-
# # Iterate through the original weights dictionary
|
| 10 |
-
# for key, value in original_weights.items():
|
| 11 |
-
# # Check if the key contains 'projec_shit'
|
| 12 |
-
# if 'projback_shit' in key:
|
| 13 |
-
# # Replace 'projec_shit' with 'project_'
|
| 14 |
-
# new_key = key.replace('projback_shit', 'projback_')
|
| 15 |
-
# new_weights[new_key] = value
|
| 16 |
-
# else:
|
| 17 |
-
# # If the key doesn't contain 'projec_shit', keep it unchanged
|
| 18 |
-
# new_weights[key] = value
|
| 19 |
-
# if 'projback_shit_2' in key:
|
| 20 |
-
# # Replace 'projec_shit' with 'project_'
|
| 21 |
-
# new_key = key.replace('projback_shit_2', 'projback_2')
|
| 22 |
-
# new_weights[new_key] = value
|
| 23 |
-
# else:
|
| 24 |
-
# # If the key doesn't contain 'projec_shit', keep it unchanged
|
| 25 |
-
# new_weights[key] = value
|
| 26 |
-
|
| 27 |
-
# # Save the modified weights
|
| 28 |
-
# torch.save(new_weights, '/home/xteam/zhaohao/pycharmproject/RDNet/new_weights.pth')
|
| 29 |
-
|
| 30 |
-
# print("Weights file has been updated.")
|
| 31 |
-
|
| 32 |
-
# # 打印原始权重字典中的所有键,以检查确切的层名称
|
| 33 |
-
# print("原始权重文件中的层名:")
|
| 34 |
-
# for key in original_weights['icnn'].keys():
|
| 35 |
-
# print(key)
|
| 36 |
-
|
| 37 |
-
# 创建一个新的权重字典
|
| 38 |
-
new_weights = {'icnn': {}}
|
| 39 |
-
|
| 40 |
-
# 遍历原始权重字典
|
| 41 |
-
for key, value in original_weights['icnn'].items():
|
| 42 |
-
# 检查并替换包含 'projback_shit' 的键
|
| 43 |
-
if 'projback_shit_2' in key:
|
| 44 |
-
new_key = key.replace('projback_shit_2', 'projback_2')
|
| 45 |
-
new_weights['icnn'][new_key] = value
|
| 46 |
-
|
| 47 |
-
# 检查并替换包含 'projback_shit_2' 的键
|
| 48 |
-
elif 'projback_shit' in key:
|
| 49 |
-
new_key = key.replace('projback_shit', 'projback_')
|
| 50 |
-
new_weights['icnn'][new_key] = value
|
| 51 |
-
else:
|
| 52 |
-
# 如果键不包含上述字符串,保持不变
|
| 53 |
-
new_weights['icnn'][key] = value
|
| 54 |
-
|
| 55 |
-
# 打印新的权重字典中的所有键,以验证更改
|
| 56 |
-
print("\n更新后的权重文件中的层名:")
|
| 57 |
-
for key in new_weights['icnn'].keys():
|
| 58 |
-
print(key)
|
| 59 |
-
|
| 60 |
-
# 保存修改后的权重
|
| 61 |
-
torch.save(new_weights, '/home/xteam/zhaohao/pycharmproject/RDNet/new_weights_4.pth')
|
| 62 |
-
|
| 63 |
-
print("\n权重文件已更新。")
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_sirs.py
DELETED
|
@@ -1,60 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from os.path import join
|
| 3 |
-
|
| 4 |
-
import torch.backends.cudnn as cudnn
|
| 5 |
-
|
| 6 |
-
# import data.sirs_dataset as datasets
|
| 7 |
-
import data.dataset_sir as datasets
|
| 8 |
-
from data.image_folder import read_fns
|
| 9 |
-
from engine import Engine
|
| 10 |
-
from options.net_options.train_options import TrainOptions
|
| 11 |
-
from tools import mutils
|
| 12 |
-
|
| 13 |
-
opt = TrainOptions().parse()
|
| 14 |
-
|
| 15 |
-
opt.isTrain = False
|
| 16 |
-
cudnn.benchmark = True
|
| 17 |
-
opt.no_log = True
|
| 18 |
-
opt.display_id = 0
|
| 19 |
-
opt.verbose = False
|
| 20 |
-
datadir = os.path.join(os.path.expanduser('~'), '/opt/datasets/sirs')
|
| 21 |
-
|
| 22 |
-
eval_dataset_real = datasets.DSRTestDataset(join(datadir, f'test/real20_{opt.real20_size}'),
|
| 23 |
-
fns=read_fns('data/real_test.txt'), if_align=opt.if_align)
|
| 24 |
-
eval_dataset_solidobject = datasets.DSRTestDataset(join(datadir, 'test/SIR2/SolidObjectDataset'),
|
| 25 |
-
if_align=opt.if_align)
|
| 26 |
-
eval_dataset_postcard = datasets.DSRTestDataset(join(datadir, 'test/SIR2/PostcardDataset'), if_align=opt.if_align)
|
| 27 |
-
eval_dataset_wild = datasets.DSRTestDataset(join(datadir, 'test/SIR2/WildSceneDataset'), if_align=opt.if_align)
|
| 28 |
-
|
| 29 |
-
eval_dataloader_real = datasets.DataLoader(
|
| 30 |
-
eval_dataset_real, batch_size=1, shuffle=True,
|
| 31 |
-
num_workers=opt.nThreads, pin_memory=True)
|
| 32 |
-
|
| 33 |
-
eval_dataloader_solidobject = datasets.DataLoader(
|
| 34 |
-
eval_dataset_solidobject, batch_size=1, shuffle=False,
|
| 35 |
-
num_workers=opt.nThreads, pin_memory=True)
|
| 36 |
-
|
| 37 |
-
eval_dataloader_postcard = datasets.DataLoader(
|
| 38 |
-
eval_dataset_postcard, batch_size=1, shuffle=False,
|
| 39 |
-
num_workers=opt.nThreads, pin_memory=True)
|
| 40 |
-
|
| 41 |
-
eval_dataloader_wild = datasets.DataLoader(
|
| 42 |
-
eval_dataset_wild, batch_size=1, shuffle=False,
|
| 43 |
-
num_workers=opt.nThreads, pin_memory=True)
|
| 44 |
-
|
| 45 |
-
engine = Engine(opt, eval_dataset_real, eval_dataset_solidobject, eval_dataset_postcard, eval_dataloader_wild)
|
| 46 |
-
|
| 47 |
-
"""Main Loop"""
|
| 48 |
-
result_dir = os.path.join('./results', opt.name, mutils.get_formatted_time())
|
| 49 |
-
|
| 50 |
-
res1 = engine.eval(eval_dataloader_real, dataset_name='testdata_real',
|
| 51 |
-
savedir=join(result_dir, 'real20'), suffix='real20')
|
| 52 |
-
|
| 53 |
-
res2 = engine.eval(eval_dataloader_solidobject, dataset_name='testdata_solidobject',
|
| 54 |
-
savedir=join(result_dir, 'solidobject'), suffix='solidobject')
|
| 55 |
-
res3 = engine.eval(eval_dataloader_postcard, dataset_name='testdata_postcard',
|
| 56 |
-
savedir=join(result_dir, 'postcard'), suffix='postcard')
|
| 57 |
-
|
| 58 |
-
res4 = engine.eval(eval_dataloader_wild, dataset_name='testdata_wild',
|
| 59 |
-
savedir=join(result_dir, 'wild'), suffix='wild')
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|