|
|
import sys |
|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
sys.path.insert(0, Path(__file__).parent.as_posix()) |
|
|
|
|
|
|
|
|
import cv2 |
|
|
import pyiqa |
|
|
import torch |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
from onnx_runner import OnnxRunner |
|
|
|
|
|
|
|
|
def collect_common_image_pairs(
    lq_dir: Path, hq_dir: Path
) -> tuple[list[Path], list[Path]]:
    """Pair each HQ image with its LQ counterpart by file stem.

    Only regular files with a ``.png``/``.jpg``/``.jpeg`` suffix
    (case-insensitive) are considered. HQ images are processed in sorted
    stem order, and each one is paired with:

    1. the LQ image with exactly the same stem, if one exists; otherwise
    2. the first LQ image, in sorted filename order, whose stem starts with
       the HQ stem and is strictly longer (e.g. ``0001`` -> ``0001x4``).
       LQ images whose stem exactly matches some HQ image are excluded from
       this fallback, so they are never paired with a second HQ image.

    HQ images with no partner are skipped.

    Args:
        lq_dir: Directory holding the low-quality inputs.
        hq_dir: Directory holding the high-quality references.

    Returns:
        ``(lq_paths, hq_paths)``: equal-length, index-aligned lists.
    """
    exts = {".png", ".jpg", ".jpeg"}

    def is_img(p: Path) -> bool:
        return p.is_file() and p.suffix.lower() in exts

    # Sorting makes every "first match" below deterministic; iterdir() order
    # depends on the filesystem.
    hq_map = {p.stem: p for p in sorted(hq_dir.iterdir()) if is_img(p)}
    lq_files = sorted(p for p in lq_dir.iterdir() if is_img(p))

    # First LQ file per stem, for O(1) exact-match lookup.
    lq_by_stem: dict[str, Path] = {}
    for p in lq_files:
        lq_by_stem.setdefault(p.stem, p)

    # LQ files that exactly match an HQ stem belong to that HQ image and must
    # not be reused as a prefix fallback for a shorter HQ stem.
    prefix_candidates = [p for p in lq_files if p.stem not in hq_map]

    lq_paths: list[Path] = []
    hq_paths: list[Path] = []
    for base in sorted(hq_map):
        best_lq = lq_by_stem.get(base)
        if best_lq is None:
            best_lq = next(
                (
                    p
                    for p in prefix_candidates
                    if p.stem.startswith(base) and len(p.stem) > len(base)
                ),
                None,
            )
        if best_lq is not None:
            hq_paths.append(hq_map[base])
            lq_paths.append(best_lq)

    return lq_paths, hq_paths
|
|
|
|
|
|
|
|
def align_shape(sr_bgr: np.ndarray, hq_bgr: np.ndarray):
    """Make the SR image match the HQ image's shape so they can be compared.

    When the two arrays already have identical shapes, both are returned
    unchanged. Otherwise the SR image is resized (bilinear) to the HQ image's
    height and width. The HQ image is never modified.

    Returns:
        ``(sr_bgr, hq_bgr)``, with ``sr_bgr`` possibly resized.
    """
    if sr_bgr.shape == hq_bgr.shape:
        return sr_bgr, hq_bgr

    target_h, target_w = hq_bgr.shape[0], hq_bgr.shape[1]
    # cv2.resize takes the destination size as (width, height).
    resized_sr = cv2.resize(
        sr_bgr, (target_w, target_h), interpolation=cv2.INTER_LINEAR
    )
    return resized_sr, hq_bgr
