Upload 20 files
Browse files- LMAR_GAN_train.py +331 -0
- LMAR_VGG_train.py +300 -0
- LMAR_test.py +98 -0
- base_test.py +106 -0
- base_train.py +261 -0
- config/LMAR_config.yaml +48 -0
- config/base_config.yaml +42 -0
- data/__init__.py +1 -0
- data/loader.py +109 -0
- loss.py +221 -0
- metrics.py +133 -0
- model/LMAR_model.py +277 -0
- model/__init__.py +5 -0
- model/interp_methods.py +69 -0
- model/model.py +194 -0
- model/module.py +248 -0
- model/resize_right.py +437 -0
- pretrained_models/LMAR_model.bin +3 -0
- pretrained_models/base_model.bin +3 -0
- utils.py +177 -0
LMAR_GAN_train.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import yaml
|
| 3 |
+
import torchvision.transforms as transforms
|
| 4 |
+
from utils import read_args, save_checkpoint, AverageMeter, CosineAnnealingWarmRestarts
|
| 5 |
+
import time
|
| 6 |
+
from tqdm import trange, tqdm
|
| 7 |
+
from torchvision.utils import save_image
|
| 8 |
+
# from tensorboardX import SummaryWriter
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import time
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
|
| 15 |
+
import torch
|
| 16 |
+
from torch import optim
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torchvision.utils as vutils
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
|
| 21 |
+
from data import *
|
| 22 |
+
from model import *
|
| 23 |
+
from loss import *
|
| 24 |
+
import pyiqa
|
| 25 |
+
from torch.autograd import Variable
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
# Running count of optimizer steps; shared between train()/main() via `global`.
global_step = 0
# Full-reference IQA metrics from pyiqa; 'ssimc' computes SSIM on color images
# (downsample=True matches the common SSIM evaluation protocol).
psnr_calculator = pyiqa.create_metric('psnr').cuda()
ssim_calculator = pyiqa.create_metric('ssimc', downsample=True).cuda()

# LSGAN-style adversarial objective: MSE between discriminator output and labels.
criterion_GAN = nn.MSELoss()
# NOTE(review): torch.cuda.FloatTensor is a legacy alias — tensors created from
# it are allocated directly on the GPU; script therefore requires CUDA.
Tensor = torch.cuda.FloatTensor

# Maximum Mean Discrepancy loss (from loss.py); instantiated here but not
# referenced in this file's visible code — presumably kept for experiments.
mmdLoss = MMDLoss().cuda()

# cos_loss = cos_loss
# feature_extractor.eval()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def train(model, data_loader, criterion, optimizer_G, optimizer_D, epoch, args, discriminator):
    """Run one epoch of adversarial training.

    The generator (`model`) is updated with a combination of an adversarial
    pixel loss, a feature resize loss, a pseudo-similarity loss and an
    upsampling SSIM loss; the `discriminator` is then updated with the
    standard real/fake LSGAN objective on downscaled images.

    Args:
        model: generator producing (down_x, hr_feature, new_lr_feature,
            ori_lr_feature, residual, res) — signature inferred from the
            unpacking below; TODO confirm against model definition.
        data_loader: yields (inp_img, gt_img, down_h, down_w, inp_img_path).
        criterion: feature reconstruction loss (SmoothL1 from main()).
        optimizer_G / optimizer_D: generator / discriminator optimizers.
        epoch: current epoch number (used for logging and file names).
        args: parsed config namespace (output_dir, train_loader, ...).
        discriminator: patch discriminator applied to downscaled images.
    """
    global global_step
    iter_bar = tqdm(data_loader, desc='Iter (loss=X.XXX)')
    nbatches = len(data_loader)

    total_losses = AverageMeter()

    pixel_losses = AverageMeter()
    resize_losses = AverageMeter()
    pseudo_losses = AverageMeter()
    up_losses = AverageMeter()
    dis_losses = AverageMeter()

    psnrs = AverageMeter()
    ssims = AverageMeter()

    optimizer_G.zero_grad()
    optimizer_D.zero_grad()

    start_time = time.time()

    # Output folders for intermediate visualizations and checkpoints.
    if not os.path.exists(args.output_dir + '/image_train'):
        os.mkdir(args.output_dir + '/image_train')

    if not os.path.exists(args.output_dir + "/models"):
        os.mkdir(args.output_dir + "/models")

    for i, batch in enumerate(iter_bar):
        optimizer_G.zero_grad()
        optimizer_D.zero_grad()

        inp_img, gt_img, down_h, down_w, inp_img_path = batch
        batch_size = inp_img.size(0)
        inp_img = inp_img.cuda()
        gt_img = gt_img.cuda()

        # Per-batch target downscale resolution comes from the loader;
        # .item() implies the loader yields scalar tensors (batch_size 1 for
        # the size entries) — TODO confirm in data/loader.py.
        down_size = (down_h.item(), down_w.item())
        # NOTE(review): eval() on a config string — trusted config assumed.
        up_size = eval(args.train_loader["img_size"])

        down_x, hr_feature, new_lr_feature, ori_lr_feature, residual, res = model(inp_img, down_size, up_size)

        # Patch-discriminator output shape: one map downscaled by 2^4
        # (i.e. 4 stride-2 layers assumed in Discriminator — TODO confirm).
        dis_patch_lr = (1, down_size[0] // 2 ** 4, down_size[1] // 2 ** 4)
        valid_lr = Variable(Tensor(np.ones((batch_size, *dis_patch_lr))), requires_grad=False)
        fake_lr = Variable(Tensor(np.zeros((batch_size, *dis_patch_lr))), requires_grad=False)

        # --- Generator losses ---
        # Adversarial term: push discriminator to label generated down_x real.
        pixel_loss = criterion_GAN(discriminator(down_x), valid_lr)
        pixel_losses.update(pixel_loss.item(), batch_size)

        # Feature reconstruction between HR features and new LR features.
        resize_loss = criterion(hr_feature, new_lr_feature)
        resize_losses.update(resize_loss.item(), batch_size)

        # Similarity loss scaled by 5000 to balance magnitudes — empirical
        # weight, presumably tuned; no derivation visible here.
        pseudo_loss = similarity_loss(new_lr_feature, hr_feature) * 5000
        pseudo_losses.update(pseudo_loss.item(), batch_size)

        # SSIM-style feature loss; `gradient` is returned but unused here.
        up_loss, gradient = feat_ssim(new_lr_feature, hr_feature, inp_img)
        up_losses.update(up_loss.item(), batch_size)

        total_loss = pixel_loss + resize_loss + pseudo_loss + up_loss
        total_losses.update(total_loss.item(), batch_size)

        total_loss.backward()
        optimizer_G.step()

        # --- Discriminator update (generator output detached) ---
        # "Real" samples are plain bicubic-free downscales of the input.
        loss_real_lr = criterion_GAN(discriminator(resize(inp_img, out_shape=down_size, antialiasing=False)), valid_lr)

        loss_fake_lr = criterion_GAN(discriminator(down_x.detach()), fake_lr)

        loss_D = (loss_fake_lr + loss_real_lr) * 0.5
        dis_losses.update(loss_D.item(), batch_size)

        loss_D.backward()
        optimizer_D.step()

        iter_bar.set_description('Iter (loss=%5.6f)' % (total_losses.avg + dis_losses.avg))

        # Periodic qualitative dumps: downscale comparison, feature maps,
        # and the (amplified) residual image.
        if i % 200 == 0:
            error = torch.abs(resize(inp_img, out_shape=down_size, antialiasing=False) - down_x)
            saved_image = torch.cat(
                [resize(inp_img, out_shape=down_size, antialiasing=False)[0:2], down_x[0:2], error[0:2]],
                dim=0)
            save_image(saved_image, args.output_dir + '/image_train/epoch_{}_iter_down_{}.png'.format(epoch, i))

            saved_image = torch.cat(
                [torch.mean(hr_feature, dim=1, keepdim=True)[0:2], torch.mean(new_lr_feature, dim=1, keepdim=True)[0:2],
                 torch.mean(ori_lr_feature, dim=1, keepdim=True)[0:2], torch.mean(torch.abs(new_lr_feature-ori_lr_feature), dim=1, keepdim=True)[0:2]],
                dim=0)
            save_image(saved_image, args.output_dir + '/image_train/epoch_{}_iter_feat_{}.png'.format(epoch, i))
            # x10 amplification purely for visibility of small residuals.
            residual = residual * 10
            save_image(residual[0], args.output_dir + '/image_train/epoch_{}_iter_out_{}.png'.format(epoch, i))

        # Roughly 10 log lines per epoch. PSNR/SSIM are placeholders (0.0)
        # during training — real metrics are computed in evaluate().
        if i % max(1, nbatches // 10) == 0:
            psnr_val, ssim_val = 0.0, 0.0
            psnrs.update(psnr_val, batch_size)
            ssims.update(ssim_val, batch_size)

            logging.info(
                "Epoch {}, learning rates {:}, Iter {}, total_loss {:.4f}, pixel_loss {:.4f}, resize_loss {:.4f}, pseudo_loss {:.4f}, up_loss {:.4f}, dis_loss: {:.4f}, PSNR {:.4f}, SSIM {:.4f}, Elapse time {:.2f}\n".format(
                    epoch, optimizer_G.param_groups[0]["lr"], i, total_losses.avg, pixel_losses.avg, resize_losses.avg,
                    pseudo_losses.avg, up_losses.avg, dis_losses.avg,
                    psnrs.avg, ssims.avg,
                    time.time() - start_time))

    # Checkpoint generator and discriminator every epoch.
    if epoch % 1 == 0:
        logging.info("** ** * Saving model and optimizer ** ** * ")

        output_model_file = os.path.join(args.output_dir + "/models", "model.%d.bin" % (epoch))
        state = {"epoch": epoch, "state_dict": model.state_dict(), "step": global_step}
        save_checkpoint(state, output_model_file)

        output_model_file = os.path.join(args.output_dir + "/models", "discriminator.%d.bin" % (epoch))
        state = {"epoch": epoch, "state_dict": discriminator.state_dict(), "step": global_step}
        save_checkpoint(state, output_model_file)
        logging.info("Save model to %s", output_model_file)

    logging.info(
        "Finish training epoch %d, avg total_loss: %.4f, avg pixel_loss: %.4f, avg resize_loss: %.4f, avg pseudo_loss: %.4f, avg up_loss: %.4f, "
        "avg dis_loss: %.4f, avg PSNR: %.2f, avg SSIM: %.2F, and takes %.2f seconds" % (
            epoch, total_losses.avg, pixel_losses.avg, resize_losses.avg, pseudo_losses.avg, up_losses.avg, dis_losses.avg, psnrs.avg,
            ssims.avg,
            time.time() - start_time))

    logging.info("***** CUDA.empty_cache() *****\n")
    torch.cuda.empty_cache()
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def evaluate(model, load_path, data_loader, epoch):
    """Load the checkpoint at `load_path` into `model` and report PSNR/SSIM.

    A downscale resolution is drawn uniformly at random from the configured
    test sizes so each evaluation probes one of the supported resolutions.

    Args:
        model: generator instance (weights are overwritten from checkpoint).
        load_path: path to a checkpoint dict containing "state_dict".
        data_loader: yields (inp_img, gt_img, inp_img_path).
        epoch: epoch number, for logging context only.

    NOTE(review): `args` is read from module scope (set in __main__), not
    passed as a parameter — consider threading it through explicitly.
    """
    checkpoint = torch.load(load_path)
    model.load_state_dict(checkpoint["state_dict"])
    model.cuda()
    model.eval()

    psnrs = AverageMeter()
    ssims = AverageMeter()
    # Fix: convert the random tensor to a Python int — a plain list cannot be
    # indexed with a torch.Tensor. NOTE(review): high=5 assumes the config
    # lists at least 5 test sizes — TODO confirm against LMAR_config.yaml.
    random_index = torch.randint(low=0, high=5, size=(1,)).item()
    down_size = eval(args.test_loader["img_size"])
    down_size = down_size[random_index]
    logging.info("Inference at down size: {}".format(down_size))
    up_size = eval(args.test_loader["gt_size"])

    start_time = time.time()
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            inp_img, gt_img, inp_img_path = batch
            inp_img = inp_img.cuda()
            # Fix: the metrics compare GPU tensors, so the ground truth must
            # also be moved to the GPU (was left on CPU -> device mismatch).
            gt_img = gt_img.cuda()
            batch_size = inp_img.size(0)
            up_out, _ = model(inp_img, down_size, up_size, test_flag=True)

            # metrics on the [0, 1]-clamped reconstruction
            clamped_out = torch.clamp(up_out, 0, 1)
            psnr_val, ssim_val = psnr_calculator(clamped_out, gt_img), ssim_calculator(clamped_out, gt_img)
            psnrs.update(torch.mean(psnr_val).item(), batch_size)
            ssims.update(torch.mean(ssim_val).item(), batch_size)
            torch.cuda.empty_cache()

            if i % 100 == 0:
                logging.info(
                    "PSNR {:.4f}, SSIM {:.4f}, Elapse time {:.2f}\n".format(psnrs.avg, ssims.avg,
                                                                            time.time() - start_time))

    logging.info("avg PSNR: %.4f, avg SSIM: %.4F, and takes %.2f seconds" % (
        psnrs.avg, ssims.avg, time.time() - start_time))
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def main(args):
    """Entry point: build generator/discriminator, loaders, optimizers and
    schedulers, then run adversarial training with periodic evaluation."""
    global global_step

    start_epoch = 1
    global_step = 0

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    # Persist the resolved configuration next to the outputs for reproducibility.
    with open(os.path.join(args.output_dir, "args.json"), "w") as f:
        json.dump(args.__dict__, f, sort_keys=True, indent=2)

    log_format = "%(asctime)s %(levelname)-8s %(message)s"
    log_file = os.path.join(args.output_dir, "train_log")
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())

    # device setting
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    logging.info(args.__dict__)

    model = codebook_model(args)

    # Patch discriminator over 3-channel (RGB) downscaled images.
    discriminator = Discriminator(3).cuda()

    # Only trainable generator parameters are optimized (frozen parts skipped).
    optimizer_G = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.optimizer["lr"],
                             betas=(0.9, 0.999))

    optimizer_D = optim.Adam(list(discriminator.parameters()),
                             lr=args.optimizer["lr"],
                             betas=(0.9, 0.999))

    logging.info("Building data loader")

    # NOTE(review): eval() on config strings assumes a trusted config file.
    if args.train_loader["loader"] == "resize":
        train_transforms = transforms.Compose([transforms.Resize(eval(args.train_loader["img_size"])),
                                               transforms.ToTensor()])
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), train_transforms, False,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], random_flag=False)

    elif args.train_loader["loader"] == "crop":
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), False, True,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], random_flag=args.train_loader["random_flag"])

    elif args.train_loader["loader"] == "default":
        train_transforms = transforms.Compose([transforms.ToTensor()])
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), train_transforms, False,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], random_flag=args.train_loader["random_flag"])
    else:
        raise NotImplementedError

    if args.test_loader["loader"] == "default":

        test_transforms = transforms.Compose([transforms.ToTensor()])
        test_loader = get_loader(args.data["test_dir"],
                                 None, test_transforms, False,
                                 int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                                 args.test_loader["shuffle"], random_flag=False)

    elif args.test_loader["loader"] == "resize":

        test_transforms = transforms.Compose([transforms.Resize(eval(args.test_loader["img_size"])),
                                              transforms.ToTensor()])
        test_loader = get_loader(args.data["test_dir"],
                                 eval(args.test_loader["img_size"]), test_transforms, False,
                                 int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                                 args.test_loader["shuffle"], random_flag=False)
    else:
        raise NotImplementedError

    # criterion = similarity_loss
    criterion = nn.SmoothL1Loss()
    # criterion = nn.L1Loss()

    # vgg_loss = VGGLoss()

    if args.optimizer["type"] == "cos":
        # Fix: this branch previously referenced an undefined `optimizer` and
        # never defined lr_scheduler_G/lr_scheduler_D, so the training loop
        # crashed with a NameError at lr_scheduler_G.step(). Build one cosine
        # scheduler per optimizer instead.
        lr_scheduler_G = CosineAnnealingWarmRestarts(optimizer_G, T_0=args.optimizer["T_0"],
                                                     T_mult=args.optimizer["T_MULT"],
                                                     eta_min=args.optimizer["ETA_MIN"],
                                                     last_epoch=-1)
        lr_scheduler_D = CosineAnnealingWarmRestarts(optimizer_D, T_0=args.optimizer["T_0"],
                                                     T_mult=args.optimizer["T_MULT"],
                                                     eta_min=args.optimizer["ETA_MIN"],
                                                     last_epoch=-1)
    elif args.optimizer["type"] == "step":
        lr_scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=args.optimizer["step"],
                                                         gamma=args.optimizer["gamma"])
        lr_scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=args.optimizer["step"],
                                                         gamma=args.optimizer["gamma"])
    else:
        # Fail fast on a misconfigured scheduler type instead of a late NameError.
        raise NotImplementedError("Unknown optimizer type: %s" % args.optimizer["type"])

    t_total = int(len(train_loader) * args.optimizer["total_epoch"])
    logging.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    logging.info("***** Running training *****")
    logging.info("  Batch size = %d", args.train_loader["batch_size"])
    logging.info("  Num steps = %d", t_total)
    logging.info("  Loader length = %d", len(train_loader))

    model.train()
    model.cuda()

    logging.info("Begin training from epoch = %d\n", start_epoch)
    for epoch in trange(start_epoch, args.optimizer["total_epoch"] + 1, desc="Epoch"):
        train(model, train_loader, criterion, optimizer_G, optimizer_D, epoch, args, discriminator)
        lr_scheduler_G.step()
        lr_scheduler_D.step()
        if epoch % args.evaluate_intervel == 0:
            logging.info("***** Running testing *****")
            load_path = os.path.join(args.output_dir + "/models", "model.%d.bin" % (epoch))
            evaluate(model, load_path, test_loader, epoch)
            logging.info("***** End testing *****")
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
if __name__ == '__main__':
    # NOTE(review): config path is hard-coded to a developer machine —
    # consider taking it from a CLI argument or environment variable.
    parser = read_args("/home/yuwei/code/cvpr/config/LMAR_config.yaml")
    args = parser.parse_args()
    main(args)
