# ryzenai-psfrgan / onnx_eval.py
# hongw.qin — upload models (commit d1faacc)
import sys
import json
from pathlib import Path
sys.path.insert(0, Path(__file__).parent.as_posix())
import cv2
import pyiqa
import torch
import numpy as np
from tqdm import tqdm
from onnx_runner import OnnxRunner
def collect_common_image_pairs(
    lq_dir: Path, hq_dir: Path
) -> tuple[list[Path], list[Path]]:
    """Pair LQ images with HQ images by filename stem.

    An HQ image is matched to the LQ image with the identical stem, or,
    failing that, to the first LQ image whose stem strictly extends the
    HQ stem (e.g. HQ "img001" matches LQ "img001_x4"). HQ images with no
    match are skipped.

    Args:
        lq_dir: directory of low-quality input images.
        hq_dir: directory of high-quality reference images.

    Returns:
        (lq_paths, hq_paths): parallel lists of matched image paths,
        ordered by HQ stem.
    """
    exts = {".png", ".jpg", ".jpeg"}

    def is_img(p: Path) -> bool:
        return p.is_file() and p.suffix.lower() in exts

    hq_map = {p.stem: p for p in hq_dir.iterdir() if is_img(p)}
    lq_files = [p for p in lq_dir.iterdir() if is_img(p)]
    # O(1) exact-stem lookup instead of a linear scan per HQ name
    # (the original was O(len(hq) * len(lq))). setdefault keeps the
    # FIRST file seen per stem, mirroring the original first-match
    # semantics when two LQ files share a stem.
    lq_map: dict[str, Path] = {}
    for p in lq_files:
        lq_map.setdefault(p.stem, p)

    lq_paths: list[Path] = []
    hq_paths: list[Path] = []
    for base in sorted(hq_map.keys()):
        # try full match first
        best_lq = lq_map.get(base)
        # try prefix match then (LQ stem must be strictly longer)
        if best_lq is None:
            best_lq = next(
                (
                    p
                    for p in lq_files
                    if p.stem.startswith(base) and len(p.stem) > len(base)
                ),
                None,
            )
        if best_lq is not None:  # matched
            hq_paths.append(hq_map[base])
            lq_paths.append(best_lq)
    return lq_paths, hq_paths
def align_shape(sr_bgr: np.ndarray, hq_bgr: np.ndarray):
    """Resize ``sr_bgr`` to match ``hq_bgr``'s spatial size if needed.

    Returns the (possibly resized) SR image alongside the untouched HQ
    image. When the shapes already agree, both inputs pass through as-is.
    """
    if sr_bgr.shape == hq_bgr.shape:
        return sr_bgr, hq_bgr
    # cv2.resize expects (width, height), i.e. shape axes reversed.
    target_size = (hq_bgr.shape[1], hq_bgr.shape[0])
    resized = cv2.resize(sr_bgr, target_size, interpolation=cv2.INTER_LINEAR)
    return resized, hq_bgr
def gen_sr_images(
    hq_dir: Path,
    lq_dir: Path,
    out_dir: Path,
    onnx_path: Path,
    latent_path: Path,
    max_samples: int | None,
):
    """Run the ONNX model over matched LQ images and save SR results.

    Args:
        hq_dir: directory of ground-truth HQ images (used for pairing and
            output-shape alignment).
        lq_dir: directory of degraded LQ inputs.
        out_dir: destination for SR PNGs (created if missing).
        onnx_path: path to the ``.onnx`` model file.
        latent_path: path to the ``.npy`` latent file consumed by OnnxRunner.
        max_samples: optional cap on the number of pairs processed
            (None = all; values below 1 are clamped to 1).

    Returns:
        (hq_paths, sr_paths): parallel lists of HQ reference paths and the
        written SR image paths.
    """
    out_dir.mkdir(exist_ok=True, parents=True)
    onnx_runner = OnnxRunner(onnx_path, latent_path)
    lq_paths, hq_paths = collect_common_image_pairs(lq_dir, hq_dir)
    if max_samples is not None:
        # Clamp to at least one sample so a value of 0 still evaluates
        # something rather than producing an empty run.
        limit = max(max_samples, 1)
        lq_paths = lq_paths[:limit]
        hq_paths = hq_paths[:limit]
    sr_paths = []
    for lq_img_path, hq_img_path in tqdm(
        list(zip(lq_paths, hq_paths)), desc="generating"
    ):
        lq_bgr = cv2.imread(lq_img_path.as_posix(), cv2.IMREAD_COLOR)
        assert lq_bgr is not None, f"failed to read {lq_img_path}"
        sr_bgr = onnx_runner.run(lq_bgr)
        hq_bgr = cv2.imread(hq_img_path.as_posix(), cv2.IMREAD_COLOR)
        # Fail fast on unreadable HQ files; previously this crashed later
        # inside align_shape with a confusing AttributeError on None.
        assert hq_bgr is not None, f"failed to read {hq_img_path}"
        sr_bgr, hq_bgr = align_shape(sr_bgr, hq_bgr)
        out_path = out_dir / f"{lq_img_path.stem}.png"
        cv2.imwrite(out_path.as_posix(), sr_bgr)
        sr_paths.append(out_path)
    return hq_paths, sr_paths
