SuperPatrick commited on
Commit
98feea6
·
verified ·
1 Parent(s): 74f7b76

Upload 20 files

Browse files
LMAR_GAN_train.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import yaml
3
+ import torchvision.transforms as transforms
4
+ from utils import read_args, save_checkpoint, AverageMeter, CosineAnnealingWarmRestarts
5
+ import time
6
+ from tqdm import trange, tqdm
7
+ from torchvision.utils import save_image
8
+ # from tensorboardX import SummaryWriter
9
+ import os
10
+ import json
11
+ import time
12
+ import logging
13
+
14
+ os.environ['CUDA_VISIBLE_DEVICES'] = '1'
15
+ import torch
16
+ from torch import optim
17
+ import torch.nn as nn
18
+ import torchvision.utils as vutils
19
+ import torch.nn.functional as F
20
+
21
+ from data import *
22
+ from model import *
23
+ from loss import *
24
+ import pyiqa
25
+ from torch.autograd import Variable
26
+ import numpy as np
27
+
28
+ global_step = 0
29
+ psnr_calculator = pyiqa.create_metric('psnr').cuda()
30
+ ssim_calculator = pyiqa.create_metric('ssimc', downsample=True).cuda()
31
+
32
+ criterion_GAN = nn.MSELoss()
33
+ Tensor = torch.cuda.FloatTensor
34
+
35
+ mmdLoss = MMDLoss().cuda()
36
+
37
+ # cos_loss = cos_loss
38
+ # feature_extractor.eval()
39
+
40
+
41
def train(model, data_loader, criterion, optimizer_G, optimizer_D, epoch, args, discriminator):
    """Run one epoch of adversarial training for the learnable downscaler.

    The generator (``model``) is optimized with an adversarial pixel loss, a
    feature resize loss, a similarity ("pseudo") loss and an SSIM-based
    up-sampling loss; the ``discriminator`` is then trained to separate
    plain resized inputs from the generator's learned downscales.

    Args:
        model: generator returning (down_x, hr_feature, new_lr_feature,
            ori_lr_feature, residual, res) for an input image.
        data_loader: yields (inp_img, gt_img, down_h, down_w, inp_img_path).
        criterion: feature reconstruction loss (SmoothL1 in ``main``).
        optimizer_G: generator optimizer.
        optimizer_D: discriminator optimizer.
        epoch: current 1-based epoch index (used in file names / logs).
        args: parsed config namespace (output_dir, train_loader, ...).
        discriminator: patch discriminator over low-resolution images.
    """
    global global_step
    iter_bar = tqdm(data_loader, desc='Iter (loss=X.XXX)')
    nbatches = len(data_loader)

    # Running averages for logging (one meter per loss term).
    total_losses = AverageMeter()

    pixel_losses = AverageMeter()
    resize_losses = AverageMeter()
    pseudo_losses = AverageMeter()
    up_losses = AverageMeter()
    dis_losses = AverageMeter()

    psnrs = AverageMeter()
    ssims = AverageMeter()

    optimizer_G.zero_grad()
    optimizer_D.zero_grad()

    start_time = time.time()

    if not os.path.exists(args.output_dir + '/image_train'):
        os.mkdir(args.output_dir + '/image_train')

    if not os.path.exists(args.output_dir + "/models"):
        os.mkdir(args.output_dir + "/models")

    for i, batch in enumerate(iter_bar):
        optimizer_G.zero_grad()
        optimizer_D.zero_grad()

        inp_img, gt_img, down_h, down_w, inp_img_path = batch
        batch_size = inp_img.size(0)
        inp_img = inp_img.cuda()
        gt_img = gt_img.cuda()

        # The target down-size varies per batch (supplied by the loader);
        # eval() turns the "(H, W)" string from the YAML config into a tuple.
        down_size = (down_h.item(), down_w.item())
        up_size = eval(args.train_loader["img_size"])

        down_x, hr_feature, new_lr_feature, ori_lr_feature, residual, res = model(inp_img, down_size, up_size)

        # Patch-discriminator label maps; the 2**4 factor suggests the
        # discriminator downsamples by 16 -- TODO confirm against Discriminator.
        dis_patch_lr = (1, down_size[0] // 2 ** 4, down_size[1] // 2 ** 4)
        valid_lr = Variable(Tensor(np.ones((batch_size, *dis_patch_lr))), requires_grad=False)
        fake_lr = Variable(Tensor(np.zeros((batch_size, *dis_patch_lr))), requires_grad=False)

        # ---- generator update ----
        # Adversarial term: push D to judge the learned downscale as "real".
        pixel_loss = criterion_GAN(discriminator(down_x), valid_lr)
        pixel_losses.update(pixel_loss.item(), batch_size)

        resize_loss = criterion(hr_feature, new_lr_feature)
        resize_losses.update(resize_loss.item(), batch_size)

        # 5000: empirical weight balancing the similarity loss magnitude.
        pseudo_loss = similarity_loss(new_lr_feature, hr_feature) * 5000
        pseudo_losses.update(pseudo_loss.item(), batch_size)

        up_loss, gradient = feat_ssim(new_lr_feature, hr_feature, inp_img)
        up_losses.update(up_loss.item(), batch_size)

        total_loss = pixel_loss + resize_loss + pseudo_loss + up_loss
        total_losses.update(total_loss.item(), batch_size)

        total_loss.backward()
        optimizer_G.step()

        # ---- discriminator update ----
        # "Real" samples are plain resizes of the HR input; the generator
        # output is detached so no gradient reaches G from loss_D.
        loss_real_lr = criterion_GAN(discriminator(resize(inp_img, out_shape=down_size, antialiasing=False)), valid_lr)

        loss_fake_lr = criterion_GAN(discriminator(down_x.detach()), fake_lr)

        loss_D = (loss_fake_lr + loss_real_lr) * 0.5
        dis_losses.update(loss_D.item(), batch_size)

        loss_D.backward()
        optimizer_D.step()

        iter_bar.set_description('Iter (loss=%5.6f)' % (total_losses.avg + dis_losses.avg))

        # Periodically dump debug images: downscale vs. resize error, mean
        # feature maps, and the (x10 amplified) residual.
        if i % 200 == 0:
            error = torch.abs(resize(inp_img, out_shape=down_size, antialiasing=False) - down_x)
            saved_image = torch.cat(
                [resize(inp_img, out_shape=down_size, antialiasing=False)[0:2], down_x[0:2], error[0:2]],
                dim=0)
            save_image(saved_image, args.output_dir + '/image_train/epoch_{}_iter_down_{}.png'.format(epoch, i))

            saved_image = torch.cat(
                [torch.mean(hr_feature, dim=1, keepdim=True)[0:2], torch.mean(new_lr_feature, dim=1, keepdim=True)[0:2],
                 torch.mean(ori_lr_feature, dim=1, keepdim=True)[0:2], torch.mean(torch.abs(new_lr_feature-ori_lr_feature), dim=1, keepdim=True)[0:2]],
                dim=0)
            save_image(saved_image, args.output_dir + '/image_train/epoch_{}_iter_feat_{}.png'.format(epoch, i))
            residual = residual * 10
            save_image(residual[0], args.output_dir + '/image_train/epoch_{}_iter_out_{}.png'.format(epoch, i))

        # NOTE(review): PSNR/SSIM here are placeholders (always 0.0); the
        # real metrics are computed in evaluate().
        if i % max(1, nbatches // 10) == 0:
            psnr_val, ssim_val = 0.0, 0.0
            psnrs.update(psnr_val, batch_size)
            ssims.update(ssim_val, batch_size)

            logging.info(
                "Epoch {}, learning rates {:}, Iter {}, total_loss {:.4f}, pixel_loss {:.4f}, resize_loss {:.4f}, pseudo_loss {:.4f}, up_loss {:.4f}, dis_loss: {:.4f}, PSNR {:.4f}, SSIM {:.4f}, Elapse time {:.2f}\n".format(
                    epoch, optimizer_G.param_groups[0]["lr"], i, total_losses.avg, pixel_losses.avg, resize_losses.avg,
                    pseudo_losses.avg, up_losses.avg, dis_losses.avg,
                    psnrs.avg, ssims.avg,
                    time.time() - start_time))

    # Save generator and discriminator checkpoints every epoch.
    if epoch % 1 == 0:
        logging.info("** ** * Saving model and optimizer ** ** * ")

        output_model_file = os.path.join(args.output_dir + "/models", "model.%d.bin" % (epoch))
        state = {"epoch": epoch, "state_dict": model.state_dict(), "step": global_step}
        save_checkpoint(state, output_model_file)

        output_model_file = os.path.join(args.output_dir + "/models", "discriminator.%d.bin" % (epoch))
        state = {"epoch": epoch, "state_dict": discriminator.state_dict(), "step": global_step}
        save_checkpoint(state, output_model_file)
        logging.info("Save model to %s", output_model_file)

    logging.info(
        "Finish training epoch %d, avg total_loss: %.4f, avg pixel_loss: %.4f, avg resize_loss: %.4f, avg pseudo_loss: %.4f, avg up_loss: %.4f, "
        "avg dis_loss: %.4f, avg PSNR: %.2f, avg SSIM: %.2F, and takes %.2f seconds" % (
            epoch, total_losses.avg, pixel_losses.avg, resize_losses.avg, pseudo_losses.avg, up_losses.avg, dis_losses.avg, psnrs.avg,
            ssims.avg,
            time.time() - start_time))

    logging.info("***** CUDA.empty_cache() *****\n")
    torch.cuda.empty_cache()
168
+
169
+
170
def evaluate(model, load_path, data_loader, epoch):
    """Load a checkpoint and report PSNR/SSIM on the test loader.

    A down-size is picked at random from the configured candidates for this
    evaluation run. Relies on the module-level ``args`` namespace for the
    test-loader configuration.

    Args:
        model: generator instance (weights are replaced from ``load_path``).
        load_path: path to a checkpoint saved by ``train``.
        data_loader: test loader yielding (inp_img, gt_img, inp_img_path).
        epoch: epoch index (unused here except for the caller's bookkeeping).
    """
    checkpoint = torch.load(load_path)
    model.load_state_dict(checkpoint["state_dict"])
    model.cuda()
    model.eval()

    psnrs = AverageMeter()
    ssims = AverageMeter()
    # .item() converts the 1-element tensor to a plain int before it is used
    # to index the Python list of candidate sizes.
    random_index = torch.randint(low=0, high=5, size=(1,)).item()
    down_size = eval(args.test_loader["img_size"])
    down_size = down_size[random_index]
    logging.info("Inference at down size: {}".format(down_size))
    up_size = eval(args.test_loader["gt_size"])

    start_time = time.time()
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            inp_img, gt_img, inp_img_path = batch
            inp_img = inp_img.cuda()
            # Bug fix: gt_img was left on the CPU while the model output is
            # on the GPU; the pyiqa metrics compare the two tensors, so the
            # ground truth must live on the same device.
            gt_img = gt_img.cuda()
            batch_size = inp_img.size(0)
            up_out, _ = model(inp_img, down_size, up_size, test_flag=True)

            # metrics
            clamped_out = torch.clamp(up_out, 0, 1)
            psnr_val, ssim_val = psnr_calculator(clamped_out, gt_img), ssim_calculator(clamped_out, gt_img)
            psnrs.update(torch.mean(psnr_val).item(), batch_size)
            ssims.update(torch.mean(ssim_val).item(), batch_size)
            torch.cuda.empty_cache()

            if i % 100 == 0:
                logging.info(
                    "PSNR {:.4f}, SSIM {:.4f}, Elapse time {:.2f}\n".format(psnrs.avg, ssims.avg,
                                                                            time.time() - start_time))

    logging.info("avg PSNR: %.4f, avg SSIM: %.4F, and takes %.2f seconds" % (
        psnrs.avg, ssims.avg, time.time() - start_time))
206
+
207
+
208
def main(args):
    """Set up logging, data loaders, models and optimizers, then train.

    Creates the output directory, snapshots the parsed config to
    ``args.json``, builds the generator / discriminator pair with separate
    Adam optimizers and LR schedulers, and runs the epoch loop with periodic
    evaluation.

    Args:
        args: parsed config namespace from ``read_args`` (output_dir,
            data, train_loader, test_loader, optimizer, evaluate_intervel).
    """
    global global_step

    start_epoch = 1
    global_step = 0

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    # Snapshot the effective configuration next to the logs for reproducibility.
    with open(os.path.join(args.output_dir, "args.json"), "w") as f:
        json.dump(args.__dict__, f, sort_keys=True, indent=2)

    log_format = "%(asctime)s %(levelname)-8s %(message)s"
    log_file = os.path.join(args.output_dir, "train_log")
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())

    # device setting
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    logging.info(args.__dict__)

    model = codebook_model(args)

    # 3-channel (RGB) patch discriminator.
    discriminator = Discriminator(3).cuda()

    optimizer_G = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.optimizer["lr"],
                             betas=(0.9, 0.999))

    optimizer_D = optim.Adam(list(discriminator.parameters()),
                             lr=args.optimizer["lr"],
                             betas=(0.9, 0.999))

    logging.info("Building data loader")

    if args.train_loader["loader"] == "resize":
        train_transforms = transforms.Compose([transforms.Resize(eval(args.train_loader["img_size"])),
                                               transforms.ToTensor()])
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), train_transforms, False,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], random_flag=False)

    elif args.train_loader["loader"] == "crop":
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), False, True,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], random_flag=args.train_loader["random_flag"])

    elif args.train_loader["loader"] == "default":
        train_transforms = transforms.Compose([transforms.ToTensor()])
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), train_transforms, False,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], random_flag=args.train_loader["random_flag"])
    else:
        raise NotImplementedError

    if args.test_loader["loader"] == "default":

        test_transforms = transforms.Compose([transforms.ToTensor()])
        test_loader = get_loader(args.data["test_dir"],
                                 None, test_transforms, False,
                                 int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                                 args.test_loader["shuffle"], random_flag=False)

    elif args.test_loader["loader"] == "resize":

        test_transforms = transforms.Compose([transforms.Resize(eval(args.test_loader["img_size"])),
                                              transforms.ToTensor()])
        test_loader = get_loader(args.data["test_dir"],
                                 eval(args.test_loader["img_size"]), test_transforms, False,
                                 int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                                 args.test_loader["shuffle"], random_flag=False)
    else:
        raise NotImplementedError

    # Feature reconstruction loss for the generator.
    criterion = nn.SmoothL1Loss()

    if args.optimizer["type"] == "cos":
        # Bug fix: this branch previously referenced an undefined `optimizer`
        # and built a single `lr_scheduler` that the epoch loop never stepped
        # (the loop steps lr_scheduler_G / lr_scheduler_D), so "cos" crashed
        # with a NameError. Build one cosine scheduler per optimizer instead.
        lr_scheduler_G = CosineAnnealingWarmRestarts(optimizer_G, T_0=args.optimizer["T_0"],
                                                     T_mult=args.optimizer["T_MULT"],
                                                     eta_min=args.optimizer["ETA_MIN"],
                                                     last_epoch=-1)
        lr_scheduler_D = CosineAnnealingWarmRestarts(optimizer_D, T_0=args.optimizer["T_0"],
                                                     T_mult=args.optimizer["T_MULT"],
                                                     eta_min=args.optimizer["ETA_MIN"],
                                                     last_epoch=-1)
    elif args.optimizer["type"] == "step":
        lr_scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=args.optimizer["step"],
                                                         gamma=args.optimizer["gamma"])
        lr_scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=args.optimizer["step"],
                                                         gamma=args.optimizer["gamma"])
    else:
        # Bug fix: fail fast on unknown scheduler types instead of hitting a
        # NameError on lr_scheduler_G far from the misconfiguration.
        raise NotImplementedError

    t_total = int(len(train_loader) * args.optimizer["total_epoch"])
    logging.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    logging.info("***** Running training *****")
    logging.info("  Batch size = %d", args.train_loader["batch_size"])
    logging.info("  Num steps = %d", t_total)
    logging.info("  Loader length = %d", len(train_loader))

    model.train()
    model.cuda()

    logging.info("Begin training from epoch = %d\n", start_epoch)
    for epoch in trange(start_epoch, args.optimizer["total_epoch"] + 1, desc="Epoch"):
        train(model, train_loader, criterion, optimizer_G, optimizer_D, epoch, args, discriminator)
        lr_scheduler_G.step()
        lr_scheduler_D.step()
        if epoch % args.evaluate_intervel == 0:
            logging.info("***** Running testing *****")
            load_path = os.path.join(args.output_dir + "/models", "model.%d.bin" % (epoch))
            evaluate(model, load_path, test_loader, epoch)
            logging.info("***** End testing *****")
