"""Gradio demo that removes image backgrounds with an ONNX segmentation model.

The model ``BEN2-folded.onnx`` (next to this script) is loaded once at import
time.  Each request resizes the uploaded RGB image to the model's input size,
runs inference through a pre-built ONNX Runtime I/O binding, and writes the
original image with the min-max-normalised prediction as its alpha channel to
a lossless WebP file.
"""

import functools
import re
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Final, cast

import cv2
import gradio as gr
import numpy as np
import onnx
import onnxruntime as ort

SCRIPT_DIR = Path(__file__).resolve().parent
ONNX_MODEL_PATH = SCRIPT_DIR / "BEN2-folded.onnx"


@dataclass(frozen=True)
class PreallocatedBuffers:
    """Scratch arrays reused by every request to avoid per-call allocation.

    ``dataclass`` keeps the hand-written ``__init__`` below (it never replaces
    an explicitly defined one); ``object.__setattr__`` is required because the
    instance is frozen.
    """

    # (H, W, C) uint8 image at the model's spatial resolution; cv2.resize target.
    resized_image: np.ndarray
    # (B, C, H, W) tensor fed to the model; written in place each request.
    input_batch: np.ndarray

    def __init__(self, input_node: ort.NodeArg, *, input_dtype=np.float32):
        """Allocate buffers matching ``input_node``'s shape.

        Args:
            input_node: The model's image input.  Assumed to have a static
                4-D NCHW shape; symbolic (string) dimensions would make
                ``np.empty`` fail.
            input_dtype: Element type of ``input_batch``.  Defaults to float32
                for backward compatibility; pass the model's declared input
                dtype so the buffer matches what the session expects.
        """
        b, c, h, w = input_node.shape
        object.__setattr__(self, "resized_image", np.empty([h, w, c], dtype=np.uint8))
        object.__setattr__(self, "input_batch", np.empty([b, c, h, w], dtype=input_dtype))


def get_ort_session_device_type(session: ort.InferenceSession) -> str:
    """Return a lower-case device name for the session's first provider.

    ``get_providers()`` returns provider names in priority order, e.g.
    ``["CUDAExecutionProvider", "CPUExecutionProvider"]``.  Everything from
    ``"ExecutionProvider"`` onward is dropped from the first entry
    (``"CUDAExecutionProvider"`` -> ``"cuda"``).

    Raises:
        ValueError: if the provider name does not contain "ExecutionProvider".
    """
    provider = session.get_providers()[0]
    return provider[: provider.index("ExecutionProvider")].lower()


def nodearg_to_numpy_dtype(node: ort.NodeArg) -> np.dtype:
    """Map a tensor-typed ONNX Runtime ``NodeArg`` to its NumPy dtype.

    ``node.type`` is a string such as ``"tensor(float)"``.  The element name is
    upper-cased, looked up on ``onnx.TensorProto.DataType`` (e.g. ``FLOAT``),
    and converted with ``onnx.helper.tensor_dtype_to_np_dtype``.

    Raises:
        ValueError: if ``node.type`` is not a plain ``tensor(<elem>)`` type
            (e.g. a sequence or map input).
    """
    match = re.fullmatch(r"tensor\((\w+)\)", node.type)
    if match is None:
        raise ValueError(f"Unsupported ONNX node type {node.type!r} for {node.name!r}")
    enum_val = getattr(onnx.TensorProto.DataType, match.group(1).upper())
    return onnx.helper.tensor_dtype_to_np_dtype(enum_val)


# Providers are tried in order; ORT falls back to later entries when earlier
# ones are unavailable on this machine.
ort_session = ort.InferenceSession(
    ONNX_MODEL_PATH,
    providers=[
        "CUDAExecutionProvider",
        (
            "CoreMLExecutionProvider",
            {
                "ModelFormat": "MLProgram",
                "RequireStaticInputShapes": "1",
                "AllowLowPrecisionAccumulationOnGPU": "1",
            },
        ),
        "CPUExecutionProvider",
    ],
)
input_node = ort_session.get_inputs()[0]
output_node = ort_session.get_outputs()[0]
buffers = PreallocatedBuffers(
    input_node=input_node, input_dtype=nodearg_to_numpy_dtype(input_node)
)

device_type = get_ort_session_device_type(ort_session)
if device_type == "coreml":
    # The I/O binding below treats CoreML like CPU: inputs/outputs live in
    # host memory.
    device_type = "cpu"

