hongw.qin committed
Commit d1faacc · 1 Parent(s): 567d35c

upload models
.gitignore ADDED
@@ -0,0 +1,6 @@
+ .vscode/
+ .venv/
+ *.pyc
+ __pycache__/
+ outputs/
+ datasets/
README.md CHANGED
@@ -1,3 +1,110 @@
  ---
  license: apache-2.0
+ tags:
+ - RyzenAI
+ - Int8 quantization
+ - Face Restoration
+ - PSFRGAN
+ - ONNX
+ - Computer Vision
+ metrics:
+ - PSNR
+ - MS_SSIM
+ - FID
  ---
+
+ # PSFRGAN for face restoration
+
+ The model operates at 512x512 resolution and is particularly effective at restoring faces with various degradations, including blur, noise, compression artifacts, and low resolution.
+
+ It was introduced in the paper _Progressive Semantic-Aware Style Transformation for Blind Face Restoration_ by Chaofeng Chen et al. at CVPR 2021.
+
+ We have developed a modified version optimized for [AMD Ryzen AI](https://onnxruntime.ai/docs/execution-providers/Vitis-AI-ExecutionProvider.html).
+
+ ## Model description
+
+ PSFRGAN (Progressive Semantic-aware Face Restoration Generative Adversarial Network) is a deep learning model designed for blind face restoration: it recovers high-quality face images from severely degraded inputs.
+
+ ## Intended uses & limitations
+
+ You can use this model for face restoration tasks. See the [model hub](https://huggingface.co/models?search=amd/ryzenai-psfrgan) for all available PSFRGAN models.
+
+ ## How to use
+
+ ### Installation
+
+ ```bash
+ # inference only
+ pip install -r requirements-infer.txt
+ # inference & evaluation
+ pip install -r requirements-eval.txt
+ ```
+
+ ### Data Preparation (optional: for accuracy evaluation)
+
+ 1. Download `CelebA-Test (LQ)` and `CelebA-Test (HQ)` from the [GFP-GAN homepage](https://xinntao.github.io/projects/gfpgan)
+ 2. Organize the dataset directory as follows (a sanity-check sketch follows the listing):
+
+ ```Plain
+ └── datasets
+     ├── celeba_512_validation
+     │   ├── 00000000.png
+     │   └── ...
+     └── celeba_512_validation_lq
+         ├── 00000000.png
+         └── ...
+ ```
+
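+ As a quick sanity check, the minimal sketch below (paths as in the layout above) verifies that every HQ image has an LQ counterpart. `onnx_eval.py` pairs the two sets by filename stem, also accepting LQ names that extend the HQ stem as a prefix, so exact-stem mismatches reported here are not necessarily fatal:
+
+ ```python
+ from pathlib import Path
+
+ hq_dir = Path("datasets/celeba_512_validation")
+ lq_dir = Path("datasets/celeba_512_validation_lq")
+
+ hq_stems = {p.stem for p in hq_dir.glob("*.png")}
+ lq_stems = {p.stem for p in lq_dir.glob("*.png")}
+
+ # exact-stem matches only; the eval script additionally accepts prefix matches
+ unmatched = sorted(hq_stems - lq_stems)
+ print(f"{len(hq_stems)} HQ / {len(lq_stems)} LQ images, {len(unmatched)} HQ without an exact LQ match")
+ ```
+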
+ ### Test & Evaluation
+
+ - Run inference on images
+
+ ```bash
+ python onnx_inference.py --onnx psfrgan_nchw_fp32.onnx --latent latent.npy --input /Path/To/Image --out-dir outputs
+ python onnx_inference.py --onnx psfrgan_nhwc_int8.onnx --latent latent.npy --input /Path/To/Image --out-dir outputs
+ ```
+
+ **Arguments:**
+
+ - `--input`: Accepts either a single image file path or a directory path. If it's a file, the script will process that image only. If it's a directory, the script will recursively scan for `.png`, `.jpg`, and `.jpeg` files and process all of them.
+ - `--latent`: (Optional) Path to the latent code file (`.npy`). If not provided, random latent values will be generated with a fixed seed for reproducibility.
+ - `--out-dir`: Output directory where the restored images will be saved.
+
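+ Both scripts wrap the bundled `OnnxRunner` class (`onnx_runner.py`), which you can also call directly; a minimal sketch (the image path is illustrative):
+
+ ```python
+ import cv2
+ from onnx_runner import OnnxRunner
+
+ # latent_path=None falls back to a seeded random latent (see onnx_runner.py)
+ runner = OnnxRunner("psfrgan_nhwc_int8.onnx", latent_path="latent.npy")
+
+ bgr = cv2.imread("face.png", cv2.IMREAD_COLOR)  # uint8 BGR input
+ restored = runner.run(bgr)  # uint8 BGR, resized back to the input size
+ cv2.imwrite("face_restored.png", restored)
+ ```
+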
+ - Evaluate the quantized model
+
+ ```bash
+ # eval fp32
+ python onnx_eval.py \
+     --onnx psfrgan_nchw_fp32.onnx \
+     --latent latent.npy \
+     --hq-dir datasets/celeba_512_validation \
+     --lq-dir datasets/celeba_512_validation_lq \
+     --out-dir outputs/fp32 -clean
+
+ # eval int8
+ python onnx_eval.py \
+     --onnx psfrgan_nhwc_int8.onnx \
+     --latent latent.npy \
+     --hq-dir datasets/celeba_512_validation \
+     --lq-dir datasets/celeba_512_validation_lq \
+     --out-dir outputs/int8 -clean
+ ```
+
+ Each run writes a JSON summary (`eval_<model>_result.json`, containing the PSNR, MS_SSIM, and FID scores) to `--out-dir`; the `-clean` flag removes the intermediate SR images afterwards.
+
+ ### Performance
+
+ | Model          | PSNR(↑) | MS_SSIM(↑) | FID(↓) |
+ | -------------- | ------- | ---------- | ------ |
+ | PSFRGAN (fp32) | 25.27   | 0.8500     | 21.99  |
+ | PSFRGAN (int8) | 25.27   | 0.8487     | 24.34  |
+
+ ---
+
+ ```bibtex
+ @inproceedings{ChenPSFRGAN,
+   author = {Chen, Chaofeng and Li, Xiaoming and Yang, Lingbo and Lin, Xianhui and Zhang, Lei and Wong, Kwan-Yee~K.},
+   title = {Progressive Semantic-Aware Style Transformation for Blind Face Restoration},
+   booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+   year = {2021}
+ }
+ ```
latent.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6570f1486bc5366e148bc7bdbd6054bc07e54d3a575bffde060f1f36a742f2b9
+ size 1048704
onnx_eval.py ADDED
@@ -0,0 +1,206 @@
+ import sys
+ import json
+ from pathlib import Path
+
+ sys.path.insert(0, Path(__file__).parent.as_posix())
+
+
+ import cv2
+ import pyiqa
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from onnx_runner import OnnxRunner
+
+
+ def collect_common_image_pairs(
+     lq_dir: Path, hq_dir: Path
+ ) -> tuple[list[Path], list[Path]]:
+     exts = {".png", ".jpg", ".jpeg"}
+
+     def is_img(p: Path) -> bool:
+         return p.is_file() and p.suffix.lower() in exts
+
+     hq_map = {p.stem: p for p in hq_dir.iterdir() if is_img(p)}
+     hq_names = sorted(hq_map.keys())
+
+     lq_files = [p for p in lq_dir.iterdir() if is_img(p)]
+
+     lq_paths: list[Path] = []
+     hq_paths: list[Path] = []
+     for base in hq_names:
+         # try an exact stem match first
+         best_lq = next((p for p in lq_files if p.stem == base), None)
+
+         # then fall back to a prefix match
+         if best_lq is None:
+             best_lq = next(
+                 (
+                     p
+                     for p in lq_files
+                     if p.stem.startswith(base) and len(p.stem) > len(base)
+                 ),
+                 None,
+             )
+
+         if best_lq is not None:  # matched
+             hq_paths.append(hq_map[base])
+             lq_paths.append(best_lq)
+
+     return lq_paths, hq_paths
+
+
+ def align_shape(sr_bgr: np.ndarray, hq_bgr: np.ndarray):
+     if sr_bgr.shape != hq_bgr.shape:
+         sr_bgr = cv2.resize(
+             sr_bgr,
+             (hq_bgr.shape[1], hq_bgr.shape[0]),
+             interpolation=cv2.INTER_LINEAR,
+         )
+
+     return sr_bgr, hq_bgr
+
+
+ def gen_sr_images(
+     hq_dir: Path,
+     lq_dir: Path,
+     out_dir: Path,
+     onnx_path: Path,
+     latent_path: Path,
+     max_samples: int,
+ ):
+     out_dir.mkdir(exist_ok=True, parents=True)
+
+     onnx_runner = OnnxRunner(onnx_path, latent_path)
+
+     lq_paths, hq_paths = collect_common_image_pairs(lq_dir, hq_dir)
+
+     if max_samples is not None:
+         lq_paths = lq_paths[: max(max_samples, 1)]
+         hq_paths = hq_paths[: max(max_samples, 1)]
+
+     sr_paths = []
+     for i in tqdm(range(len(lq_paths)), desc="generating"):
+         lq_img_path = lq_paths[i]
+         lq_bgr = cv2.imread(lq_img_path.as_posix(), cv2.IMREAD_COLOR)
+         assert lq_bgr is not None
+         sr_bgr = onnx_runner.run(lq_bgr)
+
+         hq_img_path = hq_paths[i]
+         hq_bgr = cv2.imread(hq_img_path.as_posix(), cv2.IMREAD_COLOR)
+
+         sr_bgr, hq_bgr = align_shape(sr_bgr, hq_bgr)
+
+         out_path = out_dir / f"{lq_img_path.stem}.png"
+         cv2.imwrite(out_path.as_posix(), sr_bgr)
+
+         sr_paths.append(out_path)
+
+     return hq_paths, sr_paths
+
+
+ def eval_metrics(
+     hq_paths: list[Path],
+     sr_paths: list[Path],
+     hq_dir: Path,
+     sr_dir: Path,
+     device: torch.device | None = None,
+ ) -> dict[str, float]:
+     assert len(hq_paths) == len(sr_paths)
+
+     device = device or (
+         torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+     )
+
+     psnr_metric = pyiqa.create_metric("psnr", device=device)  # FR: sr, ref
+     ms_ssim_metric = pyiqa.create_metric("ms_ssim", device=device)  # FR: sr, ref
+     fid_metric = pyiqa.create_metric("fid")
+
+     with torch.inference_mode():
+         psnr_vals = []
+         ms_ssim_vals = []
+         for sr_p, hq_p in zip(sr_paths, hq_paths):
+             sr_p = sr_p.as_posix()
+             hq_p = hq_p.as_posix()
+             psnr_vals.append(psnr_metric(sr_p, hq_p).detach())
+             ms_ssim_vals.append(ms_ssim_metric(sr_p, hq_p).detach())
+
+         psnr = torch.stack(psnr_vals).mean().item()
+         ms_ssim = torch.stack(ms_ssim_vals).mean().item()
+
+         fid = fid_metric(
+             sr_dir.as_posix(),
+             hq_dir.as_posix(),
+             mode="clean",
+             batch_size=1,
+             num_workers=0,
+         ).item()
+
+     return {"psnr": psnr, "ms_ssim": ms_ssim, "fid": fid}
+
+
+ def main(args):
+     onnx_path = Path(args.onnx)
+     latent_path = Path(args.latent)
+     hq_dir = Path(args.hq_dir)
+     lq_dir = Path(args.lq_dir)
+     out_dir = Path(args.out_dir)
+
+     assert onnx_path.suffix == ".onnx" and onnx_path.is_file()
+     assert latent_path.suffix == ".npy" and latent_path.is_file()
+     assert lq_dir.is_dir(), f"{lq_dir} is not a dir!"
+     assert hq_dir.is_dir(), f"{hq_dir} is not a dir!"
+
+     sr_dir = out_dir / "sr"
+     hq_paths, sr_paths = gen_sr_images(
+         hq_dir, lq_dir, sr_dir, onnx_path, latent_path, args.max_samples
+     )
+
+     scores = eval_metrics(hq_paths, sr_paths, hq_dir, sr_dir)
+
+     summary = {
+         "onnx": onnx_path.as_posix(),
+         "psnr": scores["psnr"],
+         "ms_ssim": scores["ms_ssim"],
+         "fid": scores["fid"],
+     }
+
+     out_file = out_dir / f"eval_{onnx_path.stem}_result.json"
+     with open(out_file, "w") as f:
+         json.dump(summary, f, indent=2)
+     dataset_name = hq_dir.name  # e.g. celeba_512_validation
+     print(f"summary of {dataset_name}: PSNR | MS_SSIM | FID")
+     print(
+         f"{dataset_name}: {scores['psnr']:.2f} | {scores['ms_ssim']:.4f} | {scores['fid']:.2f}"
+     )
+     print(f"result saved to {out_file}")
+
+     if args.clean:
+         import shutil
+
+         print(f"cleaning SR output dir: {sr_dir}")
+         shutil.rmtree(sr_dir.as_posix(), ignore_errors=True)
+
+
+ if __name__ == "__main__":
+     from argparse import ArgumentParser
+
+     parser = ArgumentParser()
+     parser.add_argument("--onnx", type=str, required=True)
+     parser.add_argument("--latent", type=str, required=True)
+     parser.add_argument("--hq-dir", type=str, required=True)
+     parser.add_argument("--lq-dir", type=str, required=True)
+     parser.add_argument("--out-dir", type=str, default="outputs")
+     parser.add_argument(
+         "--max-samples",
+         type=int,
+         default=None,
+         help="limit the number of samples used (debug only); None means no limit",
+     )
+     parser.add_argument(
+         "-clean",
+         action="store_true",
+         default=False,
+         help="remove the generated SR images when finished",
+     )
+     main(parser.parse_args())
onnx_inference.py ADDED
@@ -0,0 +1,53 @@
+ import sys
+ from pathlib import Path
+
+
+ sys.path.insert(0, Path(__file__).parent.as_posix())
+
+
+ import cv2
+ from onnx_runner import OnnxRunner
+
+
+ def main(args):
+     onnx_path = Path(args.onnx)
+     input_path = Path(args.input)
+     out_dir = Path(args.out_dir)
+
+     assert onnx_path.suffix == ".onnx"
+
+     if input_path.is_file():
+         input_images_path = [input_path]
+     else:
+         input_images_path = sorted(
+             [
+                 p
+                 for p in input_path.rglob("*")
+                 if p.suffix.lower() in (".png", ".jpg", ".jpeg")
+             ]
+         )
+
+     out_dir.mkdir(exist_ok=True, parents=True)
+     onnx_runner = OnnxRunner(onnx_path, args.latent)
+     for input_img_path in input_images_path:
+         input_img_path: Path
+
+         input_bgr = cv2.imread(input_img_path.as_posix(), cv2.IMREAD_COLOR)
+         assert input_bgr is not None
+         out_bgr = onnx_runner.run(input_bgr)
+
+         out_path = out_dir / f"{input_img_path.stem}.png"
+         cv2.imwrite(out_path.as_posix(), out_bgr)
+         print(f"saved {out_path}")
+
+
+ if __name__ == "__main__":
+     from argparse import ArgumentParser
+
+     parser = ArgumentParser()
+     parser.add_argument("--onnx", type=str, required=True)
+     parser.add_argument("--input", type=str, required=True)
+     parser.add_argument("--out-dir", type=str, required=True)
+     parser.add_argument("--latent", type=str, default=None)
+
+     main(parser.parse_args())
onnx_runner.py ADDED
@@ -0,0 +1,119 @@
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+ import onnxruntime as ort
+
+
+ def parse_input_shape_fmt(input_shape):
+     """Determine whether a 4-D input shape is NCHW or NHWC.
+
+     We assume the channel dimension is smaller than the spatial (h & w) dimensions.
+     """
+     assert len(input_shape) == 4
+
+     c1, c2, c3 = input_shape[1:]
+
+     if c1 < min(c2, c3):  # c1 is the channel dimension
+         return "nchw"
+     elif c3 < min(c1, c2):  # c3 is the channel dimension
+         return "nhwc"
+     else:
+         raise ValueError(f"can not parse input format for shape: {input_shape}")
+
+
+ def preprocess(img_bgr: np.ndarray, input_shape_hw: tuple[int, int]):
+     in_h, in_w = input_shape_hw
+
+     resized_bgr = cv2.resize(img_bgr, (in_w, in_h), interpolation=cv2.INTER_LINEAR)
+     resized_rgb = cv2.cvtColor(resized_bgr, cv2.COLOR_BGR2RGB)
+     normed_rgb = (resized_rgb / 255.0 - 0.5) / 0.5  # normalize 0~255 -> -1~1
+
+     return normed_rgb
+
+
+ def postprocess(pred_3d: np.ndarray, pred_fmt: str, origin_hw: tuple[int, int]):
+     de_normed_3d = (pred_3d * 0.5 + 0.5) * 255  # de-normalize -1~1 -> 0~255
+
+     if pred_fmt == "nchw":
+         hwc = np.transpose(de_normed_3d, [1, 2, 0])  # chw -> hwc
+     else:  # nhwc
+         hwc = de_normed_3d  # unchanged
+
+     pred_rgb = np.clip(hwc, 0, 255).astype(np.uint8)
+     pred_bgr = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)
+
+     if tuple(pred_bgr.shape[:2]) != tuple(origin_hw):
+         pred_bgr = cv2.resize(pred_bgr, origin_hw[::-1], interpolation=cv2.INTER_LINEAR)
+
+     return pred_bgr
+
+
+ class OnnxRunner:
+     def __init__(self, onnx_path, latent_path=None, debug=False):
+         if "CUDAExecutionProvider" in ort.get_available_providers():
+             providers = ["CUDAExecutionProvider"]
+         else:
+             providers = ["CPUExecutionProvider"]
+
+         ort_session = ort.InferenceSession(str(onnx_path), providers=providers)
+
+         input0 = ort_session.get_inputs()[0]
+         self.input_name = input0.name
+         self.input_shape = tuple(input0.shape)
+         self.input_format = parse_input_shape_fmt(input0.shape)
+         self.ort_session = ort_session
+         self.debug = debug
+
+         if self.input_format == "nchw":
+             self._in_h, self._in_w = self.input_shape[2:]
+         else:  # nhwc
+             self._in_h, self._in_w = self.input_shape[1:3]
+
+         if len(ort_session.get_inputs()) == 2:
+             latent_input = ort_session.get_inputs()[1]
+             self.latent_input_name = latent_input.name
+             if latent_path is not None and Path(latent_path).is_file():
+                 latent = np.load(str(latent_path))  # stored in nchw format
+                 latent = np.transpose(latent, [0, 2, 3, 1])  # nchw -> nhwc
+             else:
+                 rng = np.random.default_rng(seed=5122)
+                 latent = rng.standard_normal(latent_input.shape)
+             self.latent = np.float32(latent)
+         else:
+             self.latent_input_name = None
+
+         if debug:
+             self._dbg_out_dir = Path(__file__).parent / "outputs"
+             self._dbg_out_dir.mkdir(exist_ok=True, parents=True)
+
+     def run(self, original_bgr: np.ndarray) -> np.ndarray:
+         """Enhance the given uint8 BGR image and return the enhanced uint8 BGR image."""
+         assert original_bgr.dtype == np.uint8
+         assert original_bgr.ndim == 3
+         assert original_bgr.shape[2] == 3
+
+         # =====================
+         # preprocessing
+         # =====================
+         input_hwc = preprocess(original_bgr, (self._in_h, self._in_w))
+
+         # =====================
+         # inference
+         # =====================
+         if self.input_format == "nchw":
+             input_3d = np.transpose(input_hwc, [2, 0, 1])  # hwc -> chw
+         else:  # nhwc
+             input_3d = input_hwc
+
+         feed = {
+             self.input_name: np.float32(input_3d[None, ...]),
+         }
+         if self.latent_input_name is not None:
+             feed[self.latent_input_name] = self.latent
+
+         outputs = self.ort_session.run(None, feed)
+
+         pred_3d: np.ndarray = outputs[0][0]
+         enhanced_bgr = postprocess(pred_3d, self.input_format, original_bgr.shape[:2])
+
+         return enhanced_bgr
psfrgan_nchw_fp32.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a4869eef68b926f381b921d4322052ce78d98dbc9419d38cd7e21c8b757e3dc0
+ size 26298729
psfrgan_nhwc_int8.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:21a7d09b2406ddce9a4540c6088152223b25aef6af13c1b0524b7b6c757d6a78
+ size 25331858
requirements-eval.txt ADDED
@@ -0,0 +1,6 @@
+ onnxruntime==1.22
+ numpy==1.26.*
+ opencv-python==4.8.*
+ tqdm
+ torch==2.6.0
+ pyiqa @ git+https://github.com/chaofengc/IQA-PyTorch.git@e851fd62e66a97345e1281d80e8deb4ab7b93c83
requirements-infer.txt ADDED
@@ -0,0 +1,4 @@
+ onnxruntime==1.22
+ numpy==1.26.*
+ opencv-python==4.8.*
+ tqdm