"""Evaluate an ONNX super-resolution model on paired LQ/HQ image folders.

For every HQ image that has a matching LQ image, the LQ image is run through
``OnnxRunner``. The output is resized to the HQ resolution when the shapes
differ and saved as PNG. PSNR and MS-SSIM are computed per pair with ``pyiqa``
and averaged; FID is computed at directory level. The scores are printed and
written to a JSON file.
"""

import sys
import json
from pathlib import Path

# Make sibling modules (onnx_runner) importable regardless of the CWD.
sys.path.insert(0, Path(__file__).parent.as_posix())

import cv2
import pyiqa
import torch
import numpy as np
from tqdm import tqdm

from onnx_runner import OnnxRunner

# File extensions (compared case-insensitively) that are treated as images.
IMG_EXTS = frozenset({".png", ".jpg", ".jpeg"})


def _is_img(p: Path) -> bool:
    """Return True if *p* is a regular file with an image extension."""
    return p.is_file() and p.suffix.lower() in IMG_EXTS


def collect_common_image_pairs(
    lq_dir: Path, hq_dir: Path
) -> tuple[list[Path], list[Path]]:
    """Pair each HQ image with an LQ image by file stem.

    An LQ image matches the HQ image ``base`` when its stem equals ``base``.
    Otherwise it matches when its stem starts with ``base`` (e.g. ``0001x4``
    for ``0001``). Among several prefix candidates, the shortest stem wins,
    with ties broken by path. For example, ``img1`` prefers ``img1_x4`` over
    ``img10_x4``. HQ images without any match are skipped.

    Args:
        lq_dir: Directory with low-quality input images.
        hq_dir: Directory with high-quality reference images.

    Returns:
        ``(lq_paths, hq_paths)``: parallel lists, ordered by HQ stem.
    """
    # Directory listings are sorted so that duplicate stems (e.g. a.png and
    # a.jpg) resolve the same way on every filesystem.
    hq_map = {p.stem: p for p in sorted(hq_dir.iterdir()) if _is_img(p)}
    lq_files = [p for p in sorted(lq_dir.iterdir()) if _is_img(p)]
    lq_by_stem: dict[str, Path] = {}
    for p in lq_files:
        lq_by_stem.setdefault(p.stem, p)

    lq_paths: list[Path] = []
    hq_paths: list[Path] = []
    for base in sorted(hq_map):
        # Try an exact stem match first.
        best_lq = lq_by_stem.get(base)
        if best_lq is None:
            # Fall back to a prefix match; prefer the closest (shortest) stem.
            candidates = [
                p
                for p in lq_files
                if p.stem.startswith(base) and len(p.stem) > len(base)
            ]
            if candidates:
                best_lq = min(candidates, key=lambda p: (len(p.stem), p))
        if best_lq is not None:
            hq_paths.append(hq_map[base])
            lq_paths.append(best_lq)
    return lq_paths, hq_paths


