| import torch, os, glob, pyiqa |
| from argparse import ArgumentParser |
| import numpy as np |
| from PIL import Image |
| from tqdm import tqdm |
| from torchvision import transforms |
|
|
# Command-line interface: where to find the ground-truth (HR) images and the
# super-resolved (SR) outputs that are scored against them.
parser = ArgumentParser(
    description="Score super-resolved images against ground-truth images with pyiqa metrics."
)
parser.add_argument(
    "--HR_dir",
    type=str,
    default="testset/RealSR/HR",
    help="directory of ground-truth high-resolution images "
    "(paired with SR_dir files by sorted path order)",
)
parser.add_argument(
    "--SR_dir",
    type=str,
    default="result/RealSR",
    help="directory of super-resolved images to evaluate",
)
args = parser.parse_args()
|
|
# Device on which images are scored and the pyiqa metric networks are created.
# Fall back to the CPU when CUDA is unavailable, instead of failing later when
# metrics and tensors are moved to a non-existent GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
# PSNR and SSIM are evaluated on the Y channel of the YCbCr colour space.
_y_channel_opts = {"test_y_channel": True, "color_space": "ycbcr"}

psnr = pyiqa.create_metric("psnr", device=device, **_y_channel_opts)
ssim = pyiqa.create_metric("ssim", device=device, **_y_channel_opts)

# All remaining metrics use pyiqa's default settings; created in this order.
lpips, dists, fid, niqe, maniqa, clipiqa, musiq = (
    pyiqa.create_metric(metric_name, device=device)
    for metric_name in ("lpips", "dists", "fid", "niqe", "maniqa-pipal", "clipiqa", "musiq")
)
|
|
def _list_files(directory):
    """Return the regular files directly inside *directory*, sorted by path.

    Sub-directories matched by the ``*`` glob are skipped; otherwise they would
    be handed to ``Image.open`` in the scoring loop.
    """
    return sorted(
        path for path in glob.glob(os.path.join(directory, "*")) if os.path.isfile(path)
    )


test_SR_paths = _list_files(args.SR_dir)
test_HR_paths = _list_files(args.HR_dir)

# SR and HR images are paired purely by position in these sorted lists. A count
# mismatch would make zip() silently drop the surplus images, so the averages
# would no longer cover the whole test set. Fail early with a clear message.
if not test_SR_paths:
    raise SystemExit(f"no files found in SR_dir {args.SR_dir!r}")
if len(test_SR_paths) != len(test_HR_paths):
    raise SystemExit(
        f"SR_dir has {len(test_SR_paths)} files but HR_dir has "
        f"{len(test_HR_paths)}; they are paired by sorted order and must match"
    )
|
|
# Per-image scores, averaged once the loop finishes. The key order here also
# fixes the order in which results are printed at the end of the script.
metrics = {"psnr": [], "ssim": [], "lpips": [], "dists": [], "niqe": [], "maniqa": [], "musiq": [], "clipiqa": []}

# PIL image -> float tensor of shape (C, H, W) scaled to [0, 1]; built once
# instead of once per image.
_to_tensor = transforms.ToTensor()


def _load_image(path):
    """Open *path*, convert it to RGB and return a (1, 3, H, W) tensor on ``device``.

    The file handle is closed before returning rather than left to the garbage
    collector.
    """
    with Image.open(path) as img:
        return _to_tensor(img.convert("RGB")).to(device).unsqueeze(0)


# SR and HR images are paired by position in their sorted path lists. zip()
# stops at the shorter list, so the progress-bar total is set to match.
for SR_path, HR_path in tqdm(
    zip(test_SR_paths, test_HR_paths),
    total=min(len(test_SR_paths), len(test_HR_paths)),
):
    SR = _load_image(SR_path)
    HR = _load_image(HR_path)
    # Full-reference metrics: the SR output compared against its HR counterpart.
    metrics["psnr"].append(psnr(SR, HR).item())
    metrics["ssim"].append(ssim(SR, HR).item())
    metrics["lpips"].append(lpips(SR, HR).item())
    metrics["dists"].append(dists(SR, HR).item())
    # No-reference metrics: only the SR output is scored.
    metrics["niqe"].append(niqe(SR).item())
    metrics["maniqa"].append(maniqa(SR).item())
    metrics["clipiqa"].append(clipiqa(SR).item())
    metrics["musiq"].append(musiq(SR).item())
|
|
# Replace each per-image score list with its mean; insertion order is kept.
metrics = {name: np.mean(scores) for name, scores in metrics.items()}

# FID is computed once over the two whole directories, not per image pair,
# and is appended last so it prints after the averaged metrics.
metrics["fid"] = fid(args.SR_dir, args.HR_dir)
|
|
# Significant digits used when printing each metric; any metric not listed
# here is printed with four.
_significant_digits = {"niqe": 3, "fid": 5}

for name, value in metrics.items():
    digits = _significant_digits.get(name, 4)
    print(name, f"{value:.{digits}g}")