File size: 4,146 Bytes
98feea6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import yaml
import torchvision.transforms as transforms
from utils import read_args, save_checkpoint, AverageMeter, calculate_metrics, CosineAnnealingWarmRestarts
# import torchvision.transforms.InterpolationMode
import time
from tqdm import trange, tqdm
from torchvision.utils import save_image
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import json
import time
import logging
import torch
from torch import nn, optim
import numpy as np
import torch.nn.functional as F

from model import *
from data import *
from PIL import Image
from torchvision.transforms import Resize
import pyiqa
from thop import profile
from thop import clever_format

# Module-level IQA metric instances (built once at import time, moved to GPU).
# NOTE(review): these call .cuda() at import — the module cannot be imported on
# a CPU-only machine; confirm that is acceptable for all entry points.
psnr_calculator = pyiqa.create_metric('psnr').cuda()
# 'ssimc' = SSIM computed on color channels; downsample=True matches the
# reference SSIM implementation's pre-filtering — TODO confirm against pyiqa docs.
ssim_calculator = pyiqa.create_metric('ssimc', downsample=True).cuda()
lpips_calculator = pyiqa.create_metric('lpips').cuda()
niqe_calculator = pyiqa.create_metric('niqe').cuda()


def test(load_path, data_loader, args):
    """Evaluate a checkpoint on *data_loader* and log PSNR/SSIM (LPIPS/NIQE slots kept).

    Args:
        load_path: path to a checkpoint file containing a "state_dict" entry.
        data_loader: iterable yielding (input_img, gt_img, inp_img_path) batches.
        args: parsed config namespace; uses args.test_loader["gt_size"].

    Side effects: loads the model onto the GPU and writes metric averages to the
    module logger. Returns nothing.
    """
    model = net(args)
    checkpoint = torch.load(load_path)
    model.load_state_dict(checkpoint["state_dict"])
    model.cuda()
    model.eval()

    psnrs = AverageMeter()
    ssims = AverageMeter()
    lpipss = AverageMeter()  # currently unused (LPIPS disabled below); avg stays 0
    niqes = AverageMeter()   # currently unused (NIQE disabled below); avg stays 0

    # Run-level timer. BUGFIX: the original re-assigned `start_time` inside the
    # loop, so elapsed-time logs only covered the most recent batch instead of
    # the whole run. A separate name is now used for per-batch timing.
    run_start = time.time()
    down_size = (1440, 2560)
    logging.info("Inference at down size: {}".format(down_size))
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            input_img, gt_img, inp_img_path = batch

            input_img = input_img.cuda()
            batch_size = input_img.size(0)

            # Downscale for inference, then restore to the GT resolution.
            input_img = resize(input_img, out_shape=down_size, antialiasing=False)
            out_img = model(input_img)
            # NOTE(review): eval() on a config string is fine for trusted YAML
            # but unsafe on untrusted input — consider ast.literal_eval.
            out_img = resize(out_img, out_shape=eval(args.test_loader["gt_size"]), antialiasing=False)

            # Metrics are computed on outputs clamped to the valid [0, 1] range.
            clamped_out = torch.clamp(out_img, 0, 1)
            psnr_val, ssim_val = psnr_calculator(clamped_out, gt_img), ssim_calculator(clamped_out, gt_img)
            psnrs.update(torch.mean(psnr_val).item(), batch_size)
            ssims.update(torch.mean(ssim_val).item(), batch_size)

            # LPIPS / NIQE kept disabled to match original behavior.
            # lpips = lpips_calculator(clamped_out, gt_img)
            # lpipss.update(torch.mean(lpips).item(), batch_size)
            # niqe = niqe_calculator(clamped_out)
            # niqes.update(torch.mean(niqe).item(), batch_size)
            torch.cuda.empty_cache()

            if i % 700 == 0:
                logging.info(
                    "PSNR {:.4f}, SSIM {:.4f}, LPIPS {:.4F}, NIQE {:.4F}, Elapse time {:.2f}\n".format(
                        psnrs.avg, ssims.avg, lpipss.avg, niqes.avg,
                        time.time() - run_start))

        logging.info("Finish test: avg PSNR: %.4f, avg SSIM: %.4F, avg LPIPS: %.4F, avg NIQE: %.4F, and takes %.2f seconds" % (
            psnrs.avg, ssims.avg, lpipss.avg, niqes.avg, time.time() - run_start))

def main(args, load_path):
    """Set up logging and the test data loader, then run evaluation.

    Args:
        args: parsed config namespace (needs output_dir, data, test_loader).
        load_path: checkpoint path forwarded to test().
    """
    # BUGFIX: os.mkdir raises if any parent directory is missing, and the
    # exists-check/mkdir pair races; makedirs(exist_ok=True) handles both.
    os.makedirs(args.output_dir, exist_ok=True)
    test_transforms = transforms.Compose([transforms.ToTensor()])

    # Log to a file inside output_dir and mirror everything to stderr.
    log_format = "%(asctime)s %(levelname)-8s %(message)s"
    log_file = os.path.join(args.output_dir, "baseline_log")
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())

    logging.info("Building data loader")

    # NOTE(review): eval() on config strings is fine for trusted YAML but
    # unsafe on untrusted input — consider ast.literal_eval.
    test_loader = get_loader(args.data["test_dir"],
                             eval(args.test_loader["img_size"]), test_transforms, False,
                             int(args.test_loader["batch_size"]), args.test_loader["num_workers"],
                             args.test_loader["shuffle"], random_flag=False)
    test(load_path, test_loader, args)


if __name__ == '__main__':
    # Build the argument parser from the baseline YAML config, then evaluate
    # the pretrained base model checkpoint.
    config_path = "/home/yuwei/code/cvpr/config/base_config.yaml"
    checkpoint_path = "./pretrained_models/base_model.bin"
    args = read_args(config_path).parse_args()
    main(args, checkpoint_path)