|
LMAR_VGG_train.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
|
| 4 |
+
import argparse
|
| 5 |
+
import yaml
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
from utils import read_args, save_checkpoint, AverageMeter, CosineAnnealingWarmRestarts
|
| 8 |
+
import time
|
| 9 |
+
from tqdm import trange, tqdm
|
| 10 |
+
from torchvision.utils import save_image
|
| 11 |
+
import json
|
| 12 |
+
import time
|
| 13 |
+
import logging
|
| 14 |
+
import torch
|
| 15 |
+
from torch import nn, optim
|
| 16 |
+
import torchvision.utils as vutils
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from data import *
|
| 21 |
+
from model import *
|
| 22 |
+
from loss import *
|
| 23 |
+
import pyiqa
|
| 24 |
+
|
| 25 |
+
from torch.autograd import Variable
|
| 26 |
+
|
| 27 |
+
# Running count of optimizer steps; shared between train()/main() via `global`.
global_step = 0
# Full-reference IQA metrics from pyiqa; 'ssimc' computes SSIM on color images.
psnr_calculator = pyiqa.create_metric('psnr').cuda()
ssim_calculator = pyiqa.create_metric('ssimc', downsample=True).cuda()

# Frozen VGG feature extractor used as a perceptual loss (from loss.py);
# eval() disables dropout/batch-norm updates for deterministic features.
feature_extractor = VGGPerceptualLoss(resize=False).cuda()
feature_extractor.eval()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def weight_annealing(epoch: int) -> float:
    """Return the perceptual (resize) loss weight for a given epoch.

    The weight is held at 1 for the first two epochs as a warm-up, then
    dropped by a factor of 1000 for all later epochs.

    Args:
        epoch: 0-based (or 1-based — caller passes the trange epoch) epoch index.

    Returns:
        1 while epoch < 2, otherwise 0.001.
    """
    initial_weight = 1
    # Keep the full weight during the warm-up phase.
    if epoch < 2:
        return initial_weight
    # Afterwards the perceptual term is strongly down-weighted.
    return initial_weight * 0.001
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def train(model, data_loader, criterion, optimizer, epoch, args):
    """Run one epoch of training with a VGG perceptual loss (no GAN).

    The generator is optimized with a feature reconstruction loss
    (`criterion`), an epoch-annealed VGG perceptual loss on the downscaled
    image, and an SSIM-style feature loss.

    Args:
        model: generator producing (down_x, hr_feature, new_lr_feature,
            ori_lr_feature, residual, res) — inferred from unpacking below.
        data_loader: yields (inp_img, gt_img, down_h, down_w, inp_img_path).
        criterion: feature loss (SmoothL1 from main()).
        optimizer: single Adam optimizer over all model parameters.
        epoch: current epoch number (drives weight_annealing and file names).
        args: parsed config namespace.
    """
    global global_step
    iter_bar = tqdm(data_loader, desc='Iter (loss=X.XXX)')
    nbatches = len(data_loader)

    total_losses = AverageMeter()

    pixel_losses = AverageMeter()
    resize_losses = AverageMeter()
    pseudo_losses = AverageMeter()

    psnrs = AverageMeter()
    ssims = AverageMeter()

    optimizer.zero_grad()

    start_time = time.time()

    # Output folders for intermediate visualizations and checkpoints.
    if not os.path.exists(args.output_dir + '/image_train'):
        os.mkdir(args.output_dir + '/image_train')

    if not os.path.exists(args.output_dir + "/models"):
        os.mkdir(args.output_dir + "/models")

    for i, batch in enumerate(iter_bar):
        optimizer.zero_grad()
        inp_img, gt_img, down_h, down_w, inp_img_path = batch
        batch_size = inp_img.size(0)
        inp_img = inp_img.cuda()
        gt_img = gt_img.cuda()

        # Per-batch target downscale resolution from the loader; .item()
        # implies scalar size tensors — TODO confirm in data/loader.py.
        down_size = (down_h.item(), down_w.item())
        # NOTE(review): eval() on a config string — trusted config assumed.
        up_size = eval(args.train_loader["img_size"])

        down_x, hr_feature, new_lr_feature, ori_lr_feature, residual, res = model(inp_img, down_size, up_size)

        # Feature reconstruction between the new LR features and HR features.
        pixel_loss = criterion(new_lr_feature, hr_feature)

        pixel_losses.update(pixel_loss.item(), batch_size)

        # VGG perceptual loss between the learned downscale and a plain
        # resize of the input; annealed per epoch by weight_annealing().
        resize_loss = feature_extractor(down_x, resize(inp_img, out_shape=down_size, antialiasing=False),
                                        feature_layers=[3])
        resize_loss = resize_loss * weight_annealing(epoch)
        resize_losses.update(resize_loss.item(), batch_size)

        # SSIM-style feature loss; the returned gradient term is unused here.
        pseudo_loss, _ = feat_ssim(new_lr_feature, hr_feature, inp_img)
        pseudo_losses.update(pseudo_loss.item(), batch_size)

        total_loss = pixel_loss + resize_loss + pseudo_loss
        total_losses.update(total_loss.item(), batch_size)

        total_loss.backward()

        optimizer.step()

        iter_bar.set_description('Iter (loss=%5.6f)' % total_losses.avg)

        # Periodic qualitative dumps: downscale comparison, feature maps,
        # and the (amplified) residual image.
        if i % 200 == 0:
            # print(residual.max())
            error = torch.abs(resize(inp_img, out_shape=down_size, antialiasing=False) - down_x)
            # error = (error - error.min()) / (error.max()-error.min())
            saved_image = torch.cat(
                [resize(inp_img, out_shape=down_size, antialiasing=False)[0:2], down_x[0:2], error[0:2]],
                dim=0)
            save_image(saved_image, args.output_dir + '/image_train/epoch_{}_iter_down_{}.png'.format(epoch, i))

            saved_image = torch.cat(
                [torch.mean(hr_feature, dim=1, keepdim=True)[0:2], torch.mean(new_lr_feature, dim=1, keepdim=True)[0:2],
                 torch.mean(ori_lr_feature, dim=1, keepdim=True)[0:2],
                 torch.mean(torch.abs(new_lr_feature - ori_lr_feature), dim=1, keepdim=True)[0:2]],
                dim=0)
            save_image(saved_image, args.output_dir + '/image_train/epoch_{}_iter_feat_{}.png'.format(epoch, i))
            # residual = (residual - residual.min()) / (residual.max()-residual.min())
            # x10 amplification purely for visibility of small residuals.
            residual = residual * 10
            save_image(residual[0], args.output_dir + '/image_train/epoch_{}_iter_out_{}.png'.format(epoch, i))

        # Roughly 10 log lines per epoch. PSNR/SSIM are placeholders (0.0)
        # during training — real metrics are computed in evaluate().
        if i % max(1, nbatches // 10) == 0:
            psnr_val, ssim_val = 0.0, 0.0
            psnrs.update(psnr_val, batch_size)
            ssims.update(ssim_val, batch_size)

            logging.info(
                "Epoch {}, learning rates {:}, Iter {}, total_loss {:.4f}, pixel_loss {:.4f}, resize_loss {:.4f}, pseudo_loss {:.4f}, PSNR {:.4f}, SSIM {:.4f}, Elapse time {:.2f}\n".format(
                    epoch, optimizer.param_groups[0]["lr"], i, total_losses.avg, pixel_losses.avg, resize_losses.avg,
                    pseudo_losses.avg,
                    psnrs.avg, ssims.avg,
                    time.time() - start_time))

    # Checkpoint model plus optimizer state every epoch.
    if epoch % 1 == 0:
        logging.info("** ** * Saving model and optimizer ** ** * ")

        output_model_file = os.path.join(args.output_dir + "/models", "model.%d.bin" % (epoch))
        state = {"epoch": epoch, "state_dict": model.state_dict(),
                 "optimizer": optimizer.state_dict(), "step": global_step}

        save_checkpoint(state, output_model_file)
        logging.info("Save model to %s", output_model_file)

    logging.info(
        "Finish training epoch %d, avg total_loss: %.4f, avg pixel_loss: %.4f, avg resize_loss: %.4f, avg pseudo_loss: %.4f, avg PSNR: %.2f, avg SSIM: %.2F, and takes %.2f seconds" % (
            epoch, total_losses.avg, pixel_losses.avg, resize_losses.avg, pseudo_losses.avg, psnrs.avg, ssims.avg,
            time.time() - start_time))

    logging.info("***** CUDA.empty_cache() *****\n")
    torch.cuda.empty_cache()
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def evaluate(model, load_path, data_loader, epoch):
    """Load the checkpoint at `load_path` into `model` and report PSNR/SSIM.

    A downscale resolution is drawn uniformly at random from the configured
    test sizes so each evaluation probes one of the supported resolutions.

    Args:
        model: generator instance (weights are overwritten from checkpoint).
        load_path: path to a checkpoint dict containing "state_dict".
        data_loader: yields (inp_img, gt_img, inp_img_path).
        epoch: epoch number, for logging context only.

    NOTE(review): `args` is read from module scope (set in __main__), not
    passed as a parameter — consider threading it through explicitly.
    """
    checkpoint = torch.load(load_path)
    model.load_state_dict(checkpoint["state_dict"])
    model.cuda()
    model.eval()

    psnrs = AverageMeter()
    ssims = AverageMeter()
    # Fix: convert the random tensor to a Python int — a plain list cannot be
    # indexed with a torch.Tensor. NOTE(review): high=5 assumes the config
    # lists at least 5 test sizes — TODO confirm against LMAR_config.yaml.
    random_index = torch.randint(low=0, high=5, size=(1,)).item()
    down_size = eval(args.test_loader["img_size"])
    down_size = down_size[random_index]
    logging.info("Inference at down size: {}".format(down_size))
    up_size = eval(args.test_loader["gt_size"])

    start_time = time.time()
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            inp_img, gt_img, inp_img_path = batch
            inp_img = inp_img.cuda()
            # Fix: the metrics compare GPU tensors, so the ground truth must
            # also be moved to the GPU (was left on CPU -> device mismatch).
            gt_img = gt_img.cuda()
            batch_size = inp_img.size(0)
            up_out = model(inp_img, down_size, up_size, test_flag=True)

            # metrics on the [0, 1]-clamped reconstruction
            clamped_out = torch.clamp(up_out, 0, 1)
            psnr_val, ssim_val = psnr_calculator(clamped_out, gt_img), ssim_calculator(clamped_out, gt_img)
            psnrs.update(torch.mean(psnr_val).item(), batch_size)
            ssims.update(torch.mean(ssim_val).item(), batch_size)
            torch.cuda.empty_cache()

            if i % 100 == 0:
                logging.info(
                    "PSNR {:.4f}, SSIM {:.4f}, Elapse time {:.2f}\n".format(psnrs.avg, ssims.avg,
                                                                            time.time() - start_time))

    logging.info(f"Finish test at epoch {epoch}: avg PSNR: %.4f, avg SSIM: %.4F, and takes %.2f seconds" % (
        psnrs.avg, ssims.avg, time.time() - start_time))
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def main(args):
    """Entry point: build model, loaders, optimizer and scheduler, then run
    VGG-perceptual training with periodic evaluation."""
    global global_step

    start_epoch = 1
    global_step = 0

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    # Persist the resolved configuration next to the outputs for reproducibility.
    with open(os.path.join(args.output_dir, "args.json"), "w") as f:
        json.dump(args.__dict__, f, sort_keys=True, indent=2)

    log_format = "%(asctime)s %(levelname)-8s %(message)s"
    log_file = os.path.join(args.output_dir, "train_log")
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())

    # device setting
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    logging.info(args.__dict__)

    model = codebook_model(args)

    optimizer = optim.Adam(model.parameters(), lr=args.optimizer["lr"],
                           betas=(0.9, 0.999))

    logging.info("Building data loader")

    # NOTE(review): eval() on config strings assumes a trusted config file.
    if args.train_loader["loader"] == "resize":
        train_transforms = transforms.Compose([transforms.Resize(eval(args.train_loader["img_size"])),
                                               transforms.ToTensor()])
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), train_transforms, False,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], random_flag=False)

    elif args.train_loader["loader"] == "crop":
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), False, True,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], random_flag=args.train_loader["random_flag"])

    elif args.train_loader["loader"] == "default":
        train_transforms = transforms.Compose([transforms.ToTensor()])
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), train_transforms, False,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], random_flag=args.train_loader["random_flag"])
    else:
        raise NotImplementedError

    if args.test_loader["loader"] == "default":

        test_transforms = transforms.Compose([transforms.ToTensor()])
        test_loader = get_loader(args.data["test_dir"],
                                 eval(args.test_loader["img_size"]), test_transforms, False,
                                 int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                                 args.test_loader["shuffle"], random_flag=False)

    elif args.test_loader["loader"] == "resize":
        test_transforms = transforms.Compose([transforms.Resize(eval(args.test_loader["img_size"])),
                                              transforms.ToTensor()])
        test_loader = get_loader(args.data["test_dir"],
                                 eval(args.test_loader["img_size"]), test_transforms, False,
                                 int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                                 args.test_loader["shuffle"], random_flag=False)
    else:
        raise NotImplementedError

    criterion = nn.SmoothL1Loss()
    # vgg_loss = VGGLoss()

    if args.optimizer["type"] == "cos":
        lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=args.optimizer["T_0"],
                                                   T_mult=args.optimizer["T_MULT"],
                                                   eta_min=args.optimizer["ETA_MIN"],
                                                   last_epoch=-1)
    elif args.optimizer["type"] == "step":
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.optimizer["step"],
                                                       gamma=args.optimizer["gamma"])
    else:
        # Fix: without this branch an unknown scheduler type left
        # `lr_scheduler` undefined, producing a NameError much later at
        # lr_scheduler.step(). Fail fast with a clear message instead.
        raise NotImplementedError("Unknown optimizer type: %s" % args.optimizer["type"])

    t_total = int(len(train_loader) * args.optimizer["total_epoch"])
    logging.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    logging.info("***** Running training *****")
    logging.info("  Batch size = %d", args.train_loader["batch_size"])
    logging.info("  Num steps = %d", t_total)
    logging.info("  Loader length = %d", len(train_loader))

    model.train()
    model.cuda()

    logging.info("Begin training from epoch = %d\n", start_epoch)
    # evaluate(model, "/home/yuwei/experiment/cvpr/prompt_final_vgg_gradient/models/model.1.bin", test_loader, 1)
    for epoch in trange(start_epoch, args.optimizer["total_epoch"] + 1, desc="Epoch"):
        train(model, train_loader, criterion, optimizer, epoch, args)
        lr_scheduler.step()
        if epoch % args.evaluate_intervel == 0:
            logging.info("***** Running testing *****")
            load_path = os.path.join(args.output_dir + "/models", "model.%d.bin" % (epoch))
            evaluate(model, load_path, test_loader, epoch)
            logging.info("***** End testing *****")
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
if __name__ == '__main__':
    # NOTE(review): config path is hard-coded to a developer machine —
    # consider taking it from a CLI argument or environment variable.
    parser = read_args("/home/yuwei/code/cvpr/config/LMAR_config.yaml")
    args = parser.parse_args()
    main(args)
