"""Gradio demo: background removal with BiRefNet-lite compiled via OpenVINO.

Downloads a selected OpenVINO IR variant of BiRefNet-lite from the Hugging
Face Hub, compiles it for the configured device, and exposes a UI that
returns the predicted foreground mask, an RGBA cutout, and per-stage
timing / spec information.
"""

import os
import time
from dataclasses import dataclass
from functools import lru_cache

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from openvino import Core
from PIL import Image, ImageOps

# Deployment knobs, all overridable via environment variables.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "ibrhr/BiRefNet-lite-openvino-xeon-w2145")
DEVICE = os.getenv("OPENVINO_DEVICE", "CPU")
DEFAULT_MODEL_VARIANT_KEY = os.getenv("MODEL_VARIANT", "fp32_1024x1024")

# Standard ImageNet per-channel normalization constants (RGB order).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


@dataclass(frozen=True)
class ModelVariant:
    """One published OpenVINO IR of the model plus its benchmark metadata."""

    key: str  # stable identifier; used as the dropdown value
    label: str  # human-readable dropdown label
    xml: str  # repo-relative path to the IR topology file (.xml)
    precision: str  # precision description shown in the specs panel
    input_size: int  # square model input resolution, in pixels
    benchmark_ms: float  # measured latency on the reference CPU
    benchmark_fps: float  # measured throughput on the reference CPU
    notes: str  # free-form guidance shown in the specs panel

    @property
    def bin(self) -> str:
        """Repo-relative path to the IR weights file paired with ``xml``."""
        return self.xml.replace(".xml", ".bin")


MODEL_VARIANTS = (
    ModelVariant(
        key="int8_1024x1024",
        label="INT8 NNCF - 1024x1024 - 1272 ms / 0.79 FPS",
        xml="openvino_int8/birefnet_lite_1024x1024_int8.xml",
        precision="INT8 NNCF",
        input_size=1024,
        benchmark_ms=1272.2,
        benchmark_fps=0.79,
        notes="Best benchmarked full-quality option on the target CPU.",
    ),
    ModelVariant(
        key="int8_512x512",
        label="INT8 NNCF - 512x512 - 332 ms / 3.01 FPS",
        xml="openvino_int8/birefnet_lite_512x512_int8.xml",
        precision="INT8 NNCF",
        input_size=512,
        benchmark_ms=332.32,
        benchmark_fps=3.01,
        notes="Fastest benchmarked option, with lower input resolution.",
    ),
    ModelVariant(
        key="fp16_1024x1024",
        label="FP16 - 1024x1024 - 1419 ms / 0.70 FPS",
        xml="openvino_fp16/birefnet_lite_1024x1024_fp16.xml",
        precision="FP16",
        input_size=1024,
        benchmark_ms=1419.0,
        benchmark_fps=0.70,
        notes="Smaller weights than FP32 at full input resolution.",
    ),
    ModelVariant(
        key="fp16_512x512",
        label="FP16 - 512x512 - 366 ms / 2.73 FPS",
        xml="openvino_fp16/birefnet_lite_512x512_fp16.xml",
        precision="FP16",
        input_size=512,
        benchmark_ms=365.97,
        benchmark_fps=2.73,
        notes="Smaller weights than FP32 at lower input resolution.",
    ),
    ModelVariant(
        key="fp32_1024x1024",
        label="FP32 - 1024x1024 - 1441 ms / 0.69 FPS",
        xml="openvino_fp32/birefnet_lite_1024x1024.xml",
        precision="FP32",
        input_size=1024,
        benchmark_ms=1440.9,
        benchmark_fps=0.69,
        notes="Original default and reference OpenVINO precision.",
    ),
    ModelVariant(
        key="fp32_512x512",
        label="FP32 - 512x512 - 366 ms / 2.73 FPS",
        xml="openvino_fp32/birefnet_lite_512x512.xml",
        precision="FP32",
        input_size=512,
        benchmark_ms=366.46,
        benchmark_fps=2.73,
        notes="Reference OpenVINO precision at lower input resolution.",
    ),
    ModelVariant(
        key="int8wo_1024x1024",
        label="INT8 weight-only - 1024x1024 - 1440 ms / 0.69 FPS",
        xml="openvino_int8wo/birefnet_lite_1024x1024_int8wo.xml",
        precision="INT8 weight-only",
        input_size=1024,
        benchmark_ms=1439.53,
        benchmark_fps=0.69,
        notes="Alternative weight-only quantized full-resolution model.",
    ),
    ModelVariant(
        key="int8wo_512x512",
        label="INT8 weight-only - 512x512 - 366 ms / 2.73 FPS",
        xml="openvino_int8wo/birefnet_lite_512x512_int8wo.xml",
        precision="INT8 weight-only",
        input_size=512,
        benchmark_ms=365.75,
        benchmark_fps=2.73,
        notes="Alternative weight-only quantized lower-resolution model.",
    ),
)

