# Scrape residue from the HuggingFace file viewer (uploader avatar, commit
# message, commit hash) — kept as comments so the module stays valid Python:
# ibrhrsw's picture
# option to pick any model
# ea84176
import os
import time
from dataclasses import dataclass
from functools import lru_cache
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from openvino import Core
from PIL import Image, ImageOps
# Hugging Face repo that hosts the exported OpenVINO model files.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "ibrhr/BiRefNet-lite-openvino-xeon-w2145")
# OpenVINO device string handed to compile_model (e.g. "CPU").
DEVICE = os.getenv("OPENVINO_DEVICE", "CPU")
# Variant key used when the caller supplies none; must match a MODEL_VARIANTS key.
DEFAULT_MODEL_VARIANT_KEY = os.getenv("MODEL_VARIANT", "fp32_1024x1024")
# Standard ImageNet channel statistics used to normalize RGB inputs in preprocess().
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
@dataclass(frozen=True)
class ModelVariant:
    """Immutable description of one exported OpenVINO variant of BiRefNet-lite.

    Used both as dropdown metadata in the UI and as the lookup record for
    downloading and compiling the model files.
    """

    key: str  # stable identifier (env var / dropdown value), e.g. "fp32_1024x1024"
    label: str  # human-readable dropdown label (precision, size, benchmark)
    xml: str  # path of the OpenVINO IR .xml file inside MODEL_REPO_ID
    precision: str  # e.g. "FP32", "FP16", "INT8 NNCF", "INT8 weight-only"
    input_size: int  # square model input resolution in pixels per side
    benchmark_ms: float  # measured latency in milliseconds (shown in specs)
    benchmark_fps: float  # measured throughput (shown in specs)
    notes: str  # short free-text description surfaced in the specs output

    @property
    def bin(self) -> str:
        # Weights file sits next to the .xml with the same stem.
        # (Name shadows the `bin` builtin, but only as a class attribute.)
        return self.xml.replace(".xml", ".bin")
# All published variants, in the order they appear in the UI dropdown.
# Benchmark numbers in the labels/fields mirror measurements on the target CPU
# (see the repo named in MODEL_REPO_ID).
MODEL_VARIANTS = (
    ModelVariant(
        key="int8_1024x1024",
        label="INT8 NNCF - 1024x1024 - 1272 ms / 0.79 FPS",
        xml="openvino_int8/birefnet_lite_1024x1024_int8.xml",
        precision="INT8 NNCF",
        input_size=1024,
        benchmark_ms=1272.2,
        benchmark_fps=0.79,
        notes="Best benchmarked full-quality option on the target CPU.",
    ),
    ModelVariant(
        key="int8_512x512",
        label="INT8 NNCF - 512x512 - 332 ms / 3.01 FPS",
        xml="openvino_int8/birefnet_lite_512x512_int8.xml",
        precision="INT8 NNCF",
        input_size=512,
        benchmark_ms=332.32,
        benchmark_fps=3.01,
        notes="Fastest benchmarked option, with lower input resolution.",
    ),
    ModelVariant(
        key="fp16_1024x1024",
        label="FP16 - 1024x1024 - 1419 ms / 0.70 FPS",
        xml="openvino_fp16/birefnet_lite_1024x1024_fp16.xml",
        precision="FP16",
        input_size=1024,
        benchmark_ms=1419.0,
        benchmark_fps=0.70,
        notes="Smaller weights than FP32 at full input resolution.",
    ),
    ModelVariant(
        key="fp16_512x512",
        label="FP16 - 512x512 - 366 ms / 2.73 FPS",
        xml="openvino_fp16/birefnet_lite_512x512_fp16.xml",
        precision="FP16",
        input_size=512,
        benchmark_ms=365.97,
        benchmark_fps=2.73,
        notes="Smaller weights than FP32 at lower input resolution.",
    ),
    ModelVariant(
        key="fp32_1024x1024",
        label="FP32 - 1024x1024 - 1441 ms / 0.69 FPS",
        xml="openvino_fp32/birefnet_lite_1024x1024.xml",
        precision="FP32",
        input_size=1024,
        benchmark_ms=1440.9,
        benchmark_fps=0.69,
        notes="Original default and reference OpenVINO precision.",
    ),
    ModelVariant(
        key="fp32_512x512",
        label="FP32 - 512x512 - 366 ms / 2.73 FPS",
        xml="openvino_fp32/birefnet_lite_512x512.xml",
        precision="FP32",
        input_size=512,
        benchmark_ms=366.46,
        benchmark_fps=2.73,
        notes="Reference OpenVINO precision at lower input resolution.",
    ),
    ModelVariant(
        key="int8wo_1024x1024",
        label="INT8 weight-only - 1024x1024 - 1440 ms / 0.69 FPS",
        xml="openvino_int8wo/birefnet_lite_1024x1024_int8wo.xml",
        precision="INT8 weight-only",
        input_size=1024,
        benchmark_ms=1439.53,
        benchmark_fps=0.69,
        notes="Alternative weight-only quantized full-resolution model.",
    ),
    ModelVariant(
        key="int8wo_512x512",
        label="INT8 weight-only - 512x512 - 366 ms / 2.73 FPS",
        xml="openvino_int8wo/birefnet_lite_512x512_int8wo.xml",
        precision="INT8 weight-only",
        input_size=512,
        benchmark_ms=365.75,
        benchmark_fps=2.73,
        notes="Alternative weight-only quantized lower-resolution model.",
    ),
)
# Key -> variant lookup used by get_model_variant().
MODEL_VARIANTS_BY_KEY = {variant.key: variant for variant in MODEL_VARIANTS}
@dataclass(frozen=True)
class Runtime:
    """A compiled model plus the handles and metadata needed to run it.

    Built once per variant by get_runtime() and cached for the process lifetime.
    """

    compiled_model: object  # result of Core.compile_model for this variant
    input_node: object  # compiled_model.input(0), used as the inference feed key
    output_node: object  # compiled_model.output(0), used to pick the result tensor
    variant: ModelVariant  # the variant this runtime was built from
    model_path: str  # local path of the downloaded .xml file
    load_seconds: float  # wall time spent downloading + compiling (shown in specs)
    device: str  # device string the model was compiled for