|
LMAR_test.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import yaml
|
| 3 |
+
import torchvision.transforms as transforms
|
| 4 |
+
from utils import read_args, save_checkpoint, AverageMeter, CosineAnnealingWarmRestarts
|
| 5 |
+
import time
|
| 6 |
+
from tqdm import trange, tqdm
|
| 7 |
+
from torchvision.utils import save_image
|
| 8 |
+
import os
|
| 9 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
| 10 |
+
import json
|
| 11 |
+
import time
|
| 12 |
+
import logging
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn, optim
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
import copy
|
| 19 |
+
from model import *
|
| 20 |
+
from data import *
|
| 21 |
+
from PIL import Image
|
| 22 |
+
from torch.optim import LBFGS
|
| 23 |
+
import pyiqa
|
| 24 |
+
from thop import profile
|
| 25 |
+
from thop import clever_format
|
| 26 |
+
|
| 27 |
+
from torchvision.models.feature_extraction import create_feature_extractor
|
| 28 |
+
|
| 29 |
+
# Full-reference image-quality metrics, instantiated once at import time and
# kept on the GPU ('ssimc' is pyiqa's SSIM variant for color images).
psnr_calculator = pyiqa.create_metric('psnr').cuda()
ssim_calculator = pyiqa.create_metric('ssimc', downsample=True).cuda()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def test(load_path, data_loader, args):
    """Run LMAR inference with the checkpoint at ``load_path`` and report
    PSNR/SSIM averages over ``data_loader``.

    NOTE(review): the LPIPS/NIQE meters are never updated here, so their
    logged averages stay at 0 — kept only for log-format compatibility.
    """
    # Restore the model from the checkpoint and switch to inference mode.
    net = codebook_model(args)
    ckpt = torch.load(load_path)
    net.load_state_dict(ckpt["state_dict"])
    net.cuda()
    net.eval()

    # Running averages for each metric.
    psnrs = AverageMeter()
    ssims = AverageMeter()
    lpipss = AverageMeter()
    niqes = AverageMeter()

    # Fixed intermediate resolution; the target (GT) resolution comes from
    # the config string, evaluated into a tuple.
    down_size = (1440, 2560)
    logging.info("Inference at down size: {}".format(down_size))
    # NOTE(review): eval() on a config value — acceptable only for trusted configs.
    up_size = eval(args.test_loader["gt_size"])

    start_time = time.time()
    with torch.no_grad():
        for step, sample in enumerate(tqdm(data_loader)):
            inp_img, gt_img, inp_img_path = sample
            inp_img = inp_img.cuda()
            n_samples = inp_img.size(0)
            gt_img = gt_img.cuda()

            up_out = net(inp_img, down_size, up_size, test_flag=True)
            name = inp_img_path[0].split("/")[-1]
            # save_image(up_out[0], os.path.join(save_path, name))

            # Metrics are computed on the [0, 1]-clamped prediction.
            clamped_out = torch.clamp(up_out, 0, 1)
            psnr_val = psnr_calculator(clamped_out, gt_img)
            ssim_val = ssim_calculator(clamped_out, gt_img)
            psnrs.update(psnr_val.item(), n_samples)
            ssims.update(ssim_val.item(), n_samples)

            # Periodic progress report.
            if step % 700 == 0:
                logging.info(
                    "PSNR {:.4f}, SSIM {:.4f}, LPIPS {:.4F}, NIQE {:.4F}, Elapse time {:.2f}\n".format(
                        psnrs.avg, ssims.avg, lpipss.avg, niqes.avg,
                        time.time() - start_time))

    logging.info("Finish test: avg PSNR: %.4f, avg SSIM: %.4F, avg LPIPS: %.4F, avg NIQE: %.4F, and takes %.2f seconds" % (
        psnrs.avg, ssims.avg, lpipss.avg, niqes.avg, time.time() - start_time))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def main(args, load_path):
    """Set up logging and the test data loader, then evaluate the checkpoint.

    Args:
        args: parsed config namespace (see config/LMAR_config.yaml).
        load_path: path to the pretrained LMAR checkpoint.
    """
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    test_transforms = transforms.Compose([transforms.ToTensor()])

    # Log to <output_dir>/test_log and mirror to the console.
    log_format = "%(asctime)s %(levelname)-8s %(message)s"
    log_file = os.path.join(args.output_dir, "test_log")
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())

    logging.info("Building data loader")

    test_loader = get_loader(args.data["test_dir"],
                             eval(args.test_loader["img_size"]), test_transforms, False,
                             int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                             args.test_loader["shuffle"], random_flag=False)
    # BUG FIX: the original called the undefined name `test_time`, which
    # raised NameError; the evaluation routine defined above is `test`.
    test(load_path, test_loader, args)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == '__main__':
    parser = read_args("/home/yuwei/code/cvpr/config/LMAR_config.yaml")
    args = parser.parse_args()
    # BUG FIX: the original path mixed separators ("./pretrained_models\LMAR_model.bin");
    # the backslash is part of the filename on POSIX, so the checkpoint was never found.
    main(args, "./pretrained_models/LMAR_model.bin")
|
base_test.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import yaml
|
| 3 |
+
import torchvision.transforms as transforms
|
| 4 |
+
from utils import read_args, save_checkpoint, AverageMeter, calculate_metrics, CosineAnnealingWarmRestarts
|
| 5 |
+
# import torchvision.transforms.InterpolationMode
|
| 6 |
+
import time
|
| 7 |
+
from tqdm import trange, tqdm
|
| 8 |
+
from torchvision.utils import save_image
|
| 9 |
+
import os
|
| 10 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
| 11 |
+
import json
|
| 12 |
+
import time
|
| 13 |
+
import logging
|
| 14 |
+
import torch
|
| 15 |
+
from torch import nn, optim
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
from model import *
|
| 20 |
+
from data import *
|
| 21 |
+
from PIL import Image
|
| 22 |
+
from torchvision.transforms import Resize
|
| 23 |
+
import pyiqa
|
| 24 |
+
from thop import profile
|
| 25 |
+
from thop import clever_format
|
| 26 |
+
|
| 27 |
+
# Image-quality metrics, instantiated once at import time on the GPU.
# PSNR/SSIM are full-reference; LPIPS is learned-perceptual; NIQE is no-reference.
psnr_calculator = pyiqa.create_metric('psnr').cuda()
ssim_calculator = pyiqa.create_metric('ssimc', downsample=True).cuda()
lpips_calculator = pyiqa.create_metric('lpips').cuda()
niqe_calculator = pyiqa.create_metric('niqe').cuda()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def test(load_path, data_loader, args):
    """Evaluate the baseline model checkpoint at ``load_path``.

    Downscales each input to a fixed intermediate size, runs the model,
    resizes the output back to the configured GT size, and accumulates
    PSNR/SSIM (LPIPS/NIQE code is kept but disabled).
    """
    # if not os.path.exists(args.output_dir + '/out_my'):
    #     os.mkdir(args.output_dir + '/out_my')

    # save_path = args.output_dir + "/out_my"
    model = net(args)
    checkpoint = torch.load(load_path)
    model.load_state_dict(checkpoint["state_dict"])
    model.cuda()
    model.eval()

    psnrs = AverageMeter()
    ssims = AverageMeter()
    lpipss = AverageMeter()
    niqes = AverageMeter()

    start_time = time.time()
    down_size = (1440, 2560)
    logging.info("Inference at down size: {}".format(down_size))
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            input_img, gt_img, inp_img_path = batch

            name = inp_img_path[0].split("/")[-1]
            input_img = input_img.cuda()
            # BUG FIX: the ground truth must be on the same device as the
            # prediction before the pyiqa metrics compare them; the original
            # left it on the CPU.
            gt_img = gt_img.cuda()
            batch_size = input_img.size(0)
            # BUG FIX: the original re-assigned `start_time` here, so both the
            # periodic and the final elapsed-time reports covered only the
            # last iteration instead of the whole run.
            input_img = resize(input_img, out_shape=down_size, antialiasing=False)
            out_img = model(input_img)
            out_img = resize(out_img, out_shape=eval(args.test_loader["gt_size"]), antialiasing=False)

            # metrics (computed on the [0, 1]-clamped prediction)
            clamped_out = torch.clamp(out_img, 0, 1)
            psnr_val, ssim_val = psnr_calculator(clamped_out, gt_img), ssim_calculator(clamped_out, gt_img)
            psnrs.update(torch.mean(psnr_val).item(), batch_size)
            ssims.update(torch.mean(ssim_val).item(), batch_size)

            # lpips = lpips_calculator(clamped_out, gt_img)
            # lpipss.update(torch.mean(lpips).item(), batch_size)
            # niqe = niqe_calculator(clamped_out)
            # niqes.update(torch.mean(niqe).item(), batch_size)
            torch.cuda.empty_cache()

            if i % 700 == 0:
                logging.info(
                    "PSNR {:.4f}, SSIM {:.4f}, LPIPS {:.4F}, NIQE {:.4F}, Elapse time {:.2f}\n".format(
                        psnrs.avg, ssims.avg, lpipss.avg, niqes.avg,
                        time.time() - start_time))

    logging.info("Finish test: avg PSNR: %.4f, avg SSIM: %.4F, avg LPIPS: %.4F, avg NIQE: %.4F, and takes %.2f seconds" % (
        psnrs.avg, ssims.avg, lpipss.avg, niqes.avg, time.time() - start_time))
|
| 83 |
+
|
| 84 |
+
def main(args, load_path):
    """Prepare logging and the evaluation data loader, then score the
    checkpoint at ``load_path`` with ``test``."""
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    test_transforms = transforms.Compose([transforms.ToTensor()])

    # Log to <output_dir>/baseline_log and mirror to the console.
    log_format = "%(asctime)s %(levelname)-8s %(message)s"
    log_file = os.path.join(args.output_dir, "baseline_log")
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())

    logging.info("Building data loader")

    loader_cfg = args.test_loader
    test_loader = get_loader(args.data["test_dir"],
                             eval(loader_cfg["img_size"]), test_transforms, False,
                             int(loader_cfg["batch_size"]), loader_cfg["num_workers"],
                             loader_cfg["shuffle"], random_flag=False)
    test(load_path, test_loader, args)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == '__main__':
    # Load the baseline config, parse arguments, and evaluate the bundled
    # pretrained checkpoint.
    config_path = "/home/yuwei/code/cvpr/config/base_config.yaml"
    args = read_args(config_path).parse_args()
    main(args, "./pretrained_models/base_model.bin")
|
base_train.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import yaml
|
| 3 |
+
import torchvision.transforms as transforms
|
| 4 |
+
from utils import read_args, save_checkpoint, AverageMeter, calculate_metrics, CosineAnnealingWarmRestarts
|
| 5 |
+
import time
|
| 6 |
+
from tqdm import trange, tqdm
|
| 7 |
+
from torchvision.utils import save_image
|
| 8 |
+
# from tensorboardX import SummaryWriter
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import time
|
| 12 |
+
import logging
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn, optim
|
| 15 |
+
import torchvision.utils as vutils
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
from data import *
|
| 19 |
+
from model import *
|
| 20 |
+
from loss import *
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
|
| 24 |
+
|
| 25 |
+
global_step = 0
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def train(model, data_loader, criterion, optimizer, epoch, args):
    """Run one training epoch and checkpoint the model afterwards.

    Optimizes `model` with a pixel-wise `criterion` over `data_loader`,
    logs running losses roughly 10 times per epoch, dumps sample image
    triplets (input / output / GT) every 200 iterations, and saves a
    checkpoint (model + optimizer + step) at the end of the epoch.

    Args:
        model: network already on the GPU (inputs are moved via .cuda()).
        data_loader: yields (input_img, gt_img, image_path) batches.
        criterion: pixel loss (e.g. L1) applied to (out_img, gt_img).
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (used for logging and file names).
        args: config namespace providing output_dir and hyper-params.
    """
    global global_step
    iter_bar = tqdm(data_loader, desc='Iter (loss=X.XXX)')
    nbatches = len(data_loader)

    # Running averages for this epoch's losses and (optional) metrics.
    total_losses = AverageMeter()
    pixel_losses = AverageMeter()
    gradient_losses = AverageMeter()
    psnrs = AverageMeter()
    ssims = AverageMeter()

    optimizer.zero_grad()

    start_time = time.time()

    # Ensure the per-run output folders exist before the first write.
    if not os.path.exists(args.output_dir + '/image_train'):
        os.mkdir(args.output_dir + '/image_train')

    if not os.path.exists(args.output_dir + "/models"):
        os.mkdir(args.output_dir + "/models")

    for i, batch in enumerate(iter_bar):
        optimizer.zero_grad()

        input_img, gt_img, image_path = batch
        input_img = input_img.cuda()
        gt_img = gt_img.cuda()
        batch_size = input_img.size(0)

        out_img = model(input_img)

        pixel_loss = criterion(out_img, gt_img)
        pixel_losses.update(pixel_loss.item(), batch_size)

        # gradient_loss = vggloss(out_img, gt_img).cuda()
        # gradient_loss = args.hyper_params["x_lambda"] * gradient_loss
        # gradient_losses.update(gradient_loss.item(), batch_size)

        # Only the pixel loss is active; the VGG/gradient term above is
        # intentionally disabled in this (pre-training) script.
        total_loss = pixel_loss
        total_losses.update(total_loss.item(), batch_size)

        total_loss.backward()
        optimizer.step()

        iter_bar.set_description('Iter (loss=%5.6f)' % total_losses.avg)

        # Dump a visual sanity check (first two samples of the batch) every
        # 200 iterations.
        if i % 200 == 0:
            saved_image = torch.cat([input_img[0:2], out_img[0:2], gt_img[0:2]], dim=0)
            save_image(saved_image, args.output_dir + '/image_train/epoch_{}_iter_{}.jpg'.format(epoch, i))

        # metrics
        # NOTE: metric computation is disabled during training, so the
        # PSNR/SSIM averages logged below stay at 0.
        norm_out = torch.clamp(out_img, 0, 1)
        #psnr_val, ssim_val = calculate_metrics(norm_out, gt_img)
        #psnrs.update(psnr_val.item(), batch_size)
        #ssims.update(ssim_val.item(), batch_size)

        # Log roughly 10 progress reports per epoch.
        if i % max(1, nbatches // 10) == 0:
            logging.info(
                "Epoch {}, learning rates {:}, Iter {}, total_loss {:.4f}, pixel_loss {:.4f}, PSNR {:.4f}, SSIM {:.4f}, Elapse time {:.2f}\n".format(
                    epoch, optimizer.param_groups[0]["lr"], i, total_losses.avg, pixel_losses.avg,
                    psnrs.avg, ssims.avg,
                    time.time() - start_time))

    # Checkpoint every epoch (the modulus is 1, i.e. always true; kept for
    # easy adjustment of the save frequency).
    if epoch % 1 == 0:
        logging.info("** ** * Saving model and optimizer ** ** * ")

        output_model_file = os.path.join(args.output_dir + "/models", "model.%d.bin" % (epoch))
        state = {"epoch": epoch, "state_dict": model.state_dict(),
                 "optimizer": optimizer.state_dict(), "step": global_step}

        save_checkpoint(state, output_model_file)
        logging.info("Save model to %s", output_model_file)

    logging.info(
        "Finish training epoch %d, avg total_loss: %.4f, avg pixel_loss: %.4f, avg PSNR: %.2f, avg SSIM: %.2F, and takes %.2f seconds" % (
            epoch, total_losses.avg, pixel_losses.avg, psnrs.avg, ssims.avg,
            time.time() - start_time))

    logging.info("***** CUDA.empty_cache() *****\n")
    torch.cuda.empty_cache()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def evaluate(model, load_path, data_loader, epoch):
    """Reload ``model``'s weights from ``load_path`` and compute PSNR/SSIM
    over ``data_loader`` without gradients.

    Args:
        model: network instance (weights are overwritten from the checkpoint).
        load_path: checkpoint file produced by `train`.
        data_loader: yields (input_img, gt_img, inp_img_path) batches.
        epoch: epoch index, used only for the final log line.
    """
    checkpoint = torch.load(load_path)
    model.load_state_dict(checkpoint["state_dict"])
    model.cuda()
    model.eval()

    psnrs = AverageMeter()
    ssims = AverageMeter()

    start_time = time.time()
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            input_img, gt_img, inp_img_path = batch
            input_img = input_img.cuda()
            # BUG FIX: move the ground truth to the GPU as well; the original
            # compared a CUDA prediction against a CPU target.
            gt_img = gt_img.cuda()
            batch_size = input_img.size(0)
            out_img = model(input_img)

            # metrics (computed on the [0, 1]-clamped prediction)
            norm_out = torch.clamp(out_img, 0, 1)
            psnr_val, ssim_val = calculate_metrics(norm_out, gt_img)
            psnrs.update(psnr_val.item(), batch_size)
            ssims.update(ssim_val.item(), batch_size)
            torch.cuda.empty_cache()

            if i % 100 == 0:
                logging.info(
                    "PSNR {:.4f}, SSIM {:.4f}, Elapse time {:.2f}\n".format(psnrs.avg, ssims.avg,
                                                                            time.time() - start_time))

    logging.info(f"Finish test at epoch {epoch}: avg PSNR: %.4f, avg SSIM: %.4F, and takes %.2f seconds" % (
        psnrs.avg, ssims.avg, time.time() - start_time))

    # BUG FIX: the original left the model in eval() mode, so every epoch
    # after the first evaluation trained with eval-mode behavior
    # (BN/dropout frozen, if the model has any — TODO confirm against the
    # model definition). Restore training mode before returning.
    model.train()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def main(args):
    """Entry point: build logging, model, optimizer, data loaders, and the
    LR schedule from the parsed config, then run the train/eval loop.

    Args:
        args: parsed config namespace (see config/base_config.yaml).
    """
    global global_step

    start_epoch = 1
    global_step = 0

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    # Persist the resolved config next to the run artifacts.
    with open(os.path.join(args.output_dir, "args.json"), "w") as f:
        json.dump(args.__dict__, f, sort_keys=True, indent=2)

    # Log to <output_dir>/train_log and mirror to the console.
    log_format = "%(asctime)s %(levelname)-8s %(message)s"
    log_file = os.path.join(args.output_dir, "train_log")
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())

    # device setting
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    logging.info(args.__dict__)

    if args.resume["flag"]:
        # Resume model + optimizer state from the configured checkpoint.
        model = net(args)
        model.to(args.device)
        check_point = torch.load(args.resume["checkpoint"])
        model.load_state_dict(check_point["state_dict"])
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.optimizer["lr"],
                               betas=(0.9, 0.999))
        optimizer.load_state_dict(check_point["optimizer"])
        start_epoch = check_point["epoch"] + 1
        # start_epoch = check_point["epoch"]

    else:
        model = net(args)
        model.to(args.device)
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.optimizer["lr"],
                               betas=(0.9, 0.999))

    logging.info("Building data loader")

    if args.train_loader["loader"] == "resize":
        train_transforms = transforms.Compose([transforms.Resize(eval(args.train_loader["img_size"])),
                                               transforms.ToTensor()])
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), train_transforms, False,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], inference_flag=False)

    elif args.train_loader["loader"] == "crop":
        train_loader = get_loader(args.data["train_dir"],
                                  eval(args.train_loader["img_size"]), False, True,
                                  int(args.train_loader["batch_size"]), args.train_loader["num_workers"],
                                  args.train_loader["shuffle"], inference_flag=False)
    else:
        raise NotImplementedError

    if args.test_loader["loader"] == "default":

        test_transforms = transforms.Compose([transforms.ToTensor()])
        test_loader = get_loader(args.data["test_dir"],
                                 eval(args.test_loader["img_size"]), test_transforms, False,
                                 int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                                 args.test_loader["shuffle"], inference_flag=False)

    elif args.test_loader["loader"] == "resize":

        test_transforms = transforms.Compose([transforms.Resize(eval(args.test_loader["img_size"])),
                                              transforms.ToTensor()])
        test_loader = get_loader(args.data["test_dir"],
                                 eval(args.test_loader["img_size"]), test_transforms, False,
                                 int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                                 args.test_loader["shuffle"], inference_flag=False)
    else:
        # BUG FIX: the original fell through silently for an unknown loader
        # type, leaving `test_loader` undefined and crashing later with a
        # NameError. Fail fast, consistent with the train-loader handling.
        raise NotImplementedError

    criterion = nn.L1Loss()
    # vgg_loss = VGGLoss()

    if args.optimizer["type"] == "cos":
        lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=args.optimizer["T_0"],
                                                   T_mult=args.optimizer["T_MULT"],
                                                   eta_min=args.optimizer["ETA_MIN"],
                                                   last_epoch=-1)
    elif args.optimizer["type"] == "step":
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.optimizer["step"],
                                                       gamma=args.optimizer["gamma"])
    else:
        # BUG FIX: same fail-fast guard for an unknown scheduler type,
        # which would otherwise leave `lr_scheduler` undefined.
        raise NotImplementedError

    if args.resume["flag"]:
        # Fast-forward the LR schedule to the resumed epoch.
        for i in range(start_epoch):
            lr_scheduler.step()

    t_total = int(len(train_loader) * args.optimizer["total_epoch"])
    logging.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    logging.info("***** Running training *****")
    logging.info(" Batch size = %d", args.train_loader["batch_size"])
    logging.info(" Num steps = %d", t_total)
    logging.info(" Loader length = %d", len(train_loader))

    model.train()
    model.cuda()

    logging.info("Begin training from epoch = %d\n", start_epoch)
    for epoch in trange(start_epoch, args.optimizer["total_epoch"] + 1, desc="Epoch"):
        train(model, train_loader, criterion, optimizer, epoch, args)
        lr_scheduler.step()
        if epoch % args.evaluate_intervel == 0:
            logging.info("***** Running testing *****")
            load_path = os.path.join(args.output_dir + "/models", "model.%d.bin" % (epoch))
            evaluate(model, load_path, test_loader, epoch)
            logging.info("***** End testing *****")
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
if __name__ == '__main__':
    # Parse arguments from the baseline YAML config and start training.
    config_path = "/home/yuwei/code/cvpr/config/base_config.yaml"
    cli_args = read_args(config_path).parse_args()
    main(cli_args)
