|
|
import argparse |
|
|
import yaml |
|
|
import torchvision.transforms as transforms |
|
|
from utils import read_args, save_checkpoint, AverageMeter, calculate_metrics, CosineAnnealingWarmRestarts |
|
|
|
|
|
import time |
|
|
from tqdm import trange, tqdm |
|
|
from torchvision.utils import save_image |
|
|
import os |
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0' |
|
|
import json |
|
|
import time |
|
|
import logging |
|
|
import torch |
|
|
from torch import nn, optim |
|
|
import numpy as np |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from model import * |
|
|
from data import * |
|
|
from PIL import Image |
|
|
from torchvision.transforms import Resize |
|
|
import pyiqa |
|
|
from thop import profile |
|
|
from thop import clever_format |
|
|
|
|
|
# Image-quality metric objects shared by the evaluation code below. They are
# built once, at import time, and moved to the GPU with ``.cuda()``; creation
# order is PSNR, SSIM, LPIPS, NIQE (tuple elements evaluate left to right).
# NOTE(review): 'ssimc' with downsample=True is presumably pyiqa's colour SSIM
# variant evaluated on a downsampled image — confirm against the pyiqa docs.
psnr_calculator, ssim_calculator, lpips_calculator, niqe_calculator = (
    pyiqa.create_metric('psnr').cuda(),
    pyiqa.create_metric('ssimc', downsample=True).cuda(),
    pyiqa.create_metric('lpips').cuda(),
    pyiqa.create_metric('niqe').cuda(),
)
|
|
|
|
|
|
|
|
def test(load_path, data_loader, args, down_size=(1440, 2560), log_every=700,
         compute_lpips_niqe=False):
    """Evaluate a checkpoint on ``data_loader`` and log image-quality metrics.

    Each input batch is resized to ``down_size`` before the forward pass. The
    prediction is resized back to ``args.test_loader["gt_size"]``, clamped to
    [0, 1] and scored against the ground truth with the module-level pyiqa
    metric objects.

    Args:
        load_path: Path to a checkpoint whose ``"state_dict"`` entry is loaded
            into ``net(args)``.
        data_loader: Iterable yielding ``(input_img, gt_img, inp_img_path)``.
        args: Parsed configuration. ``args.test_loader["gt_size"]`` must be a
            string that evaluates to the output shape passed to ``resize``.
        down_size: Shape the input is resized to before inference. The
            default keeps the original hard-coded value.
        log_every: Log running averages every ``log_every`` batches (default
            700, as before). A value <= 0 disables the periodic log.
        compute_lpips_niqe: If True, also compute LPIPS and NIQE. This is off
            by default, matching the original behaviour; those averages then
            stay at the meters' initial value.

    Returns:
        dict with the average ``"psnr"``, ``"ssim"``, ``"lpips"`` and
        ``"niqe"``, plus the total wall-clock ``"time"`` in seconds.
    """
    model = net(args)
    # Load onto the CPU first so a checkpoint saved on a different GPU index
    # still loads; the model is moved to the GPU right after.
    checkpoint = torch.load(load_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.cuda()
    model.eval()

    psnrs = AverageMeter()
    ssims = AverageMeter()
    lpipss = AverageMeter()
    niqes = AverageMeter()

    # NOTE(review): eval() executes arbitrary code from the config string.
    # Only use trusted config files. It is evaluated once, not per batch.
    gt_size = eval(args.test_loader["gt_size"])

    # Total-run timer. It is kept separate from the per-batch timer below;
    # previously one variable was reused for both, so the final message
    # reported only the last batch's time.
    total_start = time.time()

    logging.info("Inference at down size: {}".format(down_size))

    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            input_img, gt_img, _inp_img_path = batch

            # Keep the prediction and the reference on the same device.
            input_img = input_img.cuda()
            gt_img = gt_img.cuda()
            batch_size = input_img.size(0)

            batch_start = time.time()

            input_img = resize(input_img, out_shape=down_size, antialiasing=False)
            out_img = model(input_img)
            out_img = resize(out_img, out_shape=gt_size, antialiasing=False)

            clamped_out = torch.clamp(out_img, 0, 1)

            psnr_val = psnr_calculator(clamped_out, gt_img)
            ssim_val = ssim_calculator(clamped_out, gt_img)
            psnrs.update(torch.mean(psnr_val).item(), batch_size)
            ssims.update(torch.mean(ssim_val).item(), batch_size)

            if compute_lpips_niqe:
                # LPIPS compares against the reference; NIQE scores the
                # output alone.
                lpips_val = lpips_calculator(clamped_out, gt_img)
                niqe_val = niqe_calculator(clamped_out)
                lpipss.update(torch.mean(lpips_val).item(), batch_size)
                niqes.update(torch.mean(niqe_val).item(), batch_size)

            # Release cached blocks between the large, full-resolution batches.
            torch.cuda.empty_cache()

            if log_every > 0 and i % log_every == 0:
                # "Elapse time" here is the time spent on the current batch.
                logging.info(
                    "PSNR {:.4f}, SSIM {:.4f}, LPIPS {:.4F}, NIQE {:.4F}, Elapse time {:.2f}\n".format(
                        psnrs.avg, ssims.avg, lpipss.avg, niqes.avg,
                        time.time() - batch_start))

    total_time = time.time() - total_start
    logging.info("Finish test: avg PSNR: %.4f, avg SSIM: %.4F, avg LPIPS: %.4F, avg NIQE: %.4F, and takes %.2f seconds" % (
        psnrs.avg, ssims.avg, lpipss.avg, niqes.avg, total_time))

    return {
        "psnr": psnrs.avg,
        "ssim": ssims.avg,
        "lpips": lpipss.avg,
        "niqe": niqes.avg,
        "time": total_time,
    }
|
|
|
|
|
def main(args, load_path):
    """Prepare the output directory and logging, build the test loader, run ``test``.

    Args:
        args: Parsed configuration exposing ``output_dir``, ``data["test_dir"]``
            and the ``test_loader`` settings (``img_size``, ``batch_size``,
            ``num_workers``, ``shuffle``).
        load_path: Checkpoint path forwarded to ``test``.
    """
    # makedirs creates missing parent directories and tolerates a directory
    # that already exists. The former exists()/mkdir pair failed on nested
    # paths and could race with another process creating the directory.
    os.makedirs(args.output_dir, exist_ok=True)

    test_transforms = transforms.Compose([transforms.ToTensor()])

    # Log to <output_dir>/baseline_log with timestamps, and echo every message
    # to the console through an extra root StreamHandler.
    log_format = "%(asctime)s %(levelname)-8s %(message)s"
    log_file = os.path.join(args.output_dir, "baseline_log")
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())

    logging.info("Building data loader")

    # NOTE(review): eval() executes arbitrary code from the config string.
    # Only use trusted config files.
    img_size = eval(args.test_loader["img_size"])
    test_loader = get_loader(args.data["test_dir"],
                             img_size, test_transforms, False,
                             int(args.test_loader["batch_size"]),
                             # Cast like batch_size so a string value in the
                             # config does not reach the loader unconverted.
                             int(args.test_loader["num_workers"]),
                             args.test_loader["shuffle"], random_flag=False)

    test(load_path, test_loader, args)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Machine-specific locations of the YAML config and the checkpoint to
    # evaluate; edit these when running in a different environment.
    config_path = "/home/yuwei/code/cvpr/config/base_config.yaml"
    checkpoint_path = "./pretrained_models/base_model.bin"

    # read_args builds an argument parser from the config file.
    cli_args = read_args(config_path).parse_args()
    main(cli_args, checkpoint_path)
|
|
|