def get_model_variant(variant_key: str | None) -> ModelVariant:
    """Resolve a variant key (falling back to the configured default).

    Raises gr.Error with the list of valid keys when the key is unknown,
    so the message surfaces directly in the Gradio UI.
    """
    resolved = variant_key or DEFAULT_MODEL_VARIANT_KEY
    variant = MODEL_VARIANTS_BY_KEY.get(resolved)
    if variant is None:
        valid = ", ".join(MODEL_VARIANTS_BY_KEY)
        raise gr.Error(f"Unknown model variant '{resolved}'. Valid variants: {valid}")
    return variant
def _resampling(name: str) -> int:
    """Return the PIL resampling filter whose enum member is named `name`."""
    filter_value = getattr(Image.Resampling, name)
    return filter_value
@lru_cache(maxsize=len(MODEL_VARIANTS))
def get_runtime(variant_key: str) -> Runtime:
    """Download, reshape, and compile the model for `variant_key`.

    Cached per variant key (the cache can hold one entry per known variant),
    so downloads and compilation happen at most once per process.
    """
    variant = get_model_variant(variant_key)
    t_start = time.perf_counter()
    xml_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=variant.xml)
    bin_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=variant.bin)
    core = Core()
    model = core.read_model(model=xml_path, weights=bin_path)
    # Pin a static NCHW shape (batch 1, RGB, square input) before compiling.
    model.reshape({model.input(0): [1, 3, variant.input_size, variant.input_size]})
    compiled = core.compile_model(model, DEVICE)
    return Runtime(
        compiled_model=compiled,
        input_node=compiled.input(0),
        output_node=compiled.output(0),
        variant=variant,
        model_path=xml_path,
        load_seconds=time.perf_counter() - t_start,
        device=DEVICE,
    )
def preprocess(image: Image.Image, model_size: int) -> np.ndarray:
    """Convert a PIL image into a normalized, contiguous NCHW float32 tensor.

    Applies EXIF orientation, converts to RGB, resizes to the model's square
    input, scales to [0, 1], and normalizes with ImageNet statistics.
    """
    rgb = ImageOps.exif_transpose(image).convert("RGB")
    resized = rgb.resize((model_size, model_size), _resampling("BICUBIC"))
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    normalized = (pixels - IMAGENET_MEAN) / IMAGENET_STD
    # HWC -> CHW, then prepend the batch axis.
    chw = normalized.transpose(2, 0, 1)
    return np.ascontiguousarray(chw[np.newaxis], dtype=np.float32)
def sigmoid(array: np.ndarray) -> np.ndarray:
    """Elementwise logistic function, with inputs clamped to [-50, 50].

    The clamp keeps np.exp from overflowing for large-magnitude logits.
    """
    safe = np.clip(array, -50.0, 50.0)
    return 1.0 / (np.exp(-safe) + 1.0)