|
config/LMAR_config.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
output_dir: '/home/yuwei/experiment/cvpr/LMAR_cubic'
|
| 2 |
+
data:
|
| 3 |
+
train_dir: /home/data/yuwei/data/uhd4k_ll/train
|
| 4 |
+
test_dir: /home/data/yuwei/data/uhd4k_ll/test
|
| 5 |
+
|
| 6 |
+
model:
|
| 7 |
+
in_channel: 3
|
| 8 |
+
model_channel: 8
|
| 9 |
+
sparsity_threshold: 0.01
|
| 10 |
+
num_blocks: 8
|
| 11 |
+
threslhold_frac: 0.6
|
| 12 |
+
hidden_channel: 48
|
| 13 |
+
|
| 14 |
+
train_loader:
|
| 15 |
+
num_workers: 8
|
| 16 |
+
batch_size: 1
|
| 17 |
+
loader: crop
|
| 18 |
+
img_size: (1024, 1024)
|
| 19 |
+
shuffle: True
|
| 20 |
+
gt_size: (2160, 3840)
|
| 21 |
+
random_flag: True
|
| 22 |
+
|
| 23 |
+
test_loader:
|
| 24 |
+
num_workers: 8
|
| 25 |
+
batch_size: 1
|
| 26 |
+
loader: default
|
| 27 |
+
img_size: ((1440, 2560), (1080, 1920), (1200, 1600), (720, 1280), (540, 960))
|
| 28 |
+
shuffle: False
|
| 29 |
+
gt_size: (2160, 3840)
|
| 30 |
+
|
| 31 |
+
optimizer:
|
| 32 |
+
type: step
|
| 33 |
+
total_epoch: 12
|
| 34 |
+
lr: 0.0004
|
| 35 |
+
T_0: 0.00001
|
| 36 |
+
T_MULT: 1
|
| 37 |
+
ETA_MIN: 0.000001
|
| 38 |
+
step: 4
|
| 39 |
+
gamma: 0.75
|
| 40 |
+
|
| 41 |
+
hyper_params:
|
| 42 |
+
lambda: 0.5
|
| 43 |
+
|
| 44 |
+
resume:
|
| 45 |
+
flag: True
|
| 46 |
+
checkpoint: ./pretrained_models/base_model.bin
|
| 47 |
+
|
| 48 |
+
evaluate_intervel: 1
|
config/base_config.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
output_dir: '/home/yuwei/experiment/cvpr/uhd4k_ll_pretrain'
|
| 2 |
+
data:
|
| 3 |
+
train_dir: /home/data/yuwei/data/uhd4k_ll/train
|
| 4 |
+
test_dir: /home/data/yuwei/data/uhd4k_ll/test
|
| 5 |
+
|
| 6 |
+
model:
|
| 7 |
+
in_channel: 3
|
| 8 |
+
model_channel: 8
|
| 9 |
+
|
| 10 |
+
train_loader:
|
| 11 |
+
num_workers: 8
|
| 12 |
+
batch_size: 2
|
| 13 |
+
loader: resize
|
| 14 |
+
img_size: (1024, 1024)
|
| 15 |
+
shuffle: True
|
| 16 |
+
|
| 17 |
+
test_loader:
|
| 18 |
+
num_workers: 8
|
| 19 |
+
batch_size: 1
|
| 20 |
+
loader: default
|
| 21 |
+
img_size: (1200, 1600)
|
| 22 |
+
shuffle: False
|
| 23 |
+
gt_size: (2160, 3840)
|
| 24 |
+
|
| 25 |
+
optimizer:
|
| 26 |
+
type: step
|
| 27 |
+
total_epoch: 100
|
| 28 |
+
lr: 0.001
|
| 29 |
+
T_0: 100
|
| 30 |
+
T_MULT: 1
|
| 31 |
+
ETA_MIN: 0.000001
|
| 32 |
+
step: 20
|
| 33 |
+
gamma: 0.75
|
| 34 |
+
|
| 35 |
+
hyper_params:
|
| 36 |
+
x_lambda: 0.03
|
| 37 |
+
|
| 38 |
+
resume:
|
| 39 |
+
flag: False
|
| 40 |
+
checkpoint: Null
|
| 41 |
+
|
| 42 |
+
evaluate_intervel: 5
|
data/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .loader import *
|
data/loader.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from torch.utils.data.dataset import Dataset
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import torchvision.transforms.functional as TF
|
| 8 |
+
import torchvision.transforms as tf
|
| 9 |
+
from PIL import Image, ImageFile
|
| 10 |
+
import random
|
| 11 |
+
import math
|
| 12 |
+
from model import *
|
| 13 |
+
import torch
|
| 14 |
+
# import cv2
|
| 15 |
+
# cv2.setNumThreads(0)
|
| 16 |
+
|
| 17 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class base_dataset(Dataset):
    """Paired image dataset.

    Reads matching file names from ``<data_dir>/input`` and
    ``<data_dir>/gt``, optionally applying a torchvision transform
    (``transforms``) or a paired random crop of ``img_size`` (``crop``).
    Each item is ``(input_tensor, gt_tensor, input_path)``.
    """

    def __init__(self, data_dir, img_size, transforms=False, crop=False):
        names = sorted(os.listdir(data_dir + "/input"))
        self.input_imgs = [os.path.join(data_dir + "/input", n) for n in names]
        self.gt_imgs = [os.path.join(data_dir + "/gt", n) for n in names]
        self.transforms = transforms
        self.crop = crop
        self.img_size = img_size

    def __len__(self):
        return len(self.gt_imgs)

    def __getitem__(self, index):
        inp_img_path = self.input_imgs[index]
        inp_img = Image.open(inp_img_path).convert("RGB")
        gt_img = Image.open(self.gt_imgs[index]).convert("RGB")

        # The same transform object is applied to both sides of the pair.
        if self.transforms:
            inp_img = self.transforms(inp_img)
            gt_img = self.transforms(gt_img)

        if self.crop:
            inp_img, gt_img = self.crop_image(inp_img, gt_img)

        return inp_img, gt_img, inp_img_path

    def crop_image(self, inp_img, gt_img):
        """Apply one shared random crop of ``self.img_size`` to both images
        and convert each result to a tensor."""
        crop_h, crop_w = self.img_size
        top, left, h, w = tf.RandomCrop.get_params(
            inp_img, output_size=((crop_h, crop_w)))
        inp_img = TF.to_tensor(TF.crop(inp_img, top, left, h, w))
        gt_img = TF.to_tensor(TF.crop(gt_img, top, left, h, w))
        return inp_img, gt_img
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class random_scale_dataset(Dataset):
    """Paired image dataset that also draws a random square down-sample size.

    Like ``base_dataset`` but each item additionally carries a random
    ``(down_h, down_w)`` target size, drawn uniformly (step 8) from
    ``[img_size[0] * 0.25, img_size[0])``. Items are
    ``(input, gt, down_h, down_w, input_path)``.
    """

    def __init__(self, data_dir, img_size, transforms=False, crop=False):
        imgs = sorted(os.listdir(data_dir + "/input"))
        self.input_imgs = [os.path.join(data_dir + "/input", name) for name in imgs]
        self.gt_imgs = [os.path.join(data_dir + "/gt", name) for name in imgs]
        self.transforms = transforms
        self.crop = crop
        self.img_size = img_size

    def __getitem__(self, index):
        inp_img_path = self.input_imgs[index]
        gt_img_path = self.gt_imgs[index]
        inp_img = Image.open(inp_img_path).convert("RGB")
        gt_img = Image.open(gt_img_path).convert("RGB")

        # BUG FIX: random.randrange requires integer bounds; the original
        # passed `self.img_size[0] * 0.25` (a float), which raises on
        # modern Python. Truncate to int first.
        low = int(self.img_size[0] * 0.25)
        random_scale_factor = random.randrange(low, self.img_size[0], 8)
        down_h = down_w = random_scale_factor

        if self.transforms:
            inp_img = self.transforms(inp_img)
            gt_img = self.transforms(gt_img)
            return inp_img, gt_img, down_h, down_w, inp_img_path

        if self.crop:
            inp_img, gt_img = self.crop_image(inp_img, gt_img)
        # BUG FIX: the original implicitly returned None when neither
        # `transforms` nor `crop` was set; always return the sample tuple.
        return inp_img, gt_img, down_h, down_w, inp_img_path

    def __len__(self):
        return len(self.gt_imgs)

    def crop_image(self, inp_img, gt_img):
        """Apply one shared random crop of ``self.img_size`` to both images
        and convert each result to a tensor."""
        crop_h, crop_w = self.img_size
        i, j, h, w = tf.RandomCrop.get_params(
            inp_img, output_size=((crop_h, crop_w)))
        inp_img = TF.crop(inp_img, i, j, h, w)
        gt_img = TF.crop(gt_img, i, j, h, w)
        inp_img = TF.to_tensor(inp_img)
        gt_img = TF.to_tensor(gt_img)

        return inp_img, gt_img
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def get_loader(data_dir, img_size, transforms, crop_flag, batch_size, num_workers, shuffle, random_flag=False, inference_flag=False):
    """Build a DataLoader over the paired dataset rooted at ``data_dir``.

    ``random_flag`` selects ``random_scale_dataset`` (items carry a random
    down-sample size) instead of ``base_dataset``; ``inference_flag`` is
    accepted for call-site compatibility but unused. Loader settings
    (batch size, workers, shuffling, pinned memory) are identical either way.
    """
    dataset_cls = random_scale_dataset if random_flag else base_dataset
    dataset = dataset_cls(data_dir, img_size, transforms, crop_flag)
    return DataLoader(dataset, batch_size=batch_size,
                      shuffle=shuffle, num_workers=num_workers, pin_memory=True)