326
+
327
+
328
if __name__ == '__main__':
    # Entry point: parse the YAML config (hard-coded path -- adjust per host)
    # and launch training.
    parser = read_args("/home/yuwei/code/cvpr/config/LMAR_config.yaml")
    args = parser.parse_args()
    main(args)
LMAR_VGG_train.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ['CUDA_VISIBLE_DEVICES'] = '1'
4
+ import argparse
5
+ import yaml
6
+ import torchvision.transforms as transforms
7
+ from utils import read_args, save_checkpoint, AverageMeter, CosineAnnealingWarmRestarts
8
+ import time
9
+ from tqdm import trange, tqdm
10
+ from torchvision.utils import save_image
11
+ import json
12
+ import time
13
+ import logging
14
+ import torch
15
+ from torch import nn, optim
16
+ import torchvision.utils as vutils
17
+ import torch.nn.functional as F
18
+ import torch.nn as nn
19
+
20
+ from data import *
21
+ from model import *
22
+ from loss import *
23
+ import pyiqa
24
+
25
+ from torch.autograd import Variable
26
+
27
+ global_step = 0
28
+ psnr_calculator = pyiqa.create_metric('psnr').cuda()
29
+ ssim_calculator = pyiqa.create_metric('ssimc', downsample=True).cuda()
30
+
31
+ feature_extractor = VGGPerceptualLoss(resize=False).cuda()
32
+ feature_extractor.eval()
33
+
34
+
35
def weight_annealing(epoch):
    """Return the weight applied to the perceptual resize loss at ``epoch``.

    The weight is held at 1 for the first two epochs and then annealed down
    to 1e-3 for the remainder of training.
    """
    base_weight = 1
    if epoch >= 2:
        # After the warm-up phase the weight is sharply reduced.
        return base_weight * 0.001
    # Warm-up phase: keep the initial weight unchanged.
    return base_weight
41
+
42
+
43
def train(model, data_loader, criterion, optimizer, epoch, args):
    """Run one epoch of VGG-perceptual training for the learnable downscaler.

    Unlike the GAN variant, the downscale quality is supervised with a
    (weight-annealed) VGG perceptual loss against a plain resize, plus a
    feature pixel loss and an SSIM-based pseudo loss.

    Args:
        model: generator returning (down_x, hr_feature, new_lr_feature,
            ori_lr_feature, residual, res).
        data_loader: yields (inp_img, gt_img, down_h, down_w, inp_img_path).
        criterion: feature reconstruction loss (SmoothL1 in ``main``).
        optimizer: single Adam optimizer over all model parameters.
        epoch: current 1-based epoch index.
        args: parsed config namespace.
    """
    global global_step
    iter_bar = tqdm(data_loader, desc='Iter (loss=X.XXX)')
    nbatches = len(data_loader)

    # Running averages for logging (one meter per loss term).
    total_losses = AverageMeter()

    pixel_losses = AverageMeter()
    resize_losses = AverageMeter()
    pseudo_losses = AverageMeter()

    psnrs = AverageMeter()
    ssims = AverageMeter()

    optimizer.zero_grad()

    start_time = time.time()

    if not os.path.exists(args.output_dir + '/image_train'):
        os.mkdir(args.output_dir + '/image_train')

    if not os.path.exists(args.output_dir + "/models"):
        os.mkdir(args.output_dir + "/models")

    for i, batch in enumerate(iter_bar):
        optimizer.zero_grad()
        inp_img, gt_img, down_h, down_w, inp_img_path = batch
        batch_size = inp_img.size(0)
        inp_img = inp_img.cuda()
        gt_img = gt_img.cuda()

        # Per-batch random down-size from the loader; eval() turns the
        # "(H, W)" string from the YAML config into a tuple.
        down_size = (down_h.item(), down_w.item())
        up_size = eval(args.train_loader["img_size"])


        down_x, hr_feature, new_lr_feature, ori_lr_feature, residual, res = model(inp_img, down_size, up_size)

        pixel_loss = criterion(new_lr_feature, hr_feature)

        pixel_losses.update(pixel_loss.item(), batch_size)

        # Perceptual loss between the learned downscale and a plain resize,
        # annealed over epochs (see weight_annealing).
        resize_loss = feature_extractor(down_x, resize(inp_img, out_shape=down_size, antialiasing=False),
                                        feature_layers=[3])
        resize_loss = resize_loss * weight_annealing(epoch)
        resize_losses.update(resize_loss.item(), batch_size)


        pseudo_loss, _ = feat_ssim(new_lr_feature, hr_feature, inp_img)
        pseudo_losses.update(pseudo_loss.item(), batch_size)

        total_loss = pixel_loss + resize_loss + pseudo_loss
        total_losses.update(total_loss.item(), batch_size)

        total_loss.backward()

        optimizer.step()

        iter_bar.set_description('Iter (loss=%5.6f)' % total_losses.avg)

        # Periodically dump debug images: downscale vs. resize error, mean
        # feature maps, and the (x10 amplified) residual.
        if i % 200 == 0:
            # print(residual.max())
            error = torch.abs(resize(inp_img, out_shape=down_size, antialiasing=False) - down_x)
            # error = (error - error.min()) / (error.max()-error.min())
            saved_image = torch.cat(
                [resize(inp_img, out_shape=down_size, antialiasing=False)[0:2], down_x[0:2], error[0:2]],
                dim=0)
            save_image(saved_image, args.output_dir + '/image_train/epoch_{}_iter_down_{}.png'.format(epoch, i))

            saved_image = torch.cat(
                [torch.mean(hr_feature, dim=1, keepdim=True)[0:2], torch.mean(new_lr_feature, dim=1, keepdim=True)[0:2],
                 torch.mean(ori_lr_feature, dim=1, keepdim=True)[0:2],
                 torch.mean(torch.abs(new_lr_feature - ori_lr_feature), dim=1, keepdim=True)[0:2]],
                dim=0)
            save_image(saved_image, args.output_dir + '/image_train/epoch_{}_iter_feat_{}.png'.format(epoch, i))
            # residual = (residual - residual.min()) / (residual.max()-residual.min())
            residual = residual * 10
            save_image(residual[0], args.output_dir + '/image_train/epoch_{}_iter_out_{}.png'.format(epoch, i))

        # NOTE(review): PSNR/SSIM here are placeholders (always 0.0); the
        # real metrics are computed in evaluate().
        if i % max(1, nbatches // 10) == 0:
            psnr_val, ssim_val = 0.0, 0.0
            psnrs.update(psnr_val, batch_size)
            ssims.update(ssim_val, batch_size)

            logging.info(
                "Epoch {}, learning rates {:}, Iter {}, total_loss {:.4f}, pixel_loss {:.4f}, resize_loss {:.4f}, pseudo_loss {:.4f}, PSNR {:.4f}, SSIM {:.4f}, Elapse time {:.2f}\n".format(
                    epoch, optimizer.param_groups[0]["lr"], i, total_losses.avg, pixel_losses.avg, resize_losses.avg,
                    pseudo_losses.avg,
                    psnrs.avg, ssims.avg,
                    time.time() - start_time))

    # Checkpoint model + optimizer state every epoch.
    if epoch % 1 == 0:
        logging.info("** ** * Saving model and optimizer ** ** * ")

        output_model_file = os.path.join(args.output_dir + "/models", "model.%d.bin" % (epoch))
        state = {"epoch": epoch, "state_dict": model.state_dict(),
                 "optimizer": optimizer.state_dict(), "step": global_step}

        save_checkpoint(state, output_model_file)
        logging.info("Save model to %s", output_model_file)

    logging.info(
        "Finish training epoch %d, avg total_loss: %.4f, avg pixel_loss: %.4f, avg resize_loss: %.4f, avg pseudo_loss: %.4f, avg PSNR: %.2f, avg SSIM: %.2F, and takes %.2f seconds" % (
            epoch, total_losses.avg, pixel_losses.avg, resize_losses.avg, pseudo_losses.avg, psnrs.avg, ssims.avg,
            time.time() - start_time))

    logging.info("***** CUDA.empty_cache() *****\n")
    torch.cuda.empty_cache()
150
+
151
+
152
def evaluate(model, load_path, data_loader, epoch):
    """Load a checkpoint and report PSNR/SSIM on the test loader.

    A down-size is picked at random from the configured candidates for this
    evaluation run. Relies on the module-level ``args`` namespace for the
    test-loader configuration.

    Args:
        model: generator instance (weights are replaced from ``load_path``).
        load_path: path to a checkpoint saved by ``train``.
        data_loader: test loader yielding (inp_img, gt_img, inp_img_path).
        epoch: epoch index, used only in the final log line.
    """
    checkpoint = torch.load(load_path)
    model.load_state_dict(checkpoint["state_dict"])
    model.cuda()
    model.eval()

    psnrs = AverageMeter()
    ssims = AverageMeter()
    # .item() converts the 1-element tensor to a plain int before it is used
    # to index the Python list of candidate sizes.
    random_index = torch.randint(low=0, high=5, size=(1,)).item()
    down_size = eval(args.test_loader["img_size"])
    down_size = down_size[random_index]
    logging.info("Inference at down size: {}".format(down_size))
    up_size = eval(args.test_loader["gt_size"])

    start_time = time.time()
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            inp_img, gt_img, inp_img_path = batch
            inp_img = inp_img.cuda()
            # Bug fix: gt_img was left on the CPU while the model output is
            # on the GPU; the pyiqa metrics compare the two tensors, so the
            # ground truth must live on the same device.
            gt_img = gt_img.cuda()
            batch_size = inp_img.size(0)
            up_out = model(inp_img, down_size, up_size, test_flag=True)

            # metrics
            clamped_out = torch.clamp(up_out, 0, 1)
            psnr_val, ssim_val = psnr_calculator(clamped_out, gt_img), ssim_calculator(clamped_out, gt_img)
            psnrs.update(torch.mean(psnr_val).item(), batch_size)
            ssims.update(torch.mean(ssim_val).item(), batch_size)
            torch.cuda.empty_cache()

            if i % 100 == 0:
                logging.info(
                    "PSNR {:.4f}, SSIM {:.4f}, Elapse time {:.2f}\n".format(psnrs.avg, ssims.avg,
                                                                            time.time() - start_time))

    logging.info(f"Finish test at epoch {epoch}: avg PSNR: %.4f, avg SSIM: %.4F, and takes %.2f seconds" % (
        psnrs.avg, ssims.avg, time.time() - start_time))
188
+
189
+
190
def main(args):
    """Set up logging, data loaders, model and optimizer, then train.

    VGG-perceptual training variant: a single generator with one Adam
    optimizer and one LR scheduler, trained by ``train`` and periodically
    evaluated by ``evaluate``.

    Args:
        args: parsed config namespace from ``read_args`` (output_dir,
            data, train_loader, test_loader, optimizer, evaluate_intervel).
    """
    global global_step

    start_epoch = 1
    global_step = 0

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    # Snapshot the effective configuration next to the logs for reproducibility.
    with open(os.path.join(args.output_dir, "args.json"), "w") as f:
        json.dump(args.__dict__, f, sort_keys=True, indent=2)

    log_format = "%(asctime)s %(levelname)-8s %(message)s"
    log_file = os.path.join(args.output_dir, "train_log")
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())

    # device setting
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    logging.info(args.__dict__)

    model = codebook_model(args)

    optimizer = optim.Adam(model.parameters(), lr=args.optimizer["lr"],
                           betas=(0.9, 0.999))

    logging.info("Building data loader")

    if args.train_loader["loader"] == "resize":
        train_transforms = transforms.Compose([transforms.Resize(eval(args.train_loader["img_size"])),
                                               transforms.ToTensor()])
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), train_transforms, False,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], random_flag=False)

    elif args.train_loader["loader"] == "crop":
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), False, True,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], random_flag=args.train_loader["random_flag"])

    elif args.train_loader["loader"] == "default":
        train_transforms = transforms.Compose([transforms.ToTensor()])
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), train_transforms, False,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], random_flag=args.train_loader["random_flag"])
    else:
        raise NotImplementedError

    if args.test_loader["loader"] == "default":

        test_transforms = transforms.Compose([transforms.ToTensor()])
        test_loader = get_loader(args.data["test_dir"],
                                 eval(args.test_loader["img_size"]), test_transforms, False,
                                 int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                                 args.test_loader["shuffle"], random_flag=False)

    elif args.test_loader["loader"] == "resize":
        test_transforms = transforms.Compose([transforms.Resize(eval(args.test_loader["img_size"])),
                                              transforms.ToTensor()])
        test_loader = get_loader(args.data["test_dir"],
                                 eval(args.test_loader["img_size"]), test_transforms, False,
                                 int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                                 args.test_loader["shuffle"], random_flag=False)
    else:
        raise NotImplementedError

    # Feature reconstruction loss for the generator.
    criterion = nn.SmoothL1Loss()
    # vgg_loss = VGGLoss()

    if args.optimizer["type"] == "cos":
        lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=args.optimizer["T_0"],
                                                   T_mult=args.optimizer["T_MULT"],
                                                   eta_min=args.optimizer["ETA_MIN"],
                                                   last_epoch=-1)
    elif args.optimizer["type"] == "step":
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.optimizer["step"],
                                                       gamma=args.optimizer["gamma"])
    else:
        # Bug fix: previously an unknown scheduler type left `lr_scheduler`
        # undefined and the epoch loop crashed later with a NameError; fail
        # fast here, consistent with the loader branches above.
        raise NotImplementedError

    t_total = int(len(train_loader) * args.optimizer["total_epoch"])
    logging.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    logging.info("***** Running training *****")
    logging.info("  Batch size = %d", args.train_loader["batch_size"])
    logging.info("  Num steps = %d", t_total)
    logging.info("  Loader length = %d", len(train_loader))

    model.train()
    model.cuda()

    logging.info("Begin training from epoch = %d\n", start_epoch)
    for epoch in trange(start_epoch, args.optimizer["total_epoch"] + 1, desc="Epoch"):
        train(model, train_loader, criterion, optimizer, epoch, args)
        lr_scheduler.step()
        if epoch % args.evaluate_intervel == 0:
            logging.info("***** Running testing *****")
            load_path = os.path.join(args.output_dir + "/models", "model.%d.bin" % (epoch))
            evaluate(model, load_path, test_loader, epoch)
            logging.info("***** End testing *****")
295
+
296
+
297
if __name__ == '__main__':
    # Entry point: parse the YAML config (hard-coded path -- adjust per host)
    # and launch training.
    parser = read_args("/home/yuwei/code/cvpr/config/LMAR_config.yaml")
    args = parser.parse_args()
    main(args)