def postprocess_mask(output: np.ndarray, size: tuple[int, int]) -> Image.Image:
    """Turn raw model logits into an L-mode mask image resized to `size`."""
    logits = np.asarray(output, dtype=np.float32)
    # Strip leading singleton axes (batch/channel) down to a plain HxW map.
    while logits.ndim > 2:
        logits = logits[0]
    probabilities = sigmoid(logits)
    gray = np.clip(probabilities * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(gray, mode="L").resize(size, _resampling("LANCZOS"))
def remove_background(image: Image.Image, model_variant_key: str):
    """Run the selected BiRefNet variant on `image`.

    Returns a 4-tuple for the Gradio outputs:
    (mask image, RGBA cutout, timing summary text, specs dict).
    Raises gr.Error when no image has been provided.
    """
    if image is None:
        raise gr.Error("Upload an image first.")
    t_total = time.perf_counter()
    variant = get_model_variant(model_variant_key)
    runtime = get_runtime(variant.key)
    source = ImageOps.exif_transpose(image).convert("RGB")

    # Stage 1: image -> normalized NCHW tensor.
    t_stage = time.perf_counter()
    tensor = preprocess(source, variant.input_size)
    pre_seconds = time.perf_counter() - t_stage

    # Stage 2: OpenVINO inference.
    t_stage = time.perf_counter()
    raw_output = runtime.compiled_model({runtime.input_node: tensor})[runtime.output_node]
    infer_seconds = time.perf_counter() - t_stage

    # Stage 3: logits -> mask, then mask -> alpha channel on the original.
    t_stage = time.perf_counter()
    mask_image = postprocess_mask(raw_output, source.size)
    cutout = source.convert("RGBA")
    cutout.putalpha(mask_image)
    post_seconds = time.perf_counter() - t_stage
    total_seconds = time.perf_counter() - t_total

    timing = (
        f"Total: {total_seconds:.3f} s\n"
        f"Preprocess: {pre_seconds * 1000:.1f} ms\n"
        f"Inference: {infer_seconds * 1000:.1f} ms\n"
        f"Postprocess: {post_seconds * 1000:.1f} ms"
    )
    specs = {
        "model": MODEL_REPO_ID,
        "variant": variant.key,
        "variant_label": variant.label,
        "model_xml": variant.xml,
        "device": runtime.device,
        "precision": variant.precision,
        "model_input_size": f"{variant.input_size}x{variant.input_size}",
        "benchmark_ms": variant.benchmark_ms,
        "benchmark_fps": variant.benchmark_fps,
        "variant_notes": variant.notes,
        "uploaded_image_size": f"{source.width}x{source.height}",
        "input_tensor_shape": list(tensor.shape),
        "output_tensor_shape": list(np.asarray(raw_output).shape),
        "model_load_seconds": round(runtime.load_seconds, 3),
        "total_seconds": round(total_seconds, 3),
        "preprocess_ms": round(pre_seconds * 1000, 1),
        "inference_ms": round(infer_seconds * 1000, 1),
        "postprocess_ms": round(post_seconds * 1000, 1),
    }
    return mask_image, cutout, timing, specs
# --- Gradio UI wiring ------------------------------------------------------
with gr.Blocks(title="BiRefNet OpenVINO") as demo:
    gr.Markdown("# BiRefNet OpenVINO")
    with gr.Row():
        input_image = gr.Image(label="Image", type="pil")
        model_dropdown = gr.Dropdown(
            label="Model variant",
            # (label, value) pairs: show the benchmarked label, submit the key.
            choices=[(variant.label, variant.key) for variant in MODEL_VARIANTS],
            # Route the default through get_model_variant so a bad MODEL_VARIANT
            # env value fails loudly at startup rather than at first inference.
            value=get_model_variant(DEFAULT_MODEL_VARIANT_KEY).key,
            interactive=True,
        )
    run_button = gr.Button("Run", variant="primary")
    with gr.Row():
        mask_output = gr.Image(label="Mask", type="pil")
        cutout_output = gr.Image(label="Background removed", type="pil", format="png")
    with gr.Row():
        timing_output = gr.Textbox(label="Processing time", lines=4)
        specs_output = gr.JSON(label="Specs")
    # Inference runs on explicit button click...
    run_button.click(
        fn=remove_background,
        inputs=[input_image, model_dropdown],
        outputs=[mask_output, cutout_output, timing_output, specs_output],
    )
    # ...and automatically right after an image upload.
    input_image.upload(
        fn=remove_background,
        inputs=[input_image, model_dropdown],
        outputs=[mask_output, cutout_output, timing_output, specs_output],
    )
if __name__ == "__main__":
    # Queue waiting requests (up to 8) since CPU inference can take >1 s each.
    demo.queue(max_size=8).launch()