|
loss.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchvision
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class VGG19(torch.nn.Module):
    """Pretrained VGG-19 split into five sequential feature stages.

    `forward` returns the activations after each stage (layers 0-1, 2-6, 7-11,
    12-20, 21-29 of `vgg19().features`); with `requires_grad=False` (default)
    all weights are frozen. Child-module names are kept as the absolute VGG
    layer indices, matching the original state_dict layout.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        features = torchvision.models.vgg19(pretrained=True).features
        stage_bounds = [0, 2, 7, 12, 21, 30]
        stages = []
        for start, stop in zip(stage_bounds[:-1], stage_bounds[1:]):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), features[layer_idx])
            stages.append(stage)
        self.slice1, self.slice2, self.slice3, self.slice4, self.slice5 = stages
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run X through all five stages, collecting each intermediate output."""
        outputs = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            outputs.append(h)
        return outputs
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class VGGLoss(nn.Module):
    """Multi-scale VGG-19 feature-matching L1 loss.

    Each of the five VGG19 stage outputs contributes an L1 term, weighted more
    heavily at deeper stages; target features are detached so gradients only
    flow through `x`. Requires CUDA (VGG19 is moved to the GPU).
    """

    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        feats_x, feats_y = self.vgg(x), self.vgg(y)
        total = 0
        for weight, fx, fy in zip(self.weights, feats_x, feats_y):
            total = total + weight * self.criterion(fx, fy.detach())
        return total
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class VGGPerceptualLoss(torch.nn.Module):
    # NOTE(review): dead code — this class is shadowed by the second
    # `VGGPerceptualLoss` definition later in this file. Because Python resolves
    # the name at call time, `VGGPerceptualLoss()` in __init__ instantiates the
    # LATER class of the same name (not this one). The `lam`/`lam_p` parameters
    # are accepted but never used. Consider renaming or removing this wrapper.
    def __init__(self, lam=1, lam_p=1):
        super(VGGPerceptualLoss, self).__init__()
        # Delegates to the (shadowing) VGG-16 perceptual loss defined below.
        self.loss_fn = VGGPerceptualLoss()

    def forward(self, out, gt):
        # Only feature layer 2 contributes when going through this wrapper.
        loss = self.loss_fn(out, gt, feature_layers=[2])

        return loss
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class VGGPerceptualLoss(torch.nn.Module):
    """VGG-16 perceptual (and optional Gram/style) loss.

    Normalizes both images with ImageNet statistics, optionally resizes to
    224x224, runs them through the first four conv stages of a pretrained
    VGG-16, and sums L1 distances over the selected `feature_layers` plus Gram
    matrix L1 distances over `style_layers`. Requires CUDA.
    """

    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            # BUG FIX: the original iterated `for p in bl`, which yields child
            # *modules* and only set a stray `requires_grad` attribute on them,
            # leaving every VGG weight trainable. Freeze the parameters instead.
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks).cuda()
        self.transform = torch.nn.functional.interpolate
        # NOTE(review): `.cuda()` on a Parameter returns a plain (unregistered)
        # tensor, so mean/std never appear in state_dict or optimizer params;
        # kept as-is to preserve existing checkpoint compatibility.
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)).cuda()
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)).cuda()
        self.resize = resize

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        # Grayscale inputs are tiled to 3 channels to match VGG's expectation.
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Gram matrices capture channel correlations (texture/style).
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def scharr(x):  # average the RGB channels before calling; operates on a grayscale image
    """Scharr edge-magnitude map of `x`, normalized back to [0, 1].

    Unfolds 3x3 neighborhoods (replication-padded), applies the horizontal and
    vertical Scharr kernels as dot products, folds back to image layout, and
    averages the absolute responses. Requires CUDA for the kernel tensors.
    NOTE(review): the matmul with a 9-element kernel assumes a single input
    channel (c == 1) — confirm callers pre-average RGB as the comment says.
    """
    b, c, h, w = x.shape
    pad = nn.ReplicationPad2d(padding=(1, 1, 1, 1))
    x = pad(x)
    kx = F.unfold(x, kernel_size=3, stride=1, padding=0)  # b,n*k*k,n_H*n_W
    kx = kx.permute([0, 2, 1])  # b,n_H*n_W,n*k*k
    # kx=kx.view(1, b*h*w, 9) #1,b*n_H*n_W,n*k*k

    # Horizontal (w1) and vertical (w2) Scharr kernels, flattened row-major.
    w1 = torch.tensor([-3, 0, 3, -10, 0, 10, -3, 0, 3]).float().cuda()
    w2 = torch.tensor([-3, -10, -3, 0, 0, 0, 3, 10, 3]).float().cuda()

    # Scale to [0, 255] so the clamp below matches 8-bit gradient ranges.
    y1 = torch.matmul((kx * 255.0), w1)  # 1,b*n_H*n_W,1
    y2 = torch.matmul((kx * 255.0), w2)  # 1,b*n_H*n_W,1
    # y1=y1.view(b,h*w,1) #b,n_H*n_W,1
    y1 = y1.unsqueeze(-1).permute([0, 2, 1])  # b,1,n_H*n_W
    # y2=y2.view(b,h*w,1) #b,n_H*n_W,1
    y2 = y2.unsqueeze(-1).permute([0, 2, 1])  # b,1,n_H*n_W

    y1 = F.fold(y1, output_size=(h, w), kernel_size=1)  # b,m,n_H,n_W
    y2 = F.fold(y2, output_size=(h, w), kernel_size=1)  # b,m,n_H,n_W
    y1 = y1.clamp(-255, 255)
    y2 = y2.clamp(-255, 255)
    return (0.5 * torch.abs(y1) + 0.5 * torch.abs(y2)) / 255.0
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def gram_matrix(input):
    """Normalized Gram matrix of a feature batch.

    Flattens `input` of shape (a, b, c, d) to (a*b, c*d) feature rows, takes
    the outer product, and divides by the total element count so the scale is
    independent of the feature-map size.
    """
    a, b, c, d = input.size()  # a = batch size, b = feature maps, (c, d) = map dims
    flat_features = input.reshape(a * b, c * d)
    gram = torch.mm(flat_features, flat_features.t())
    # Normalize by the number of elements in each feature map.
    return gram.div(a * b * c * d)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class StyleLoss(nn.Module):
    """MSE between the Gram matrices of input and (detached) target features."""

    def __init__(self):
        super(StyleLoss, self).__init__()

    def forward(self, input_fea, target_fea):
        target_gram = gram_matrix(target_fea).detach()
        input_gram = gram_matrix(input_fea)
        return F.mse_loss(input_gram, target_gram)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def cos_loss(feat1, feat2):
    """Negative mean cosine similarity (minimizing this maximizes alignment)."""
    similarity = F.cosine_similarity(feat1, feat2)
    return -similarity.mean()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def feat_scharr(x):
    """Collapse a feature map to a min-max-normalized grayscale image and run Scharr."""
    gray = torch.mean(x, dim=1, keepdim=True)
    gray = (gray - gray.min()) / (gray.max() - gray.min())
    return scharr(gray * 255)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def feat_ssim(feat1, feat2, gt):
    """Edge-weighted L1 distance between two feature maps.

    Builds a Scharr edge mask from the (channel-averaged) ground truth, resizes
    it to the feature resolution, and weights the per-element |feat1 - feat2|.

    Returns:
        (mean weighted loss, the resized edge mask)
    """
    edge_mask = scharr(torch.mean(gt, dim=1, keepdim=True))
    edge_mask = F.interpolate(edge_mask, size=(feat1.shape[2], feat1.shape[3]), mode="bicubic")
    weighted_diff = torch.abs(feat1 - feat2) * edge_mask
    return torch.mean(weighted_diff), edge_mask
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def similarity_loss(f_s, f_t):
    """Attention-transfer loss: squared distance between normalized attention maps.

    The attention map of a feature tensor is its channel-wise mean of squares,
    flattened per sample and L2-normalized.
    """
    def _attention(feat):
        return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))

    diff = _attention(f_s) - _attention(f_t)
    return diff.pow(2).mean()
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class RBF(nn.Module):
    """Multi-bandwidth RBF (Gaussian) kernel for MMD.

    Sums `n_kernels` Gaussians whose bandwidths are `mul_factor`-spaced
    multiples of a base bandwidth (estimated from the data when `bandwidth`
    is None). CUDA is required in `forward`.
    """

    def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None):
        super().__init__()
        # e.g. n_kernels=5, mul_factor=2 -> multipliers [1/4, 1/2, 1, 2, 4].
        self.bandwidth_multipliers = mul_factor ** (torch.arange(n_kernels) - n_kernels // 2)
        self.bandwidth = bandwidth

    def get_bandwidth(self, L2_distances):
        if self.bandwidth is None:
            # Median-free heuristic: mean off-diagonal squared distance.
            # `.data` detaches so the bandwidth is not back-propagated through.
            n_samples = L2_distances.shape[0]
            return L2_distances.data.sum() / (n_samples ** 2 - n_samples)

        # NOTE(review): forward calls `.cuda()` on this return value, so a
        # user-supplied bandwidth must be a tensor, not a plain float — confirm.
        return self.bandwidth

    def forward(self, X):
        # Pairwise squared Euclidean distances, shape (N, N).
        L2_distances = torch.cdist(X, X) ** 2

        # Broadcast over kernels: (K, N, N) exponentials summed over K.
        return torch.exp(
            -L2_distances[None, ...].cuda() / (self.get_bandwidth(L2_distances).cuda() * self.bandwidth_multipliers.cuda())[:, None, None]).sum(dim=0)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class MMDLoss(nn.Module):
    """Maximum Mean Discrepancy between two sample batches X and Y.

    MMD^2 = E[k(X,X)] - 2*E[k(X,Y)] + E[k(Y,Y)], computed from one kernel
    matrix over the stacked samples. Requires CUDA (kernel is moved to GPU).
    """

    def __init__(self, kernel=None):
        super().__init__()
        # Avoid a mutable default argument: the original `kernel=RBF()` built a
        # single shared RBF module at class-definition time. Construct a fresh
        # default per instance instead; passing a kernel keeps old behavior.
        self.kernel = (kernel if kernel is not None else RBF()).cuda()

    def forward(self, X, Y):
        # One kernel matrix over [X; Y]; the four quadrants give XX, XY, YY.
        K = self.kernel(torch.vstack([X, Y]))

        X_size = X.shape[0]
        XX = K[:X_size, :X_size].mean()
        XY = K[:X_size, X_size:].mean()
        YY = K[X_size:, X_size:].mean()
        return XX - 2 * XY + YY
|
metrics.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torchvision.transforms as transforms
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: psnr result (inf for identical images).
    """

    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
    if input_order == 'CHW':
        # BUG FIX: the original validated 'CHW' but never reordered, so the
        # border crop below sliced channel/height instead of height/width.
        img = np.transpose(img, (1, 2, 0))
        img2 = np.transpose(img2, (1, 2, 0))
    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img = to_y_channel(img)
        img2 = to_y_channel(img2)

    mse = np.mean((img - img2)**2)
    if mse == 0:
        # Identical images: PSNR diverges.
        return float('inf')
    return 20. * np.log10(255. / np.sqrt(mse))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _ssim(img, img2):
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`. Uses an 11x11 Gaussian window
    (sigma 1.5) and the standard constants c1/c2 for 8-bit dynamic range.

    Args:
        img (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.

    Returns:
        float: ssim result.
    """

    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)
    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss_1d, gauss_1d.transpose())

    # Gaussian-weighted local mean, cropped to the valid (non-border) region.
    def _blur(arr):
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu1 = _blur(img)
    mu2 = _blur(img2)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    # Local (co)variances via E[x^2] - E[x]^2.
    sigma1_sq = _blur(img**2) - mu1_sq
    sigma2_sq = _blur(img2**2) - mu2_sq
    sigma12 = _blur(img * img2) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
    return ssim_map.mean()
|
| 77 |
+
|
| 78 |
+
def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: ssim result.
    """

    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
    # NOTE(review): input_order is validated but no reordering is performed
    # here (reorder_image was commented out upstream) — HWC is assumed below.
    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img = to_y_channel(img)
        img2 = to_y_channel(img2)

    # Per-channel SSIM, averaged.
    channel_scores = [_ssim(img[..., ch], img2[..., ch]) for ch in range(img.shape[2])]
    return np.array(channel_scores).mean()
|
| 123 |
+
|
| 124 |
+
if __name__ == '__main__':

    # Manual smoke test: comparing an image with itself should print
    # PSNR = inf and SSIM = 1.0. The hard-coded path only exists on the
    # author's machine.
    # test_transforms = transforms.Compose([transforms.Resize((512, 512)),transforms.ToTensor()])
    # inp_img = Image.open("/mnt/disk1/yuwei/data/4Kdehaze/train/clear/0_000002.jpg").convert("RGB")
    # img = test_transforms(inp_img)
    img = cv2.imread("/mnt/disk1/yuwei/data/4Kdehaze/train/clear/0_000002.jpg")
    psnr = calculate_psnr(img, img, 0)
    ssim = calculate_ssim(img, img, 0)
    print(psnr)
    print(ssim)
|
model/LMAR_model.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from model import net
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torchvision.transforms import Resize
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
from resize_right import resize
|
| 9 |
+
except:
|
| 10 |
+
from .resize_right import resize
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from .interp_methods import *
|
| 14 |
+
except:
|
| 15 |
+
from interp_methods import *
|
| 16 |
+
|
| 17 |
+
from torchvision.models import vgg19
|
| 18 |
+
from torchvision.models.feature_extraction import create_feature_extractor
|
| 19 |
+
|
| 20 |
+
import tinycudann as tcnn
|
| 21 |
+
from torchvision.utils import save_image
|
| 22 |
+
import torchvision.transforms as transforms
|
| 23 |
+
from torchviz import make_dot
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def make_coord(shape, ranges=None, flatten=True):
    """Make coordinates at grid centers.

    For each dimension of `shape`, places `n` evenly spaced points at cell
    centers within [-1, 1] (or the per-dimension range from `ranges`).
    Returns shape (*shape, ndim), or (prod(shape), ndim) when flattened.
    """
    axis_centers = []
    for dim, n in enumerate(shape):
        if ranges is None:
            lo, hi = -1, 1
        else:
            lo, hi = ranges[dim]
        half_cell = (hi - lo) / (2 * n)
        # First center sits half a cell in from `lo`; subsequent centers step
        # by one full cell width (2 * half_cell).
        centers = lo + half_cell + (2 * half_cell) * torch.arange(n).float()
        axis_centers.append(centers)
    grid = torch.stack(torch.meshgrid(*axis_centers), dim=-1)
    if flatten:
        grid = grid.view(-1, grid.shape[-1])
    return grid
|
| 42 |
+
|
| 43 |
+
def get_local_grid(img):
    """Per-pixel normalized coordinate grid on CUDA, broadcast over the batch.

    Shape: (B, 2, H, W) matching `img`'s spatial size.
    """
    grid = make_coord(img.shape[-2:], flatten=False).cuda()
    grid = grid.permute(2, 0, 1).unsqueeze(0)

    return grid.expand(img.shape[0], 2, *img.shape[-2:])
|
| 49 |
+
|
| 50 |
+
def creat_coord(x):
    """Return the coordinate grid of `x` in two layouts, both on CUDA.

    Returns:
        coord:  (B, 2, H, W) grid of normalized pixel-center coordinates.
        coord_: (B, H*W, 2) flattened copy, clamped just inside (-1, 1) so it
                is safe to use as grid_sample sampling locations.
    """
    batch = x.shape[0]
    coord = make_coord(x.shape[-2:], flatten=False)
    coord = coord.permute(2, 0, 1).contiguous().unsqueeze(0)
    coord = coord.expand(batch, 2, *coord.shape[-2:])

    flat = coord.clone().clamp_(-1 + 1e-6, 1 - 1e-6)
    flat = flat.permute(0, 2, 3, 1).contiguous()
    flat = flat.view(batch, -1, coord.size(1))
    return coord.cuda(), flat.cuda()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_cell(img, local_grid):
    """Cell tensor holding each pixel's normalized extent: (2/H, 2/W).

    Channel 0 carries the height extent, channel 1 the width extent; the
    result matches `local_grid`'s shape/device/dtype.
    """
    cell = torch.ones_like(local_grid)
    cell[:, 0] = cell[:, 0] * (2 / img.size(2))
    cell[:, 1] = cell[:, 1] * (2 / img.size(3))

    return cell
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TcnnFCBlock(tcnn.Network):
    """Fully-fused MLP (tiny-cuda-nn) with a flatten/unflatten forward wrapper.

    The FullyFusedMLP kernel restricts hidden_features to {16, 32, 64, 128}.
    Output dtype follows tcnn (typically float16) — callers cast as needed.
    """

    def __init__(
            self, in_features, out_features,
            num_hidden_layers, hidden_features,
            activation: str = 'LeakyRelu', last_activation: str = 'None',
            seed=42):
        assert hidden_features in [16, 32, 64, 128], "hidden_features can only be 16, 32, 64, or 128."
        super().__init__(in_features, out_features, network_config={
            "otype": "FullyFusedMLP",  # Component type.
            "activation": activation,  # Activation of hidden layers.
            "output_activation": last_activation,  # Activation of the output layer.
            "n_neurons": hidden_features,  # Neurons in each hidden layer.  # May only be 16, 32, 64, or 128.
            "n_hidden_layers": num_hidden_layers,  # Number of hidden layers.
        }, seed=seed)

    def forward(self, x: torch.Tensor):
        # Collapse all leading batch dims into one, run the MLP, then restore them.
        prefix = x.shape[:-1]
        return super().forward(x.flatten(0, -2)).unflatten(0, prefix)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class LMAR_model(nn.Module):
    """Learnable downscaling ("LMAR") wrapper around a frozen pretrained `net`.

    A small fused MLP (tinycudann) predicts per-pixel 3x3 resampling kernels
    from relative coordinates, cell sizes and Laplacian detail; the resampled
    image modulates the plainly-downscaled input before it is fed to the
    frozen restoration backbone. CUDA is required throughout.
    """

    def __init__(self, args):
        super().__init__()
        # Resume settings come from the config: args.resume = {"flag", "checkpoint"}.
        self.resume_flag = args.resume["flag"]
        self.load_path = args.resume["checkpoint"]

        if self.resume_flag and self.load_path:
            # Load the pretrained backbone and freeze it: only the weight MLP
            # and the modulation conv below are trained.
            self.model = net(args)
            checkpoint = torch.load(self.load_path)
            self.model.load_state_dict(checkpoint["state_dict"])
            for param in self.model.parameters():
                param.requires_grad_(False)
        # NOTE(review): when the resume flag/checkpoint are absent, self.model is
        # never created and create_feature_extractor below will raise — confirm
        # the config always supplies a checkpoint.

        self.in_channel = 3
        self.out_channel = 3
        self.kernel_size = 3
        # imnet input dims: 2 (rel coord) + 2 (rel cell) + 3 (Laplacian RGB) = 7.
        # Output dims: one 3x3 kernel per (in, out) channel pair = 3*3*3*3 = 81.
        self.imnet = TcnnFCBlock(7, self.in_channel * self.out_channel * self.kernel_size * self.kernel_size, 5,
                                 128).cuda()
        # Tap a mid-level backbone feature map for the feature-consistency losses.
        self.mid_nodes = {"hr_backbone.skip2": "bottom"}
        self.extractor_mid = create_feature_extractor(self.model, self.mid_nodes)
        # Fuses [plain downscale, learned resample] (3 + 3 channels) back to 3.
        self.modulation = nn.Conv2d(6, 3, 1, 1, 0)
        # self.projection = nn.Conv2d()

    def forward(self, x, down_size, up_size, test_flag=False):
        """Dispatch to inference (test_flag=True) or the training path.

        Args:
            x: (B, 3, H, W) CUDA image batch.
            down_size: (h, w) target size of the learned downscale.
            up_size: (H, W) size to upscale back to.
        """
        if test_flag:
            up_out, _ = self.inference(x, down_size, up_size)
            return up_out, _
        else:
            down_x, hr_feature, new_lr_feature, ori_lr_feature, residual, res = self.train_model(x, down_size, up_size)
            return down_x, hr_feature, new_lr_feature, ori_lr_feature, residual, res

    def train_model(self, x, down_size, up_size):
        """Training path: build the modulated LR image and the feature triplet.

        Returns:
            down_x:         modulated low-resolution input.
            hr_feature:     backbone feature of the original HR image.
            new_lr_feature: backbone feature of the modulated LR image (resized).
            ori_lr_feature: backbone feature of the plain LR image (resized).
            out:            the learned resampling residual (downscaled).
            res:            0 (backbone output is not computed during training).
        """
        # down_sizer = transforms.Resize(size=down_size,
        #                                interpolation=transforms.InterpolationMode.BILINEAR)
        # up_sizer = transforms.Resize(size=up_size,
        #                              interpolation=transforms.InterpolationMode.BILINEAR)

        b = x.shape[0]
        # Plain (non-learned) downscale used as the baseline LR image.
        # down_x = down_sizer(x)
        down_x = resize(x, out_shape=down_size, antialiasing=False)
        # down_x = resize(x, out_shape=down_size, antialiasing=True)

        hr_feature = self.extractor_mid(x)["bottom"]
        # feature_sizer = transforms.Resize(size=(hr_feature.shape[2], hr_feature.shape[3]),
        #                                   interpolation=transforms.InterpolationMode.BILINEAR)

        # For each HR pixel, find its nearest LR pixel center and the relative
        # offset to it (scaled to LR pixel units).
        hr_coord, hr_coord_ = self.creat_coord(x)
        lr_coord, _ = self.creat_coord(down_x)
        q_coord = F.grid_sample(lr_coord, hr_coord_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)
        q_coord = q_coord.view(b, -1, hr_coord.size(2) * hr_coord.size(3)).permute(0, 2, 1).contiguous()

        # test_coord = F.grid_sample(lr_coord, hr_coord.permute(0, 2, 3, 1), mode='bilinear', align_corners=False)
        # test_rel_coord = hr_coord - test_coord
        # test_rel_coord = test_rel_coord.view(b, -1, 2)

        # test_rel_coord[:, :, 0] *= down_x.shape[-2]
        # test_rel_coord[:, :, 1] *= down_x.shape[-1]

        rel_coord = hr_coord_ - q_coord
        rel_coord[:, :, 0] *= down_x.shape[-2]
        rel_coord[:, :, 1] *= down_x.shape[-1]

        # High-frequency detail lost by the plain down/up round trip.
        laplacian = x - resize(down_x, out_shape=up_size, antialiasing=False)
        # laplacian = x - resize(down_x, out_shape=up_size, antialiasing=True)

        laplacian = laplacian.reshape(b, laplacian.size(1), -1).permute(0, 2, 1).contiguous()

        # cell: per-pixel normalized extents, scaled to LR pixel units.
        hr_grid = self.get_local_grid(x)
        hr_cell = self.get_cell(x, hr_grid)
        hr_cell_ = hr_cell.clone()
        hr_cell_ = hr_cell_.permute(0, 2, 3, 1).contiguous()
        rel_cell = hr_cell_.view(b, -1, hr_cell.size(1))
        rel_cell[:, :, 0] *= down_x.shape[-2]
        rel_cell[:, :, 1] *= down_x.shape[-1]

        # Predict one 3x3xC kernel per HR pixel and apply it to the unfolded
        # neighborhoods (cast back to float32 since tcnn outputs half precision).
        inp = torch.cat([rel_coord.cuda(), rel_cell.cuda(), laplacian], dim=-1)
        local_weight = self.imnet(inp)
        local_weight = local_weight.type(torch.float32)
        local_weight = local_weight.view(b, -1, x.shape[1] * 9, 3).contiguous()

        unfolded_x = F.unfold(x, 3, padding=1).view(b, -1, x.shape[2] * x.shape[3]).permute(0, 2, 1).contiguous()
        cols = unfolded_x.unsqueeze(2)
        out = torch.matmul(cols, local_weight).squeeze(2).permute(0, 2, 1).contiguous().view(b, -1, x.size(2),
                                                                                             x.size(3))
        out = resize(out, out_shape=down_size, antialiasing=False)
        # out = resize(out, out_shape=down_size, antialiasing=True)

        # out = down_sizer(out)

        # ori: backbone feature of the plain LR image, resized to HR-feature size.
        ori_lr_feature = self.extractor_mid(down_x)["bottom"]
        ori_lr_feature = resize(ori_lr_feature, out_shape=(hr_feature.shape[2], hr_feature.shape[3]),
                                antialiasing=False)
        # ori_lr_feature = resize(ori_lr_feature, out_shape=(hr_feature.shape[2], hr_feature.shape[3]), antialiasing=True)
        # ori_lr_feature = feature_sizer(ori_lr_feature)

        # new: backbone feature of the modulated LR image.
        down_x = self.modulation(torch.cat([down_x, out], dim=1))
        new_lr_feature = self.extractor_mid(down_x)["bottom"]

        new_lr_feature = resize(new_lr_feature, out_shape=(hr_feature.shape[2], hr_feature.shape[3]),
                                antialiasing=False)
        # new_lr_feature = resize(new_lr_feature, out_shape=(hr_feature.shape[2], hr_feature.shape[3]), antialiasing=True)

        # new_lr_feature = feature_sizer(new_lr_feature)

        # res = resize(self.model(self.modulation(torch.cat([down_x, out], dim=1))), out_shape=up_size,
        #              antialiasing=False)

        # res = up_sizer(self.model(self.modulation(torch.cat([down_x, out], dim=1))))
        res = 0

        return down_x, hr_feature, \
            new_lr_feature, ori_lr_feature, out, res

    def inference(self, x, down_size, up_size):
        """Inference path: same kernel prediction as train_model, then run the
        frozen backbone on the modulated LR image and resize to `up_size`.

        Returns:
            (restored output at up_size, modulated LR input)
        """
        b = x.shape[0]
        down_x = resize(x, out_shape=down_size, antialiasing=False)
        hr_coord, hr_coord_ = self.creat_coord(x)
        lr_coord, _ = self.creat_coord(down_x)
        q_coord = F.grid_sample(lr_coord, hr_coord_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)
        q_coord = q_coord.view(b, -1, hr_coord.size(2) * hr_coord.size(3)).permute(0, 2, 1).contiguous()

        # Offset to the nearest LR pixel center, in LR pixel units.
        rel_coord = hr_coord_ - q_coord
        rel_coord[:, :, 0] *= down_x.shape[-2]
        rel_coord[:, :, 1] *= down_x.shape[-1]

        hr_grid = self.get_local_grid(x)
        hr_cell = self.get_cell(x, hr_grid)

        hr_cell_ = hr_cell.clone()
        hr_cell_ = hr_cell_.permute(0, 2, 3, 1).contiguous()

        rel_cell = hr_cell_.view(b, -1, hr_cell.size(1))
        rel_cell[:, :, 0] *= down_x.shape[-2]
        rel_cell[:, :, 1] *= down_x.shape[-1]

        # High-frequency detail lost by the plain down/up round trip.
        laplacian = x - resize(down_x, out_shape=up_size, antialiasing=False)
        # laplacian = x - resize(down_x, out_shape=up_size, antialiasing=True)

        laplacian = laplacian.reshape(b, laplacian.size(1), -1).permute(0, 2, 1).contiguous()
        # laplacian = F.unfold(laplacian, 3, padding=1).view(b, -1, laplacian.shape[2] * laplacian.shape[3]).permute(0, 2, 1).contiguous()

        inp = torch.cat([rel_coord.cuda(), rel_cell.cuda(), laplacian], dim=-1)
        local_weight = self.imnet(inp)
        local_weight = local_weight.type(torch.float32)
        local_weight = local_weight.view(b, -1, x.shape[1] * 9, 3)

        unfolded_x = F.unfold(x, 3, padding=1).view(b, -1, x.shape[2] * x.shape[3]).permute(0, 2, 1).contiguous()

        cols = unfolded_x.unsqueeze(2)

        out = torch.matmul(cols, local_weight).squeeze(2).permute(0, 2, 1).contiguous().view(b, -1, x.size(2),
                                                                                             x.size(3))
        out = resize(out, out_shape=down_size, antialiasing=False)
        down_x = self.modulation(torch.cat([down_x, out], dim=1))

        res = resize(self.model(down_x), out_shape=up_size, antialiasing=False)
        return res, down_x

    def creat_coord(self, x):
        """Method form of the module-level `creat_coord`: returns (grid, flat
        clamped grid) for `x` on CUDA."""
        b = x.shape[0]
        coord = make_coord(x.shape[-2:], flatten=False)
        coord = coord.permute(2, 0, 1).contiguous().unsqueeze(0)
        coord = coord.expand(b, 2, *coord.shape[-2:])

        coord_ = coord.clone()
        # Clamp just inside (-1, 1) so grid_sample never samples exactly on the edge.
        coord_ = coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
        coord_ = coord_.permute(0, 2, 3, 1).contiguous()
        coord_ = coord_.view(b, -1, coord.size(1))
        return coord.cuda(), coord_.cuda()

    def get_local_grid(self, img):
        """(B, 2, H, W) grid of normalized pixel-center coordinates on CUDA."""
        local_grid = make_coord(img.shape[-2:], flatten=False).cuda()
        local_grid = local_grid.permute(2, 0, 1).unsqueeze(0)
        local_grid = local_grid.expand(img.shape[0], 2, *img.shape[-2:])

        return local_grid

    def get_cell(self, img, local_grid):
        """Per-pixel normalized cell extents: channel 0 = 2/H, channel 1 = 2/W."""
        cell = torch.ones_like(local_grid)
        cell[:, 0] *= 2 / img.size(2)
        cell[:, 1] *= 2 / img.size(3)

        return cell
|
| 277 |
+
|
model/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .model import net
|
| 2 |
+
from .resize_right import resize
|
| 3 |
+
from .interp_methods import *
|
| 4 |
+
from .module import Discriminator, Discriminator_new
|
| 5 |
+
from .LMAR_model import *
|
model/interp_methods.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from math import pi
|
| 2 |
+
|
| 3 |
+
try:
|
| 4 |
+
import torch
|
| 5 |
+
except ImportError:
|
| 6 |
+
torch = None
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
import numpy
|
| 10 |
+
except ImportError:
|
| 11 |
+
numpy = None
|
| 12 |
+
|
| 13 |
+
if numpy is None and torch is None:
|
| 14 |
+
raise ImportError("Must have either Numpy or PyTorch but both not found")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def set_framework_dependencies(x):
    """Return (framework module, dtype-cast fn, float32 eps) for array `x`.

    Dispatches on whether `x` is a NumPy ndarray or a torch tensor. The
    `numpy is not None` guard fixes a crash in the original: when NumPy is
    absent the module-level try-import leaves `numpy = None`, and
    `type(x) is numpy.ndarray` then raised AttributeError.
    """
    if numpy is not None and type(x) is numpy.ndarray:
        to_dtype = lambda a: a  # NumPy comparisons already yield usable dtypes here
        fw = numpy
    else:
        to_dtype = lambda a: a.to(x.dtype)  # cast boolean masks to x's dtype
        fw = torch
    eps = fw.finfo(fw.float32).eps
    return fw, to_dtype, eps
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def support_sz(sz):
    """Return a decorator that tags a kernel function with its support width."""
    def tag(func):
        func.support_sz = sz
        return func
    return tag
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@support_sz(4)
def cubic(x):
    """Keys bicubic kernel (a = -0.5), support of 4 samples."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    ax = fw.abs(x)
    ax2 = ax ** 2
    ax3 = ax ** 3
    inner = (1.5 * ax3 - 2.5 * ax2 + 1.) * to_dtype(ax <= 1.)
    outer = (-0.5 * ax3 + 2.5 * ax2 - 4. * ax + 2.) * to_dtype((1. < ax) & (ax <= 2.))
    return inner + outer
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@support_sz(4)
def lanczos2(x):
    """Lanczos kernel with a = 2 (support 4); eps guards the 0/0 at x = 0."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    num = fw.sin(pi * x) * fw.sin(pi * x / 2) + eps
    den = (pi ** 2 * x ** 2 / 2) + eps
    return (num / den) * to_dtype(abs(x) < 2)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@support_sz(6)
def lanczos3(x):
    """Lanczos kernel with a = 3 (support 6); eps guards the 0/0 at x = 0."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    num = fw.sin(pi * x) * fw.sin(pi * x / 3) + eps
    den = (pi ** 2 * x ** 2 / 3) + eps
    return (num / den) * to_dtype(abs(x) < 3)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@support_sz(2)