LMAR_test.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import yaml
3
+ import torchvision.transforms as transforms
4
+ from utils import read_args, save_checkpoint, AverageMeter, CosineAnnealingWarmRestarts
5
+ import time
6
+ from tqdm import trange, tqdm
7
+ from torchvision.utils import save_image
8
+ import os
9
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
10
+ import json
11
+ import time
12
+ import logging
13
+ import torch
14
+ from torch import nn, optim
15
+ import numpy as np
16
+ import torch.nn.functional as F
17
+
18
+ import copy
19
+ from model import *
20
+ from data import *
21
+ from PIL import Image
22
+ from torch.optim import LBFGS
23
+ import pyiqa
24
+ from thop import profile
25
+ from thop import clever_format
26
+
27
+ from torchvision.models.feature_extraction import create_feature_extractor
28
+
29
+ psnr_calculator = pyiqa.create_metric('psnr').cuda()
30
+ ssim_calculator = pyiqa.create_metric('ssimc', downsample=True).cuda()
31
+
32
+
33
def test(load_path, data_loader, args):
    """Evaluate a trained LMAR checkpoint at a fixed down-size.

    Loads the model weights from ``load_path``, runs inference over
    ``data_loader`` with a hard-coded down-size of (1440, 2560), and logs
    running PSNR/SSIM averages.

    Args:
        load_path: checkpoint file containing a "state_dict" entry.
        data_loader: test loader yielding (inp_img, gt_img, inp_img_path).
        args: parsed config namespace (test_loader settings).
    """
    model = codebook_model(args)
    checkpoint = torch.load(load_path)
    model.load_state_dict(checkpoint["state_dict"])
    model.cuda()
    model.eval()

    psnrs = AverageMeter()
    ssims = AverageMeter()
    # NOTE(review): lpipss / niqes are never updated below, so the LPIPS and
    # NIQE columns in the log always read 0.0.
    lpipss = AverageMeter()
    niqes = AverageMeter()

    # Fixed evaluation down-size (H, W).
    down_size = (1440, 2560)
    logging.info("Inference at down size: {}".format(down_size))
    up_size = eval(args.test_loader["gt_size"])

    start_time = time.time()
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            inp_img, gt_img, inp_img_path = batch
            inp_img = inp_img.cuda()
            batch_size = inp_img.size(0)
            gt_img = gt_img.cuda()
            up_out = model(inp_img, down_size, up_size, test_flag=True)
            # Kept for the optional save_image call below.
            name = inp_img_path[0].split("/")[-1]
            # save_image(up_out[0], os.path.join(save_path, name))

            # metrics
            clamped_out = torch.clamp(up_out, 0, 1)

            psnr_val, ssim_val = psnr_calculator(clamped_out, gt_img), ssim_calculator(clamped_out, gt_img)
            psnrs.update(psnr_val.item(), batch_size)
            ssims.update(ssim_val.item(), batch_size)

            if i % 700 == 0:
                logging.info(
                    "PSNR {:.4f}, SSIM {:.4f}, LPIPS {:.4F}, NIQE {:.4F}, Elapse time {:.2f}\n".format(psnrs.avg, ssims.avg, lpipss.avg, niqes.avg,
                                                                                                       time.time() - start_time))

    logging.info("Finish test: avg PSNR: %.4f, avg SSIM: %.4F, avg LPIPS: %.4F, avg NIQE: %.4F, and takes %.2f seconds" % (
        psnrs.avg, ssims.avg, lpipss.avg, niqes.avg, time.time() - start_time))
74
+
75
+
76
def main(args, load_path):
    """Build the test loader, set up logging, and run ``test``.

    Args:
        args: parsed config namespace (output_dir, data, test_loader).
        load_path: path to the checkpoint to evaluate.
    """
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    test_transforms = transforms.Compose([transforms.ToTensor()])

    log_format = "%(asctime)s %(levelname)-8s %(message)s"
    log_file = os.path.join(args.output_dir, "test_log")
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())

    logging.info("Building data loader")

    test_loader = get_loader(args.data["test_dir"],
                             eval(args.test_loader["img_size"]), test_transforms, False,
                             int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                             args.test_loader["shuffle"], random_flag=False)
    # Bug fix: this previously called the undefined name `test_time`,
    # raising a NameError; the evaluation function defined above is `test`.
    test(load_path, test_loader, args)
93
+
94
+
95
if __name__ == '__main__':
    # Entry point: parse the YAML config (hard-coded path -- adjust per host)
    # and evaluate the pretrained checkpoint.
    parser = read_args("/home/yuwei/code/cvpr/config/LMAR_config.yaml")
    args = parser.parse_args()
    # Bug fix: the checkpoint path used a Windows backslash
    # ("pretrained_models\LMAR_model.bin"), which is not a path separator on
    # POSIX systems; use a forward slash so the file is found on Linux.
    main(args, "./pretrained_models/LMAR_model.bin")
base_test.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import yaml
3
+ import torchvision.transforms as transforms
4
+ from utils import read_args, save_checkpoint, AverageMeter, calculate_metrics, CosineAnnealingWarmRestarts
5
+ # import torchvision.transforms.InterpolationMode
6
+ import time
7
+ from tqdm import trange, tqdm
8
+ from torchvision.utils import save_image
9
+ import os
10
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
11
+ import json
12
+ import time
13
+ import logging
14
+ import torch
15
+ from torch import nn, optim
16
+ import numpy as np
17
+ import torch.nn.functional as F
18
+
19
+ from model import *
20
+ from data import *
21
+ from PIL import Image
22
+ from torchvision.transforms import Resize
23
+ import pyiqa
24
+ from thop import profile
25
+ from thop import clever_format
26
+
27
+ psnr_calculator = pyiqa.create_metric('psnr').cuda()
28
+ ssim_calculator = pyiqa.create_metric('ssimc', downsample=True).cuda()
29
+ lpips_calculator = pyiqa.create_metric('lpips').cuda()
30
+ niqe_calculator = pyiqa.create_metric('niqe').cuda()
31
+
32
+
33
def test(load_path, data_loader, args):
    """Baseline evaluation: fixed-size resize -> baseline net -> resize back.

    The input is resized to a hard-coded (1440, 2560) LR size, passed through
    the baseline network, resized back up to the configured GT size, and
    scored with PSNR/SSIM against the ground truth.

    Args:
        load_path: checkpoint file containing a "state_dict" entry.
        data_loader: test loader yielding (input_img, gt_img, inp_img_path).
        args: parsed config namespace (test_loader settings).
    """
    model = net(args)
    checkpoint = torch.load(load_path)
    model.load_state_dict(checkpoint["state_dict"])
    model.cuda()
    model.eval()

    psnrs = AverageMeter()
    ssims = AverageMeter()
    # NOTE(review): LPIPS/NIQE updates are commented out below, so these
    # meters stay at 0.0 in the log output.
    lpipss = AverageMeter()
    niqes = AverageMeter()

    start_time = time.time()
    # Fixed evaluation down-size (H, W).
    down_size = (1440, 2560)
    logging.info("Inference at down size: {}".format(down_size))
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            input_img, gt_img, inp_img_path = batch

            input_img = input_img.cuda()
            # Bug fix: the metrics compare the CUDA model output against
            # gt_img, so the ground truth must live on the same device.
            gt_img = gt_img.cuda()
            batch_size = input_img.size(0)
            # Bug fix: a per-iteration `start_time = time.time()` reset was
            # removed here -- it shadowed the epoch timer above, so the final
            # "takes %.2f seconds" only measured the last batch.
            input_img = resize(input_img, out_shape=down_size, antialiasing=False)
            out_img = model(input_img)
            out_img = resize(out_img, out_shape=eval(args.test_loader["gt_size"]), antialiasing=False)

            # metrics
            clamped_out = torch.clamp(out_img, 0, 1)
            psnr_val, ssim_val = psnr_calculator(clamped_out, gt_img), ssim_calculator(clamped_out, gt_img)
            psnrs.update(torch.mean(psnr_val).item(), batch_size)
            ssims.update(torch.mean(ssim_val).item(), batch_size)

            # lpips = lpips_calculator(clamped_out, gt_img)
            # lpipss.update(torch.mean(lpips).item(), batch_size)
            # niqe = niqe_calculator(clamped_out)
            # niqes.update(torch.mean(niqe).item(), batch_size)
            torch.cuda.empty_cache()

            if i % 700 == 0:
                logging.info(
                    "PSNR {:.4f}, SSIM {:.4f}, LPIPS {:.4F}, NIQE {:.4F}, Elapse time {:.2f}\n".format(psnrs.avg, ssims.avg, lpipss.avg, niqes.avg,
                                                                                                       time.time() - start_time))

    logging.info("Finish test: avg PSNR: %.4f, avg SSIM: %.4F, avg LPIPS: %.4F, avg NIQE: %.4F, and takes %.2f seconds" % (
        psnrs.avg, ssims.avg, lpipss.avg, niqes.avg, time.time() - start_time))
83
+
84
def main(args, load_path):
    """Build the test loader, set up logging, and run the baseline ``test``.

    Args:
        args: parsed config namespace (output_dir, data, test_loader).
        load_path: path to the baseline checkpoint to evaluate.
    """
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    test_transforms = transforms.Compose([transforms.ToTensor()])

    log_format = "%(asctime)s %(levelname)-8s %(message)s"
    log_file = os.path.join(args.output_dir, "baseline_log")
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())

    logging.info("Building data loader")

    test_loader = get_loader(args.data["test_dir"],
                             eval(args.test_loader["img_size"]), test_transforms, False,
                             int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                             args.test_loader["shuffle"], random_flag=False)
    test(load_path, test_loader, args)
101
+
102
+
103
if __name__ == '__main__':
    # Entry point: parse the YAML config (hard-coded path -- adjust per host)
    # and evaluate the pretrained baseline checkpoint.
    parser = read_args("/home/yuwei/code/cvpr/config/base_config.yaml")
    args = parser.parse_args()
    main(args, "./pretrained_models/base_model.bin")
base_train.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import yaml
3
+ import torchvision.transforms as transforms
4
+ from utils import read_args, save_checkpoint, AverageMeter, calculate_metrics, CosineAnnealingWarmRestarts
5
+ import time
6
+ from tqdm import trange, tqdm
7
+ from torchvision.utils import save_image
8
+ # from tensorboardX import SummaryWriter
9
+ import os
10
+ import json
11
+ import time
12
+ import logging
13
+ import torch
14
+ from torch import nn, optim
15
+ import torchvision.utils as vutils
16
+ import torch.nn.functional as F
17
+
18
+ from data import *
19
+ from model import *
20
+ from loss import *
21
+
22
+
23
+ os.environ['CUDA_VISIBLE_DEVICES'] = '1'
24
+
25
+ global_step = 0
26
+
27
+
28
def train(model, data_loader, criterion, optimizer, epoch, args):
    """Run one training epoch over data_loader and checkpoint the model after it."""
    global global_step

    progress = tqdm(data_loader, desc='Iter (loss=X.XXX)')
    nbatches = len(data_loader)

    # Running averages used only for logging.
    total_losses = AverageMeter()
    pixel_losses = AverageMeter()
    gradient_losses = AverageMeter()
    psnrs = AverageMeter()
    ssims = AverageMeter()

    optimizer.zero_grad()
    start_time = time.time()

    # Output folders for periodic image dumps and per-epoch checkpoints.
    for sub in ('/image_train', '/models'):
        if not os.path.exists(args.output_dir + sub):
            os.mkdir(args.output_dir + sub)

    for step, batch in enumerate(progress):
        optimizer.zero_grad()

        input_img, gt_img, image_path = batch
        input_img = input_img.cuda()
        gt_img = gt_img.cuda()
        batch_size = input_img.size(0)

        out_img = model(input_img)

        pixel_loss = criterion(out_img, gt_img)
        pixel_losses.update(pixel_loss.item(), batch_size)

        # The perceptual/gradient term is disabled in this baseline, so the
        # total objective is just the pixel loss.
        total_loss = pixel_loss
        total_losses.update(total_loss.item(), batch_size)

        total_loss.backward()
        optimizer.step()

        progress.set_description('Iter (loss=%5.6f)' % total_losses.avg)

        # Dump an input/output/gt panel every 200 iterations for visual checks.
        if step % 200 == 0:
            panel = torch.cat([input_img[0:2], out_img[0:2], gt_img[0:2]], dim=0)
            save_image(panel, args.output_dir + '/image_train/epoch_{}_iter_{}.jpg'.format(epoch, step))

        # metrics
        norm_out = torch.clamp(out_img, 0, 1)
        # PSNR/SSIM accumulation is disabled here, so psnrs/ssims stay at 0.

        # Log roughly ten times per epoch.
        if step % max(1, nbatches // 10) == 0:
            logging.info(
                "Epoch {}, learning rates {:}, Iter {}, total_loss {:.4f}, pixel_loss {:.4f}, PSNR {:.4f}, SSIM {:.4f}, Elapse time {:.2f}\n".format(
                    epoch, optimizer.param_groups[0]["lr"], step, total_losses.avg, pixel_losses.avg,
                    psnrs.avg, ssims.avg,
                    time.time() - start_time))

    # Checkpoint after every epoch (the modulo is kept for easy retuning).
    if epoch % 1 == 0:
        logging.info("** ** * Saving model and optimizer ** ** * ")

        output_model_file = os.path.join(args.output_dir + "/models", "model.%d.bin" % (epoch))
        state = {"epoch": epoch, "state_dict": model.state_dict(),
                 "optimizer": optimizer.state_dict(), "step": global_step}

        save_checkpoint(state, output_model_file)
        logging.info("Save model to %s", output_model_file)

    logging.info(
        "Finish training epoch %d, avg total_loss: %.4f, avg pixel_loss: %.4f, avg PSNR: %.2f, avg SSIM: %.2F, and takes %.2f seconds" % (
            epoch, total_losses.avg, pixel_losses.avg, psnrs.avg, ssims.avg,
            time.time() - start_time))

    logging.info("***** CUDA.empty_cache() *****\n")
    torch.cuda.empty_cache()
108
+
109
+
110
def evaluate(model, load_path, data_loader, epoch):
    """Reload the checkpoint at load_path into model and report PSNR/SSIM over data_loader."""
    state = torch.load(load_path)
    model.load_state_dict(state["state_dict"])
    model.cuda()
    model.eval()

    psnrs = AverageMeter()
    ssims = AverageMeter()

    start_time = time.time()
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(data_loader)):
            input_img, gt_img, inp_img_path = batch
            input_img = input_img.cuda()
            batch_size = input_img.size(0)
            out_img = model(input_img)

            # metrics
            # NOTE(review): gt_img is never moved to CUDA while out_img lives on
            # the GPU — presumably calculate_metrics handles mixed devices; verify.
            norm_out = torch.clamp(out_img, 0, 1)
            psnr_val, ssim_val = calculate_metrics(norm_out, gt_img)
            psnrs.update(psnr_val.item(), batch_size)
            ssims.update(ssim_val.item(), batch_size)
            torch.cuda.empty_cache()

            if idx % 100 == 0:
                logging.info(
                    "PSNR {:.4f}, SSIM {:.4f}, Elapse time {:.2f}\n".format(psnrs.avg, ssims.avg,
                                                                            time.time() - start_time))

    logging.info(f"Finish test at epoch {epoch}: avg PSNR: %.4f, avg SSIM: %.4F, and takes %.2f seconds" % (
        psnrs.avg, ssims.avg, time.time() - start_time))
