| import functools |
| import re |
| import tempfile |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Final, cast |
|
|
| import cv2 |
| import gradio as gr |
| import numpy as np |
| import onnx |
| import onnxruntime as ort |
|
|
| SCRIPT_DIR = Path(__file__).resolve().parent |
| ONNX_MODEL_PATH = SCRIPT_DIR / "BEN2-folded.onnx" |
|
|
|
|
@dataclass(frozen=True)
class PreallocatedBuffers:
    """Host-side NumPy buffers allocated once from the model input's shape.

    The input shape is interpreted as ``(batch, channels, height, width)``.

    Attributes are documented inline below. The instance is frozen, so the
    buffers cannot be rebound after construction (their contents are still
    mutable and are meant to be overwritten on every inference call).

    Raises:
        ValueError: If the input shape is not four fully static, positive
            integer dimensions (e.g. the model declares a symbolic batch
            dimension such as ``"batch_size"`` or an unknown ``None`` one).
    """

    # (height, width, channels) uint8 scratch image for the resized input.
    resized_image: np.ndarray
    # (batch, channels, height, width) float32 tensor fed to the model.
    input_batch: np.ndarray

    def __init__(self, input_node: ort.NodeArg):
        shape = tuple(input_node.shape)
        # Symbolic/unknown dimensions show up as str/None; np.empty would
        # reject them with an opaque TypeError, so fail early and clearly.
        is_static = all(isinstance(dim, int) and dim > 0 for dim in shape)
        if len(shape) != 4 or not is_static:
            raise ValueError(
                "PreallocatedBuffers requires a static 4-D "
                f"(batch, channels, height, width) input shape, got {shape!r}"
            )
        b, c, h, w = shape

        # The dataclass is frozen, so regular attribute assignment is
        # blocked; object.__setattr__ is the supported way to initialise it.
        object.__setattr__(
            self, "resized_image", np.empty([h, w, c], dtype=np.uint8)
        )
        object.__setattr__(
            self, "input_batch", np.empty([b, c, h, w], dtype=np.float32)
        )
|
|
|
|
def get_ort_session_device_type(session: ort.InferenceSession) -> str:
    """Derive a lowercase device tag from the session's first provider.

    The first name returned by ``session.get_providers()`` is cut at the
    ``"ExecutionProvider"`` marker and lowercased, so for example
    ``"CUDAExecutionProvider"`` becomes ``"cuda"``.

    Raises:
        IndexError: If the session reports no providers.
        ValueError: If the provider name does not contain
            ``"ExecutionProvider"``.
    """
    marker = "ExecutionProvider"
    # NOTE(review): presumably ONNX Runtime lists providers in priority
    # order, making the first one the provider actually in use — confirm.
    first_provider = session.get_providers()[0]
    marker_start = first_provider.index(marker)
    device_tag = first_provider[:marker_start]
    return device_tag.lower()
|
|
|
|
def nodearg_to_numpy_dtype(node: ort.NodeArg) -> np.dtype:
    """Translate a node's ``tensor(<elem>)`` type string into a NumPy dtype.

    The element name inside ``tensor(...)`` is upper-cased, looked up on
    ``onnx.TensorProto.DataType`` and converted with
    ``onnx.helper.tensor_dtype_to_np_dtype``.

    Args:
        node: Object whose ``type`` attribute is a string such as
            ``"tensor(float)"``.

    Returns:
        The NumPy dtype corresponding to the tensor's element type.

    Raises:
        ValueError: If ``node.type`` does not start with ``tensor(<elem>)``
            (for example sequence or map types), instead of failing later
            with an ``AttributeError`` on a ``None`` match.
    """
    match = re.match(r"tensor\((\w+)\)", node.type)
    if match is None:
        raise ValueError(
            f"Unsupported node type {node.type!r}: expected 'tensor(<elem>)'"
        )
    elem_name = match.group(1).upper()
    enum_val = getattr(onnx.TensorProto.DataType, elem_name)
    return onnx.helper.tensor_dtype_to_np_dtype(enum_val)
|
|
|
|
# Options passed to the CoreML execution provider when it is selected.
_COREML_PROVIDER_OPTIONS = {
    "ModelFormat": "MLProgram",
    "RequireStaticInputShapes": "1",
    "AllowLowPrecisionAccumulationOnGPU": "1",
}

# One shared inference session; providers are requested as CUDA, CoreML, CPU.
ort_session = ort.InferenceSession(
    ONNX_MODEL_PATH,
    providers=[
        "CUDAExecutionProvider",
        ("CoreMLExecutionProvider", _COREML_PROVIDER_OPTIONS),
        "CPUExecutionProvider",
    ],
)
# Only the model's first input and first output are used.
input_node = ort_session.get_inputs()[0]
output_node = ort_session.get_outputs()[0]

# Host buffers sized from the model input, reused by every request.
buffers = PreallocatedBuffers(input_node=input_node)