def linear(x):
    """Triangle (bilinear) kernel on [-1, 1]."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    rising = (x + 1) * to_dtype((-1 <= x) & (x < 0))
    falling = (1 - x) * to_dtype((0 <= x) & (x <= 1))
    return rising + falling
|
| 65 |
+
|
| 66 |
+
@support_sz(1)
def box(x):
    """Nearest-neighbor (box) kernel: 1 on [-1, 1], 0 elsewhere."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    left = to_dtype((-1 <= x) & (x < 0))
    right = to_dtype((0 <= x) & (x <= 1))
    return left + right
|
model/model.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Import shared building blocks; the ImportError fallback lets this file work
# both as a package module (relative import) and as a plain script.
# (The original used a bare `except:`, which also swallowed unrelated errors,
# and imported torch/nn/F twice.)
try:
    from .module import *
except ImportError:
    from module import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
|
| 14 |
+
|
| 15 |
+
class SuperUnet_MS(nn.Module):
    """U-Net-style encoder/decoder with two stride-2 stages and learned skips.

    Besides the usual encoder->decoder skip connections (`fus1`/`fus2`), every
    stage blends a resampled copy of its own input back into its output via a
    `skip` block (`skip1`/`skip2`/`skip4`/`skip6`), each with learnable mixing
    weights. Relies on `basic_block` from model.module; assumed to preserve
    the (N, C, H, W) shape -- TODO confirm against its definition.

    forward(x) returns (full-resolution features, bottleneck-input features).
    NOTE: attribute names keep the original "dowm" spelling so existing
    checkpoints still load.
    """

    def __init__(self, channels, block="INV"):
        super(SuperUnet_MS, self).__init__()
        # ---------ENCODE
        self.layer_dowm1 = basic_block(channels, channels, block)
        # Stride-2 4x4 conv halves H and W while doubling channels.
        self.dowm1 = nn.Sequential(nn.Conv2d(channels, channels * 2, 4, 2, 1, bias=True),
                                   nn.InstanceNorm2d(channels * 2, affine=True), nn.LeakyReLU(0.2, inplace=True))
        self.layer_dowm2 = basic_block(channels * 2, channels * 2, block)
        self.dowm2 = nn.Sequential(nn.Conv2d(channels * 2, channels * 4, 4, 2, 1, bias=True),
                                   nn.InstanceNorm2d(channels * 4, affine=True), nn.LeakyReLU(0.2, inplace=True))
        # ---------DECODE
        self.layer_bottom = basic_block(channels * 4, channels * 4, block)
        # Transposed convs mirror the encoder: double H/W, halve channels.
        self.up2 = nn.Sequential(nn.ConvTranspose2d(channels * 4, channels * 2, 4, 2, 1, bias=True),
                                 nn.InstanceNorm2d(channels * 2, affine=True), nn.LeakyReLU(0.2, inplace=True))
        self.layer_up2 = basic_block(channels * 2, channels * 2, block)
        self.up1 = nn.Sequential(nn.ConvTranspose2d(channels * 2, channels, 4, 2, 1, bias=True),
                                 nn.InstanceNorm2d(channels, affine=True), nn.LeakyReLU(0.2, inplace=True))
        self.layer_up1 = basic_block(channels, channels, block)
        # ---------SKIP (encoder -> decoder fusions at matching resolutions)
        self.fus2 = skip(channels * 4, channels * 2, "HIN")
        self.fus1 = skip(channels * 2, channels, "HIN")
        # ---------SKIP (intra-stage residual-style skips over each down/up step)
        self.skip_down1 = nn.Sequential(nn.Conv2d(channels, channels, 4, 2, 1, bias=True),
                                        nn.InstanceNorm2d(channels, affine=True), nn.LeakyReLU(0.2, inplace=True))
        self.skip1 = skip(channels * 3, channels * 2, "CONV")
        self.skip_down2 = nn.Sequential(nn.Conv2d(channels * 2, channels, 4, 2, 1, bias=True),
                                        nn.InstanceNorm2d(channels, affine=True), nn.LeakyReLU(0.2, inplace=True))
        self.skip2 = skip(channels * 5, channels * 4, "CONV")
        # self.skip3 = skip(channels*2, channels, "CONV")
        self.skip_up4 = nn.Sequential(nn.ConvTranspose2d(channels * 4, channels, 4, 2, 1, bias=True),
                                      nn.InstanceNorm2d(channels, affine=True), nn.LeakyReLU(0.2, inplace=True))
        self.skip4 = skip(channels * 3, channels * 2, "CONV")
        # self.skip5 = skip(channels*2, channels, "CONV")
        self.skip_up6 = nn.Sequential(nn.ConvTranspose2d(channels * 2, channels, 4, 2, 1, bias=True),
                                      nn.InstanceNorm2d(channels, affine=True), nn.LeakyReLU(0.2, inplace=True))
        self.skip6 = skip(channels * 2, channels, "CONV")

    def forward(self, x):
        # ---------ENCODE
        x_11 = self.layer_dowm1(x)
        x_down1 = self.dowm1(x_11)
        # x =self.skip_down1(x)
        # print(x.shape, x_down1.shape)

        # Blend a downsampled copy of the stage input with the stage output.
        x_down1 = self.skip1(torch.cat([self.skip_down1(x), x_down1], 1), x_down1)

        x_12 = self.layer_dowm2(x_down1)
        x_down2 = self.dowm2(x_12)
        x_down2 = self.skip2(torch.cat([self.skip_down2(x_down1), x_down2], 1), x_down2)

        x_bottom = self.layer_bottom(x_down2)

        # ---------DECODE
        x_up2 = self.up2(x_bottom)
        x_22 = self.layer_up2(x_up2)
        # Intra-stage skip over the up step, then the encoder->decoder fusion.
        x_22 = self.skip4(torch.cat([self.skip_up4(x_bottom), x_22], 1), x_22)
        x_22 = self.fus2(torch.cat([x_12, x_22], 1), x_22)

        x_up1 = self.up1(x_22)
        x_21 = self.layer_up1(x_up1)
        x_21 = self.skip6(torch.cat([self.skip_up6(x_22), x_21], 1), x_21)
        x_21 = self.fus1(torch.cat([x_11, x_21], 1), x_21)
        # Second return value is the tensor fed into the bottleneck stage.
        return x_21, x_down2
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class skip(nn.Module):
    """Learned skip connection: out = alpha1 * body(x) + alpha2 * y.

    `block` selects the transform applied to `x`:
      "CONV" -- 1x1 conv + InstanceNorm + ReLU,
      "ID"   -- identity,
      "INV"  -- invertible block followed by a 1x1 conv,
      "HIN"  -- half-instance-norm block.
    alpha1/alpha2 are learnable scalars initialized to 1.0 and 0.5.
    """

    def __init__(self, channels_in, channels_out, block):
        super(skip, self).__init__()
        if block == "CONV":
            self.body = nn.Sequential(nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=True),
                                      nn.InstanceNorm2d(channels_out, affine=True), nn.ReLU(inplace=True), )
        elif block == "ID":
            self.body = nn.Identity()
        elif block == "INV":
            self.body = nn.Sequential(InvBlock(channels_in, channels_in // 2),
                                      nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=True), )
        elif block == "HIN":
            self.body = nn.Sequential(HinBlock(channels_in, channels_out))
        else:
            # Fail fast: the original silently left self.body unset and only
            # crashed later with an AttributeError in forward().
            raise ValueError(f"unknown skip block type: {block!r}")
        # --------------------------------------
        self.alpha1 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
        self.alpha1.data.fill_(1.0)
        self.alpha2 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
        self.alpha2.data.fill_(0.5)

    def forward(self, x, y):
        out = self.alpha1 * self.body(x) + self.alpha2 * y
        return out
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def subnet(net_structure, init='xavier'):
    """Return a constructor(channel_in, channel_out) for the named sub-block.

    Only 'HIN' is implemented. The original constructor silently returned
    None for any other name, crashing much later at the call site; it now
    raises immediately. `init` is kept for interface compatibility but is
    currently unused.
    """
    def constructor(channel_in, channel_out):
        if net_structure == 'HIN':
            return HinBlock(channel_in, channel_out)
        raise NotImplementedError(f"unknown subnet structure: {net_structure!r}")

    return constructor
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class InvBlock(nn.Module):
    """Invertible-style coupling block with a residual connection.

    The input is split channel-wise into (split_len1, split_len2) halves
    (split_len1 is usually half of channel_num); an additive + affine
    coupling mixes the halves, and the coupled result is added back onto the
    input. The bounded log-scale is stashed on `self.s` after each forward
    (kept from the original -- presumably consumed elsewhere, e.g. for a
    log-det term; verify against callers).
    """

    def __init__(self, channel_num, channel_split_num, subnet_constructor=subnet('HIN'),
                 clamp=0.8):
        super(InvBlock, self).__init__()
        self.split_len1 = channel_split_num
        self.split_len2 = channel_num - channel_split_num

        self.clamp = clamp

        self.F = subnet_constructor(self.split_len2, self.split_len1)
        self.G = subnet_constructor(self.split_len1, self.split_len2)
        self.H = subnet_constructor(self.split_len1, self.split_len2)

    def forward(self, x):
        first = x.narrow(1, 0, self.split_len1)
        second = x.narrow(1, self.split_len1, self.split_len2)

        y1 = first + self.F(second)
        # Shifted sigmoid keeps the log-scale in (-clamp, clamp).
        self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1)
        y2 = second.mul(torch.exp(self.s)) + self.G(y1)
        return torch.cat((y1, y2), 1) + x
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class sample_block(nn.Module):
    """Dilated 3x3 conv + InstanceNorm + ReLU.

    `size` selects Conv2d ("DOWN") or ConvTranspose2d ("UP"); with stride 1
    and padding == dilation, spatial size is preserved in both cases.
    """

    def __init__(self, channels_in, channels_out, size, dil):
        super(sample_block, self).__init__()
        # ------------------------------------------
        if size == "DOWN":
            self.conv = nn.Sequential(
                nn.Conv2d(channels_in, channels_out, 3, 1, dil, dilation=dil),
                nn.InstanceNorm2d(channels_out, affine=True),
                nn.ReLU(inplace=True),
            )
        elif size == "UP":
            self.conv = nn.Sequential(
                nn.ConvTranspose2d(channels_in, channels_out, 3, 1, dil, dilation=dil),
                nn.InstanceNorm2d(channels_out, affine=True),
                nn.ReLU(inplace=True),
            )
        else:
            # Fail fast: the original left self.conv unset for unknown sizes
            # and crashed later in forward() with an AttributeError.
            raise ValueError(f"unknown sample_block size: {size!r}")

    def forward(self, x):
        return self.conv(x)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class HinBlock(nn.Module):
    """Half Instance Normalization block (HINet-style) with a 1x1 residual path.

    Instance norm is applied to only the first half of the first conv's
    output channels; the second half passes through unnormalized.
    """

    def __init__(self, in_size, out_size):
        super(HinBlock, self).__init__()
        self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0)
        self.norm = nn.InstanceNorm2d(out_size // 2, affine=True)

        self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu_1 = nn.Sequential(nn.LeakyReLU(0.2, inplace=False), )
        self.conv_2 = nn.Sequential(nn.Conv2d(out_size, out_size, kernel_size=3, stride=1, padding=1, bias=True),
                                    nn.LeakyReLU(0.2, inplace=False), )

    def forward(self, x):
        feat = self.conv_1(x)
        normed_half, raw_half = torch.chunk(feat, 2, dim=1)
        feat = torch.cat([self.norm(normed_half), raw_half], dim=1)
        feat = self.conv_2(self.relu_1(feat))
        return feat + self.identity(x)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class net(nn.Module):
    """Full-resolution restoration network.

    DoubleConv stem -> SuperUnet_MS backbone -> 1x1 conv to a 3-channel
    output. The backbone's mid-level (bottleneck) features are computed but
    intentionally discarded here.
    `args` is expected to expose a `model` dict with "in_channel" and
    "model_channel" keys.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args.model
        width = self.args["model_channel"] * 2
        self.hr_inc = DoubleConv(self.args["in_channel"], width)
        self.hr_backbone = SuperUnet_MS(width)
        self.final_out = nn.Conv2d(width, 3, kernel_size=1, bias=False)

    def forward(self, x):
        feats, _mid_feat = self.hr_backbone(self.hr_inc(x))
        return self.final_out(feats)
|
| 194 |
+
|
model/module.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
from torchvision.transforms.functional import rgb_to_grayscale
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
class DoubleConv(nn.Module):
    """(Conv3x3 -> ReLU) x 2.

    Conv weights are He-normal initialized; any BatchNorm2d children would
    get weight=1 / bias=0 (none are created here, but `_init_weights`
    handles them for safety).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True)
        )
        self.apply(self._init_weights)

    def forward(self, x):
        return self.double_conv(x)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            # He (Kaiming) normal init based on fan-out.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / fan_out))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Down(nn.Module):
    """Halve spatial resolution with a 2x2 max-pool, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.maxpool_conv(x)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Up(nn.Module):
    """Upscale x1 (bilinear or transposed conv), pad it to match the skip
    tensor x2, concatenate, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Bilinear upsampling keeps the channel count; DoubleConv then
            # reduces via a halved mid-channel width.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Input is NCHW: pad x1 symmetrically so its HxW matches x2's,
        # handling odd differences (see the U-Net padding-fix commits).
        diff_h = x2.size()[2] - x1.size()[2]
        diff_w = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                        diff_h // 2, diff_h - diff_h // 2])
        return self.conv(torch.cat([x2, x1], dim=1))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# spatial attention