142
+
143
+
144
def main(args):
    """Training entry point: build model, data loaders, optimizer and scheduler,
    then alternate train epochs with periodic evaluation.

    Args:
        args: parsed config namespace (see config/base_config.yaml).
    """
    global global_step

    start_epoch = 1
    global_step = 0

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    # Persist the resolved config next to the run outputs for reproducibility.
    with open(os.path.join(args.output_dir, "args.json"), "w") as f:
        json.dump(args.__dict__, f, sort_keys=True, indent=2)

    log_format = "%(asctime)s %(levelname)-8s %(message)s"
    log_file = os.path.join(args.output_dir, "train_log")
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())

    # device setting
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    logging.info(args.__dict__)

    # Build the model, optionally restoring weights + optimizer state.
    if args.resume["flag"]:
        model = net(args)
        model.to(args.device)
        check_point = torch.load(args.resume["checkpoint"])
        model.load_state_dict(check_point["state_dict"])
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.optimizer["lr"],
                               betas=(0.9, 0.999))
        optimizer.load_state_dict(check_point["optimizer"])
        start_epoch = check_point["epoch"] + 1
    else:
        model = net(args)
        model.to(args.device)
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.optimizer["lr"],
                               betas=(0.9, 0.999))

    logging.info("Building data loader")

    # NOTE: img_size strings from the YAML config are eval()'d into tuples
    # (trusted config only).
    if args.train_loader["loader"] == "resize":
        train_transforms = transforms.Compose([transforms.Resize(eval(args.train_loader["img_size"])),
                                               transforms.ToTensor()])
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), train_transforms, False,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], inference_flag=False)
    elif args.train_loader["loader"] == "crop":
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), False, True,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], inference_flag=False)
    else:
        raise NotImplementedError

    if args.test_loader["loader"] == "default":
        test_transforms = transforms.Compose([transforms.ToTensor()])
        test_loader = get_loader(args.data["test_dir"],
                                 eval(args.test_loader["img_size"]), test_transforms, False,
                                 int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                                 args.test_loader["shuffle"], inference_flag=False)
    elif args.test_loader["loader"] == "resize":
        test_transforms = transforms.Compose([transforms.Resize(eval(args.test_loader["img_size"])),
                                              transforms.ToTensor()])
        test_loader = get_loader(args.data["test_dir"],
                                 eval(args.test_loader["img_size"]), test_transforms, False,
                                 int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                                 args.test_loader["shuffle"], inference_flag=False)
    else:
        # Fail fast instead of hitting an unbound test_loader later on.
        raise NotImplementedError

    criterion = nn.L1Loss()

    if args.optimizer["type"] == "cos":
        lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=args.optimizer["T_0"],
                                                   T_mult=args.optimizer["T_MULT"],
                                                   eta_min=args.optimizer["ETA_MIN"],
                                                   last_epoch=-1)
    elif args.optimizer["type"] == "step":
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.optimizer["step"],
                                                       gamma=args.optimizer["gamma"])
    else:
        # Fail fast instead of hitting an unbound lr_scheduler later on.
        raise NotImplementedError

    # Fast-forward the LR schedule to the resumed epoch.
    if args.resume["flag"]:
        for _ in range(start_epoch):
            lr_scheduler.step()

    t_total = int(len(train_loader) * args.optimizer["total_epoch"])
    logging.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    logging.info("***** Running training *****")
    logging.info("  Batch size = %d", args.train_loader["batch_size"])
    logging.info("  Num steps = %d", t_total)
    logging.info("  Loader length = %d", len(train_loader))

    model.train()
    model.cuda()

    logging.info("Begin training from epoch = %d\n", start_epoch)
    for epoch in trange(start_epoch, args.optimizer["total_epoch"] + 1, desc="Epoch"):
        train(model, train_loader, criterion, optimizer, epoch, args)
        lr_scheduler.step()
        if epoch % args.evaluate_intervel == 0:
            logging.info("***** Running testing *****")
            load_path = os.path.join(args.output_dir + "/models", "model.%d.bin" % (epoch))
            evaluate(model, load_path, test_loader, epoch)
            logging.info("***** End testing *****")
256
+
257
+
258
if __name__ == '__main__':
    # Parse the YAML-backed CLI configuration and launch training.
    config_parser = read_args("/home/yuwei/code/cvpr/config/base_config.yaml")
    main(config_parser.parse_args())
config/LMAR_config.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ output_dir: '/home/yuwei/experiment/cvpr/LMAR_cubic'
2
+ data:
3
+ train_dir: /home/data/yuwei/data/uhd4k_ll/train
4
+ test_dir: /home/data/yuwei/data/uhd4k_ll/test
5
+
6
+ model:
7
+ in_channel: 3
8
+ model_channel: 8
9
+ sparsity_threshold: 0.01
10
+ num_blocks: 8
11
+ threslhold_frac: 0.6
12
+ hidden_channel: 48
13
+
14
+ train_loader:
15
+ num_workers: 8
16
+ batch_size: 1
17
+ loader: crop
18
+ img_size: (1024, 1024)
19
+ shuffle: True
20
+ gt_size: (2160, 3840)
21
+ random_flag: True
22
+
23
+ test_loader:
24
+ num_workers: 8
25
+ batch_size: 1
26
+ loader: default
27
+ img_size: ((1440, 2560), (1080, 1920), (1200, 1600), (720, 1280), (540, 960))
28
+ shuffle: False
29
+ gt_size: (2160, 3840)
30
+
31
+ optimizer:
32
+ type: step
33
+ total_epoch: 12
34
+ lr: 0.0004
35
+ T_0: 0.00001
36
+ T_MULT: 1
37
+ ETA_MIN: 0.000001
38
+ step: 4
39
+ gamma: 0.75
40
+
41
+ hyper_params:
42
+ lambda: 0.5
43
+
44
+ resume:
45
+ flag: True
46
+ checkpoint: ./pretrained_models/base_model.bin
47
+
48
+ evaluate_intervel: 1
config/base_config.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ output_dir: '/home/yuwei/experiment/cvpr/uhd4k_ll_pretrain'
2
+ data:
3
+ train_dir: /home/data/yuwei/data/uhd4k_ll/train
4
+ test_dir: /home/data/yuwei/data/uhd4k_ll/test
5
+
6
+ model:
7
+ in_channel: 3
8
+ model_channel: 8
9
+
10
+ train_loader:
11
+ num_workers: 8
12
+ batch_size: 2
13
+ loader: resize
14
+ img_size: (1024, 1024)
15
+ shuffle: True
16
+
17
+ test_loader:
18
+ num_workers: 8
19
+ batch_size: 1
20
+ loader: default
21
+ img_size: (1200, 1600)
22
+ shuffle: False
23
+ gt_size: (2160, 3840)
24
+
25
+ optimizer:
26
+ type: step
27
+ total_epoch: 100
28
+ lr: 0.001
29
+ T_0: 100
30
+ T_MULT: 1
31
+ ETA_MIN: 0.000001
32
+ step: 20
33
+ gamma: 0.75
34
+
35
+ hyper_params:
36
+ x_lambda: 0.03
37
+
38
+ resume:
39
+ flag: False
40
+ checkpoint: Null
41
+
42
+ evaluate_intervel: 5
data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .loader import *
data/loader.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from torch.utils.data import DataLoader
5
+ from torch.utils.data.dataset import Dataset
6
+ from PIL import Image
7
+ import torchvision.transforms.functional as TF
8
+ import torchvision.transforms as tf
9
+ from PIL import Image, ImageFile
10
+ import random
11
+ import math
12
+ from model import *
13
+ import torch
14
+ # import cv2
15
+ # cv2.setNumThreads(0)
16
+
17
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
18
+
19
+
20
class base_dataset(Dataset):
    """Paired image dataset: <root>/input/<name> aligned with <root>/gt/<name>.

    transforms: optional torchvision transform applied to both images.
    crop: when truthy, take one shared random crop of img_size and return tensors.
    """

    def __init__(self, data_dir, img_size, transforms=False, crop=False):
        names = sorted(os.listdir(data_dir + "/input"))
        self.input_imgs = [os.path.join(data_dir + "/input", n) for n in names]
        self.gt_imgs = [os.path.join(data_dir + "/gt", n) for n in names]
        self.transforms = transforms
        self.crop = crop
        self.img_size = img_size

    def __len__(self):
        return len(self.gt_imgs)

    def __getitem__(self, index):
        inp_path = self.input_imgs[index]
        degraded = Image.open(inp_path).convert("RGB")
        reference = Image.open(self.gt_imgs[index]).convert("RGB")

        if self.transforms:
            degraded = self.transforms(degraded)
            reference = self.transforms(reference)

        if self.crop:
            degraded, reference = self.crop_image(degraded, reference)

        return degraded, reference, inp_path

    def crop_image(self, inp_img, gt_img):
        # Sample ONE crop window and apply it to both images so they stay aligned.
        crop_h, crop_w = self.img_size
        top, left, h, w = tf.RandomCrop.get_params(inp_img, output_size=(crop_h, crop_w))
        inp_img = TF.to_tensor(TF.crop(inp_img, top, left, h, w))
        gt_img = TF.to_tensor(TF.crop(gt_img, top, left, h, w))
        return inp_img, gt_img
56
+
57
+
58
class random_scale_dataset(Dataset):
    """Paired dataset that also emits a random target down-scale size.

    Each item is (input, gt, down_h, down_w, input_path), where down_h == down_w
    is drawn from [img_size[0] * 0.25, img_size[0]) in steps of 8.
    """

    def __init__(self, data_dir, img_size, transforms=False, crop=False):
        names = sorted(os.listdir(data_dir + "/input"))
        self.input_imgs = [os.path.join(data_dir + "/input", n) for n in names]
        self.gt_imgs = [os.path.join(data_dir + "/gt", n) for n in names]
        self.transforms = transforms
        self.crop = crop
        self.img_size = img_size

    def __getitem__(self, index):
        inp_img_path = self.input_imgs[index]
        inp_img = Image.open(inp_img_path).convert("RGB")
        gt_img = Image.open(self.gt_imgs[index]).convert("RGB")

        # BUG FIX: randrange requires integer bounds; the original passed a
        # float (img_size[0] * 0.25), which raises TypeError on Python >= 3.12.
        low = int(self.img_size[0] * 0.25)
        down_h = down_w = random.randrange(low, self.img_size[0], 8)

        if self.transforms:
            inp_img = self.transforms(inp_img)
            gt_img = self.transforms(gt_img)
        elif self.crop:
            inp_img, gt_img = self.crop_image(inp_img, gt_img)
        else:
            # BUG FIX: the original fell through and implicitly returned None
            # here; convert to tensors so every path yields a usable sample.
            inp_img = TF.to_tensor(inp_img)
            gt_img = TF.to_tensor(gt_img)

        return inp_img, gt_img, down_h, down_w, inp_img_path

    def __len__(self):
        return len(self.gt_imgs)

    def crop_image(self, inp_img, gt_img):
        # Shared random crop window keeps the input/gt pair aligned.
        crop_h, crop_w = self.img_size
        i, j, h, w = tf.RandomCrop.get_params(inp_img, output_size=(crop_h, crop_w))
        inp_img = TF.to_tensor(TF.crop(inp_img, i, j, h, w))
        gt_img = TF.to_tensor(TF.crop(gt_img, i, j, h, w))
        return inp_img, gt_img
98
+
99
+
100
def get_loader(data_dir, img_size, transforms, crop_flag, batch_size, num_workers, shuffle, random_flag=False, inference_flag=False):
    """Build a pinned-memory DataLoader over data_dir.

    random_flag selects random_scale_dataset over base_dataset; inference_flag
    is accepted for call-site compatibility but unused here.
    """
    dataset_cls = random_scale_dataset if random_flag else base_dataset
    dataset = dataset_cls(data_dir, img_size, transforms, crop_flag)
    return DataLoader(dataset, batch_size=batch_size,
                      shuffle=shuffle, num_workers=num_workers, pin_memory=True)
loss.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+
7
class VGG19(torch.nn.Module):
    """VGG-19 feature pyramid: forward returns the activations after each of
    five relu stages (used as a perceptual-loss backbone)."""

    # (attribute name, [start, stop)) index ranges over vgg19().features.
    _STAGES = (("slice1", 0, 2), ("slice2", 2, 7), ("slice3", 7, 12),
               ("slice4", 12, 21), ("slice5", 21, 30))

    def __init__(self, requires_grad=False):
        super().__init__()
        features = torchvision.models.vgg19(pretrained=True).features
        for attr, start, stop in self._STAGES:
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                stage.add_module(str(idx), features[idx])
            setattr(self, attr, stage)
        if not requires_grad:
            # Freeze the backbone when it is only used as a fixed feature extractor.
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        outputs = []
        h = X
        for attr, _, _ in self._STAGES:
            h = getattr(self, attr)(h)
            outputs.append(h)
        return outputs
38
+
39
+
40
class VGGLoss(nn.Module):
    """Weighted multi-stage VGG-19 perceptual L1 loss (target features detached)."""

    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19().cuda()
        self.criterion = nn.L1Loss()
        # Deeper stages contribute progressively more.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        feats_x, feats_y = self.vgg(x), self.vgg(y)
        total = 0
        for weight, fx, fy in zip(self.weights, feats_x, feats_y):
            total += weight * self.criterion(fx, fy.detach())
        return total
53
+
54
+
55
class VGGPerceptualLoss(torch.nn.Module):
    """Thin wrapper computing a single-layer VGG perceptual loss.

    NOTE(review): this class is shadowed by the second `VGGPerceptualLoss`
    defined below, so this definition is effectively dead code; when __init__
    runs, the name `VGGPerceptualLoss` resolves to that later class.
    NOTE(review): the `lam` / `lam_p` arguments are accepted but never used.
    """

    def __init__(self, lam=1, lam_p=1):
        super(VGGPerceptualLoss, self).__init__()
        self.loss_fn = VGGPerceptualLoss()

    def forward(self, out, gt):
        # Only feature layer 2 contributes to the distance.
        return self.loss_fn(out, gt, feature_layers=[2])
64
+
65
+
66
class VGGPerceptualLoss(torch.nn.Module):
    """VGG-16 perceptual + optional Gram-matrix style loss.

    Inputs are ImageNet-normalized (and optionally resized to 224x224) before
    being pushed through four frozen VGG-16 feature stages.
    """

    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            # BUG FIX: the original iterated the *modules* of each Sequential
            # ("for p in bl") and set a meaningless .requires_grad attribute on
            # them, leaving the VGG weights trainable. Freeze the parameters.
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks).cuda()
        self.transform = torch.nn.functional.interpolate
        # NOTE(review): .cuda() on a Parameter returns a plain tensor copy, so
        # mean/std are not registered on the module and will not follow later
        # .to()/.cuda() moves; kept as-is to preserve existing behavior.
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)).cuda()
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)).cuda()
        self.resize = resize

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        """Sum of L1 feature distances (feature_layers) and L1 Gram-matrix
        distances (style_layers) between input and target."""
        if input.shape[1] != 3:
            # Tile grayscale inputs to three channels.
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
107
+
108
+
109
def scharr(x):
    """Scharr edge-magnitude map of x (callers average RGB to gray first, per
    the original note). Input intensities in [0, 1]; output back in [0, 1]."""
    b, c, h, w = x.shape

    # Replicate-pad by 1 so every pixel has a full 3x3 neighbourhood.
    padded = nn.ReplicationPad2d(padding=(1, 1, 1, 1))(x)
    patches = F.unfold(padded, kernel_size=3, stride=1, padding=0)  # b, n*k*k, h*w
    patches = patches.permute([0, 2, 1])  # b, h*w, n*k*k

    # Horizontal / vertical Scharr kernels flattened to length-9 vectors.
    kernel_h = torch.tensor([-3, 0, 3, -10, 0, 10, -3, 0, 3]).float().cuda()
    kernel_v = torch.tensor([-3, -10, -3, 0, 0, 0, 3, 10, 3]).float().cuda()

    scaled = patches * 255.0  # operate on 0-255 intensities
    grad_h = torch.matmul(scaled, kernel_h).unsqueeze(-1).permute([0, 2, 1])  # b, 1, h*w
    grad_v = torch.matmul(scaled, kernel_v).unsqueeze(-1).permute([0, 2, 1])  # b, 1, h*w

    # Re-assemble the flat responses into b x 1 x h x w maps and clip.
    grad_h = F.fold(grad_h, output_size=(h, w), kernel_size=1).clamp(-255, 255)
    grad_v = F.fold(grad_v, output_size=(h, w), kernel_size=1).clamp(-255, 255)

    return (0.5 * torch.abs(grad_h) + 0.5 * torch.abs(grad_v)) / 255.0
132
+
133
+
134
def gram_matrix(input):
    """Normalized Gram matrix of a feature-map batch.

    Flattens (a, b, c, d) to (a*b, c*d), computes the inner-product matrix and
    divides by the total element count so the scale is size-independent.
    """
    a, b, c, d = input.size()
    flat = input.reshape(a * b, c * d)
    gram = flat @ flat.t()
    return gram / (a * b * c * d)
146
+
147
+
148
class StyleLoss(nn.Module):
    """MSE between the Gram matrices of two feature maps (target detached)."""

    def __init__(self):
        super(StyleLoss, self).__init__()

    @staticmethod
    def _gram(feat):
        # Inlined normalized Gram matrix (same math as module-level gram_matrix).
        a, b, c, d = feat.size()
        flat = feat.reshape(a * b, c * d)
        return (flat @ flat.t()) / (a * b * c * d)

    def forward(self, input_fea, target_fea):
        target = self._gram(target_fea).detach()
        return F.mse_loss(self._gram(input_fea), target)