def align_shape(
    sr_bgr: np.ndarray, hq_bgr: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Resize *sr_bgr* to the height/width of *hq_bgr* if their shapes differ.

    Bilinear interpolation is used. *hq_bgr* is returned unchanged.
    """
    if sr_bgr.shape != hq_bgr.shape:
        # cv2.resize takes dsize as (width, height).
        sr_bgr = cv2.resize(
            sr_bgr,
            (hq_bgr.shape[1], hq_bgr.shape[0]),
            interpolation=cv2.INTER_LINEAR,
        )
    return sr_bgr, hq_bgr


def _read_bgr(path: Path) -> np.ndarray:
    """Read *path* with cv2.IMREAD_COLOR, raising if OpenCV cannot decode it."""
    img = cv2.imread(path.as_posix(), cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread signals missing/undecodable files by returning None.
        raise OSError(f"cannot read image: {path}")
    return img


def gen_sr_images(
    hq_dir: Path,
    lq_dir: Path,
    out_dir: Path,
    onnx_path: Path,
    latent_path: Path,
    max_samples: int | None,
) -> tuple[list[Path], list[Path]]:
    """Run the ONNX model on every matched LQ image and save the outputs as PNG.

    Args:
        hq_dir: HQ reference directory; used for pairing and output sizing.
        lq_dir: LQ input directory.
        out_dir: Destination for ``<lq_stem>.png`` outputs; created if missing.
        onnx_path: Model file passed to ``OnnxRunner``.
        latent_path: ``.npy`` file passed to ``OnnxRunner``.
        max_samples: If not None, only the first ``max(max_samples, 1)``
            pairs are processed.

    Returns:
        ``(hq_paths, sr_paths)``: parallel lists of reference and output paths.

    Raises:
        OSError: if an input image cannot be read or an output cannot be written.
    """
    out_dir.mkdir(exist_ok=True, parents=True)
    onnx_runner = OnnxRunner(onnx_path, latent_path)

    lq_paths, hq_paths = collect_common_image_pairs(lq_dir, hq_dir)
    if max_samples is not None:
        limit = max(max_samples, 1)
        lq_paths = lq_paths[:limit]
        hq_paths = hq_paths[:limit]

    sr_paths: list[Path] = []
    for lq_img_path, hq_img_path in tqdm(
        zip(lq_paths, hq_paths), total=len(lq_paths), desc="generating"
    ):
        lq_bgr = _read_bgr(lq_img_path)
        # NOTE(review): assumes run() returns an image array cv2 can resize
        # and write (HxWxC, BGR) — confirm against onnx_runner.
        sr_bgr = onnx_runner.run(lq_bgr)
        hq_bgr = _read_bgr(hq_img_path)
        sr_bgr, _ = align_shape(sr_bgr, hq_bgr)
        out_path = out_dir / f"{lq_img_path.stem}.png"
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(out_path.as_posix(), sr_bgr):
            raise OSError(f"failed to write image: {out_path}")
        sr_paths.append(out_path)
    return hq_paths, sr_paths


def eval_metrics(
    hq_paths: list[Path],
    sr_paths: list[Path],
    hq_dir: Path,
    sr_dir: Path,
    device: torch.device | None = None,
) -> dict[str, float]:
    """Compute mean PSNR, mean MS-SSIM and FID for SR outputs vs. references.

    PSNR and MS-SSIM are evaluated per pair and averaged. FID is computed
    between the whole *sr_dir* and the whole *hq_dir*, i.e. including any
    images in those directories that are not in the path lists.

    Args:
        hq_paths: Reference images, parallel to *sr_paths*.
        sr_paths: Super-resolved images.
        hq_dir: Reference directory used for FID.
        sr_dir: SR output directory used for FID.
        device: Torch device for all metrics; defaults to CUDA if available.

    Returns:
        ``{"psnr": ..., "ms_ssim": ..., "fid": ...}``.

    Raises:
        ValueError: if the path lists differ in length or are empty.
    """
    if len(hq_paths) != len(sr_paths):
        raise ValueError(
            f"hq/sr path count mismatch: {len(hq_paths)} != {len(sr_paths)}"
        )
    if not sr_paths:
        raise ValueError("no image pairs to evaluate")

    if device is None:
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
    psnr_metric = pyiqa.create_metric("psnr", device=device)  # FR: sr, ref
    ms_ssim_metric = pyiqa.create_metric("ms_ssim", device=device)  # FR: sr, ref
    fid_metric = pyiqa.create_metric("fid", device=device)

    with torch.inference_mode():
        psnr_vals = []
        ms_ssim_vals = []
        for sr_p, hq_p in zip(sr_paths, hq_paths):
            sr_s = sr_p.as_posix()
            hq_s = hq_p.as_posix()
            psnr_vals.append(psnr_metric(sr_s, hq_s).detach())
            ms_ssim_vals.append(ms_ssim_metric(sr_s, hq_s).detach())
        psnr = torch.stack(psnr_vals).mean().item()
        ms_ssim = torch.stack(ms_ssim_vals).mean().item()
        fid = fid_metric(
            sr_dir.as_posix(),
            hq_dir.as_posix(),
            mode="clean",
            batch_size=1,
            num_workers=0,
        ).item()
    return {"psnr": psnr, "ms_ssim": ms_ssim, "fid": fid}


def main(args):
    """Generate SR images, evaluate them, and write a JSON summary.

    Reads ``args.onnx``, ``args.latent``, ``args.hq_dir``, ``args.lq_dir``,
    ``args.out_dir``, ``args.max_samples`` and ``args.clean``.

    Raises:
        FileNotFoundError: if the model or latent file is missing or has the
            wrong suffix.
        NotADirectoryError: if the LQ or HQ directory does not exist.
    """
    onnx_path = Path(args.onnx)
    latent_path = Path(args.latent)
    hq_dir = Path(args.hq_dir)
    lq_dir = Path(args.lq_dir)
    out_dir = Path(args.out_dir)

    # Explicit checks instead of asserts, which are stripped under `python -O`.
    if onnx_path.suffix != ".onnx" or not onnx_path.is_file():
        raise FileNotFoundError(f"{onnx_path} is not an existing .onnx file!")
    if latent_path.suffix != ".npy" or not latent_path.is_file():
        raise FileNotFoundError(f"{latent_path} is not an existing .npy file!")
    if not lq_dir.is_dir():
        raise NotADirectoryError(f"{lq_dir} is not a dir!")
    if not hq_dir.is_dir():
        raise NotADirectoryError(f"{hq_dir} is not a dir!")

    sr_dir = out_dir / "sr"
    hq_paths, sr_paths = gen_sr_images(
        hq_dir, lq_dir, sr_dir, onnx_path, latent_path, args.max_samples
    )

    # FID reads the whole sr_dir, so leftovers from an earlier run (e.g. with
    # a larger --max-samples) would silently skew it.
    written = set(sr_paths)
    stale = [p for p in sr_dir.iterdir() if _is_img(p) and p not in written]
    if stale:
        print(
            f"warning: {len(stale)} image(s) in {sr_dir} were not produced by "
            "this run and will be included in FID",
            file=sys.stderr,
        )

    scores = eval_metrics(hq_paths, sr_paths, hq_dir, sr_dir)
    summary = {
        "onnx": onnx_path.as_posix(),
        "psnr": scores["psnr"],
        "ms_ssim": scores["ms_ssim"],
        "fid": scores["fid"],
    }
    out_file = out_dir / f"eval_{onnx_path.stem}_result.json"
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    dataset_name = hq_dir.parent.name
    print(f"summary of {dataset_name}: PSNR | MS_SSIM | FID")
    print(
        f"{dataset_name}: {scores['psnr']:.2f} | {scores['ms_ssim']:.4f} | {scores['fid']:.2f}"
    )
    print(f"result saved to {out_file}")

    if args.clean:
        import shutil

        print(f"cleaning enhanced lq dir: {sr_dir}")
        shutil.rmtree(sr_dir.as_posix(), ignore_errors=True)


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--onnx", type=str, required=True)
    parser.add_argument("--latent", type=str, required=True)
    parser.add_argument("--hq-dir", type=str, required=True)
    parser.add_argument("--lq-dir", type=str, required=True)
    parser.add_argument("--out-dir", type=str, default="outputs")
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="limit number of used samples(debug purpose only), None means not-limited",
    )
    # "-clean" is kept for backward compatibility; "--clean" is the
    # conventional spelling. Both store to args.clean.
    parser.add_argument(
        "-clean",
        "--clean",
        action="store_true",
        default=False,
        help="remove the generated SR images (<out-dir>/sr) when finished",
    )
    main(parser.parse_args())