# BGone / app.py
# itskyf's picture
# fix: enable Gradio's queue
# 9483f35 verified
import functools
import re
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Final, cast
import cv2
import gradio as gr
import numpy as np
import onnx
import onnxruntime as ort
# Absolute directory that contains this script (symlinks resolved).
SCRIPT_DIR = Path(__file__).resolve().parent
# Path of the ONNX model file, looked up next to this script.
ONNX_MODEL_PATH = SCRIPT_DIR / "BEN2-folded.onnx"
@dataclass(frozen=True)
class PreallocatedBuffers:
    """Scratch arrays allocated once from the model input node's shape.

    Both arrays are created with ``np.empty``, so their contents are
    undefined until something writes into them.

    Attributes:
        resized_image: ``(h, w, c)`` array of ``uint8``.
        input_batch: ``(b, c, h, w)`` array of ``float32``.
    """

    resized_image: np.ndarray
    input_batch: np.ndarray

    def __init__(self, input_node: ort.NodeArg):
        """Allocate both buffers from ``input_node.shape``.

        Args:
            input_node: Node whose ``shape`` unpacks into exactly four
                entries, read here as ``(b, c, h, w)``.
        """
        batch, channels, height, width = input_node.shape
        allocations = {
            "resized_image": np.empty(
                (height, width, channels), dtype=np.uint8
            ),
            "input_batch": np.empty(
                (batch, channels, height, width), dtype=np.float32
            ),
        }
        # The dataclass is frozen, so plain attribute assignment is rejected;
        # object.__setattr__ bypasses that check during construction.
        for attr_name, array in allocations.items():
            object.__setattr__(self, attr_name, array)
def get_ort_session_device_type(session: ort.InferenceSession) -> str:
    """Return the lower-cased device prefix of the session's first provider.

    The first entry of ``session.get_providers()`` (for example
    ``"CUDAExecutionProvider"``) is cut where ``"ExecutionProvider"``
    begins, and the remaining prefix is lower-cased (``"cuda"``).

    Args:
        session: Session whose provider list is inspected.

    Returns:
        The text before ``"ExecutionProvider"`` in the first provider
        name, in lower case.

    Raises:
        ValueError: If the first provider name does not contain
            ``"ExecutionProvider"``.
    """
    providers = session.get_providers()
    primary = providers[0]
    # str.index (not str.find) so a name without the marker fails loudly.
    marker_pos = primary.index("ExecutionProvider")
    device_prefix = primary[:marker_pos]
    return device_prefix.lower()
def nodearg_to_numpy_dtype(node: ort.NodeArg):
    """Translate a node's ``tensor(<elem>)`` type string into a NumPy dtype.

    The element name inside ``tensor(...)`` is upper-cased, looked up on
    ``onnx.TensorProto.DataType`` and then converted with
    ``onnx.helper.tensor_dtype_to_np_dtype``.

    Args:
        node: Node whose ``type`` attribute is a string beginning with
            ``tensor(<elem>)``.

    Returns:
        The NumPy dtype corresponding to the node's element type.

    Raises:
        ValueError: If ``node.type`` does not begin with ``tensor(<elem>)``.
    """
    match = re.match(r"tensor\((\w+)\)", node.type)
    if match is None:
        # Fail with the offending type string instead of an opaque
        # AttributeError raised from calling .group() on None.
        raise ValueError(f"Unsupported node type: {node.type!r}")
    elem_name = match.group(1).upper()
    enum_val = getattr(onnx.TensorProto.DataType, elem_name)
    return onnx.helper.tensor_dtype_to_np_dtype(enum_val)
# Build the inference session once at import time. The provider list names
# CUDA first, then CoreML (with explicit options), then CPU.
ort_session = ort.InferenceSession(
    ONNX_MODEL_PATH,
    providers=[
        "CUDAExecutionProvider",
        (
            "CoreMLExecutionProvider",
            {
                "ModelFormat": "MLProgram",
                "RequireStaticInputShapes": "1",
                "AllowLowPrecisionAccumulationOnGPU": "1",
            },
        ),
        "CPUExecutionProvider",
    ],
)

# Only the first input and the first output of the model are used.
input_node = ort_session.get_inputs()[0]
output_node = ort_session.get_outputs()[0]
buffers = PreallocatedBuffers(input_node=input_node)

# A CoreML session is handled like a CPU one for binding purposes: its
# input is bound from host memory and its output OrtValue is created on "cpu".
device_type = get_ort_session_device_type(ort_session)
device_type = "cpu" if device_type == "coreml" else device_type