157
+
158
+
159
def cos_loss(feat1, feat2):
    """Negated mean cosine similarity: minimizing this maximizes alignment."""
    similarity = F.cosine_similarity(feat1, feat2)
    return -similarity.mean()
162
+
163
+
164
def feat_scharr(x):
    """Edge map of a feature tensor: channel-mean, min-max rescale to 0-255,
    then the Scharr operator.

    NOTE(review): divides by (max - min); a constant input yields NaNs.
    """
    gray = torch.mean(x, dim=1, keepdim=True)
    gray = (gray - gray.min()) / (gray.max() - gray.min()) * 255
    return scharr(gray)
169
+
170
+
171
def feat_ssim(feat1, feat2, gt):
    """Edge-weighted L1 distance between two feature maps.

    The weight mask is the Scharr edge map of gt (channel-mean), resized to the
    feature resolution. Returns (mean weighted difference, mask).
    """
    mask = scharr(torch.mean(gt, dim=1, keepdim=True))
    mask = F.interpolate(mask, size=(feat1.shape[2], feat1.shape[3]), mode="bicubic")
    weighted = torch.abs(feat1 - feat2) * mask
    return torch.mean(weighted), mask
177
+
178
+
179
def similarity_loss(f_s, f_t):
    """MSE between L2-normalized, channel-pooled squared activation maps of
    the two feature tensors f_s and f_t."""
    def attention(feat):
        pooled = feat.pow(2).mean(1)  # average squared activations over channels
        return F.normalize(pooled.view(feat.size(0), -1))

    return (attention(f_s) - attention(f_t)).pow(2).mean()