|
|
|
|
|
|
|
|
def _read_bgr(path: Path) -> np.ndarray:
    """Read *path* with OpenCV as a colour image; raise if it cannot be decoded."""
    img = cv2.imread(path.as_posix(), cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(f"failed to read image: {path}")
    return img


def gen_sr_images(
    hq_dir: Path,
    lq_dir: Path,
    out_dir: Path,
    onnx_path: Path,
    latent_path: Path,
    max_samples: int | None,
):
    """Run the ONNX SR model on each paired LQ image and save the outputs.

    LQ/HQ pairs come from ``collect_common_image_pairs``. Each LQ image is
    passed through ``OnnxRunner(onnx_path, latent_path).run``. The output is
    resized to the HQ image's shape when they differ (``align_shape``) and
    written to ``out_dir/<lq stem>.png``.

    Args:
        hq_dir: Directory with the high-quality reference images.
        lq_dir: Directory with the low-quality input images.
        out_dir: Destination for the SR PNGs; created if missing.
        onnx_path: ONNX model file passed to ``OnnxRunner``.
        latent_path: ``.npy`` file passed to ``OnnxRunner``.
        max_samples: If not None, use at most this many pairs. Values below 1
            are clamped to 1.

    Returns:
        ``(hq_paths, sr_paths)``: index-aligned lists of HQ references and
        the SR images written for them.

    Raises:
        FileNotFoundError: If no LQ/HQ pairs are found, or an image cannot be
            decoded.
        OSError: If an SR image cannot be written.
    """
    out_dir.mkdir(exist_ok=True, parents=True)

    lq_paths, hq_paths = collect_common_image_pairs(lq_dir, hq_dir)
    if not lq_paths:
        # Fail here with a clear message instead of on empty metric lists later.
        raise FileNotFoundError(
            f"no matching LQ/HQ image pairs found in {lq_dir} and {hq_dir}"
        )

    if max_samples is not None:
        limit = max(max_samples, 1)
        lq_paths = lq_paths[:limit]
        hq_paths = hq_paths[:limit]

    onnx_runner = OnnxRunner(onnx_path, latent_path)

    sr_paths: list[Path] = []
    pairs = zip(lq_paths, hq_paths)
    for lq_img_path, hq_img_path in tqdm(pairs, total=len(lq_paths), desc="generating"):
        lq_bgr = _read_bgr(lq_img_path)
        # Read the reference before inference so a bad HQ file fails fast.
        hq_bgr = _read_bgr(hq_img_path)

        # NOTE(review): assumes OnnxRunner.run returns an array that
        # cv2.imwrite accepts (e.g. uint8 HxWx3 BGR); confirm in onnx_runner.
        sr_bgr = onnx_runner.run(lq_bgr)
        sr_bgr, _ = align_shape(sr_bgr, hq_bgr)

        out_path = out_dir / f"{lq_img_path.stem}.png"
        if not cv2.imwrite(out_path.as_posix(), sr_bgr):
            raise OSError(f"failed to write SR image to {out_path}")
        sr_paths.append(out_path)

    return hq_paths, sr_paths
|
|
|
|
|
|
|
|
def eval_metrics( |
|
|
hq_paths: list[Path], |
|
|
sr_paths: list[Path], |
|
|
hq_dir: Path, |
|
|
sr_dir: Path, |
|
|
device: torch.device | None = None, |
|
|
) -> dict[str, float]: |
|
|
assert len(hq_paths) == len(sr_paths) |
|
|
|
|
|
device = device or ( |
|
|
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") |
|
|
) |
|
|
|
|
|
psnr_metric = pyiqa.create_metric("psnr", device=device) |
|
|
ms_ssim_metric = pyiqa.create_metric("ms_ssim", device=device) |
|
|
fid_metric = pyiqa.create_metric("fid") |
|
|
|
|
|
with torch.inference_mode(): |
|
|
psnr_vals = [] |
|
|
ms_ssim_vals = [] |
|
|
for sr_p, hq_p in zip(sr_paths, hq_paths): |
|
|
sr_p = sr_p.as_posix() |
|
|
hq_p = hq_p.as_posix() |
|
|
psnr_vals.append(psnr_metric(sr_p, hq_p).detach()) |
|
|
ms_ssim_vals.append(ms_ssim_metric(sr_p, hq_p).detach()) |
|
|
|
|
|
psnr = torch.stack(psnr_vals).mean().item() |
|
|
ms_ssim = torch.stack(ms_ssim_vals).mean().item() |
|
|
|
|
|
fid = fid_metric( |
|
|
sr_dir.as_posix(), |
|
|
hq_dir.as_posix(), |
|
|
mode="clean", |
|
|
batch_size=1, |
|
|
num_workers=0, |
|
|
).item() |
|
|
|
|
|
return {"psnr": psnr, "ms_ssim": ms_ssim, "fid": fid} |
|
|
|
|
|
|
|
|
def main(args):
    """Generate SR images, score them and write a JSON summary.

    SR outputs are written to ``<out_dir>/sr``. Any existing content there is
    removed first, because FID is computed over that whole directory. The
    summary is written to ``<out_dir>/eval_<onnx stem>_result.json``.

    Args:
        args: Parsed CLI namespace with ``onnx``, ``latent``, ``hq_dir``,
            ``lq_dir``, ``out_dir``, ``max_samples`` and ``clean``.

    Raises:
        FileNotFoundError: If the ONNX model or latent file is missing or has
            the wrong suffix.
        NotADirectoryError: If ``lq_dir`` or ``hq_dir`` is not a directory.
    """
    import shutil

    onnx_path = Path(args.onnx)
    latent_path = Path(args.latent)
    hq_dir = Path(args.hq_dir)
    lq_dir = Path(args.lq_dir)
    out_dir = Path(args.out_dir)

    # Explicit checks instead of assert, which is stripped under ``python -O``.
    if onnx_path.suffix != ".onnx" or not onnx_path.is_file():
        raise FileNotFoundError(f"{onnx_path} is not an existing .onnx file!")
    if latent_path.suffix != ".npy" or not latent_path.is_file():
        raise FileNotFoundError(f"{latent_path} is not an existing .npy file!")
    if not lq_dir.is_dir():
        raise NotADirectoryError(f"{lq_dir} is not a dir!")
    if not hq_dir.is_dir():
        raise NotADirectoryError(f"{hq_dir} is not a dir!")

    sr_dir = out_dir / "sr"
    # FID scores every file in sr_dir. Leftovers from an earlier run (another
    # model or a different --max-samples) would skew it, so start empty.
    if sr_dir.exists():
        print(f"removing stale SR outputs in {sr_dir}")
        shutil.rmtree(sr_dir)

    hq_paths, sr_paths = gen_sr_images(
        hq_dir, lq_dir, sr_dir, onnx_path, latent_path, args.max_samples
    )

    scores = eval_metrics(hq_paths, sr_paths, hq_dir, sr_dir)

    summary = {
        "onnx": onnx_path.as_posix(),
        "psnr": scores["psnr"],
        "ms_ssim": scores["ms_ssim"],
        "fid": scores["fid"],
    }

    out_file = out_dir / f"eval_{onnx_path.stem}_result.json"
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    # absolute() so a relative hq_dir like "HR" still has a named parent.
    dataset_name = hq_dir.absolute().parent.name
    print(f"summary of {dataset_name}: PSNR | MS_SSIM | FID")
    print(
        f"{dataset_name}: {scores['psnr']:.2f} | {scores['ms_ssim']:.4f} | {scores['fid']:.2f}"
    )
    print(f"result saved to {out_file}")

    if args.clean:
        print(f"cleaning enhanced lq dir: {sr_dir}")
        shutil.rmtree(sr_dir.as_posix(), ignore_errors=True)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser(
        description="Evaluate an ONNX super-resolution model with PSNR, MS-SSIM and FID."
    )
    parser.add_argument("--onnx", type=str, required=True, help="path to the .onnx model")
    parser.add_argument("--latent", type=str, required=True, help="path to the .npy latent file")
    parser.add_argument("--hq-dir", type=str, required=True, help="directory of high-quality references")
    parser.add_argument("--lq-dir", type=str, required=True, help="directory of low-quality inputs")
    parser.add_argument("--out-dir", type=str, default="outputs", help="where SR images and results go")
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="limit number of used samples (debug purpose only); None means not-limited",
    )
    # "--clean" matches the other options; "-clean" is kept for backward
    # compatibility. Both map to args.clean.
    parser.add_argument(
        "--clean",
        "-clean",
        action="store_true",
        help="delete the generated SR images (<out-dir>/sr) when finished",
    )
    main(parser.parse_args())
|
|
|