| from pathlib import Path |
|
|
| import cv2 |
| import numpy as np |
| import onnxruntime as ort |
|
|
|
|
def parse_input_shape_fmt(input_shape):
    """Infer whether a 4-D input shape is laid out as NCHW or NHWC.

    The channel axis is assumed to be strictly smaller than both spatial
    axes. The leading (batch) dimension is ignored, so it may be anything.

    Args:
        input_shape: Sequence of four dimensions, e.g. ``(1, 3, 256, 256)``.

    Returns:
        ``"nchw"`` if the second dimension is the smallest of the last three,
        ``"nhwc"`` if the last dimension is.

    Raises:
        ValueError: If the shape does not have exactly four dimensions, if
            any of the last three dimensions is not a concrete integer, or
            if neither candidate channel axis is strictly the smallest.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if len(input_shape) != 4:
        raise ValueError(
            f"expected a 4-dimensional input shape, got: {input_shape}"
        )

    dims = tuple(input_shape[1:])

    # Non-integer dims (e.g. symbolic names or None) cannot be ordered in a
    # meaningful way; fail with a clear message instead of a TypeError or a
    # lexicographic string comparison.
    if not all(isinstance(dim, (int, np.integer)) for dim in dims):
        raise ValueError(f"can not parse input format for shape: {input_shape}")

    first, middle, last = dims
    if first < min(middle, last):
        return "nchw"
    if last < min(first, middle):
        return "nhwc"
    raise ValueError(f"can not parse input format for shape: {input_shape}")
|
|
|
|
def preprocess(img_bgr: np.ndarray, input_shape_hw: tuple[int, int]):
    """Resize a BGR image to the model resolution and normalise it as RGB.

    Args:
        img_bgr: Image in BGR channel order, as accepted by ``cv2.resize``.
        input_shape_hw: Target ``(height, width)`` of the model input.

    Returns:
        The resized image in RGB channel order (height x width x channel),
        with pixel values mapped by ``(x / 255 - 0.5) / 0.5`` so that
        0 becomes -1 and 255 becomes 1.
    """
    target_h, target_w = input_shape_hw

    # cv2.resize takes the destination size as (width, height).
    rgb = cv2.cvtColor(
        cv2.resize(img_bgr, (target_w, target_h), interpolation=cv2.INTER_LINEAR),
        cv2.COLOR_BGR2RGB,
    )

    unit_range = rgb / 255.0
    return (unit_range - 0.5) / 0.5
|
|
|
|
def postprocess(pred_3d: np.ndarray, pred_fmt: str, origin_hw: tuple[int, int]):
    """Convert a normalised model prediction back to a uint8 BGR image.

    Args:
        pred_3d: Prediction without the batch axis, in RGB channel order.
            Values are de-normalised with ``(x * 0.5 + 0.5) * 255``, i.e. the
            inverse of a ``(x / 255 - 0.5) / 0.5`` normalisation.
        pred_fmt: ``"nchw"`` if ``pred_3d`` is channel-first (C, H, W); any
            other value is treated as channel-last (H, W, C).
        origin_hw: ``(height, width)`` the returned image must have.

    Returns:
        uint8 BGR image of size ``origin_hw``; it is resized with bilinear
        interpolation only when the prediction has a different size.
    """
    de_normed_3d = (pred_3d * 0.5 + 0.5) * 255

    if pred_fmt == "nchw":
        hwc = np.transpose(de_normed_3d, [1, 2, 0])
    else:
        hwc = de_normed_3d

    # Round to nearest before the cast: a bare ``astype(np.uint8)`` truncates
    # toward zero, so a value such as 127.9999 (typical float32 round-off)
    # would become 127 and the whole image would be biased darker.
    pred_rgb = np.clip(np.rint(hwc), 0, 255).astype(np.uint8)
    pred_bgr = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)

    # Normalise to a tuple of plain ints so the comparison and the ``dsize``
    # argument behave the same for tuples, lists and numpy integer sequences.
    target_hw = tuple(int(v) for v in origin_hw)
    if tuple(pred_bgr.shape[:2]) != target_hw:
        # cv2.resize takes the destination size as (width, height).
        pred_bgr = cv2.resize(
            pred_bgr, target_hw[::-1], interpolation=cv2.INTER_LINEAR
        )

    return pred_bgr
|
|
|
|
class OnnxRunner:
    """Run an ONNX image model on uint8 BGR images through onnxruntime.

    The first model input is the image; its shape decides the layout
    (``"nchw"`` or ``"nhwc"``) and the resolution images are resized to.
    If the model has exactly two inputs, the second one is fed a fixed
    latent tensor on every call.
    """

    def __init__(self, onnx_path, latent_path=None, debug=False):
        """Create the inference session and prepare the fixed inputs.

        Args:
            onnx_path: Path of the ONNX model file.
            latent_path: Optional path of a ``.npy`` file holding the latent
                tensor. Only used when the model has a second input. If it is
                ``None`` or not an existing file, a seeded standard-normal
                latent is generated instead.
            debug: When true, an ``outputs`` directory is created next to
                this source file and its path is kept in ``_dbg_out_dir``.
        """
        # Prefer CUDA when onnxruntime reports it as available.
        if "CUDAExecutionProvider" in ort.get_available_providers():
            providers = ["CUDAExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        ort_session = ort.InferenceSession(str(onnx_path), providers=providers)
        model_inputs = ort_session.get_inputs()

        input0 = model_inputs[0]
        self.input_name = input0.name
        self.input_shape = tuple(input0.shape)
        self.input_format = parse_input_shape_fmt(input0.shape)
        self.ort_session = ort_session
        self.debug = debug

        # Spatial size of the image input, depending on where the channel
        # axis sits.
        if self.input_format == "nchw":
            self._in_h, self._in_w = self.input_shape[2:]
        else:
            self._in_h, self._in_w = self.input_shape[1:3]

        if len(model_inputs) == 2:
            latent_input = model_inputs[1]
            self.latent_input_name = latent_input.name
            self.latent = self._load_latent(latent_path, latent_input.shape)
        else:
            self.latent_input_name = None

        if debug:
            self._dbg_out_dir = Path(__file__).parent / "outputs"
            self._dbg_out_dir.mkdir(exist_ok=True, parents=True)

    @staticmethod
    def _load_latent(latent_path, expected_shape):
        """Return the float32 latent tensor for the model's latent input."""
        if latent_path is not None and Path(latent_path).is_file():
            loaded = np.load(str(latent_path))
            # Files are presumably stored channel-first (TODO confirm), so
            # the default is to move the channel axis last.
            transposed = np.transpose(loaded, [0, 2, 3, 1])
            expected = tuple(expected_shape)
            # Use the file untouched only when it unambiguously already has
            # the shape the model declares and the transposed version would
            # not; every other case keeps the default transpose.
            if loaded.shape == expected and transposed.shape != expected:
                latent = loaded
            else:
                latent = transposed
        else:
            # Fixed seed so repeated runs use the same latent.
            rng = np.random.default_rng(seed=5122)
            latent = rng.standard_normal(expected_shape)
        return np.float32(latent)

    def run(self, original_bgr: np.ndarray) -> np.ndarray:
        """Enhance a uint8 BGR image and return the enhanced uint8 BGR image.

        Args:
            original_bgr: Image of shape (height, width, 3) and dtype uint8.

        Returns:
            uint8 BGR image with the same height and width as the input.

        Raises:
            ValueError: If the input is not a uint8 array of shape (H, W, 3).
        """
        # Explicit checks rather than ``assert`` so they survive ``python -O``.
        if original_bgr.dtype != np.uint8:
            raise ValueError(
                f"expected a uint8 image, got dtype {original_bgr.dtype}"
            )
        if original_bgr.ndim != 3 or original_bgr.shape[2] != 3:
            raise ValueError(
                f"expected an image of shape (H, W, 3), got {original_bgr.shape}"
            )

        input_hwc = preprocess(original_bgr, (self._in_h, self._in_w))

        # preprocess returns channel-last data; move channels first if the
        # model expects NCHW.
        if self.input_format == "nchw":
            input_3d = np.transpose(input_hwc, [2, 0, 1])
        else:
            input_3d = input_hwc

        # Add the batch axis and cast to float32 for the session.
        feed = {
            self.input_name: np.float32(input_3d[None, ...]),
        }
        if self.latent_input_name is not None:
            feed[self.latent_input_name] = self.latent

        outputs = self.ort_session.run(None, feed)

        # First output, first batch element. The output is interpreted with
        # the same layout as the image input.
        pred_3d: np.ndarray = outputs[0][0]
        enhanced_bgr = postprocess(pred_3d, self.input_format, original_bgr.shape[:2])

        return enhanced_bgr
|
|