184
+
185
+
186
class RBF(nn.Module):
    """Multi-bandwidth Gaussian (RBF) kernel over the row vectors of X.

    Uses n_kernels bandwidths spaced by powers of mul_factor around a base
    bandwidth (estimated from the data when bandwidth is None).
    """

    def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None):
        super().__init__()
        # Powers of mul_factor centred on 1 (e.g. 1/4, 1/2, 1, 2, 4).
        # NOTE(review): plain tensor, not a registered buffer, so it will not
        # move with .to(); forward() calls .cuda() on it explicitly.
        self.bandwidth_multipliers = mul_factor ** (torch.arange(n_kernels) - n_kernels // 2)
        self.bandwidth = bandwidth

    def get_bandwidth(self, L2_distances):
        # Mean pairwise squared distance (diagonal excluded) as a data-driven
        # bandwidth when none was supplied.
        if self.bandwidth is None:
            n_samples = L2_distances.shape[0]
            return L2_distances.data.sum() / (n_samples ** 2 - n_samples)
        return self.bandwidth

    def forward(self, X):
        sq_dists = torch.cdist(X, X) ** 2
        scales = self.get_bandwidth(sq_dists).cuda() * self.bandwidth_multipliers.cuda()
        # Sum the kernel responses over all bandwidths.
        return torch.exp(-sq_dists[None, ...].cuda() / scales[:, None, None]).sum(dim=0)
206
+
207
+
208
class MMDLoss(nn.Module):
    """Squared maximum mean discrepancy between two sample batches X and Y."""

    def __init__(self, kernel=None):
        super().__init__()
        # BUG FIX: the original signature was `kernel=RBF()`, which built one
        # RBF instance at import time and shared it across every MMDLoss
        # (mutable-default pitfall). Build a fresh kernel per instance instead;
        # passing an explicit kernel still works as before.
        if kernel is None:
            kernel = RBF()
        self.kernel = kernel.cuda()

    def forward(self, X, Y):
        """X: (n, d), Y: (m, d). Returns E[k(X,X)] - 2 E[k(X,Y)] + E[k(Y,Y)]."""
        K = self.kernel(torch.vstack([X, Y]))

        n = X.shape[0]
        XX = K[:n, :n].mean()
        XY = K[:n, n:].mean()
        YY = K[n:, n:].mean()
        return XX - 2 * XY + YY
metrics.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torchvision.transforms as transforms
5
+
6
+
7
def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Pixels cropped from every edge before comparison.
        input_order (str): 'HWC' or 'CHW' (validated only). Default: 'HWC'.
        test_y_channel (bool): Compare on the Y channel of YCbCr. Default: False.

    Returns:
        float: PSNR in dB (inf for identical images).
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')

    a = img.astype(np.float64)
    b = img2.astype(np.float64)

    if crop_border != 0:
        border = slice(crop_border, -crop_border)
        a = a[border, border, ...]
        b = b[border, border, ...]

    if test_y_channel:
        # NOTE(review): to_y_channel is not defined/imported in this module, so
        # this branch currently raises NameError — confirm the intended helper.
        a = to_y_channel(a)
        b = to_y_channel(b)

    mse = np.mean((a - b) ** 2)
    return float('inf') if mse == 0 else 20. * np.log10(255. / np.sqrt(mse))
43
+
44
+
45
def _ssim(img, img2):
    """SSIM for one single-channel pair in [0, 255] (11x11 Gaussian window,
    standard k1=0.01 / k2=0.03 constants). Called by calculate_ssim."""
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2

    a = img.astype(np.float64)
    b = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    # Local means; the [5:-5] crop drops the filter's border region.
    mu1 = cv2.filter2D(a, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(b, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    # Local (co)variances via E[x^2] - E[x]^2.
    sigma1_sq = cv2.filter2D(a ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(b ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(a * b, -1, window)[5:-5, 5:-5] - mu1_mu2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    return (numerator / denominator).mean()
77
+
78
def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
    """Calculate SSIM (structural similarity), channel-averaged.

    Matches the official MATLAB implementation at
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    Args:
        img (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Pixels cropped from every edge before comparison.
        input_order (str): 'HWC' or 'CHW' (validated only). Default: 'HWC'.
        test_y_channel (bool): Compare on the Y channel of YCbCr. Default: False.

    Returns:
        float: SSIM averaged over channels.
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')

    a = img.astype(np.float64)
    b = img2.astype(np.float64)

    if crop_border != 0:
        border = slice(crop_border, -crop_border)
        a = a[border, border, ...]
        b = b[border, border, ...]

    if test_y_channel:
        # NOTE(review): to_y_channel is not defined/imported in this module, so
        # this branch currently raises NameError — confirm the intended helper.
        a = to_y_channel(a)
        b = to_y_channel(b)

    per_channel = [_ssim(a[..., ch], b[..., ch]) for ch in range(a.shape[2])]
    return np.array(per_channel).mean()
123
+
124
if __name__ == '__main__':
    # Smoke test: PSNR/SSIM of an image against itself (expect inf and ~1.0).
    img = cv2.imread("/mnt/disk1/yuwei/data/4Kdehaze/train/clear/0_000002.jpg")
    print(calculate_psnr(img, img, 0))
    print(calculate_ssim(img, img, 0))
model/LMAR_model.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model import net
2
+ import torch.nn as nn
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import Resize
6
+
7
+ try:
8
+ from resize_right import resize
9
+ except:
10
+ from .resize_right import resize
11
+
12
+ try:
13
+ from .interp_methods import *
14
+ except:
15
+ from interp_methods import *
16
+
17
+ from torchvision.models import vgg19
18
+ from torchvision.models.feature_extraction import create_feature_extractor
19
+
20
+ import tinycudann as tcnn
21
+ from torchvision.utils import save_image
22
+ import torchvision.transforms as transforms
23
+ from torchviz import make_dot
24
+
25
+
26
def make_coord(shape, ranges=None, flatten=True):
    """Make coordinates at grid centers.

    Args:
        shape: spatial sizes, e.g. (H, W); one coordinate axis per entry.
        ranges: optional per-axis (v0, v1) value range; defaults to (-1, 1).
        flatten: if True, return (H*W, ndim) instead of (H, W, ndim).

    Returns:
        Tensor of grid-center coordinates in 'ij' axis order.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        # Half-cell offset so coordinates land at pixel centers.
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    # indexing='ij' matches the legacy meshgrid default and silences the
    # PyTorch deprecation warning; behavior is unchanged.
    ret = torch.stack(torch.meshgrid(*coord_seqs, indexing='ij'), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret
42
+
43
def get_local_grid(img):
    """Per-pixel center-coordinate grid for `img`, shaped (B, 2, H, W), on GPU."""
    grid = make_coord(img.shape[-2:], flatten=False).cuda()
    grid = grid.permute(2, 0, 1).unsqueeze(0)
    return grid.expand(img.shape[0], 2, *img.shape[-2:])
49
+
50
def creat_coord(x):
    """Build the coordinate grid for `x` in two layouts.

    Returns:
        (grid, flat): `grid` is (B, 2, H, W); `flat` is the same grid
        clamped just inside (-1, 1) and reshaped to (B, H*W, 2) — the
        layout expected by F.grid_sample lookups.
    """
    batch = x.shape[0]
    grid = make_coord(x.shape[-2:], flatten=False)
    grid = grid.permute(2, 0, 1).contiguous().unsqueeze(0)
    grid = grid.expand(batch, 2, *grid.shape[-2:])

    flat = grid.clone()
    flat = flat.clamp_(-1 + 1e-6, 1 - 1e-6)
    flat = flat.permute(0, 2, 3, 1).contiguous().view(batch, -1, grid.size(1))
    return grid.cuda(), flat.cuda()
61
+
62
+
63
def get_cell(img, local_grid):
    """Per-pixel cell sizes (2/H, 2/W) in the same layout as `local_grid`."""
    cell = torch.ones_like(local_grid)
    cell[:, 0].mul_(2 / img.size(2))
    cell[:, 1].mul_(2 / img.size(3))
    return cell
69
+
70
+
71
class TcnnFCBlock(tcnn.Network):
    """Fully-fused MLP backed by tiny-cuda-nn.

    Leading batch dimensions of the input are arbitrary: they are
    flattened for the kernel call and restored on the way out.
    """

    def __init__(
            self, in_features, out_features,
            num_hidden_layers, hidden_features,
            activation: str = 'LeakyRelu', last_activation: str = 'None',
            seed=42):
        assert hidden_features in [16, 32, 64, 128], "hidden_features can only be 16, 32, 64, or 128."
        config = {
            "otype": "FullyFusedMLP",             # Component type.
            "activation": activation,             # Activation of hidden layers.
            "output_activation": last_activation, # Activation of the output layer.
            "n_neurons": hidden_features,         # Width; the fused kernel restricts it.
            "n_hidden_layers": num_hidden_layers, # Number of hidden layers.
        }
        super().__init__(in_features, out_features, network_config=config, seed=seed)

    def forward(self, x: torch.Tensor):
        lead = x.shape[:-1]
        flat = x.flatten(0, -2)
        return super().forward(flat).unflatten(0, lead)
89
+
90
+
91
class LMAR_model(nn.Module):
    """Learnable downscaling module wrapped around a frozen restoration net.

    A tiny-cuda-nn MLP (`imnet`) predicts a per-pixel 3x3 mixing kernel from
    relative coordinates, cell sizes and the Laplacian residual; the kernel
    is applied via unfold+matmul, the result downscaled, and fused with the
    plain downscaled input through a 1x1 `modulation` conv. Intermediate
    features are tapped at the backbone node "hr_backbone.skip2".
    """

    def __init__(self, args):
        super().__init__()
        # Resume config: when set, load and freeze the pretrained backbone.
        self.resume_flag = args.resume["flag"]
        self.load_path = args.resume["checkpoint"]

        if self.resume_flag and self.load_path:
            self.model = net(args)
            checkpoint = torch.load(self.load_path)
            self.model.load_state_dict(checkpoint["state_dict"])
            # Backbone stays frozen; only imnet/modulation train.
            for param in self.model.parameters():
                param.requires_grad_(False)
        # NOTE(review): if resume_flag/checkpoint are falsy, self.model is
        # never created and create_feature_extractor below raises — confirm
        # the config always provides a checkpoint.

        self.in_channel = 3
        self.out_channel = 3
        self.kernel_size = 3
        # imnet input is 7-dim (2 rel-coord + 2 rel-cell + 3 Laplacian);
        # output is a flattened 3x3 kernel per (in, out) channel pair.
        self.imnet = TcnnFCBlock(7, self.in_channel * self.out_channel * self.kernel_size * self.kernel_size, 5,
                                 128).cuda()
        self.mid_nodes = {"hr_backbone.skip2": "bottom"}
        self.extractor_mid = create_feature_extractor(self.model, self.mid_nodes)
        # Fuses [down_x, predicted residual] (3+3 channels) back to 3.
        self.modulation = nn.Conv2d(6, 3, 1, 1, 0)

    def forward(self, x, down_size, up_size, test_flag=False):
        """Dispatch to inference (restored output) or training (features)."""
        if test_flag:
            up_out, _ = self.inference(x, down_size, up_size)
            return up_out, _
        else:
            down_x, hr_feature, new_lr_feature, ori_lr_feature, residual, res = self.train_model(x, down_size, up_size)
            return down_x, hr_feature, new_lr_feature, ori_lr_feature, residual, res

    def train_model(self, x, down_size, up_size):
        """Training pass: returns the modulated LR image plus HR / LR
        backbone features (original and modulated) for the losses.

        Returns:
            (down_x, hr_feature, new_lr_feature, ori_lr_feature, out, res)
            where `res` is a placeholder 0 (restored output not computed).
        """
        b = x.shape[0]
        down_x = resize(x, out_shape=down_size, antialiasing=False)

        hr_feature = self.extractor_mid(x)["bottom"]

        # Relative coordinates: HR pixel center minus its nearest LR center.
        hr_coord, hr_coord_ = self.creat_coord(x)
        lr_coord, _ = self.creat_coord(down_x)
        q_coord = F.grid_sample(lr_coord, hr_coord_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)
        q_coord = q_coord.view(b, -1, hr_coord.size(2) * hr_coord.size(3)).permute(0, 2, 1).contiguous()

        rel_coord = hr_coord_ - q_coord
        # Scale from the normalized (-1, 1) range into LR-pixel units.
        rel_coord[:, :, 0] *= down_x.shape[-2]
        rel_coord[:, :, 1] *= down_x.shape[-1]

        # High-frequency residual lost by the down/up round trip.
        laplacian = x - resize(down_x, out_shape=up_size, antialiasing=False)
        laplacian = laplacian.reshape(b, laplacian.size(1), -1).permute(0, 2, 1).contiguous()

        # cell: per-pixel cell extents, also in LR-pixel units.
        hr_grid = self.get_local_grid(x)
        hr_cell = self.get_cell(x, hr_grid)
        hr_cell_ = hr_cell.clone()
        hr_cell_ = hr_cell_.permute(0, 2, 3, 1).contiguous()
        rel_cell = hr_cell_.view(b, -1, hr_cell.size(1))
        rel_cell[:, :, 0] *= down_x.shape[-2]
        rel_cell[:, :, 1] *= down_x.shape[-1]

        # Predict one 3x3xC->3 kernel per HR pixel.
        inp = torch.cat([rel_coord.cuda(), rel_cell.cuda(), laplacian], dim=-1)
        local_weight = self.imnet(inp)
        local_weight = local_weight.type(torch.float32)
        local_weight = local_weight.view(b, -1, x.shape[1] * 9, 3).contiguous()

        # Apply the predicted kernels: unfold 3x3 patches, batched matmul.
        unfolded_x = F.unfold(x, 3, padding=1).view(b, -1, x.shape[2] * x.shape[3]).permute(0, 2, 1).contiguous()
        cols = unfolded_x.unsqueeze(2)
        out = torch.matmul(cols, local_weight).squeeze(2).permute(0, 2, 1).contiguous().view(b, -1, x.size(2),
                                                                                             x.size(3))
        out = resize(out, out_shape=down_size, antialiasing=False)

        # ori: features of the plain (un-modulated) downscaled image.
        ori_lr_feature = self.extractor_mid(down_x)["bottom"]
        ori_lr_feature = resize(ori_lr_feature, out_shape=(hr_feature.shape[2], hr_feature.shape[3]),
                                antialiasing=False)

        # new: features after fusing the predicted residual into down_x.
        down_x = self.modulation(torch.cat([down_x, out], dim=1))
        new_lr_feature = self.extractor_mid(down_x)["bottom"]

        new_lr_feature = resize(new_lr_feature, out_shape=(hr_feature.shape[2], hr_feature.shape[3]),
                                antialiasing=False)

        # Restored output is skipped during training to save compute.
        res = 0

        return down_x, hr_feature, \
               new_lr_feature, ori_lr_feature, out, res

    def inference(self, x, down_size, up_size):
        """Full pass: modulate the downscaled input, run the frozen backbone,
        and upscale its output back to `up_size`.

        Returns:
            (res, down_x): restored image and the modulated LR image.
        """
        b = x.shape[0]
        down_x = resize(x, out_shape=down_size, antialiasing=False)
        hr_coord, hr_coord_ = self.creat_coord(x)
        lr_coord, _ = self.creat_coord(down_x)
        q_coord = F.grid_sample(lr_coord, hr_coord_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)
        q_coord = q_coord.view(b, -1, hr_coord.size(2) * hr_coord.size(3)).permute(0, 2, 1).contiguous()

        rel_coord = hr_coord_ - q_coord
        rel_coord[:, :, 0] *= down_x.shape[-2]
        rel_coord[:, :, 1] *= down_x.shape[-1]

        hr_grid = self.get_local_grid(x)
        hr_cell = self.get_cell(x, hr_grid)

        hr_cell_ = hr_cell.clone()
        hr_cell_ = hr_cell_.permute(0, 2, 3, 1).contiguous()

        rel_cell = hr_cell_.view(b, -1, hr_cell.size(1))
        rel_cell[:, :, 0] *= down_x.shape[-2]
        rel_cell[:, :, 1] *= down_x.shape[-1]

        laplacian = x - resize(down_x, out_shape=up_size, antialiasing=False)

        laplacian = laplacian.reshape(b, laplacian.size(1), -1).permute(0, 2, 1).contiguous()

        inp = torch.cat([rel_coord.cuda(), rel_cell.cuda(), laplacian], dim=-1)
        local_weight = self.imnet(inp)
        local_weight = local_weight.type(torch.float32)
        local_weight = local_weight.view(b, -1, x.shape[1] * 9, 3)

        unfolded_x = F.unfold(x, 3, padding=1).view(b, -1, x.shape[2] * x.shape[3]).permute(0, 2, 1).contiguous()

        cols = unfolded_x.unsqueeze(2)

        out = torch.matmul(cols, local_weight).squeeze(2).permute(0, 2, 1).contiguous().view(b, -1, x.size(2),
                                                                                             x.size(3))
        out = resize(out, out_shape=down_size, antialiasing=False)
        down_x = self.modulation(torch.cat([down_x, out], dim=1))

        res = resize(self.model(down_x), out_shape=up_size, antialiasing=False)
        return res, down_x

    def creat_coord(self, x):
        """Coordinate grid for `x` as (B, 2, H, W) and clamped (B, H*W, 2)."""
        b = x.shape[0]
        coord = make_coord(x.shape[-2:], flatten=False)
        coord = coord.permute(2, 0, 1).contiguous().unsqueeze(0)
        coord = coord.expand(b, 2, *coord.shape[-2:])

        coord_ = coord.clone()
        coord_ = coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
        coord_ = coord_.permute(0, 2, 3, 1).contiguous()
        coord_ = coord_.view(b, -1, coord.size(1))
        return coord.cuda(), coord_.cuda()

    def get_local_grid(self, img):
        """Per-pixel center coordinates of `img`, shaped (B, 2, H, W)."""
        local_grid = make_coord(img.shape[-2:], flatten=False).cuda()
        local_grid = local_grid.permute(2, 0, 1).unsqueeze(0)
        local_grid = local_grid.expand(img.shape[0], 2, *img.shape[-2:])

        return local_grid

    def get_cell(self, img, local_grid):
        """Per-pixel cell sizes (2/H, 2/W) matching `local_grid`'s layout."""
        cell = torch.ones_like(local_grid)
        cell[:, 0] *= 2 / img.size(2)
        cell[:, 1] *= 2 / img.size(3)

        return cell
277
+
model/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .model import net
2
+ from .resize_right import resize
3
+ from .interp_methods import *
4
+ from .module import Discriminator, Discriminator_new
5
+ from .LMAR_model import *
model/interp_methods.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import pi
2
+
3
+ try:
4
+ import torch
5
+ except ImportError:
6
+ torch = None
7
+
8
+ try:
9
+ import numpy
10
+ except ImportError:
11
+ numpy = None
12
+
13
+ if numpy is None and torch is None:
14
+ raise ImportError("Must have either Numpy or PyTorch but both not found")
15
+
16
+
17
def set_framework_dependencies(x):
    """Return (framework module, dtype-cast fn, float32 eps) for `x`.

    Numpy arrays get an identity cast (boolean masks multiply fine);
    torch tensors get masks cast back to the input's dtype.
    """
    if type(x) is numpy.ndarray:
        fw, to_dtype = numpy, (lambda a: a)
    else:
        fw, to_dtype = torch, (lambda a: a.to(x.dtype))
    return fw, to_dtype, fw.finfo(fw.float32).eps
26
+
27
+
28
def support_sz(sz):
    """Decorator that tags an interpolation kernel with its support width."""
    def attach(kernel):
        kernel.support_sz = sz
        return kernel
    return attach
33
+
34
+
35
@support_sz(4)
def cubic(x):
    """Keys cubic interpolation kernel (a = -0.5), support [-2, 2]."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    a = fw.abs(x)
    a2 = a ** 2
    a3 = a ** 3
    near = (1.5 * a3 - 2.5 * a2 + 1.) * to_dtype(a <= 1.)
    far = (-0.5 * a3 + 2.5 * a2 - 4. * a + 2.) * to_dtype((1. < a) & (a <= 2.))
    return near + far
44
+
45
+
46
@support_sz(4)
def lanczos2(x):
    """Lanczos kernel with a = 2; eps guards the 0/0 limit at x == 0."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    numer = fw.sin(pi * x) * fw.sin(pi * x / 2) + eps
    denom = (pi ** 2 * x ** 2 / 2) + eps
    return (numer / denom) * to_dtype(abs(x) < 2)
51
+
52
+
53
@support_sz(6)
def lanczos3(x):
    """Lanczos kernel with a = 3; eps guards the 0/0 limit at x == 0."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    numer = fw.sin(pi * x) * fw.sin(pi * x / 3) + eps
    denom = (pi ** 2 * x ** 2 / 3) + eps
    return (numer / denom) * to_dtype(abs(x) < 3)
58
+
59
+
60
@support_sz(2)
def linear(x):
    """Triangle (bilinear) kernel, support [-1, 1]."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    rising = (x + 1) * to_dtype((-1 <= x) & (x < 0))
    falling = (1 - x) * to_dtype((0 <= x) & (x <= 1))
    return rising + falling
65
+
66
@support_sz(1)
def box(x):
    """Box (nearest-neighbor) kernel: 1 on [-1, 1], 0 elsewhere."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    left = to_dtype((-1 <= x) & (x < 0))
    right = to_dtype((0 <= x) & (x <= 1))
    return left + right
model/model.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from .module import *
3
+ except:
4
+ from module import *
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.nn.init as init
14
+
15
+ class SuperUnet_MS(nn.Module):
16
+ def __init__(self, channels, block="INV"):
17
+ super(SuperUnet_MS, self).__init__()
18
+ # ---------ENCODE
19
+ self.layer_dowm1 = basic_block(channels, channels, block)
20
+ self.dowm1 = nn.Sequential(nn.Conv2d(channels, channels * 2, 4, 2, 1, bias=True),
21
+ nn.InstanceNorm2d(channels * 2, affine=True), nn.LeakyReLU(0.2, inplace=True))
22
+ self.layer_dowm2 = basic_block(channels * 2, channels * 2, block)
23
+ self.dowm2 = nn.Sequential(nn.Conv2d(channels * 2, channels * 4, 4, 2, 1, bias=True),
24
+ nn.InstanceNorm2d(channels * 4, affine=True), nn.LeakyReLU(0.2, inplace=True))
25
+ # ---------DECODE
26
+ self.layer_bottom = basic_block(channels * 4, channels * 4, block)
27
+ self.up2 = nn.Sequential(nn.ConvTranspose2d(channels * 4, channels * 2, 4, 2, 1, bias=True),
28
+ nn.InstanceNorm2d(channels * 2, affine=True), nn.LeakyReLU(0.2, inplace=True))
29
+ self.layer_up2 = basic_block(channels * 2, channels * 2, block)
30
+ self.up1 = nn.Sequential(nn.ConvTranspose2d(channels * 2, channels, 4, 2, 1, bias=True),
31
+ nn.InstanceNorm2d(channels, affine=True), nn.LeakyReLU(0.2, inplace=True))
32
+ self.layer_up1 = basic_block(channels, channels, block)
33
+ # ---------SKIP
34
+ self.fus2 = skip(channels * 4, channels * 2, "HIN")
35
+ self.fus1 = skip(channels * 2, channels, "HIN")
36
+ # ---------SKIP
37
+ self.skip_down1 = nn.Sequential(nn.Conv2d(channels, channels, 4, 2, 1, bias=True),
38
+ nn.InstanceNorm2d(channels, affine=True), nn.LeakyReLU(0.2, inplace=True))
39
+ self.skip1 = skip(channels * 3, channels * 2, "CONV")
40
+ self.skip_down2 = nn.Sequential(nn.Conv2d(channels * 2, channels, 4, 2, 1, bias=True),
41
+ nn.InstanceNorm2d(channels, affine=True), nn.LeakyReLU(0.2, inplace=True))
42
+ self.skip2 = skip(channels * 5, channels * 4, "CONV")
43
+ # self.skip3 = skip(channels*2, channels, "CONV")
44
+ self.skip_up4 = nn.Sequential(nn.ConvTranspose2d(channels * 4, channels, 4, 2, 1, bias=True),
45
+ nn.InstanceNorm2d(channels, affine=True), nn.LeakyReLU(0.2, inplace=True))
46
+ self.skip4 = skip(channels * 3, channels * 2, "CONV")
47
+ # self.skip5 = skip(channels*2, channels, "CONV")
48
+ self.skip_up6 = nn.Sequential(nn.ConvTranspose2d(channels * 2, channels, 4, 2, 1, bias=True),
49
+ nn.InstanceNorm2d(channels, affine=True), nn.LeakyReLU(0.2, inplace=True))
50
+ self.skip6 = skip(channels * 2, channels, "CONV")
51
+
52
+ def forward(self, x):
53
+ # ---------ENCODE
54
+ x_11 = self.layer_dowm1(x)
55
+ x_down1 = self.dowm1(x_11)
56
+ # x =self.skip_down1(x)
57
+ # print(x.shape, x_down1.shape)
58
+
59
+ x_down1 = self.skip1(torch.cat([self.skip_down1(x), x_down1], 1), x_down1)
60
+
61
+ x_12 = self.layer_dowm2(x_down1)
62
+ x_down2 = self.dowm2(x_12)
63
+ x_down2 = self.skip2(torch.cat([self.skip_down2(x_down1), x_down2], 1), x_down2)
64
+
65
+ x_bottom = self.layer_bottom(x_down2)
66
+
67
+ # ---------DECODE
68
+ x_up2 = self.up2(x_bottom)
69
+ x_22 = self.layer_up2(x_up2)
70
+ x_22 = self.skip4(torch.cat([self.skip_up4(x_bottom), x_22], 1), x_22)
71
+ x_22 = self.fus2(torch.cat([x_12, x_22], 1), x_22)
72
+
73
+ x_up1 = self.up1(x_22)
74
+ x_21 = self.layer_up1(x_up1)
75
+ x_21 = self.skip6(torch.cat([self.skip_up6(x_22), x_21], 1), x_21)
76
+ x_21 = self.fus1(torch.cat([x_11, x_21], 1), x_21)
77
+ return x_21, x_down2
78
+
79
+
80
+ class skip(nn.Module):
81
+ def __init__(self, channels_in, channels_out, block):
82
+ super(skip, self).__init__()
83
+ if block == "CONV":
84
+ self.body = nn.Sequential(nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=True),
85
+ nn.InstanceNorm2d(channels_out, affine=True), nn.ReLU(inplace=True), )
86
+ if block == "ID":
87
+ self.body = nn.Identity()
88
+ if block == "INV":
89
+ self.body = nn.Sequential(InvBlock(channels_in, channels_in // 2),
90
+ nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=True), )
91
+ if block == "HIN":
92
+ self.body = nn.Sequential(HinBlock(channels_in, channels_out))
93
+ # --------------------------------------
94
+ self.alpha1 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
95
+ self.alpha1.data.fill_(1.0)
96
+ self.alpha2 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
97
+ self.alpha2.data.fill_(0.5)
98
+
99
+ def forward(self, x, y):
100
+ out = self.alpha1 * self.body(x) + self.alpha2 * y
101
+ return out
102
+
103
+
104
+ def subnet(net_structure, init='xavier'):
105
+ def constructor(channel_in, channel_out):
106
+ if net_structure == 'HIN':
107
+ return HinBlock(channel_in, channel_out)
108
+
109
+ return constructor
110
+
111
+
112
+ class InvBlock(nn.Module):
113
+ def __init__(self, channel_num, channel_split_num, subnet_constructor=subnet('HIN'),
114
+ clamp=0.8): ################ split_channel一般设为channel_num的一半
115
+ super(InvBlock, self).__init__()
116
+ # channel_num: 3
117
+ # channel_split_num: 1
118
+
119
+ self.split_len1 = channel_split_num # 1
120
+ self.split_len2 = channel_num - channel_split_num # 2
121
+
122
+ self.clamp = clamp
123
+
124
+ self.F = subnet_constructor(self.split_len2, self.split_len1)
125
+ self.G = subnet_constructor(self.split_len1, self.split_len2)
126
+ self.H = subnet_constructor(self.split_len1, self.split_len2)
127
+
128
+ def forward(self, x):
129
+ x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2))
130
+
131
+ y1 = x1 + self.F(x2) # 1 channel
132
+ self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1)
133
+ y2 = x2.mul(torch.exp(self.s)) + self.G(y1) # 2 channel
134
+ out = torch.cat((y1, y2), 1)
135
+
136
+ return out + x
137
+
138
+
139
+ class sample_block(nn.Module):
140
+ def __init__(self, channels_in, channels_out, size, dil):
141
+ super(sample_block, self).__init__()
142
+ # ------------------------------------------
143
+ if size == "DOWN":
144
+ self.conv = nn.Sequential(
145
+ nn.Conv2d(channels_in, channels_out, 3, 1, dil, dilation=dil),
146
+ nn.InstanceNorm2d(channels_out, affine=True),
147
+ nn.ReLU(inplace=True),
148
+ )
149
+ if size == "UP":
150
+ self.conv = nn.Sequential(
151
+ nn.ConvTranspose2d(channels_in, channels_out, 3, 1, dil, dilation=dil),
152
+ nn.InstanceNorm2d(channels_out, affine=True),
153
+ nn.ReLU(inplace=True),
154
+ )
155
+
156
+ def forward(self, x):
157
+ return self.conv(x)
158
+
159
+
160
+ class HinBlock(nn.Module):
161
+ def __init__(self, in_size, out_size):
162
+ super(HinBlock, self).__init__()
163
+ self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0)
164
+ self.norm = nn.InstanceNorm2d(out_size // 2, affine=True)
165
+
166
+ self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=1, bias=True)
167
+ self.relu_1 = nn.Sequential(nn.LeakyReLU(0.2, inplace=False), )
168
+ self.conv_2 = nn.Sequential(nn.Conv2d(out_size, out_size, kernel_size=3, stride=1, padding=1, bias=True),
169
+ nn.LeakyReLU(0.2, inplace=False), )
170
+
171
+ def forward(self, x):
172
+ out = self.conv_1(x)
173
+ out_1, out_2 = torch.chunk(out, 2, dim=1)
174
+ out = torch.cat([self.norm(out_1), out_2], dim=1)
175
+ out = self.relu_1(out)
176
+ out = self.conv_2(out)
177
+ out += self.identity(x)
178
+ return out
179
+
180
+
181
+ class net(nn.Module):
182
+ def __init__(self, args):
183
+ super().__init__()
184
+ self.args = args.model
185
+ self.hr_inc = DoubleConv(self.args["in_channel"], self.args["model_channel"] * 2)
186
+ self.hr_backbone = SuperUnet_MS(self.args["model_channel"] * 2)
187
+ self.final_out = nn.Conv2d(self.args["model_channel"] * 2, 3, kernel_size=1, bias=False)
188
+
189
+ def forward(self, x):
190
+ x = self.hr_inc(x)
191
+ x, mid_feat = self.hr_backbone(x)
192
+ out = self.final_out(x)
193
+ return out
194
+
model/module.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from torchvision.transforms.functional import rgb_to_grayscale
6
+ import numpy as np
7
+
8
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2 with He-style weight initialisation."""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid = mid_channels if mid_channels else out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True)
        )
        self.apply(self._init_weights)

    def forward(self, x):
        return self.double_conv(x)

    def _init_weights(self, m):
        # He-normal init for convs; identity-style init for batch norms.
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / fan_out))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
33
+
34
+
35
class Down(nn.Module):
    """Downscaling with maxpool then double conv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        return self.maxpool_conv(x)
47
+
48
+
49
class Up(nn.Module):
    """Upscaling then double conv, with padding to match the skip tensor."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Bilinear upsample keeps channels; the conv reduces them afterwards.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so its spatial size matches the skip tensor x2 (input is CHW).
        dh = x2.size()[2] - x1.size()[2]
        dw = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [dw // 2, dw - dw // 2,
                        dh // 2, dh - dh // 2])
        return self.conv(torch.cat([x2, x1], dim=1))
76
+
77
+
78
+ # spatial attention
79
# spatial attention
class SpatialGate(nn.Module):
    """Spatial attention: 1x1 conv -> sigmoid mask, applied multiplicatively."""

    def __init__(self, in_channels):
        super(SpatialGate, self).__init__()
        self.spatial = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mask = self.sigmoid(self.spatial(x))
        return mask * x
89
+
90
+
91
+ # sobel
92
# sobel
class SobelOperator(nn.Module):
    """Fixed-weight Sobel filter returning the gradient-magnitude map.

    Expects a single-channel input (B, 1, H, W).
    """

    def __init__(self):
        super(SobelOperator, self).__init__()
        self.conv_x = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.conv_y = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        kx = torch.FloatTensor([[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]])
        ky = torch.FloatTensor([[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]])
        self.conv_x.weight[0].data[:, :, :] = kx
        self.conv_y.weight[0].data[:, :, :] = ky

    def forward(self, x):
        gx = self.conv_x(x)
        gy = self.conv_y(x)
        return torch.sqrt(gx * gx + gy * gy)
105
+
106
+
107
class offset_estimator(nn.Sequential):
    """Stack of Gaussian-initialised convolutions with ReLUs.

    NOTE(review): `gaussian_2d` is not defined in this module — presumably a
    helper that builds a 2-D Gaussian kernel of the given size/FWHM; confirm
    where it is imported from. Also note the class subclasses nn.Sequential
    but stores its layers in `self.model` and overrides forward.
    """

    def __init__(self, kernel_size, fwhm, in_channel, mid_channel, out_channel) -> None:
        super().__init__()
        model = []
        assert len(kernel_size) == len(fwhm), "length error"
        for i in range(len(kernel_size)):
            if i == 0:
                # First layer: in_channel -> mid_channel.
                gaussian_weight = torch.FloatTensor(gaussian_2d(kernel_size[i], fwhm=fwhm[i]))
                gauss_filter = nn.Conv2d(in_channel, mid_channel, kernel_size[i], padding=(kernel_size[i] - 1) // 2,
                                         bias=False)
                gauss_filter.weight[0].data[:, :, :] = gaussian_weight
                model += [gauss_filter, nn.ReLU(inplace=True)]
            elif i == len(kernel_size) - 1:
                # Last layer: mid_channel -> out_channel.
                gaussian_weight = torch.FloatTensor(gaussian_2d(kernel_size[i], fwhm=fwhm[i]))
                gauss_filter = nn.Conv2d(mid_channel, out_channel, kernel_size[i], padding=(kernel_size[i] - 1) // 2,
                                         bias=False)
                gauss_filter.weight[0].data[:, :, :] = gaussian_weight
                model += [gauss_filter, nn.ReLU(inplace=True)]
            else:
                # Middle layers keep mid_channel width.
                gaussian_weight = torch.FloatTensor(gaussian_2d(kernel_size[i], fwhm=fwhm[i]))
                gauss_filter = nn.Conv2d(mid_channel, mid_channel, kernel_size[i], padding=(kernel_size[i] - 1) // 2,
                                         bias=False)
                gauss_filter.weight[0].data[:, :, :] = gaussian_weight
                model += [gauss_filter, nn.ReLU(inplace=True)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
135
+
136
+
137
+ # Channel attention
138
# Channel attention
def logsumexp_2d(tensor):
    """Log-sum-exp over the spatial dimensions.

    Args:
        tensor: (B, C, H, W) input.

    Returns:
        (B, C, 1) tensor of per-channel logsumexp values.
    """
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    # torch.logsumexp is the built-in, numerically stable equivalent of the
    # manual max-shift trick (and also handles all -inf inputs gracefully).
    return torch.logsumexp(tensor_flatten, dim=2, keepdim=True)
143
+
144
+
145
class Flatten(nn.Module):
    """Collapse all dimensions after the batch dimension into one."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
148
+
149
+
150
class ChannelGate(nn.Module):
    """CBAM-style channel attention gate.

    Pools spatially with each strategy in `pool_types`, runs the summaries
    through a shared bottleneck MLP, sums them, and gates the input channels
    with a sigmoid of the result.

    Args:
        gate_channels: number of input channels C.
        reduction_ratio: MLP bottleneck ratio.
        pool_types: iterable of 'avg', 'max', 'lp', 'lse'.

    Raises:
        ValueError: if an unknown pool type is given (previously this either
        reused a stale value or hit an unbound local).
    """

    # Immutable default avoids the shared-mutable-default-argument pitfall.
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        hw = (x.size(2), x.size(3))
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                channel_att_raw = self.mlp(F.avg_pool2d(x, hw, stride=hw))
            elif pool_type == 'max':
                channel_att_raw = self.mlp(F.max_pool2d(x, hw, stride=hw))
            elif pool_type == 'lp':
                channel_att_raw = self.mlp(F.lp_pool2d(x, 2, hw, stride=hw))
            elif pool_type == 'lse':
                # LSE pool only
                channel_att_raw = self.mlp(logsumexp_2d(x))
            else:
                raise ValueError(f"unknown pool_type: {pool_type!r}")

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
186
+
187
+
188
+ # LBP
189
+ def LBP(image): # b, 3, h, w tensor
190
+ radius = 2
191
+ n_points = 8 * radius
192
+ method = 'uniform'
193
+ gray_img = rgb_to_grayscale(image) # b, 1, h, w
194
+ gray_img = gray_img.squeeze(1)
195
+ lbf_feature = np.zeros((gray_img.shape[0], gray_img.shape[1], gray_img.shape[2]))
196
+ for i in range(gray_img.shape[0]):
197
+ lbf_feature[i] = feature.local_binary_pattern(gray_img[i], n_points, radius, method)
198
+ return torch.FloatTensor(lbf_feature).unsqueeze(1)
199
+
200
+
201
class Discriminator(nn.Module):
    """PatchGAN-style discriminator: four stride-2 conv blocks then a 1-channel head."""

    def __init__(self, in_channel):
        super().__init__()
        self.in_channel = in_channel

        # Each stage halves the spatial resolution.
        widths = [self.in_channel, 4, 4, 4, 4]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=False))
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(4, 1, 4, padding=1, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
221
+
222
+
223
class Discriminator_new(nn.Module):
    """Deeper patch discriminator: four stages of paired stride-1/stride-2 convs."""

    def __init__(self):
        super().__init__()
        layers = []
        in_filters = 3
        for out_filters in [4, 6, 8, 10]:
            # Stage: feature conv (stride 1) then downsampling conv (stride 2).
            layers += [
                nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            in_filters = out_filters

        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(in_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)
248
+
model/resize_right.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import warnings
3
+ from math import ceil
4
+
5
+ try:
6
+ from .interp_methods import *
7
+ except:
8
+ from interp_methods import *
9
+ from fractions import Fraction
10
+
11
+
12
class NoneClass:
    # Empty stand-in base class used when PyTorch is unavailable, so the
    # module can still define classes whose base would otherwise be nn.Module.
    pass
14
+
15
+
16
+ try:
17
+ import torch
18
+ from torch import nn
19
+
20
+ nnModuleWrapped = nn.Module
21
+ except ImportError:
22
+ warnings.warn('No PyTorch found, will work only with Numpy')
23
+ torch = None
24
+ nnModuleWrapped = NoneClass
25
+
26
+ try:
27
+ import numpy
28
+ except ImportError:
29
+ warnings.warn('No Numpy found, will work only with PyTorch')
30
+ numpy = None
31
+
32
+ if numpy is None and torch is None:
33
+ raise ImportError("Must have either Numpy or PyTorch but both not found")
34
+
35
+
36
def resize(input, scale_factors=None, out_shape=None,
           interp_method=lanczos3, support_sz=None,
           antialiasing=True, by_convs=False, scale_tolerance=None,
           max_numerator=10, pad_mode='constant', adv_weights=None):
    """Resize a numpy array or torch tensor, one dimension at a time.

    Either ``scale_factors`` or ``out_shape`` must be supplied; the other is
    derived. Each resized dim goes through: project output coordinates onto
    the input, build each output pixel's field of view, compute (or reuse)
    interpolation weights, then apply them by gather or by convolution.

    Args:
        input: numpy.ndarray or torch.Tensor to resize.
        scale_factors: scalar or per-dim list/tuple of scale factors.
        out_shape: target shape (may cover only some dims; see
            ``set_scale_and_out_sz`` for the defaulting policy).
        interp_method: 1-D kernel function carrying a ``support_sz`` attribute.
        support_sz: optional override of the kernel support size.
        antialiasing: low-pass filter (stretch the kernel) when downscaling.
        by_convs: compute via strided convolutions (needs rational scales).
        scale_tolerance: accuracy tolerance when rationalizing scales.
        max_numerator: numerator cap for rationalized scales.
        pad_mode: boundary padding mode.
        adv_weights: optional list of precomputed per-dim weights used
            instead of computing them from ``interp_method``.

    Returns:
        The resized array/tensor, same framework as ``input``.
    """
    # get properties of the input tensor
    in_shape, n_dims = input.shape, input.ndim

    # fw stands for framework that can be either numpy or torch,
    # determined by the input type
    fw = numpy if type(input) is numpy.ndarray else torch
    eps = fw.finfo(fw.float32).eps
    device = input.device if fw is torch else None
    # NOTE(review): weights are accumulated here but never returned or read;
    # kept only for behavioral parity with the original implementation.
    weights_container = []

    # set missing scale factors or output shape, one according to the other;
    # raises if both are missing. also normalizes the by_convs attribute.
    scale_factors, out_shape, by_convs = set_scale_and_out_sz(in_shape,
                                                              out_shape,
                                                              scale_factors,
                                                              by_convs,
                                                              scale_tolerance,
                                                              max_numerator,
                                                              eps, fw)

    # sort indices of dimensions according to scale of each dimension.
    # since we are going dim by dim this is efficient
    sorted_filtered_dims_and_scales = [(dim, scale_factors[dim], by_convs[dim],
                                        in_shape[dim], out_shape[dim])
                                       for dim in sorted(range(n_dims),
                                       key=lambda ind: scale_factors[ind])
                                       if scale_factors[dim] != 1.]

    # unless support size is specified by the user, it is an attribute
    # of the interpolation method
    if support_sz is None:
        support_sz = interp_method.support_sz

    # output begins identical to input and changes with each iteration
    output = input

    # iterate over dims
    for i, (dim, scale_factor, dim_by_convs, in_sz, out_sz
            ) in enumerate(sorted_filtered_dims_and_scales):
        # STEP 1- PROJECTED GRID: the non-integer locations of the projection
        # of output pixel locations to the input tensor
        projected_grid = get_projected_grid(in_sz, out_sz,
                                            scale_factor, fw, dim_by_convs,
                                            device)

        # STEP 1.5: ANTIALIASING- if antialiasing is taking place, we modify
        # the window size and the interpolation method (see inside function)
        cur_interp_method, cur_support_sz = apply_antialiasing_if_needed(
            interp_method,
            support_sz,
            scale_factor,
            antialiasing)

        # STEP 2- FIELDS OF VIEW: for each output pixel, map the input pixels
        # that influence it
        field_of_view = get_field_of_view(projected_grid, cur_support_sz, fw,
                                          eps, device)

        # STEP 2.5- CALCULATE PAD AND UPDATE: pad the input for boundary
        # handling and shift coordinates into the padded space (actual
        # padding happens when weights are applied in step 4)
        pad_sz, projected_grid, field_of_view = calc_pad_sz(in_sz, out_sz,
                                                            field_of_view,
                                                            projected_grid,
                                                            scale_factor,
                                                            dim_by_convs, fw,
                                                            device)

        # STEP 3- CALCULATE WEIGHTS (or take the caller-supplied ones).
        # bugfix: compare against None with ``is not None`` — the original
        # ``!= None`` broadcasts element-wise when adv_weights is array-like.
        if adv_weights is not None:
            weights = adv_weights[i]
        else:
            weights = get_weights(cur_interp_method, projected_grid, field_of_view)
        weights_container.append(weights)

        # STEP 4- APPLY WEIGHTS: multiply each field of view by its weights;
        # by_convs dims use an equivalent (faster) convolution formulation.
        if not dim_by_convs:
            output = apply_weights(output, field_of_view, weights, dim, n_dims,
                                   pad_sz, pad_mode, fw)
        else:
            output = apply_convs(output, scale_factor, in_sz, out_sz, weights,
                                 dim, pad_sz, pad_mode, fw)
    return output
133
+
134
+
135
def get_projected_grid(in_sz, out_sz, scale_factor, fw, by_convs, device=None):
    """Project 1-D output pixel centers onto (non-integer) input coordinates.

    In the ``by_convs`` case only one cycle of grid points is needed (the
    filters repeat), so the grid length is the scale's numerator.
    """
    grid_sz = scale_factor.numerator if by_convs else out_sz
    out_coords = fw_arange(grid_sz, fw, device)

    # center-aligned projection formula (derived in the resizing paper):
    # map output index to a continuous position in input space
    sf = float(scale_factor)
    return out_coords / sf + (in_sz - 1) / 2 - (out_sz - 1) / (2 * sf)
148
+
149
+
150
def get_field_of_view(projected_grid, cur_support_sz, fw, eps, device):
    """For each output pixel, the 1-D input indices inside its kernel window.

    Returns a 2-D tensor: one row of neighbor indices per output location.
    """
    # leftmost influencing pixel per output location; eps guards the case
    # where the boundary lands exactly on an integer
    left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)

    # window-size consecutive offsets counted from the left boundary
    offsets = fw_arange(ceil(cur_support_sz - eps), fw, device)
    return left_boundaries[:, None] + offsets
160
+
161
+
162
def calc_pad_sz(in_sz, out_sz, field_of_view, projected_grid, scale_factor,
                dim_by_convs, fw, device):
    """Compute boundary padding and shift coordinates into padded space.

    Returns ``(pad_sz, projected_grid, field_of_view)``. For the gather path
    ``pad_sz`` is one ``[left, right]`` pair (negative means crop); for the
    by_convs path it is a list of per-filter ``(left, right)`` pairs.
    """
    if not dim_by_convs:
        # determine padding according to neighbor coords out of bound.
        # this is a generalized notion of padding, when pad<0 it means crop
        pad_sz = [-field_of_view[0, 0].item(),
                  field_of_view[-1, -1].item() - in_sz + 1]

        # since input image will be changed by padding, coordinates of both
        # field_of_view and projected_grid need to be updated
        field_of_view += pad_sz[0]
        projected_grid += pad_sz[0]

    else:
        # only used for by_convs, to calc the boundaries of each filter the
        # number of distinct convolutions is the numerator of the scale factor
        num_convs, stride = scale_factor.numerator, scale_factor.denominator

        # calculate left and right boundaries for each conv. left can also be
        # negative right can be bigger than in_sz. such cases imply padding if
        # needed. however if both are in-bounds, it means we need to crop,
        # practically apply the conv only on part of the image.
        left_pads = -field_of_view[:, 0]

        # next calc is tricky, explanation by rows:
        # 1) counting output pixels between the first position of each filter
        #    to the right boundary of the input
        # 2) dividing it by number of filters to count how many 'jumps'
        #    each filter does
        # 3) multiplying by the stride gives us the distance over the input
        #    coords done by all these jumps for each filter
        # 4) to this distance we add the right boundary of the filter when
        #    placed in its leftmost position. so now we get the right boundary
        #    of that filter in input coord.
        # 5) the padding size needed is obtained by subtracting the rightmost
        #    input coordinate. if the result is positive padding is needed. if
        #    negative then negative padding means shaving off pixel columns.
        right_pads = (((out_sz - fw_arange(num_convs, fw, device) - 1)  # (1)
                       // num_convs)  # (2)
                      * stride  # (3)
                      + field_of_view[:, -1]  # (4)
                      - in_sz + 1)  # (5)

        # in the by_convs case pad_sz is a list of left-right pairs. one per
        # each filter
        pad_sz = list(zip(left_pads, right_pads))

    return pad_sz, projected_grid, field_of_view
211
+
212
+
213
def get_weights(interp_method, projected_grid, field_of_view):
    """Interpolation weights per output pixel, normalized to sum to 1.

    The kernel is evaluated at the signed distances between each projected
    output location and the pixel centers in its field of view.
    """
    distances = projected_grid[:, None] - field_of_view
    weights = interp_method(distances)

    # normalize per output pixel; guard all-zero rows against div-by-zero
    totals = weights.sum(1, keepdims=True)
    totals[totals == 0] = 1
    return weights / totals
224
+
225
+
226
def apply_weights(input, field_of_view, weights, dim, n_dims, pad_sz, pad_mode,
                  fw):
    """Resample ``input`` along ``dim`` as a weighted gather.

    Each output pixel is the weighted sum of the input pixels in its 1-D
    field of view along the resized dim.
    """
    # for this operation we assume the resized dim is the first one.
    # so we transpose and will transpose back after multiplying
    tmp_input = fw_swapaxes(input, dim, 0, fw)

    # apply padding
    tmp_input = fw_pad(tmp_input, fw, pad_sz, pad_mode)

    # field_of_view is a tensor of order 2: for each output (1d location
    # along cur dim)- a list of 1d neighbors locations.
    # note that this whole operations is applied to each dim separately,
    # this is why it is all in 1d.
    # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1:
    # for each output pixel (this time indicated in all dims), these are the
    # values of the neighbors in the 1d field of view. note that we only
    # consider neighbors along the current dim, but such set exists for every
    # multi-dim location, hence the final tensor order is image_dims+1.
    neighbors = tmp_input[field_of_view]

    # weights is an order 2 tensor: for each output location along 1d- a list
    # of weights matching the field of view. we augment it with ones, for
    # broadcasting, so that when multiplies some tensor the weights affect
    # only its first dim.
    tmp_weights = fw.reshape(weights, (*weights.shape, *[1] * (n_dims - 1)))

    # now we simply multiply the weights with the neighbors, and then sum
    # along the field of view, to get a single value per out pixel
    tmp_output = (neighbors * tmp_weights).sum(1)

    # we transpose back the resized dim to its original position
    return fw_swapaxes(tmp_output, 0, dim, fw)
258
+
259
+
260
def apply_convs(input, scale_factor, in_sz, out_sz, weights, dim, pad_sz,
                pad_mode, fw):
    """Resample ``input`` along ``dim`` using a bank of strided 1-D convs.

    Equivalent to ``apply_weights`` for rational scale factors: one conv per
    numerator phase; results are interleaved into the output.
    """
    # for this operations we assume the resized dim is the last one.
    # so we transpose and will transpose back after multiplying
    input = fw_swapaxes(input, dim, -1, fw)

    # the stride for all convs is the denominator of the scale factor
    stride, num_convs = scale_factor.denominator, scale_factor.numerator

    # prepare an empty tensor for the output
    tmp_out_shape = list(input.shape)
    tmp_out_shape[-1] = out_sz
    tmp_output = fw_empty(tuple(tmp_out_shape), fw, input.device)

    # iterate over the conv operations. we have as many as the numerator
    # of the scale-factor. for each we need boundaries and a filter.
    # NOTE(review): the loop variable shadows the ``pad_sz`` parameter (a list
    # of per-filter pairs) — harmless here, but worth renaming someday.
    for conv_ind, (pad_sz, filt) in enumerate(zip(pad_sz, weights)):
        # apply padding (we pad last dim, padding can be negative)
        pad_dim = input.ndim - 1
        tmp_input = fw_pad(input, fw, pad_sz, pad_mode, dim=pad_dim)

        # apply convolution over last dim. store in the output tensor with
        # positional strides so that when the loop is complete conv results
        # are interleaved
        tmp_output[..., conv_ind::num_convs] = fw_conv(tmp_input, filt, stride)

    return fw_swapaxes(tmp_output, -1, dim, fw)
287
+
288
+
289
def set_scale_and_out_sz(in_shape, out_shape, scale_factors, by_convs,
                         scale_tolerance, max_numerator, eps, fw):
    """Resolve scale_factors / out_shape defaults and rationalize scales.

    At least one of ``scale_factors`` / ``out_shape`` must be given; the
    missing one is derived. ``by_convs`` is normalized to a per-dim list and
    disabled for dims whose scale cannot be represented accurately as a
    small fraction.

    Returns:
        (scale_factors, out_shape, by_convs) — all per-dim lists.
    """
    # eventually we must have both scale-factors and out-sizes for all in/out
    # dims. however, we support many possible partial arguments
    if scale_factors is None and out_shape is None:
        raise ValueError("either scale_factors or out_shape should be "
                         "provided")
    if out_shape is not None:
        # if out_shape has less dims than in_shape, we defaultly resize the
        # first dims for numpy and last dims for torch
        out_shape = (list(out_shape) + list(in_shape[len(out_shape):])
                     if fw is numpy
                     else list(in_shape[:-len(out_shape)]) + list(out_shape))
        if scale_factors is None:
            # if no scale given, we calculate it as the out to in ratio
            # (not recomended)
            scale_factors = [out_sz / in_sz for out_sz, in_sz
                             in zip(out_shape, in_shape)]
    if scale_factors is not None:
        # by default, if a single number is given as scale, we assume resizing
        # two dims (most common are images with 2 spatial dims)
        scale_factors = (scale_factors
                         if isinstance(scale_factors, (list, tuple))
                         else [scale_factors, scale_factors])
        # if less scale_factors than in_shape dims, we defaultly resize the
        # first dims for numpy and last dims for torch
        scale_factors = (list(scale_factors) + [1] *
                         (len(in_shape) - len(scale_factors)) if fw is numpy
                         else [1] * (len(in_shape) - len(scale_factors)) +
                         list(scale_factors))
        if out_shape is None:
            # when no out_shape given, it is calculated by multiplying the
            # scale by the in_shape (not recomended)
            out_shape = [ceil(scale_factor * in_sz)
                         for scale_factor, in_sz in
                         zip(scale_factors, in_shape)]
    # next part intentionally after out_shape determined for stability
    # we fix by_convs to be a list of truth values in case it is not
    if not isinstance(by_convs, (list, tuple)):
        by_convs = [by_convs] * len(out_shape)

    # next loop fixes the scale for each dim to be either frac or float.
    # this is determined by by_convs and by tolerance for scale accuracy.
    for ind, (sf, dim_by_convs) in enumerate(zip(scale_factors, by_convs)):
        # first we fractionaize
        if dim_by_convs:
            frac = Fraction(1 / sf).limit_denominator(max_numerator)
            frac = Fraction(numerator=frac.denominator, denominator=frac.numerator)

        # if accuracy is within tolerance scale will be frac. if not, then
        # it will be float and the by_convs attr will be set false for
        # this dim
        if scale_tolerance is None:
            scale_tolerance = eps
        if dim_by_convs and abs(frac - sf) < scale_tolerance:
            scale_factors[ind] = frac
        else:
            scale_factors[ind] = float(sf)
            by_convs[ind] = False

    return scale_factors, out_shape, by_convs
350
+
351
+
352
def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
                                 antialiasing):
    """Stretch the kernel and its support when downscaling (low-pass filter).

    Upscaling (or ``antialiasing=False``) returns the kernel untouched.
    """
    scale_factor = float(scale_factor)
    if scale_factor >= 1.0 or not antialiasing:
        # no low-pass needed: pass the original kernel through
        return interp_method, support_sz

    # downscaling: widen the window by 1/scale and rescale the 1-D kernel
    # so the stretched function still integrates to the same mass
    def stretched(arg):
        return scale_factor * interp_method(scale_factor * arg)

    return stretched, support_sz / scale_factor
365
+
366
+
367
def fw_ceil(x, fw):
    """Element-wise ceiling, cast to the framework's integer type."""
    return fw.int_(fw.ceil(x)) if fw is numpy else x.ceil().long()
372
+
373
+
374
def fw_floor(x, fw):
    """Element-wise floor, cast to the framework's integer type."""
    return fw.int_(fw.floor(x)) if fw is numpy else x.floor().long()
379
+
380
+
381
def fw_cat(x, fw):
    """Concatenate a sequence along the first axis in either framework."""
    return fw.concatenate(x) if fw is numpy else fw.cat(x)
386
+
387
+
388
def fw_swapaxes(x, ax_1, ax_2, fw):
    """Swap two axes (numpy.swapaxes vs. torch Tensor.transpose)."""
    return fw.swapaxes(x, ax_1, ax_2) if fw is numpy else x.transpose(ax_1, ax_2)
393
+
394
+
395
def fw_pad(x, fw, pad_sz, pad_mode, dim=0):
    """Pad ``x`` along a single dimension in either framework.

    Args:
        x: numpy array or torch tensor.
        fw: the ``numpy`` or ``torch`` module.
        pad_sz: (left, right) pad amounts; may arrive as a list or a tuple.
        pad_mode: padding mode understood by the framework's pad function.
        dim: the dimension to pad.

    Returns:
        The padded array/tensor (torch inputs with fewer than 3 dims are
        lifted to 4-D, matching the original behavior).
    """
    # bugfix: normalize before the zero-pad shortcut — the original compared
    # ``pad_sz == (0, 0)`` against a tuple literal, which never matched the
    # list form produced by calc_pad_sz, so the shortcut never fired.
    if tuple(pad_sz) == (0, 0):
        return x
    if fw is numpy:
        pad_vec = [(0, 0)] * x.ndim
        pad_vec[dim] = pad_sz
        return fw.pad(x, pad_width=pad_vec, mode=pad_mode)
    else:
        # torch's functional pad requires at least a 3-D tensor
        if x.ndim < 3:
            x = x[None, None, ...]

        # F.pad pads the last dim first, so bring ``dim`` to the end
        pad_vec = [0] * ((x.ndim - 2) * 2)
        pad_vec[0:2] = pad_sz
        return fw.nn.functional.pad(x.transpose(dim, -1), pad=pad_vec,
                                    mode=pad_mode).transpose(dim, -1)
410
+
411
+
412
def fw_conv(input, filter, stride):
    """1-D convolution along the last dim of an arbitrary-rank torch tensor.

    All leading dims are flattened into the height of a 4-D tensor so a
    single 1xK conv2d convolves every row identically (depthwise-like over
    the collapsed dims).
    TODO: numpy support
    """
    flat = input.reshape(1, 1, -1, input.shape[-1])
    kernel = filter.view(1, 1, 1, -1)
    convolved = torch.nn.functional.conv2d(flat, kernel, stride=(1, stride))
    return convolved.reshape(*input.shape[:-1], -1)
424
+
425
+
426
def fw_arange(upper_bound, fw, device):
    """0..upper_bound-1 as an array/tensor; tensors are placed on ``device``."""
    return (fw.arange(upper_bound) if fw is numpy
            else fw.arange(upper_bound, device=device))
431
+
432
+
433
def fw_empty(shape, fw, device):
    """Uninitialized array/tensor with the given shape (tensor on ``device``)."""
    return (fw.empty(shape) if fw is numpy
            else fw.empty(size=(*shape,), device=device))
pretrained_models/LMAR_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:27f1ada04c3297053af030ec2547a06f54d5de5e1ec20f3b430a9dd2f2f666ff
3
+ size 1475245
pretrained_models/base_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45d49de91c08e7a6080d60f7059482fcd443377982e1908045625759e5931772
3
+ size 3417093
utils.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import math
import random
import warnings

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
import yaml
from torch.optim.lr_scheduler import _LRScheduler

from metrics import calculate_psnr, calculate_ssim
11
+
12
+
13
class AverageMeter(object):
    """Tracks the latest value plus a running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
30
+
31
+
32
def calculate_metrics(imgs_1, imgs_2):
    """Mean PSNR and SSIM over a batch of image tensors.

    Each tensor pair is converted to a PIL image and then to a numpy array
    before scoring with ``calculate_psnr`` / ``calculate_ssim`` (border 0).

    Returns:
        (mean_psnr, mean_ssim) as floats.
    """
    assert imgs_1.shape[0] == imgs_2.shape[0]
    to_pil = transforms.ToPILImage()
    psnrs = []
    ssims = []
    # iterating the batch dim directly is equivalent to indexing by range
    for t1, t2 in zip(imgs_1, imgs_2):
        a = np.asarray(to_pil(t1))
        b = np.asarray(to_pil(t2))
        psnrs.append(calculate_psnr(a, b, 0))
        ssims.append(calculate_ssim(a, b, 0))
    return np.asarray(psnrs).mean(), np.asarray(ssims).mean()
47
+
48
+
49
def read_args(config_file):
    """Build an ArgumentParser whose defaults come from a YAML config file.

    Every top-level key of the config becomes a ``--key`` option defaulting
    to its configured value, so command-line flags can override the file.

    Args:
        config_file: path to the YAML configuration file.

    Returns:
        argparse.ArgumentParser ready for ``parse_args()``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default=config_file)
    # bugfix: use a context manager — the original opened the file and
    # never closed it, leaking the handle
    with open(config_file) as file:
        config = yaml.safe_load(file)
    for k, v in config.items():
        parser.add_argument(f"--{k}", default=v)
    return parser
57
+
58
+
59
def save_checkpoint(state, filename):
    """Serialize a training-state object (model/optimizer dict) to ``filename``."""
    torch.save(state, filename)
61
+
62
+
63
class CosineAnnealingWarmRestarts(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:
    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        # T_0: length of the first cycle; T_i: length of the current cycle
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min

        # T_cur: epochs elapsed inside the current cycle
        self.T_cur = 0 if last_epoch < 0 else last_epoch
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        # NOTE(review): ``warnings`` is not imported at the top of this
        # module in the visible imports — this path would raise NameError
        # unless the import is added; confirm against the module header.
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        # cosine interpolation between base_lr and eta_min over the cycle
        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Step could be called after every batch update
        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)
        This function can be called in an interleaved way.
        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
        """
        if epoch is None and self.last_epoch < 0:
            epoch = 0
        if epoch is None:
            # implicit sequential stepping: advance one epoch, restart
            # (and grow the cycle by T_mult) when the cycle completes
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    # constant cycle length: position inside cycle is a modulo
                    self.T_cur = epoch % self.T_0
                else:
                    # geometric cycles: solve for which cycle ``epoch`` is in
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        class _enable_get_lr_call:
            # context manager that flags get_lr() as being called from step()
            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False
                return self

        with _enable_get_lr_call(self):
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group['lr'] = lr
                self.print_lr(self.verbose, i, lr, epoch)
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
170
+
171
+
172
def set_seed(seed):
    """Make runs reproducible by seeding every RNG source with ``seed``."""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # seed all CUDA devices too, when available
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)