MODEL_VARIANTS_BY_KEY = {variant.key: variant for variant in MODEL_VARIANTS}


@dataclass(frozen=True)
class Runtime:
    """A compiled model plus the handles and metadata needed to run it."""

    compiled_model: object  # openvino CompiledModel
    input_node: object  # first model input port
    output_node: object  # first model output port
    variant: ModelVariant
    model_path: str  # local cache path of the downloaded .xml
    load_seconds: float  # wall-clock time for download + compile
    device: str  # OpenVINO device the model was compiled for


def get_model_variant(variant_key: str | None) -> ModelVariant:
    """Resolve ``variant_key`` (or the configured default) to a ModelVariant.

    Raises:
        gr.Error: if the key is unknown, so the UI shows a friendly message
            listing the valid keys instead of a traceback.
    """
    key = variant_key or DEFAULT_MODEL_VARIANT_KEY
    if key not in MODEL_VARIANTS_BY_KEY:
        valid_keys = ", ".join(MODEL_VARIANTS_BY_KEY)
        raise gr.Error(f"Unknown model variant '{key}'. Valid variants: {valid_keys}")
    return MODEL_VARIANTS_BY_KEY[key]


def _resampling(name: str) -> Image.Resampling:
    """Look up a PIL resampling filter by name (e.g. ``"BICUBIC"``).

    Note: the return value is an ``Image.Resampling`` enum member (the
    original ``-> int`` annotation was incorrect).
    """
    return getattr(Image.Resampling, name)


@lru_cache(maxsize=len(MODEL_VARIANTS))
def get_runtime(variant_key: str) -> Runtime:
    """Download, reshape, and compile the requested variant.

    Cached per variant key so each IR is downloaded and compiled at most
    once per process; the cache is bounded by the number of variants.
    """
    variant = get_model_variant(variant_key)
    started = time.perf_counter()
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=variant.xml)
    weights_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=variant.bin)
    core = Core()
    model = core.read_model(model=model_path, weights=weights_path)
    # Pin a static [N, C, H, W] shape so the compiled plan matches the variant.
    model.reshape({model.input(0): [1, 3, variant.input_size, variant.input_size]})
    compiled_model = core.compile_model(model, DEVICE)
    return Runtime(
        compiled_model=compiled_model,
        input_node=compiled_model.input(0),
        output_node=compiled_model.output(0),
        variant=variant,
        model_path=model_path,
        load_seconds=time.perf_counter() - started,
        device=DEVICE,
    )


def preprocess(image: Image.Image, model_size: int) -> np.ndarray:
    """Convert a PIL image into a normalized [1, 3, H, W] float32 tensor.

    Applies EXIF orientation, resizes to the model's square input size,
    scales to [0, 1], and normalizes with ImageNet statistics.
    """
    rgb_image = ImageOps.exif_transpose(image).convert("RGB")
    resized = rgb_image.resize((model_size, model_size), _resampling("BICUBIC"))
    array = np.asarray(resized, dtype=np.float32) / 255.0
    array = (array - IMAGENET_MEAN) / IMAGENET_STD
    # HWC -> CHW, then add the batch dimension expected by the model.
    array = np.transpose(array, (2, 0, 1))[None, ...]
    return np.ascontiguousarray(array, dtype=np.float32)


def sigmoid(array: np.ndarray) -> np.ndarray:
    """Numerically-safe elementwise sigmoid (inputs clipped to +/-50)."""
    clipped = np.clip(array, -50.0, 50.0)
    return 1.0 / (1.0 + np.exp(-clipped))


def postprocess_mask(output: np.ndarray, size: tuple[int, int]) -> Image.Image:
    """Turn raw model logits into an L-mode mask resized to ``size`` (W, H)."""
    mask = np.asarray(output, dtype=np.float32)
    # Strip leading batch/channel dimensions down to a 2-D map.
    while mask.ndim > 2:
        mask = mask[0]
    mask = sigmoid(mask)
    mask = np.clip(mask * 255.0, 0, 255).astype(np.uint8)
    mask_image = Image.fromarray(mask, mode="L")
    return mask_image.resize(size, _resampling("LANCZOS"))