def eval_metrics(
    hq_paths: list[Path],
    sr_paths: list[Path],
    hq_dir: Path,
    sr_dir: Path,
    device: torch.device | None = None,
) -> dict[str, float]:
    """Score SR results against HQ references with pyiqa metrics.

    PSNR and MS-SSIM are full-reference metrics computed per image pair
    and averaged; FID is computed once over the two directories.

    Args:
        hq_paths: HQ reference image paths (parallel to ``sr_paths``).
        sr_paths: SR image paths to evaluate.
        hq_dir: directory of HQ images (for FID).
        sr_dir: directory of SR images (for FID).
        device: torch device for metric computation; defaults to CUDA when
            available, otherwise CPU.

    Returns:
        Mapping with keys ``"psnr"``, ``"ms_ssim"`` and ``"fid"``.
    """
    assert len(hq_paths) == len(sr_paths)
    device = device or (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )
    psnr_metric = pyiqa.create_metric("psnr", device=device)  # FR: sr, ref
    ms_ssim_metric = pyiqa.create_metric("ms_ssim", device=device)  # FR: sr, ref
    # Pass device here too for consistency with the metrics above; the
    # original omitted it, silently ignoring a caller-supplied device.
    fid_metric = pyiqa.create_metric("fid", device=device)
    with torch.inference_mode():
        psnr_vals = []
        ms_ssim_vals = []
        for sr_p, hq_p in zip(sr_paths, hq_paths):
            sr_p = sr_p.as_posix()
            hq_p = hq_p.as_posix()
            psnr_vals.append(psnr_metric(sr_p, hq_p).detach())
            ms_ssim_vals.append(ms_ssim_metric(sr_p, hq_p).detach())
        psnr = torch.stack(psnr_vals).mean().item()
        ms_ssim = torch.stack(ms_ssim_vals).mean().item()
        fid = fid_metric(
            sr_dir.as_posix(),
            hq_dir.as_posix(),
            mode="clean",
            batch_size=1,
            num_workers=0,
        ).item()
    return {"psnr": psnr, "ms_ssim": ms_ssim, "fid": fid}
def main(args):
    """CLI entry point: generate SR images, score them, dump a JSON summary.

    Expects parsed argparse args with onnx, latent, hq_dir, lq_dir,
    out_dir, max_samples and clean attributes.
    """
    onnx_path = Path(args.onnx)
    latent_path = Path(args.latent)
    hq_dir = Path(args.hq_dir)
    lq_dir = Path(args.lq_dir)
    out_dir = Path(args.out_dir)

    # Validate all inputs up front before any work starts.
    assert onnx_path.suffix == ".onnx" and onnx_path.is_file()
    assert latent_path.suffix == ".npy" and latent_path.is_file()
    assert lq_dir.is_dir(), f"{lq_dir} is not a dir!"
    assert hq_dir.is_dir(), f"{hq_dir} is not a dir!"

    sr_dir = out_dir / "sr"
    hq_paths, sr_paths = gen_sr_images(
        hq_dir, lq_dir, sr_dir, onnx_path, latent_path, args.max_samples
    )
    scores = eval_metrics(hq_paths, sr_paths, hq_dir, sr_dir)

    # Persist a machine-readable summary alongside the SR outputs.
    summary = {"onnx": onnx_path.as_posix()}
    summary.update({key: scores[key] for key in ("psnr", "ms_ssim", "fid")})
    out_file = out_dir / f"eval_{onnx_path.stem}_result.json"
    out_file.write_text(json.dumps(summary, indent=2))

    dataset_name = hq_dir.parent.name
    print(f"summary of {dataset_name}: PSNR | MS_SSIM | FID")
    print(
        f"{dataset_name}: {scores['psnr']:.2f} | {scores['ms_ssim']:.4f} | {scores['fid']:.2f}"
    )
    print(f"result saved to {out_file}")

    if args.clean:
        import shutil

        print(f"cleaning enhanced lq dir: {sr_dir}")
        shutil.rmtree(sr_dir.as_posix(), ignore_errors=True)
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--onnx", type=str, required=True)
    parser.add_argument("--latent", type=str, required=True)
    parser.add_argument("--hq-dir", type=str, required=True)
    parser.add_argument("--lq-dir", type=str, required=True)
    parser.add_argument("--out-dir", type=str, default="outputs")
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="limit number of used samples(debug purpose only), None means not-limited",
    )
    # "-clean" (single dash) kept for backward compatibility; "--clean"
    # added so the flag matches the double-dash style of every other option.
    # Both map to the same dest, args.clean.
    parser.add_argument(
        "-clean",
        "--clean",
        action="store_true",
        default=False,
        help="clean out-dir when finished",
    )
    main(parser.parse_args())