# A CoreML session is bound exactly like a CPU one below.
_provider_device = get_ort_session_device_type(ort_session)
device_type = "cpu" if _provider_device == "coreml" else _provider_device

io_binding = ort_session.io_binding()
if device_type == "cpu":
    # Bind the preallocated host batch directly as the model input.
    io_binding.bind_cpu_input(input_node.name, buffers.input_batch)
else:
    # Allocate an input OrtValue on the device; remove_background refreshes
    # it from the host batch before each run.
    input_ortvalue = ort.OrtValue.ortvalue_from_shape_and_type(
        input_node.shape, nodearg_to_numpy_dtype(input_node), device_type
    )
    io_binding.bind_ortvalue_input(input_node.name, input_ortvalue)

# The output OrtValue is allocated once and bound for every run.
output_ortvalue = ort.OrtValue.ortvalue_from_shape_and_type(
    output_node.shape, nodearg_to_numpy_dtype(output_node), device_type
)
io_binding.bind_ortvalue_output(output_node.name, output_ortvalue)
|
|
|
|
def remove_background(rgb_image: np.ndarray, output_path_str: str) -> str:
    """Predict a foreground mask for an image and save it as a WebP with alpha.

    The image is resized into the shared preallocated buffers, scaled to
    ``[0, 1]``, run through the bound ONNX Runtime session, and the min-max
    normalised model output is used as the alpha channel of the original
    image.

    Args:
        rgb_image: Image array in RGB channel order (it is converted with
            ``cv2.COLOR_RGB2BGRA`` before writing).
        output_path_str: Destination path handed to ``cv2.imwrite``.

    Returns:
        ``output_path_str``, once the file has been written.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the image was not written.
    """
    input_batch: Final[np.ndarray] = buffers.input_batch
    resized_image: Final[np.ndarray] = buffers.resized_image
    uint8_max = np.iinfo(np.uint8).max

    # Preprocess: resize into the reusable HWC buffer, then write the
    # [0, 1]-scaled CHW view straight into the first batch slot.
    cv2.resize(
        src=rgb_image,
        dst=resized_image,
        dsize=(resized_image.shape[1], resized_image.shape[0]),
        interpolation=cv2.INTER_LANCZOS4,
    )
    chw_image = resized_image.transpose(2, 0, 1)
    np.divide(chw_image, uint8_max, out=input_batch[0])

    # Inference: on the CPU path the host batch itself is bound as the
    # input; otherwise the bound device OrtValue is refreshed from it first.
    if device_type != "cpu":
        input_ortvalue.update_inplace(input_batch)
    ort_session.run_with_iobinding(io_binding)
    outputs = output_ortvalue.numpy()

    # Postprocess: min-max normalise the mask to [0, 255]; eps keeps the
    # division defined when the mask is constant.
    raw_mask = outputs.squeeze()
    min_val = raw_mask.min()
    max_val = raw_mask.max()
    normalized_mask = (
        (raw_mask - min_val)
        / (max_val - min_val + np.finfo(np.float32).eps)
        * uint8_max
    )
    # Round before casting: a plain astype() truncates, and because of the
    # eps term the brightest pixel can evaluate to just under 255, which
    # would cap alpha at 254 and leave the foreground never fully opaque.
    alpha_mask = np.rint(normalized_mask).astype(np.uint8)

    resized_mask = cv2.resize(
        alpha_mask,
        dsize=(rgb_image.shape[1], rgb_image.shape[0]),
        interpolation=cv2.INTER_LANCZOS4,
    )
    bgra_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGRA)
    bgra_image[:, :, 3] = resized_mask

    # NOTE(review): a WebP quality above 100 is understood to select
    # OpenCV's lossless mode — confirm against the OpenCV imwrite flags.
    written = cv2.imwrite(
        output_path_str, bgra_image, [cv2.IMWRITE_WEBP_QUALITY, 101]
    )
    if not written:
        # imwrite signals failure through its return value, not an exception.
        raise OSError(f"Failed to write WebP image to {output_path_str!r}")
    return output_path_str
|
|
|
|
if __name__ == "__main__":
    # Write results to a file inside a temporary directory rather than a
    # NamedTemporaryFile: the latter keeps an open handle, and on Windows a
    # file cannot be reopened by name while that handle is still open, so
    # cv2.imwrite would fail there. The directory and its contents are
    # removed when the app exits the `with` block.
    with tempfile.TemporaryDirectory() as tmp_dir:
        webp_path = str(Path(tmp_dir) / "result.webp")
        demo = gr.Interface(
            fn=functools.partial(remove_background, output_path_str=webp_path),
            inputs=gr.Image(type="numpy", label="Upload Image"),
            outputs=[
                gr.Image(type="filepath", label="Result Image"),
            ],
            title="BGone 💕🍒",
            description=("Upload an image to get a background-free, lossless WebP."),
        )
        demo.queue()
        demo.launch()
|
|