# ryzenai-psfrgan / onnx_runner.py
# hongw.qin
# upload models
# d1faacc
from pathlib import Path
import cv2
import numpy as np
import onnxruntime as ort
def parse_input_shape_fmt(input_shape):
    """Infer whether a 4-D image input shape is laid out as NCHW or NHWC.

    We assume the channel dimension is strictly smaller than both spatial
    (h & w) dimensions, so whichever of axis 1 / axis 3 satisfies that is
    taken as the channel axis.

    Args:
        input_shape: Sequence of 4 dims, e.g. ``session.get_inputs()[0].shape``.
            The batch dim (index 0) may be symbolic; dims 1-3 must be integers.

    Returns:
        ``"nchw"`` or ``"nhwc"``.

    Raises:
        ValueError: if the shape is not 4-D, if any of dims 1-3 is not a
            concrete integer (dynamic/symbolic), or if the layout is ambiguous.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if len(input_shape) != 4:
        raise ValueError(f"expected a 4-D input shape, got: {input_shape}")
    c1, c2, c3 = input_shape[1:]
    # ONNX Runtime reports dynamic dims as str/None; comparing those with ints
    # would raise an opaque TypeError, so report it clearly instead.
    if not all(isinstance(d, (int, np.integer)) for d in (c1, c2, c3)):
        raise ValueError(f"can not parse input format for dynamic shape: {input_shape}")
    if c1 < min(c2, c3):  # c1 is channel dimension
        return "nchw"
    if c3 < min(c1, c2):  # c3 is channel dimension
        return "nhwc"
    raise ValueError(f"can not parse input format for shape: {input_shape}")
def preprocess(img_bgr: np.ndarray, input_shape_hw: tuple[int, int]):
    """Resize a BGR image to the model input size and map it to RGB in [-1, 1].

    Args:
        img_bgr: HxWx3 BGR image; the scaling below assumes pixel values in
            0..255 (presumably uint8 — confirm for other dtypes).
        input_shape_hw: Target ``(height, width)``.

    Returns:
        ``(height, width, 3)`` RGB float array (float64 for uint8 input)
        normalized from 0..255 to -1..1.
    """
    target_h, target_w = input_shape_hw
    # cv2.resize takes dsize as (width, height), not (height, width).
    rgb = cv2.cvtColor(
        cv2.resize(img_bgr, (target_w, target_h), interpolation=cv2.INTER_LINEAR),
        cv2.COLOR_BGR2RGB,
    )
    unit_range = rgb / 255.0  # 0~255 -> 0~1
    return (unit_range - 0.5) / 0.5  # 0~1 -> -1~1
def postprocess(pred_3d: np.ndarray, pred_fmt: str, origin_hw: tuple[int, int]):
    """Convert one model output in [-1, 1] back to a uint8 BGR image.

    Args:
        pred_3d: Single-sample RGB prediction, CHW when ``pred_fmt == "nchw"``
            and HWC otherwise, with values nominally in [-1, 1].
        pred_fmt: ``"nchw"`` or ``"nhwc"``; any value other than ``"nchw"``
            is treated as HWC.
        origin_hw: ``(height, width)`` the result is resized back to.

    Returns:
        uint8 BGR image of shape ``(origin_hw[0], origin_hw[1], 3)``.
    """
    de_normed_3d = (pred_3d * 0.5 + 0.5) * 255  # de-norm -1~1 -> 0~255
    if pred_fmt == "nchw":
        hwc = np.transpose(de_normed_3d, [1, 2, 0])  # chw -> hwc
    else:  # nhwc
        hwc = de_normed_3d  # unchanged
    # Round to nearest before casting: a bare astype(uint8) truncates and
    # shifts every pixel down by ~0.5 gray levels on average.
    pred_rgb = np.clip(np.rint(hwc), 0, 255).astype(np.uint8)
    pred_bgr = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)
    # Build dsize explicitly as python ints so list/ndarray `origin_hw` works too.
    out_h, out_w = int(origin_hw[0]), int(origin_hw[1])
    if tuple(pred_bgr.shape[:2]) != (out_h, out_w):
        # cv2.resize takes dsize as (width, height).
        pred_bgr = cv2.resize(pred_bgr, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
    return pred_bgr
class OnnxRunner:
    """Run a single-image ONNX enhancement model on uint8 BGR images.

    The model's first input is the image tensor; its layout (NCHW or NHWC) is
    inferred from that input's static shape. If the model has a second input,
    it is fed a fixed latent tensor, either loaded from ``latent_path`` or
    drawn from a seeded standard-normal distribution.
    """

    def __init__(self, onnx_path, latent_path=None, debug=False):
        """Create the inference session and cache input metadata.

        Args:
            onnx_path: Path to the ``.onnx`` model.
            latent_path: Optional ``.npy`` file with a latent tensor stored in
                NCHW layout. It is used only when the model has a second input.
                If it is ``None`` or not an existing file, a seeded
                standard-normal latent shaped like that input is used instead.
            debug: When true, create an ``outputs`` directory next to this file.

        Raises:
            ValueError: if the image input layout cannot be inferred.
        """
        providers = ["CPUExecutionProvider"]
        if "CUDAExecutionProvider" in ort.get_available_providers():
            # Prefer CUDA, but keep CPU listed as an explicit fallback.
            providers.insert(0, "CUDAExecutionProvider")
        ort_session = ort.InferenceSession(str(onnx_path), providers=providers)
        model_inputs = ort_session.get_inputs()

        input0 = model_inputs[0]
        self.input_name = input0.name
        self.input_shape = tuple(input0.shape)
        self.input_format = parse_input_shape_fmt(input0.shape)
        self.ort_session = ort_session
        self.debug = debug
        if self.input_format == "nchw":
            self._in_h, self._in_w = self.input_shape[2:]
        else:  # nhwc
            self._in_h, self._in_w = self.input_shape[1:3]

        if len(model_inputs) == 2:
            latent_input = model_inputs[1]
            self.latent_input_name = latent_input.name
            if latent_path is not None and Path(latent_path).is_file():
                latent = self._load_latent(latent_path, latent_input.shape)
            else:
                # Fixed seed keeps the fallback latent identical across runs.
                # NOTE(review): fails if the latent input has symbolic dims.
                rng = np.random.default_rng(seed=5122)
                latent = rng.standard_normal(latent_input.shape)
            self.latent = np.asarray(latent, dtype=np.float32)
        else:
            self.latent_input_name = None

        if debug:
            self._dbg_out_dir = Path(__file__).parent / "outputs"
            self._dbg_out_dir.mkdir(exist_ok=True, parents=True)

    @staticmethod
    def _shape_fits(actual, expected) -> bool:
        """Return True if ``actual`` matches ``expected``; non-int dims match anything."""
        if len(actual) != len(expected):
            return False
        return all(
            not isinstance(want, (int, np.integer)) or got == want
            for got, want in zip(actual, expected)
        )

    @classmethod
    def _load_latent(cls, latent_path, expected_shape) -> np.ndarray:
        """Load a 4-D latent saved in NCHW layout, in the layout the model expects.

        The latent is transposed to NHWC, as before, unless the NCHW array
        fits ``expected_shape`` and the NHWC one does not. In that case the
        model clearly wants NCHW and the array is returned unchanged.
        """
        latent_nchw = np.load(str(latent_path))
        latent_nhwc = np.transpose(latent_nchw, [0, 2, 3, 1])  # nchw -> nhwc
        if cls._shape_fits(latent_nchw.shape, expected_shape) and not cls._shape_fits(
            latent_nhwc.shape, expected_shape
        ):
            return latent_nchw
        return latent_nhwc

    def run(self, original_bgr: np.ndarray) -> np.ndarray:
        """Enhance a uint8 BGR image and return the enhanced uint8 BGR image.

        The output is resized back to the input's height and width.

        Raises:
            ValueError: if ``original_bgr`` is not an HxWx3 uint8 array.
        """
        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if original_bgr.dtype != np.uint8:
            raise ValueError(f"expected a uint8 image, got dtype {original_bgr.dtype}")
        if original_bgr.ndim != 3 or original_bgr.shape[2] != 3:
            raise ValueError(f"expected an HxWx3 BGR image, got shape {original_bgr.shape}")

        # =====================
        # preprocessing
        # =====================
        input_hwc = preprocess(original_bgr, (self._in_h, self._in_w))
        if self.input_format == "nchw":
            input_3d = np.transpose(input_hwc, [2, 0, 1])  # hwc -> chw
        else:  # nhwc
            input_3d = input_hwc

        # =====================
        # inference
        # =====================
        # Batch of one as an explicit C-contiguous float32 array; the transpose
        # above yields a non-contiguous view.
        feed = {
            self.input_name: np.ascontiguousarray(input_3d[None, ...], dtype=np.float32),
        }
        if self.latent_input_name is not None:
            feed[self.latent_input_name] = self.latent
        outputs = self.ort_session.run(None, feed)
        pred_3d: np.ndarray = outputs[0][0]  # first output, first batch item

        # =====================
        # postprocessing
        # =====================
        # NOTE: the output tensor is assumed to share the input's layout.
        enhanced_bgr = postprocess(pred_3d, self.input_format, original_bgr.shape[:2])
        return enhanced_bgr