|
| 79 |
+
# spatial attention
class SpatialGate(nn.Module):
    """Spatial attention: a 1x1 conv collapses channels to a single-channel
    sigmoid mask that rescales the input element-wise."""

    def __init__(self, in_channels):
        super(SpatialGate, self).__init__()
        self.spatial = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.spatial(x))
        return gate * x
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# sobel
|
| 92 |
+
# sobel
class SobelOperator(nn.Module):
    """Sobel gradient-magnitude filter for single-channel input.

    Weights are initialized to the standard horizontal/vertical Sobel
    kernels; they remain trainable parameters, as in the original.
    """

    def __init__(self):
        super(SobelOperator, self).__init__()
        self.conv_x = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.conv_y = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.conv_x.weight[0].data[:, :, :] = torch.FloatTensor([[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]])
        self.conv_y.weight[0].data[:, :, :] = torch.FloatTensor([[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]])

    def forward(self, x):
        gx = self.conv_x(x)
        gy = self.conv_y(x)
        return torch.sqrt(gx.pow(2) + gy.pow(2))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class offset_estimator(nn.Sequential):
    """Stack of Gaussian-initialized convs, each followed by a ReLU.

    Layer i uses kernel_size[i] / fwhm[i]. The first layer maps
    in_channel -> mid_channel, the last mid_channel -> out_channel, and any
    middle layers keep mid_channel. The three near-identical branches of the
    original loop are collapsed into one channel plan that preserves the
    original branch precedence exactly (i == 0 wins over i == last, so a
    single-layer stack maps in_channel -> mid_channel, never out_channel).
    Depends on `gaussian_2d` from utils; assumed to return a (k, k)
    array-like -- TODO confirm.
    """

    def __init__(self, kernel_size, fwhm, in_channel, mid_channel, out_channel) -> None:
        super().__init__()
        assert len(kernel_size) == len(fwhm), "length error"
        layers = []
        last = len(kernel_size) - 1
        for i, (ksz, width) in enumerate(zip(kernel_size, fwhm)):
            in_c = in_channel if i == 0 else mid_channel
            out_c = out_channel if (i == last and i != 0) else mid_channel
            gauss_filter = nn.Conv2d(in_c, out_c, ksz, padding=(ksz - 1) // 2, bias=False)
            # Seed the first output filter with a Gaussian; weights stay trainable.
            gauss_filter.weight[0].data[:, :, :] = torch.FloatTensor(gaussian_2d(ksz, fwhm=width))
            layers += [gauss_filter, nn.ReLU(inplace=True)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# Channel attention
|
| 138 |
+
# Channel attention
def logsumexp_2d(tensor):
    """Numerically stable log-sum-exp over the spatial dims of an NCHW tensor.

    Returns shape (N, C, 1): per-channel max-shifted LSE.
    """
    flat = tensor.view(tensor.size(0), tensor.size(1), -1)
    m, _ = torch.max(flat, dim=2, keepdim=True)
    return m + (flat - m).exp().sum(dim=2, keepdim=True).log()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class Flatten(nn.Module):
    """Flatten all dims after the batch dim: (N, ...) -> (N, prod(...))."""

    def forward(self, x):
        return x.view(x.size(0), -1)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class ChannelGate(nn.Module):
    """CBAM-style channel attention.

    Each pooling variant in `pool_types` produces a per-channel descriptor,
    the descriptors are passed through a shared bottleneck MLP and summed,
    and a sigmoid of the sum rescales the input channel-wise.
    (An unrecognized pool type raises at forward time, as in the original.)
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        spatial = (x.size(2), x.size(3))
        att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                att = self.mlp(F.avg_pool2d(x, spatial, stride=spatial))
            elif pool_type == 'max':
                att = self.mlp(F.max_pool2d(x, spatial, stride=spatial))
            elif pool_type == 'lp':
                att = self.mlp(F.lp_pool2d(x, 2, spatial, stride=spatial))
            elif pool_type == 'lse':
                # LSE pool only
                att = self.mlp(logsumexp_2d(x))
            att_sum = att if att_sum is None else att_sum + att

        scale = torch.sigmoid(att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# LBP
|
| 189 |
+
def LBP(image): # b, 3, h, w tensor
|
| 190 |
+
radius = 2
|
| 191 |
+
n_points = 8 * radius
|
| 192 |
+
method = 'uniform'
|
| 193 |
+
gray_img = rgb_to_grayscale(image) # b, 1, h, w
|
| 194 |
+
gray_img = gray_img.squeeze(1)
|
| 195 |
+
lbf_feature = np.zeros((gray_img.shape[0], gray_img.shape[1], gray_img.shape[2]))
|
| 196 |
+
for i in range(gray_img.shape[0]):
|
| 197 |
+
lbf_feature[i] = feature.local_binary_pattern(gray_img[i], n_points, radius, method)
|
| 198 |
+
return torch.FloatTensor(lbf_feature).unsqueeze(1)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class Discriminator(nn.Module):
    """Small PatchGAN-style discriminator.

    Four stride-2 4x4 conv + LeakyReLU blocks (width 4), then asymmetric
    zero-padding and a final 4x4 conv to a one-channel score map.
    """

    def __init__(self, in_channel):
        super().__init__()
        self.in_channel = in_channel

        def block(cin, cout):
            return [nn.Conv2d(cin, cout, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=False)]

        layers = block(self.in_channel, 4) + block(4, 4) + block(4, 4) + block(4, 4)
        layers += [nn.ZeroPad2d((1, 0, 1, 0)),
                   nn.Conv2d(4, 1, 4, padding=1, bias=False)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class Discriminator_new(nn.Module):
    """Deeper discriminator.

    Four stages (widths 3 -> 4 -> 6 -> 8 -> 10), each a stride-1 then a
    stride-2 3x3 conv with LeakyReLU, finished by asymmetric zero-padding
    and a 3x3 conv to a single output channel.
    """

    def __init__(self):
        super().__init__()

        def stage(cin, cout, first_block=False):
            # `first_block` is accepted but unused, matching the original.
            return [
                nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(cout, cout, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        widths = [4, 6, 8, 10]
        layers = []
        cin = 3
        for i, cout in enumerate(widths):
            layers += stage(cin, cout, first_block=(i == 0))
            cin = cout

        layers += [nn.ZeroPad2d((1, 0, 1, 0)),
                   nn.Conv2d(widths[-1], 1, kernel_size=3, stride=1, padding=1)]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)
|
| 248 |
+
|
model/resize_right.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
import warnings
from math import ceil
from fractions import Fraction

# The ImportError fallback lets this file run both as a package module and as
# a plain script. (The original used a bare `except:` here, which would also
# have hidden unrelated failures inside interp_methods.)
try:
    from .interp_methods import *
except ImportError:
    from interp_methods import *


class NoneClass:
    pass


try:
    import torch
    from torch import nn

    nnModuleWrapped = nn.Module
except ImportError:
    warnings.warn('No PyTorch found, will work only with Numpy')
    torch = None
    nnModuleWrapped = NoneClass

try:
    import numpy
except ImportError:
    warnings.warn('No Numpy found, will work only with PyTorch')
    numpy = None

if numpy is None and torch is None:
    raise ImportError("Must have either Numpy or PyTorch but both not found")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def resize(input, scale_factors=None, out_shape=None,
           interp_method=lanczos3, support_sz=None,
           antialiasing=True, by_convs=False, scale_tolerance=None,
           max_numerator=10, pad_mode='constant', adv_weights=None):
    """Resize an N-D tensor or ndarray (ResizeRight algorithm), one dim at a time.

    Args:
        input: torch tensor or numpy ndarray to resize.
        scale_factors / out_shape: at least one must be given; the other is derived.
        interp_method: interpolation kernel (must carry a `support_sz` attribute).
        support_sz: optional override of the kernel support size.
        antialiasing: widen the kernel when downscaling.
        by_convs: compute via strided convolutions where the scale is rational.
        scale_tolerance / max_numerator: rational-approximation controls for by_convs.
        pad_mode: boundary padding mode.
        adv_weights: optional precomputed weights, one entry per resized dim
            (in the same sorted-by-scale order used below).
    Returns:
        The resized tensor/array, in the same framework as `input`.
    """
    # get properties of the input tensor
    in_shape, n_dims = input.shape, input.ndim

    # fw stands for framework that can be either numpy or torch, determined
    # by the input type (guarded so a missing NumPy install doesn't crash)
    fw = numpy if (numpy is not None and type(input) is numpy.ndarray) else torch
    eps = fw.finfo(fw.float32).eps
    device = input.device if fw is torch else None

    # set missing scale factors or output shape, one according to the other;
    # error if both are missing. this is also where all the default policies
    # take place, and the by_convs attribute is handled carefully.
    scale_factors, out_shape, by_convs = set_scale_and_out_sz(in_shape,
                                                              out_shape,
                                                              scale_factors,
                                                              by_convs,
                                                              scale_tolerance,
                                                              max_numerator,
                                                              eps, fw)

    # sort indices of dimensions according to scale of each dimension.
    # since we are going dim by dim this is efficient
    sorted_filtered_dims_and_scales = [(dim, scale_factors[dim], by_convs[dim],
                                        in_shape[dim], out_shape[dim])
                                       for dim in sorted(range(n_dims),
                                                         key=lambda ind: scale_factors[ind])
                                       if scale_factors[dim] != 1.]

    # unless support size is specified by the user, it is an attribute
    # of the interpolation method
    if support_sz is None:
        support_sz = interp_method.support_sz

    # output begins identical to input and changes with each iteration
    output = input

    # iterate over dims
    for i, (dim, scale_factor, dim_by_convs, in_sz, out_sz
            ) in enumerate(sorted_filtered_dims_and_scales):
        # STEP 1- PROJECTED GRID: The non-integer locations of the projection
        # of output pixel locations to the input tensor
        projected_grid = get_projected_grid(in_sz, out_sz,
                                            scale_factor, fw, dim_by_convs,
                                            device)

        # STEP 1.5: ANTIALIASING- If antialiasing is taking place, we modify
        # the window size and the interpolation method (see inside function)
        cur_interp_method, cur_support_sz = apply_antialiasing_if_needed(
            interp_method,
            support_sz,
            scale_factor,
            antialiasing)

        # STEP 2- FIELDS OF VIEW: for each output pixel, map the input pixels
        # that influence it. Also calculate needed padding and update grid
        # accordingly
        field_of_view = get_field_of_view(projected_grid, cur_support_sz, fw,
                                          eps, device)

        # STEP 2.5- CALCULATE PAD AND UPDATE: according to the field of view,
        # the input should be padded to handle the boundaries, coordinates
        # should be updated. actual padding only occurs when weights are
        # applied (step 4). if using by_convs for this dim, then we need to
        # calc right and left boundaries for each filter instead.
        pad_sz, projected_grid, field_of_view = calc_pad_sz(in_sz, out_sz,
                                                            field_of_view,
                                                            projected_grid,
                                                            scale_factor,
                                                            dim_by_convs, fw,
                                                            device)

        # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in
        # the field of view for each output pixel. Precomputed weights are
        # used when provided. ("is not None" replaces the original's
        # non-idiomatic "!= None" comparison.)
        if adv_weights is not None:
            weights = adv_weights[i]
        else:
            weights = get_weights(cur_interp_method, projected_grid, field_of_view)

        # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying
        # its set of weights with the pixel values in its field of view.
        # We do this by tensor multiplication and broadcasting.
        # if by_convs is true for this dim, then we do this action by
        # convolutions. this is equivalent but faster.
        if not dim_by_convs:
            output = apply_weights(output, field_of_view, weights, dim, n_dims,
                                   pad_sz, pad_mode, fw)
        else:
            output = apply_convs(output, scale_factor, in_sz, out_sz, weights,
                                 dim, pad_sz, pad_mode, fw)
    return output
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def get_projected_grid(in_sz, out_sz, scale_factor, fw, by_convs, device=None):
    """Project 1d output pixel centers onto (non-integer) input coordinates."""
    # In the by_convs case one period of grid points suffices, so only
    # numerator-of-the-rational-scale many coordinates are generated.
    sf = float(scale_factor)
    if by_convs:
        grid_sz = scale_factor.numerator
    else:
        grid_sz = out_sz
    coords = fw_arange(grid_sz, fw, device)

    # Shift-and-scale formula derived in "From Discrete to Continuous
    # Convolutions" (Shocher et al.): centers the output grid on the input.
    offset = (in_sz - 1) / 2 - (out_sz - 1) / (2 * sf)
    return coords / sf + offset
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def get_field_of_view(projected_grid, cur_support_sz, fw, eps, device):
    """For each 1d output pixel, list the input pixel indices influencing it."""
    # The leftmost influencing neighbor sits half a support window to the
    # left of the projected center (eps guards exactly-integer boundaries).
    half_window = cur_support_sz / 2
    left_edge = fw_ceil(projected_grid - half_window - eps, fw)

    # The rest of the window is just consecutive pixels counted from there.
    offsets = fw_arange(ceil(cur_support_sz - eps), fw, device)
    return left_edge[:, None] + offsets
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def calc_pad_sz(in_sz, out_sz, field_of_view, projected_grid, scale_factor,
                dim_by_convs, fw, device):
    """Compute the padding needed along one dim and shift coordinates by it.

    Returns (pad_sz, projected_grid, field_of_view). In the plain case pad_sz
    is a [left, right] pair; when this dim is resized by convolutions it is a
    list of per-filter (left, right) pairs. A negative pad means cropping.
    """
    if not dim_by_convs:
        # determine padding according to neighbor coords out of bound.
        # this is a generalized notion of padding: when pad < 0 it means crop
        pad_sz = [-field_of_view[0, 0].item(),
                  field_of_view[-1, -1].item() - in_sz + 1]

        # since the input image will be changed by padding, coordinates of
        # both field_of_view and projected_grid need to be updated
        field_of_view += pad_sz[0]
        projected_grid += pad_sz[0]

    else:
        # only used for by_convs, to calc the boundaries of each filter. the
        # number of distinct convolutions is the numerator of the scale factor
        num_convs, stride = scale_factor.numerator, scale_factor.denominator

        # calculate left and right boundaries for each conv. left can be
        # negative and right can be bigger than in_sz; such cases imply
        # padding. however, if both are in-bounds, it means we need to crop,
        # i.e. practically apply the conv only on part of the image.
        left_pads = -field_of_view[:, 0]

        # next calc is tricky, explanation by rows:
        # 1) counting output pixels between the first position of each filter
        #    to the right boundary of the input
        # 2) dividing it by number of filters to count how many 'jumps'
        #    each filter does
        # 3) multiplying by the stride gives us the distance over the input
        #    coords done by all these jumps for each filter
        # 4) to this distance we add the right boundary of the filter when
        #    placed in its leftmost position. so now we get the right boundary
        #    of that filter in input coords.
        # 5) the padding size needed is obtained by subtracting the rightmost
        #    input coordinate. if the result is positive, padding is needed;
        #    if negative, negative padding means shaving off pixel columns.
        right_pads = (((out_sz - fw_arange(num_convs, fw, device) - 1)  # (1)
                       // num_convs)  # (2)
                      * stride  # (3)
                      + field_of_view[:, -1]  # (4)
                      - in_sz + 1)  # (5)

        # in the by_convs case pad_sz is a list of left-right pairs, one per
        # filter

        pad_sz = list(zip(left_pads, right_pads))

    return pad_sz, projected_grid, field_of_view
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def get_weights(interp_method, projected_grid, field_of_view):
    """Weigh each output pixel's field of view with the interpolation kernel.

    The kernel is evaluated at the signed distances between the projected
    output-pixel centers and the input pixel centers in their field of view.
    """
    distances = projected_grid[:, None] - field_of_view
    raw = interp_method(distances)

    # Normalize so each output pixel's weights sum to 1. All-zero rows are
    # divided by 1 instead, which avoids a 0/0 and leaves them at zero.
    totals = raw.sum(1, keepdims=True)
    totals[totals == 0] = 1
    return raw / totals
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def apply_weights(input, field_of_view, weights, dim, n_dims, pad_sz, pad_mode,
                  fw):
    """Resize `input` along `dim` as a weighted sum over each field of view."""
    # Work with the resized dim first: transpose in, compute, transpose back.
    swapped = fw_swapaxes(input, dim, 0, fw)
    swapped = fw_pad(swapped, fw, pad_sz, pad_mode)

    # field_of_view is 2d: per output position along the current dim, a list
    # of 1d neighbor indices. Fancy-indexing with it gathers, for every
    # output pixel (in all dims), its 1d neighborhood values, so the result
    # has one extra axis and is of order image_dims + 1.
    gathered = swapped[field_of_view]

    # weights is 2d (output position x neighbor). Append singleton axes so
    # broadcasting applies it only along the first (resized) dim.
    bcast_shape = (*weights.shape, *[1] * (n_dims - 1))
    w = fw.reshape(weights, bcast_shape)

    # Multiply and reduce over the neighbor axis: one value per output pixel.
    resized = (gathered * w).sum(1)

    # Restore the resized dim to its original position.
    return fw_swapaxes(resized, 0, dim, fw)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def apply_convs(input, scale_factor, in_sz, out_sz, weights, dim, pad_sz,
                pad_mode, fw):
    """Resize `input` along `dim` using a bank of strided 1d convolutions."""
    # Convolutions run over the last axis, so move the resized dim there.
    input = fw_swapaxes(input, dim, -1, fw)

    # All filters share one stride (the denominator of the rational scale);
    # there are numerator-many distinct filters.
    stride = scale_factor.denominator
    num_convs = scale_factor.numerator

    # Allocate the output with the last-axis length replaced by out_sz.
    out_shape = list(input.shape)
    out_shape[-1] = out_sz
    tmp_output = fw_empty(tuple(out_shape), fw, input.device)

    last_axis = input.ndim - 1
    for conv_ind, (conv_pad, filt) in enumerate(zip(pad_sz, weights)):
        # Pad the last dim for this filter (padding can be negative = crop).
        padded = fw_pad(input, fw, conv_pad, pad_mode, dim=last_axis)

        # Store each filter's result at strided positions so that, once the
        # loop completes, the per-filter outputs interleave into the full
        # resized axis.
        tmp_output[..., conv_ind::num_convs] = fw_conv(padded, filt, stride)

    return fw_swapaxes(tmp_output, -1, dim, fw)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def set_scale_and_out_sz(in_shape, out_shape, scale_factors, by_convs,
                         scale_tolerance, max_numerator, eps, fw):
    """Resolve scale factors and output shape for every dim.

    Exactly one of scale_factors / out_shape may be missing; the other is
    derived from it. Also normalizes by_convs into a per-dim list and, for
    by_convs dims, snaps the scale to a Fraction when it is within tolerance
    (otherwise falls back to a float and disables by_convs for that dim).
    Returns (scale_factors, out_shape, by_convs).
    """
    # eventually we must have both scale-factors and out-sizes for all in/out
    # dims. however, we support many possible partial arguments
    if scale_factors is None and out_shape is None:
        raise ValueError("either scale_factors or out_shape should be "
                         "provided")
    if out_shape is not None:
        # if out_shape has fewer dims than in_shape, we by default resize the
        # first dims for numpy and last dims for torch
        out_shape = (list(out_shape) + list(in_shape[len(out_shape):])
                     if fw is numpy
                     else list(in_shape[:-len(out_shape)]) + list(out_shape))
        if scale_factors is None:
            # if no scale given, we calculate it as the out to in ratio
            # (not recommended)
            scale_factors = [out_sz / in_sz for out_sz, in_sz
                             in zip(out_shape, in_shape)]
    if scale_factors is not None:
        # by default, if a single number is given as scale, we assume resizing
        # two dims (most common are images with 2 spatial dims)
        scale_factors = (scale_factors
                         if isinstance(scale_factors, (list, tuple))
                         else [scale_factors, scale_factors])
        # if fewer scale_factors than in_shape dims, we by default resize the
        # first dims for numpy and last dims for torch
        scale_factors = (list(scale_factors) + [1] *
                         (len(in_shape) - len(scale_factors)) if fw is numpy
                         else [1] * (len(in_shape) - len(scale_factors)) +
                         list(scale_factors))
        if out_shape is None:
            # when no out_shape given, it is calculated by multiplying the
            # scale by the in_shape (not recommended)
            out_shape = [ceil(scale_factor * in_sz)
                         for scale_factor, in_sz in
                         zip(scale_factors, in_shape)]
    # next part intentionally after out_shape determined for stability
    # we fix by_convs to be a list of truth values in case it is not
    if not isinstance(by_convs, (list, tuple)):
        by_convs = [by_convs] * len(out_shape)

    # next loop fixes the scale for each dim to be either frac or float.
    # this is determined by by_convs and by tolerance for scale accuracy.
    for ind, (sf, dim_by_convs) in enumerate(zip(scale_factors, by_convs)):
        # first we fractionalize (invert, bound the denominator, re-invert so
        # the *numerator* is bounded by max_numerator)
        if dim_by_convs:
            frac = Fraction(1 / sf).limit_denominator(max_numerator)
            frac = Fraction(numerator=frac.denominator, denominator=frac.numerator)

        # if accuracy is within tolerance scale will be frac. if not, then
        # it will be float and the by_convs attr will be set false for
        # this dim
        if scale_tolerance is None:
            scale_tolerance = eps
        if dim_by_convs and abs(frac - sf) < scale_tolerance:
            scale_factors[ind] = frac
        else:
            scale_factors[ind] = float(sf)
            by_convs[ind] = False

    return scale_factors, out_shape, by_convs
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
                                 antialiasing):
    """Stretch the kernel and its support for downscaling (low-pass filter).

    Antialiasing widens the field of view by 1/scale and rescales the 1d
    kernel so it still integrates to one. Upscaling (scale >= 1) or a
    disabled flag leaves both untouched.
    """
    sf = float(scale_factor)
    if sf >= 1.0 or not antialiasing:
        return interp_method, support_sz

    def stretched(arg):
        # compress the argument and attenuate the response by the scale
        return sf * interp_method(sf * arg)

    return stretched, support_sz / sf
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def fw_ceil(x, fw):
    """Elementwise ceil, cast to an integer dtype for the given framework."""
    if fw is not numpy:
        # torch: round up, then cast to int64
        return x.ceil().long()
    return numpy.int_(numpy.ceil(x))
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def fw_floor(x, fw):
    """Elementwise floor, cast to an integer dtype for the given framework."""
    if fw is not numpy:
        # torch: round down, then cast to int64
        return x.floor().long()
    return numpy.int_(numpy.floor(x))
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def fw_cat(x, fw):
    """Concatenate a sequence of arrays/tensors along the first axis."""
    return fw.concatenate(x) if fw is numpy else fw.cat(x)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def fw_swapaxes(x, ax_1, ax_2, fw):
    """Swap two axes (numpy calls it swapaxes, torch calls it transpose)."""
    if fw is not numpy:
        return x.transpose(ax_1, ax_2)
    return numpy.swapaxes(x, ax_1, ax_2)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def fw_pad(x, fw, pad_sz, pad_mode, dim=0):
    """Pad `x` along a single axis with the given (left, right) sizes."""
    if pad_sz == (0, 0):
        # nothing to add or remove
        return x
    if fw is numpy:
        # numpy pads all axes at once; ask for (0, 0) everywhere but `dim`.
        widths = [(0, 0)] * x.ndim
        widths[dim] = pad_sz
        return numpy.pad(x, pad_width=widths, mode=pad_mode)
    # torch's F.pad wants at least a 3d tensor and specifies pads from the
    # last dim backwards, so promote small inputs and move `dim` to the end.
    if x.ndim < 3:
        x = x[None, None, ...]
    flat_pad = [0] * ((x.ndim - 2) * 2)
    flat_pad[0:2] = pad_sz
    moved = x.transpose(dim, -1)
    padded = fw.nn.functional.pad(moved, pad=flat_pad, mode=pad_mode)
    return padded.transpose(dim, -1)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def fw_conv(input, filter, stride):
    """1d strided convolution over the last axis of an nd torch tensor.

    All leading axes are flattened into the height of a (1, 1, H, W) tensor,
    which is then convolved with a 1xK filter: every row is convolved
    identically, which is exactly a per-location 1d conv over the last dim.
    """
    # TODO: numpy support
    as_4d = input.reshape(1, 1, -1, input.shape[-1])
    kernel = filter.view(1, 1, 1, -1)
    result = torch.nn.functional.conv2d(as_4d, kernel, stride=(1, stride))
    return result.reshape(*input.shape[:-1], -1)
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def fw_arange(upper_bound, fw, device):
    """Integer range [0, upper_bound) in the requested framework."""
    if fw is not numpy:
        # torch needs the target device up-front
        return fw.arange(upper_bound, device=device)
    return numpy.arange(upper_bound)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def fw_empty(shape, fw, device):
    """Allocate an uninitialized array/tensor of the given shape."""
    if fw is not numpy:
        return fw.empty(size=(*shape,), device=device)
    return numpy.empty(shape)
|
pretrained_models/LMAR_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:27f1ada04c3297053af030ec2547a06f54d5de5e1ec20f3b430a9dd2f2f666ff
|
| 3 |
+
size 1475245
|
pretrained_models/base_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:45d49de91c08e7a6080d60f7059482fcd443377982e1908045625759e5931772
|
| 3 |
+
size 3417093
|
utils.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
import math
import random
import warnings

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
import yaml
from torch.optim.lr_scheduler import _LRScheduler

from metrics import calculate_psnr, calculate_ssim
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AverageMeter(object):
    """Tracks the latest value plus a running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def calculate_metrics(imgs_1, imgs_2):
    """Average PSNR and SSIM over a batch of paired image tensors.

    Both batches must have the same length; each pair is converted to a
    uint8 HxWxC array via PIL before the metrics are computed.
    """
    assert imgs_1.shape[0] == imgs_2.shape[0]
    to_pil = transforms.ToPILImage()
    psnrs = []
    ssims = []
    for img_a, img_b in zip(imgs_1, imgs_2):
        arr_a = np.asarray(to_pil(img_a))
        arr_b = np.asarray(to_pil(img_b))
        psnrs.append(calculate_psnr(arr_a, arr_b, 0))
        ssims.append(calculate_ssim(arr_a, arr_b, 0))
    return np.asarray(psnrs).mean(), np.asarray(ssims).mean()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def read_args(config_file):
    """Build an argparse parser whose defaults come from a YAML config file.

    Every top-level key in the YAML file becomes a `--<key>` option with the
    file's value as its default, so the command line can override the config.

    Args:
        config_file: path to a YAML file of key/value pairs.

    Returns:
        An `argparse.ArgumentParser` ready for `parse_args()`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default=config_file)
    # use a context manager so the file handle is always closed
    # (the original opened the file and never closed it)
    with open(config_file) as file:
        config = yaml.safe_load(file)
    for k, v in config.items():
        parser.add_argument(f"--{k}", default=v)
    return parser
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def save_checkpoint(state, filename):
    """Serialize a training-state object to `filename` via torch.save."""
    torch.save(state, filename)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class CosineAnnealingWarmRestarts(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:
    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983

    NOTE(review): this appears to be a vendored copy of PyTorch's scheduler of
    the same name — presumably kept locally for version compatibility; confirm.
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
        # Validate restart period and multiplier before storing any state.
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        self.T_0 = T_0          # length of the first restart cycle
        self.T_i = T_0          # length of the current cycle (grows by T_mult)
        self.T_mult = T_mult
        self.eta_min = eta_min

        # epochs elapsed since the last restart
        self.T_cur = 0 if last_epoch < 0 else last_epoch
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        # NOTE(review): `warnings` must be imported at module level for this
        # branch to work — verify the file's imports include it.
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        # Cosine interpolation between base_lr (at T_cur=0) and eta_min
        # (at T_cur=T_i), per parameter group.
        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Step could be called after every batch update
        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)
        This function can be called in an interleaved way.
        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
        """
        # Implicit stepping: advance one epoch and wrap into the next cycle
        # when the current one ends.
        if epoch is None and self.last_epoch < 0:
            epoch = 0
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            # Explicit (possibly fractional) epoch: recompute which cycle the
            # given epoch falls into and the position within it.
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    # n = index of the cycle containing `epoch` (geometric sum
                    # inversion of T_0 * (T_mult^n - 1) / (T_mult - 1))
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        class _enable_get_lr_call:
            # Context manager that flags get_lr() as being called from step().
            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False
                # NOTE(review): returning a truthy value from __exit__
                # suppresses exceptions raised inside the `with` body; the
                # upstream PyTorch source has the same quirk — confirm intent.
                return self

        # Apply the freshly computed lr to every param group.
        with _enable_get_lr_call(self):
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group['lr'] = lr
                self.print_lr(self.verbose, i, lr, epoch)
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def set_seed(seed):
    """Make runs reproducible by seeding every RNG in use.

    Seeds python's `random`, numpy, torch's CPU generator and, when CUDA is
    available, all CUDA devices.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|