# Bind input and output once at import time; each request only refreshes the
# input data and re-runs the binding.
io_binding = ort_session.io_binding()
if device_type != "cpu":
    # Device-resident input; refreshed per request via update_inplace().
    input_ortvalue = ort.OrtValue.ortvalue_from_shape_and_type(
        input_node.shape, nodearg_to_numpy_dtype(input_node), device_type
    )
    io_binding.bind_ortvalue_input(input_node.name, input_ortvalue)
else:
    # Bind the host buffer directly.  remove_background never rebinds on this
    # path; it relies on ORT reading this buffer's memory at run time, so
    # in-place writes to buffers.input_batch are picked up.
    io_binding.bind_cpu_input(input_node.name, buffers.input_batch)
output_ortvalue = ort.OrtValue.ortvalue_from_shape_and_type(
    output_node.shape, nodearg_to_numpy_dtype(output_node), device_type
)
io_binding.bind_ortvalue_output(output_node.name, output_ortvalue)


def remove_background(rgb_image: np.ndarray, output_path_str: str):
    """Cut out the foreground of ``rgb_image`` and save it as a lossless WebP.

    Args:
        rgb_image: Image delivered by ``gr.Image(type="numpy")``.  Assumed to
            be an (H, W, 3) uint8 RGB array (TODO confirm if ``image_mode`` is
            changed).  ``None`` when the form is submitted without an image.
        output_path_str: Destination of the BGRA WebP result.  The file is
            overwritten on every call.

    Returns:
        ``output_path_str`` after the file is written, or ``None`` when no
        image was supplied.

    Raises:
        OSError: if OpenCV reports that the output file could not be written.

    Note:
        Reuses the module-level ``buffers``, ``io_binding`` and OrtValues, so
        calls must not overlap.  This relies on Gradio's queue running one
        request per event at a time (its default; confirm for your version).
    """
    if rgb_image is None:
        return None

    input_batch: Final[np.ndarray] = buffers.input_batch
    resized_image: Final[np.ndarray] = buffers.resized_image
    uint8_max = np.iinfo(np.uint8).max

    # Preprocess: resize into the preallocated HWC buffer, then scale to
    # [0, 1] while writing straight into the CHW slot of the batch.
    cv2.resize(
        src=rgb_image,
        dst=resized_image,
        dsize=(resized_image.shape[1], resized_image.shape[0]),
        interpolation=cv2.INTER_LANCZOS4,
    )
    np.divide(resized_image.transpose(2, 0, 1), uint8_max, out=input_batch[0])

    # Inference
    if device_type != "cpu":
        # Copy the host batch into the existing device OrtValue (no re-alloc).
        input_ortvalue.update_inplace(input_batch)
    ort_session.run_with_iobinding(io_binding)
    outputs = output_ortvalue.numpy()

    # Postprocess: min-max normalise the raw prediction to [0, 255].
    raw_mask = outputs.squeeze()
    min_val = raw_mask.min()
    max_val = raw_mask.max()
    scaled_mask = (
        (raw_mask - min_val)
        / (max_val - min_val + np.finfo(np.float32).eps)
        * uint8_max
    )
    # Round instead of truncating.  The eps term keeps the maximum just below
    # 255 (e.g. 254.99997 in float32), which astype(np.uint8) alone would
    # floor to 254, leaving the foreground slightly transparent.
    alpha_small = np.clip(np.rint(scaled_mask), 0, uint8_max).astype(np.uint8)
    resized_mask = cv2.resize(
        alpha_small,
        dsize=(rgb_image.shape[1], rgb_image.shape[0]),
        interpolation=cv2.INTER_LANCZOS4,
    )

    bgra_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGRA)
    bgra_image[:, :, 3] = resized_mask
    # OpenCV uses lossless WebP compression for quality values above 100.
    if not cv2.imwrite(output_path_str, bgra_image, [cv2.IMWRITE_WEBP_QUALITY, 101]):
        raise OSError(f"Failed to write result image to {output_path_str!r}")
    return output_path_str


if __name__ == "__main__":
    # Use a private temporary directory rather than a held-open
    # NamedTemporaryFile: Windows does not allow reopening an open temporary
    # file by name, which cv2.imwrite needs to do.  The directory and its file
    # are removed once launch() returns.
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_path = Path(tmp_dir) / "result.webp"
        demo = gr.Interface(
            fn=functools.partial(remove_background, output_path_str=str(output_path)),
            inputs=gr.Image(type="numpy", label="Upload Image"),
            outputs=[
                gr.Image(type="filepath", label="Result Image"),
            ],
            title="BGone 💕🍒",
            description="Upload an image to get a background-free, lossless WebP.",
        )
        demo.queue()
        demo.launch()