io_binding = ort_session.io_binding()
if device_type == "cpu":
    # Bind the preallocated NumPy batch directly as the model input.
    io_binding.bind_cpu_input(input_node.name, buffers.input_batch)
else:
    # Allocate a device-side input once; it is refreshed in place per call.
    input_ortvalue = ort.OrtValue.ortvalue_from_shape_and_type(
        input_node.shape, nodearg_to_numpy_dtype(input_node), device_type
    )
    io_binding.bind_ortvalue_input(input_node.name, input_ortvalue)

# The output OrtValue is likewise allocated once and bound for every run.
output_ortvalue = ort.OrtValue.ortvalue_from_shape_and_type(
    output_node.shape, nodearg_to_numpy_dtype(output_node), device_type
)
io_binding.bind_ortvalue_output(output_node.name, output_ortvalue)
def remove_background(rgb_image: np.ndarray, output_path_str: str) -> str:
    """Predict an alpha mask for ``rgb_image`` and save the result as WebP.

    The image is resized into the shared preallocated buffer, scaled to
    ``[0, 1]`` and run through the bound ONNX Runtime session. The raw
    model output is min-max normalised to ``0..255``, resized back to the
    source resolution and written as the alpha channel of a BGRA copy of
    the input.

    Args:
        rgb_image: Source image, treated as RGB (it is converted with
            ``cv2.COLOR_RGB2BGRA``).
        output_path_str: Destination path handed to ``cv2.imwrite``.

    Returns:
        ``output_path_str``, once the file has been written.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the image was not saved.
    """
    input_batch: Final[np.ndarray] = buffers.input_batch
    resized_image: Final[np.ndarray] = buffers.resized_image

    # Preprocess: resize into the reusable HWC buffer, then write the
    # CHW view scaled to [0, 1] into the first batch slot.
    cv2.resize(
        src=rgb_image,
        dst=resized_image,
        dsize=(resized_image.shape[1], resized_image.shape[0]),
        interpolation=cv2.INTER_LANCZOS4,
    )
    chw_image = resized_image.transpose(2, 0, 1)
    np.divide(chw_image, np.iinfo(np.uint8).max, out=input_batch[0])

    # Inference
    if device_type != "cpu":
        # Update existing OrtValue's memory in-place (no re-alloc)
        input_ortvalue.update_inplace(input_batch)
    ort_session.run_with_iobinding(io_binding)
    outputs = output_ortvalue.numpy()

    # Postprocess: min-max normalise the mask to the uint8 range; the
    # epsilon keeps the division finite when the mask is constant.
    raw_mask = outputs.squeeze()
    min_val = raw_mask.min()
    max_val = raw_mask.max()
    normalized_mask = (
        (raw_mask - min_val)
        / (max_val - min_val + np.finfo(np.float32).eps)
        * np.iinfo(np.uint8).max
    )
    resized_mask = cv2.resize(
        normalized_mask.astype(np.uint8),
        dsize=(rgb_image.shape[1], rgb_image.shape[0]),
        interpolation=cv2.INTER_LANCZOS4,
    )

    bgra_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGRA)
    bgra_image[:, :, 3] = resized_mask

    # A WebP quality value above 100 selects lossless compression in OpenCV.
    written = cv2.imwrite(
        output_path_str, bgra_image, [cv2.IMWRITE_WEBP_QUALITY, 101]
    )
    if not written:
        # imwrite signals failure through its return value, not an
        # exception; surface it rather than returning an unwritten path.
        raise OSError(f"Failed to write image to {output_path_str!r}")
    return output_path_str
if __name__ == "__main__":
    # One temporary .webp file is reused as the output target for every
    # request; it is removed when the ``with`` block exits.
    with tempfile.NamedTemporaryFile(suffix=".webp") as webp_tmp:
        # Pre-bind the output path so the UI callback only takes the image.
        process_image = functools.partial(
            remove_background, output_path_str=webp_tmp.name
        )
        upload_component = gr.Image(type="numpy", label="Upload Image")
        result_component = gr.Image(type="filepath", label="Result Image")

        demo = gr.Interface(
            fn=process_image,
            inputs=upload_component,
            outputs=[result_component],
            title="BGone 💕🍒",
            description="Upload an image to get a background-free, lossless WebP.",
        )
        # Enable the request queue before starting the server.
        demo.queue()
        demo.launch()