def remove_background(image: Image.Image, model_variant_key: str):
    """Run one end-to-end inference and return UI outputs.

    Returns:
        (mask_image, cutout, timing_text, specs_dict) — the predicted mask,
        an RGBA cutout with the mask as alpha, a human-readable timing
        summary, and a machine-readable specs/timing dict.

    Raises:
        gr.Error: if no image was provided or the variant key is unknown.
    """
    if image is None:
        raise gr.Error("Upload an image first.")
    total_started = time.perf_counter()
    variant = get_model_variant(model_variant_key)
    runtime = get_runtime(variant.key)

    # Normalize orientation once; preprocess() repeats exif_transpose, which
    # should be a no-op here since the transposed copy carries no orientation
    # tag (per Pillow's exif_transpose behavior — confirm if upgrading Pillow).
    original = ImageOps.exif_transpose(image).convert("RGB")

    preprocess_started = time.perf_counter()
    tensor = preprocess(original, variant.input_size)
    preprocess_seconds = time.perf_counter() - preprocess_started

    inference_started = time.perf_counter()
    output = runtime.compiled_model({runtime.input_node: tensor})[runtime.output_node]
    inference_seconds = time.perf_counter() - inference_started

    postprocess_started = time.perf_counter()
    mask_image = postprocess_mask(output, original.size)
    cutout = original.convert("RGBA")
    cutout.putalpha(mask_image)
    postprocess_seconds = time.perf_counter() - postprocess_started

    total_seconds = time.perf_counter() - total_started
    timing = (
        f"Total: {total_seconds:.3f} s\n"
        f"Preprocess: {preprocess_seconds * 1000:.1f} ms\n"
        f"Inference: {inference_seconds * 1000:.1f} ms\n"
        f"Postprocess: {postprocess_seconds * 1000:.1f} ms"
    )
    specs = {
        "model": MODEL_REPO_ID,
        "variant": variant.key,
        "variant_label": variant.label,
        "model_xml": variant.xml,
        "device": runtime.device,
        "precision": variant.precision,
        "model_input_size": f"{variant.input_size}x{variant.input_size}",
        "benchmark_ms": variant.benchmark_ms,
        "benchmark_fps": variant.benchmark_fps,
        "variant_notes": variant.notes,
        "uploaded_image_size": f"{original.width}x{original.height}",
        "input_tensor_shape": list(tensor.shape),
        "output_tensor_shape": list(np.asarray(output).shape),
        "model_load_seconds": round(runtime.load_seconds, 3),
        "total_seconds": round(total_seconds, 3),
        "preprocess_ms": round(preprocess_seconds * 1000, 1),
        "inference_ms": round(inference_seconds * 1000, 1),
        "postprocess_ms": round(postprocess_seconds * 1000, 1),
    }
    return mask_image, cutout, timing, specs


with gr.Blocks(title="BiRefNet OpenVINO") as demo:
    gr.Markdown("# BiRefNet OpenVINO")
    with gr.Row():
        input_image = gr.Image(label="Image", type="pil")
        model_dropdown = gr.Dropdown(
            label="Model variant",
            choices=[(variant.label, variant.key) for variant in MODEL_VARIANTS],
            value=get_model_variant(DEFAULT_MODEL_VARIANT_KEY).key,
            interactive=True,
        )
    run_button = gr.Button("Run", variant="primary")
    with gr.Row():
        mask_output = gr.Image(label="Mask", type="pil")
        cutout_output = gr.Image(label="Background removed", type="pil", format="png")
    with gr.Row():
        timing_output = gr.Textbox(label="Processing time", lines=4)
        specs_output = gr.JSON(label="Specs")
    run_button.click(
        fn=remove_background,
        inputs=[input_image, model_dropdown],
        outputs=[mask_output, cutout_output, timing_output, specs_output],
    )
    # Also run automatically on upload, mirroring the Run button's wiring.
    input_image.upload(
        fn=remove_background,
        inputs=[input_image, model_dropdown],
        outputs=[mask_output, cutout_output, timing_output, specs_output],
    )

if __name__ == "__main__":
    demo.queue(max_size=8).launch()