hongw.qin committed

Commit d1faacc · 1 Parent(s): 567d35c

upload models

Browse files:
- .gitignore +6 -0
- README.md +107 -0
- latent.npy +3 -0
- onnx_eval.py +206 -0
- onnx_inference.py +53 -0
- onnx_runner.py +119 -0
- psfrgan_nchw_fp32.onnx +3 -0
- psfrgan_nhwc_int8.onnx +3 -0
- requirements-eval.txt +6 -0
- requirements-infer.txt +4 -0
.gitignore ADDED

.vscode/
.venv/
*.pyc
__pycache__/
outputs/
datasets/
README.md CHANGED

---
license: apache-2.0
tags:
- RyzenAI
- Int8 quantization
- Face Restoration
- PSFRGAN
- ONNX
- Computer Vision
metrics:
- PSNR
- MS_SSIM
- FID
---

# PSFRGAN for face restoration

The model operates at 512x512 resolution and is particularly effective at restoring faces with various degradations, including blur, noise, compression artifacts, and low resolution.

It was introduced in the paper _Progressive Semantic-Aware Style Transformation for Blind Face Restoration_ by Chaofeng Chen et al. at CVPR 2021.

We have developed a modified version optimized for [AMD Ryzen AI](https://onnxruntime.ai/docs/execution-providers/Vitis-AI-ExecutionProvider.html).

## Model description

PSFRGAN (Progressive Semantic-aware Face Restoration Generative Adversarial Network) is a deep learning model designed for blind face restoration, capable of recovering high-quality face images from severely degraded inputs.

## Intended uses & limitations

You can use this model for face restoration tasks. See the [model hub](https://huggingface.co/models?search=amd/ryzenai-psfrgan) for all available PSFRGAN models.

## How to use

### Installation

```bash
# inference only
pip install -r requirements-infer.txt
# inference & evaluation
pip install -r requirements-eval.txt
```

### Data Preparation (optional: for accuracy evaluation)

1. Download `CelebA-Test (LQ)` and `CelebA-Test (HQ)` from the [GFP-GAN homepage](https://xinntao.github.io/projects/gfpgan).
2. Organize the dataset directory as follows:

```text
└── datasets
    ├── celeba_512_validation
    │   ├── 00000000.png
    │   └── ...
    └── celeba_512_validation_lq
        ├── 00000000.png
        └── ...
```

### Test & Evaluation

- Run inference on images:

```bash
python onnx_inference.py --onnx psfrgan_nchw_fp32.onnx --latent latent.npy --input /Path/To/Image --out-dir outputs
python onnx_inference.py --onnx psfrgan_nhwc_int8.onnx --latent latent.npy --input /Path/To/Image --out-dir outputs
```

**Arguments:**

- `--input`: Accepts either a single image file path or a directory path. If it is a file, the script processes that image only. If it is a directory, the script recursively scans for .png, .jpg, and .jpeg files and processes all of them.
- `--latent`: (Optional) Path to the latent code file (.npy). If not provided, random latent values are generated with a fixed seed for reproducibility.
- `--out-dir`: Output directory where the restored images are saved.

- Evaluate the quantized model:

```bash
# eval fp32
python onnx_eval.py \
    --onnx psfrgan_nchw_fp32.onnx \
    --latent latent.npy \
    --hq-dir datasets/celeba_512_validation \
    --lq-dir datasets/celeba_512_validation_lq \
    --out-dir outputs/fp32 -clean

# eval int8
python onnx_eval.py \
    --onnx psfrgan_nhwc_int8.onnx \
    --latent latent.npy \
    --hq-dir datasets/celeba_512_validation \
    --lq-dir datasets/celeba_512_validation_lq \
    --out-dir outputs/int8 -clean
```

### Performance

| Model          | PSNR(↑) | MS_SSIM(↑) | FID(↓) |
| -------------- | ------- | ---------- | ------ |
| PSFRGAN (fp32) | 25.27   | 0.8500     | 21.99  |
| PSFRGAN (int8) | 25.27   | 0.8487     | 24.34  |

---

```bibtex
@inproceedings{ChenPSFRGAN,
  author    = {Chen, Chaofeng and Li, Xiaoming and Yang, Lingbo and Lin, Xianhui and Zhang, Lei and Wong, Kwan-Yee~K.},
  title     = {Progressive Semantic-Aware Style Transformation for Blind Face Restoration},
  booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2021}
}
```
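The inference scripts map pixel values into the model's [-1, 1] input range and back, as done by `preprocess`/`postprocess` in `onnx_runner.py`. A minimal sketch of that round trip in pure NumPy (no model required):

```python
import numpy as np

# Sample uint8 pixel values (hypothetical; any 0..255 values work).
pixels = np.array([0, 128, 255], dtype=np.uint8)

# preprocess: scale 0..255 -> 0..1, then shift/scale to -1..1
normed = (pixels / 255.0 - 0.5) / 0.5  # -1.0, ~0.0039, 1.0

# postprocess: invert the mapping and clip back to valid uint8 range
restored = np.clip((normed * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)

assert (restored == pixels).all()  # lossless round trip for uint8 inputs
```

The clip matters in practice because the network's raw output can slightly overshoot [-1, 1].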
latent.npy ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:6570f1486bc5366e148bc7bdbd6054bc07e54d3a575bffde060f1f36a742f2b9
size 1048704
onnx_eval.py ADDED

import sys
import json
from pathlib import Path

sys.path.insert(0, Path(__file__).parent.as_posix())


import cv2
import pyiqa
import torch
import numpy as np
from tqdm import tqdm
from onnx_runner import OnnxRunner


def collect_common_image_pairs(
    lq_dir: Path, hq_dir: Path
) -> tuple[list[Path], list[Path]]:
    exts = {".png", ".jpg", ".jpeg"}

    def is_img(p: Path) -> bool:
        return p.is_file() and p.suffix.lower() in exts

    hq_map = {p.stem: p for p in hq_dir.iterdir() if is_img(p)}
    hq_names = sorted(hq_map.keys())

    lq_files = [p for p in lq_dir.iterdir() if is_img(p)]

    lq_paths: list[Path] = []
    hq_paths: list[Path] = []
    for base in hq_names:
        # try an exact stem match first
        best_lq = next((p for p in lq_files if p.stem == base), None)

        # then fall back to a prefix match
        if best_lq is None:
            best_lq = next(
                (
                    p
                    for p in lq_files
                    if p.stem.startswith(base) and len(p.stem) > len(base)
                ),
                None,
            )

        if best_lq is not None:  # matched
            hq_paths.append(hq_map[base])
            lq_paths.append(best_lq)

    return lq_paths, hq_paths


def align_shape(sr_bgr: np.ndarray, hq_bgr: np.ndarray):
    if sr_bgr.shape != hq_bgr.shape:
        sr_bgr = cv2.resize(
            sr_bgr,
            (hq_bgr.shape[1], hq_bgr.shape[0]),
            interpolation=cv2.INTER_LINEAR,
        )

    return sr_bgr, hq_bgr


def gen_sr_images(
    hq_dir: Path,
    lq_dir: Path,
    out_dir: Path,
    onnx_path: Path,
    latent_path: Path,
    max_samples: int,
):
    out_dir.mkdir(exist_ok=True, parents=True)

    onnx_runner = OnnxRunner(onnx_path, latent_path)

    lq_paths, hq_paths = collect_common_image_pairs(lq_dir, hq_dir)

    if max_samples is not None:
        lq_paths = lq_paths[: max(max_samples, 1)]
        hq_paths = hq_paths[: max(max_samples, 1)]

    sr_paths = []
    for i in tqdm(range(len(lq_paths)), desc="generating"):
        lq_img_path = lq_paths[i]
        lq_bgr = cv2.imread(lq_img_path.as_posix(), cv2.IMREAD_COLOR)
        assert lq_bgr is not None
        sr_bgr = onnx_runner.run(lq_bgr)

        hq_img_path = hq_paths[i]
        hq_bgr = cv2.imread(hq_img_path.as_posix(), cv2.IMREAD_COLOR)

        sr_bgr, hq_bgr = align_shape(sr_bgr, hq_bgr)

        out_path = out_dir / f"{lq_img_path.stem}.png"
        cv2.imwrite(out_path.as_posix(), sr_bgr)

        sr_paths.append(out_path)

    return hq_paths, sr_paths


def eval_metrics(
    hq_paths: list[Path],
    sr_paths: list[Path],
    hq_dir: Path,
    sr_dir: Path,
    device: torch.device | None = None,
) -> dict[str, float]:
    assert len(hq_paths) == len(sr_paths)

    device = device or (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )

    psnr_metric = pyiqa.create_metric("psnr", device=device)  # FR: sr, ref
    ms_ssim_metric = pyiqa.create_metric("ms_ssim", device=device)  # FR: sr, ref
    fid_metric = pyiqa.create_metric("fid")

    with torch.inference_mode():
        psnr_vals = []
        ms_ssim_vals = []
        for sr_p, hq_p in zip(sr_paths, hq_paths):
            sr_p = sr_p.as_posix()
            hq_p = hq_p.as_posix()
            psnr_vals.append(psnr_metric(sr_p, hq_p).detach())
            ms_ssim_vals.append(ms_ssim_metric(sr_p, hq_p).detach())

        psnr = torch.stack(psnr_vals).mean().item()
        ms_ssim = torch.stack(ms_ssim_vals).mean().item()

        fid = fid_metric(
            sr_dir.as_posix(),
            hq_dir.as_posix(),
            mode="clean",
            batch_size=1,
            num_workers=0,
        ).item()

    return {"psnr": psnr, "ms_ssim": ms_ssim, "fid": fid}


def main(args):
    onnx_path = Path(args.onnx)
    latent_path = Path(args.latent)
    hq_dir = Path(args.hq_dir)
    lq_dir = Path(args.lq_dir)
    out_dir = Path(args.out_dir)

    assert onnx_path.suffix == ".onnx" and onnx_path.is_file()
    assert latent_path.suffix == ".npy" and latent_path.is_file()
    assert lq_dir.is_dir(), f"{lq_dir} is not a dir!"
    assert hq_dir.is_dir(), f"{hq_dir} is not a dir!"

    sr_dir = out_dir / "sr"
    hq_paths, sr_paths = gen_sr_images(
        hq_dir, lq_dir, sr_dir, onnx_path, latent_path, args.max_samples
    )

    scores = eval_metrics(hq_paths, sr_paths, hq_dir, sr_dir)

    summary = {
        "onnx": onnx_path.as_posix(),
        "psnr": scores["psnr"],
        "ms_ssim": scores["ms_ssim"],
        "fid": scores["fid"],
    }

    out_file = out_dir / f"eval_{onnx_path.stem}_result.json"
    with open(out_file, "w") as f:
        json.dump(summary, f, indent=2)
    dataset_name = hq_dir.parent.name
    print(f"summary of {dataset_name}: PSNR | MS_SSIM | FID")
    print(
        f"{dataset_name}: {scores['psnr']:.2f} | {scores['ms_ssim']:.4f} | {scores['fid']:.2f}"
    )
    print(f"result saved to {out_file}")

    if args.clean:
        import shutil

        print(f"cleaning enhanced lq dir: {sr_dir}")
        shutil.rmtree(sr_dir.as_posix(), ignore_errors=True)


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--onnx", type=str, required=True)
    parser.add_argument("--latent", type=str, required=True)
    parser.add_argument("--hq-dir", type=str, required=True)
    parser.add_argument("--lq-dir", type=str, required=True)
    parser.add_argument("--out-dir", type=str, default="outputs")
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="limit the number of samples used (debug only); None means unlimited",
    )
    parser.add_argument(
        "-clean",
        action="store_true",
        default=False,
        help="clean the generated SR directory when finished",
    )
    main(parser.parse_args())
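The HQ/LQ pairing in `collect_common_image_pairs` matches by exact stem first, then falls back to a prefix match. A minimal standalone sketch of that rule (directory and file names below are hypothetical):

```python
from pathlib import Path
import tempfile

with tempfile.TemporaryDirectory() as d:
    root = Path(d)
    (root / "hq").mkdir()
    (root / "lq").mkdir()
    for name in ["0001.png", "0002.png"]:
        (root / "hq" / name).touch()
    (root / "lq" / "0001.png").touch()     # exact stem match
    (root / "lq" / "0002_lq.png").touch()  # prefix match only

    hq_map = {p.stem: p for p in (root / "hq").iterdir()}
    lq_files = list((root / "lq").iterdir())
    pairs = []
    for base in sorted(hq_map):
        # exact match first, then any LQ stem that extends the HQ stem
        lq = next((p for p in lq_files if p.stem == base), None) \
            or next((p for p in lq_files if p.stem.startswith(base)), None)
        if lq is not None:
            pairs.append((lq.name, hq_map[base].name))

    print(pairs)  # [('0001.png', '0001.png'), ('0002_lq.png', '0002.png')]
```

The prefix fallback lets datasets whose LQ files carry a degradation suffix still pair up with their HQ counterparts.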
onnx_inference.py ADDED

import sys
from pathlib import Path


sys.path.insert(0, Path(__file__).parent.as_posix())


import cv2
from onnx_runner import OnnxRunner


def main(args):
    onnx_path = Path(args.onnx)
    input_path = Path(args.input)
    out_dir = Path(args.out_dir)

    assert onnx_path.suffix == ".onnx"

    if input_path.is_file():
        input_images_path = [input_path]
    else:
        input_images_path = sorted(
            [
                p
                for p in input_path.rglob("*")
                if p.suffix.lower() in (".png", ".jpg", ".jpeg")
            ]
        )

    out_dir.mkdir(exist_ok=True, parents=True)
    onnx_runner = OnnxRunner(onnx_path, args.latent)
    for input_img_path in input_images_path:
        input_img_path: Path

        input_bgr = cv2.imread(input_img_path.as_posix(), cv2.IMREAD_COLOR)
        assert input_bgr is not None
        out_bgr = onnx_runner.run(input_bgr)

        out_path = out_dir / f"{input_img_path.stem}.png"
        cv2.imwrite(out_path.as_posix(), out_bgr)
        print(f"saved {out_path}")


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--onnx", type=str, required=True)
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--out-dir", type=str, required=True)
    parser.add_argument("--latent", type=str, default=None)

    main(parser.parse_args())
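`onnx_inference.py` discovers input images with a case-insensitive suffix filter over `rglob`, sorted for deterministic ordering. A small self-contained sketch of that discovery step (file names are made up):

```python
from pathlib import Path
import tempfile

with tempfile.TemporaryDirectory() as d:
    root = Path(d)
    (root / "sub").mkdir()
    for name in ["a.PNG", "b.jpg", "notes.txt"]:
        (root / name).touch()
    (root / "sub" / "c.jpeg").touch()

    # suffix.lower() catches uppercase extensions; non-images are skipped
    images = sorted(
        p
        for p in root.rglob("*")
        if p.suffix.lower() in (".png", ".jpg", ".jpeg")
    )
    print([p.name for p in images])  # ['a.PNG', 'b.jpg', 'c.jpeg']
```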
onnx_runner.py ADDED

from pathlib import Path

import cv2
import numpy as np
import onnxruntime as ort


def parse_input_shape_fmt(input_shape):
    """Parse whether the input shape is in nchw or nhwc format.

    We assume c is smaller than the h and w dimensions.
    """
    assert len(input_shape) == 4

    c1, c2, c3 = input_shape[1:]

    if c1 < min(c2, c3):  # c1 is the channel dimension
        return "nchw"
    elif c3 < min(c1, c2):  # c3 is the channel dimension
        return "nhwc"
    else:
        raise ValueError(f"can not parse input format for shape: {input_shape}")


def preprocess(img_bgr: np.ndarray, input_shape_hw: tuple[int, int]):
    in_h, in_w = input_shape_hw

    resized_bgr = cv2.resize(img_bgr, (in_w, in_h), interpolation=cv2.INTER_LINEAR)
    resized_rgb = cv2.cvtColor(resized_bgr, cv2.COLOR_BGR2RGB)
    normed_rgb = (resized_rgb / 255.0 - 0.5) / 0.5  # norm 0~255 -> -1~1

    return normed_rgb


def postprocess(pred_3d: np.ndarray, pred_fmt: str, origin_hw: tuple[int, int]):
    de_normed_3d = (pred_3d * 0.5 + 0.5) * 255  # de-norm -1~1 -> 0~255

    if pred_fmt == "nchw":
        hwc = np.transpose(de_normed_3d, [1, 2, 0])  # chw -> hwc
    else:  # nhwc
        hwc = de_normed_3d  # unchanged

    pred_rgb = np.clip(hwc, 0, 255).astype(np.uint8)
    pred_bgr = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)

    if tuple(pred_bgr.shape[:2]) != tuple(origin_hw):
        pred_bgr = cv2.resize(pred_bgr, origin_hw[::-1], interpolation=cv2.INTER_LINEAR)

    return pred_bgr


class OnnxRunner:
    def __init__(self, onnx_path, latent_path=None, debug=False):
        if "CUDAExecutionProvider" in ort.get_available_providers():
            providers = ["CUDAExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        ort_session = ort.InferenceSession(str(onnx_path), providers=providers)

        input0 = ort_session.get_inputs()[0]
        self.input_name = input0.name
        self.input_shape = tuple(input0.shape)
        self.input_format = parse_input_shape_fmt(input0.shape)
        self.ort_session = ort_session
        self.debug = debug

        if self.input_format == "nchw":
            self._in_h, self._in_w = self.input_shape[2:]
        else:  # nhwc
            self._in_h, self._in_w = self.input_shape[1:3]

        if len(ort_session.get_inputs()) == 2:
            latent_input = ort_session.get_inputs()[1]
            self.latent_input_name = latent_input.name
            if latent_path is not None and Path(latent_path).is_file():
                latent = np.load(str(latent_path))  # nchw format
                latent = np.transpose(latent, [0, 2, 3, 1])  # nchw -> nhwc
            else:
                rng = np.random.default_rng(seed=5122)
                latent = rng.standard_normal(latent_input.shape)
            self.latent = np.float32(latent)
        else:
            self.latent_input_name = None

        if debug:
            self._dbg_out_dir = Path(__file__).parent / "outputs"
            self._dbg_out_dir.mkdir(exist_ok=True, parents=True)

    def run(self, original_bgr: np.ndarray) -> np.ndarray:
        """Enhance the given uint8 BGR image and return the enhanced uint8 BGR image."""
        assert original_bgr.dtype == np.uint8
        assert original_bgr.ndim == 3
        assert original_bgr.shape[2] == 3

        # =====================
        # preprocessing
        # =====================
        input_hwc = preprocess(original_bgr, (self._in_h, self._in_w))

        # =====================
        # inference
        # =====================
        if self.input_format == "nchw":
            input_3d = np.transpose(input_hwc, [2, 0, 1])  # hwc -> chw
        else:  # nhwc
            input_3d = input_hwc

        feed = {
            self.input_name: np.float32(input_3d[None, ...]),
        }
        if self.latent_input_name is not None:
            feed[self.latent_input_name] = self.latent

        outputs = self.ort_session.run(None, feed)

        pred_3d: np.ndarray = outputs[0][0]
        enhanced_bgr = postprocess(pred_3d, self.input_format, original_bgr.shape[:2])

        return enhanced_bgr
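`parse_input_shape_fmt` relies on a simple heuristic: the channel axis is assumed to be the smallest of the three trailing dimensions. Restated standalone for illustration (the helper name here is hypothetical):

```python
def guess_layout(shape):
    """Guess nchw vs nhwc, assuming channels < height and width."""
    assert len(shape) == 4
    c1, c2, c3 = shape[1:]
    if c1 < min(c2, c3):  # channel axis right after batch
        return "nchw"
    if c3 < min(c1, c2):  # channel axis last
        return "nhwc"
    raise ValueError(f"ambiguous shape: {shape}")


print(guess_layout((1, 3, 512, 512)))  # nchw
print(guess_layout((1, 512, 512, 3)))  # nhwc
```

The heuristic breaks down for tiny or square-in-all-dims inputs, which is why the original raises on ambiguous shapes instead of guessing.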
psfrgan_nchw_fp32.onnx ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:a4869eef68b926f381b921d4322052ce78d98dbc9419d38cd7e21c8b757e3dc0
size 26298729
psfrgan_nhwc_int8.onnx ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:21a7d09b2406ddce9a4540c6088152223b25aef6af13c1b0524b7b6c757d6a78
size 25331858
requirements-eval.txt ADDED

onnxruntime==1.22
numpy==1.26.*
opencv-python==4.8.*
tqdm
torch==2.6.0
pyiqa @ git+https://github.com/chaofengc/IQA-PyTorch.git@e851fd62e66a97345e1281d80e8deb4ab7b93c83
requirements-infer.txt ADDED

onnxruntime==1.22
numpy==1.26.*
opencv-python==4.8.*
tqdm