# Hugging Face Space: BiRefNet-lite background removal demo (OpenVINO, CPU).
| import os | |
| import time | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| import gradio as gr | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| from openvino import Core | |
| from PIL import Image, ImageOps | |
# Hugging Face repo that hosts the exported OpenVINO IR files (.xml/.bin pairs).
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "ibrhr/BiRefNet-lite-openvino-xeon-w2145")
# OpenVINO device plugin to compile for (e.g. "CPU"); overridable via env var.
DEVICE = os.getenv("OPENVINO_DEVICE", "CPU")
# Default dropdown selection; must match one of the ModelVariant keys below.
DEFAULT_MODEL_VARIANT_KEY = os.getenv("MODEL_VARIANT", "fp32_1024x1024")
# Standard ImageNet per-channel statistics used to normalize RGB inputs.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
@dataclass(frozen=True)
class ModelVariant:
    """Immutable description of one exported OpenVINO model variant.

    The ``@dataclass`` decorator is required: instances are constructed with
    keyword arguments in ``MODEL_VARIANTS``, which needs the generated
    ``__init__``.
    """

    key: str  # stable identifier used for lookups and as the dropdown value
    label: str  # human-readable dropdown label (precision / size / benchmark)
    xml: str  # path of the OpenVINO IR .xml file inside the HF repo
    precision: str  # e.g. "FP32", "FP16", "INT8 NNCF"
    input_size: int  # square input resolution the model expects
    benchmark_ms: float  # measured latency on the reference CPU
    benchmark_fps: float  # measured throughput on the reference CPU
    notes: str  # short free-text description shown in the specs panel

    def bin(self) -> str:
        """Return the path of the .bin weights file paired with ``xml``."""
        return self.xml.replace(".xml", ".bin")
# Catalog of all exported model variants hosted in MODEL_REPO_ID.
# Benchmark numbers are the ones baked into each label (reference CPU).
MODEL_VARIANTS = (
    ModelVariant(
        key="int8_1024x1024",
        label="INT8 NNCF - 1024x1024 - 1272 ms / 0.79 FPS",
        xml="openvino_int8/birefnet_lite_1024x1024_int8.xml",
        precision="INT8 NNCF",
        input_size=1024,
        benchmark_ms=1272.2,
        benchmark_fps=0.79,
        notes="Best benchmarked full-quality option on the target CPU.",
    ),
    ModelVariant(
        key="int8_512x512",
        label="INT8 NNCF - 512x512 - 332 ms / 3.01 FPS",
        xml="openvino_int8/birefnet_lite_512x512_int8.xml",
        precision="INT8 NNCF",
        input_size=512,
        benchmark_ms=332.32,
        benchmark_fps=3.01,
        notes="Fastest benchmarked option, with lower input resolution.",
    ),
    ModelVariant(
        key="fp16_1024x1024",
        label="FP16 - 1024x1024 - 1419 ms / 0.70 FPS",
        xml="openvino_fp16/birefnet_lite_1024x1024_fp16.xml",
        precision="FP16",
        input_size=1024,
        benchmark_ms=1419.0,
        benchmark_fps=0.70,
        notes="Smaller weights than FP32 at full input resolution.",
    ),
    ModelVariant(
        key="fp16_512x512",
        label="FP16 - 512x512 - 366 ms / 2.73 FPS",
        xml="openvino_fp16/birefnet_lite_512x512_fp16.xml",
        precision="FP16",
        input_size=512,
        benchmark_ms=365.97,
        benchmark_fps=2.73,
        notes="Smaller weights than FP32 at lower input resolution.",
    ),
    ModelVariant(
        key="fp32_1024x1024",
        label="FP32 - 1024x1024 - 1441 ms / 0.69 FPS",
        xml="openvino_fp32/birefnet_lite_1024x1024.xml",
        precision="FP32",
        input_size=1024,
        benchmark_ms=1440.9,
        benchmark_fps=0.69,
        notes="Original default and reference OpenVINO precision.",
    ),
    ModelVariant(
        key="fp32_512x512",
        label="FP32 - 512x512 - 366 ms / 2.73 FPS",
        xml="openvino_fp32/birefnet_lite_512x512.xml",
        precision="FP32",
        input_size=512,
        benchmark_ms=366.46,
        benchmark_fps=2.73,
        notes="Reference OpenVINO precision at lower input resolution.",
    ),
    ModelVariant(
        key="int8wo_1024x1024",
        label="INT8 weight-only - 1024x1024 - 1440 ms / 0.69 FPS",
        xml="openvino_int8wo/birefnet_lite_1024x1024_int8wo.xml",
        precision="INT8 weight-only",
        input_size=1024,
        benchmark_ms=1439.53,
        benchmark_fps=0.69,
        notes="Alternative weight-only quantized full-resolution model.",
    ),
    ModelVariant(
        key="int8wo_512x512",
        label="INT8 weight-only - 512x512 - 366 ms / 2.73 FPS",
        xml="openvino_int8wo/birefnet_lite_512x512_int8wo.xml",
        precision="INT8 weight-only",
        input_size=512,
        benchmark_ms=365.75,
        benchmark_fps=2.73,
        notes="Alternative weight-only quantized lower-resolution model.",
    ),
)
# Fast key -> variant lookup used by get_model_variant().
MODEL_VARIANTS_BY_KEY = {variant.key: variant for variant in MODEL_VARIANTS}
@dataclass(frozen=True)
class Runtime:
    """A compiled OpenVINO model plus the metadata needed to run it.

    The ``@dataclass`` decorator is required: ``get_runtime`` constructs
    instances with keyword arguments, which needs the generated ``__init__``.
    """

    compiled_model: object  # openvino CompiledModel, callable on a feed dict
    input_node: object  # compiled model's input port (index 0)
    output_node: object  # compiled model's output port (index 0)
    variant: "ModelVariant"  # variant this runtime was built from
    model_path: str  # local path of the downloaded .xml file
    load_seconds: float  # wall-clock time spent downloading + compiling
    device: str  # OpenVINO device the model was compiled for
def get_model_variant(variant_key: str | None) -> ModelVariant:
    """Resolve a variant key (or ``None`` for the default) to a ModelVariant.

    Raises gr.Error for keys that are not in the catalog, listing the
    valid alternatives.
    """
    resolved = variant_key if variant_key else DEFAULT_MODEL_VARIANT_KEY
    variant = MODEL_VARIANTS_BY_KEY.get(resolved)
    if variant is None:
        valid_keys = ", ".join(MODEL_VARIANTS_BY_KEY)
        raise gr.Error(f"Unknown model variant '{resolved}'. Valid variants: {valid_keys}")
    return variant
def _resampling(name: str) -> int:
    """Look up a PIL resampling filter by its name (e.g. "BICUBIC")."""
    return getattr(Image.Resampling, name)
@lru_cache(maxsize=None)
def get_runtime(variant_key: str) -> Runtime:
    """Download, read, reshape and compile the requested model variant.

    Cached with ``lru_cache`` so the expensive download + compile happens at
    most once per variant per process; without the cache every inference
    request re-downloads and re-compiles the model. The key space is bounded
    by the MODEL_VARIANTS catalog, so an unbounded cache cannot grow.

    Raises gr.Error (via get_model_variant) for unknown variant keys;
    exceptions are not cached by lru_cache, so a transient download failure
    does not poison the cache.
    """
    variant = get_model_variant(variant_key)
    started = time.perf_counter()
    # Fetch both halves of the OpenVINO IR (graph .xml + weights .bin).
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=variant.xml)
    weights_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=variant.bin())
    core = Core()
    model = core.read_model(model=model_path, weights=weights_path)
    # Pin the input to a static NCHW shape matching the variant's resolution.
    model.reshape({model.input(0): [1, 3, variant.input_size, variant.input_size]})
    compiled_model = core.compile_model(model, DEVICE)
    return Runtime(
        compiled_model=compiled_model,
        input_node=compiled_model.input(0),
        output_node=compiled_model.output(0),
        variant=variant,
        model_path=model_path,
        load_seconds=time.perf_counter() - started,
        device=DEVICE,
    )
def preprocess(image: Image.Image, model_size: int) -> np.ndarray:
    """Convert a PIL image into a normalized 1x3xHxW float32 tensor.

    Applies EXIF orientation, resizes to the model's square input size,
    scales to [0, 1], then normalizes with the ImageNet statistics.
    """
    rgb = ImageOps.exif_transpose(image).convert("RGB")
    scaled = rgb.resize((model_size, model_size), _resampling("BICUBIC"))
    pixels = np.asarray(scaled, dtype=np.float32) / 255.0
    normalized = (pixels - IMAGENET_MEAN) / IMAGENET_STD
    # HWC -> CHW, then prepend the batch axis.
    chw = np.transpose(normalized, (2, 0, 1))
    return np.ascontiguousarray(chw[None, ...], dtype=np.float32)
def sigmoid(array: np.ndarray) -> np.ndarray:
    """Elementwise logistic function, clipped to +/-50 to avoid exp overflow."""
    bounded = np.clip(array, -50.0, 50.0)
    exponentials = np.exp(-bounded)
    return 1.0 / (1.0 + exponentials)
def postprocess_mask(output: np.ndarray, size: tuple[int, int]) -> Image.Image:
    """Turn raw model logits into a grayscale PIL mask resized to ``size``."""
    logits = np.asarray(output, dtype=np.float32)
    # Peel off leading batch/channel axes until a 2-D map remains.
    while logits.ndim > 2:
        logits = logits[0]
    probabilities = sigmoid(logits)
    mask_u8 = np.clip(probabilities * 255.0, 0, 255).astype(np.uint8)
    grayscale = Image.fromarray(mask_u8, mode="L")
    return grayscale.resize(size, _resampling("LANCZOS"))
def remove_background(image: Image.Image, model_variant_key: str):
    """Run background removal on an uploaded image with the chosen variant.

    Returns a 4-tuple consumed by the Gradio outputs:
    (grayscale mask, RGBA cutout, human-readable timing text, specs dict).

    Raises gr.Error when no image was provided or the variant key is invalid.

    NOTE: the perf_counter() calls bracket each stage exactly; do not reorder
    statements here or the reported timings become wrong.
    """
    if image is None:
        raise gr.Error("Upload an image first.")
    total_started = time.perf_counter()
    variant = get_model_variant(model_variant_key)
    runtime = get_runtime(variant.key)
    # Honor EXIF orientation before measuring sizes or preprocessing.
    original = ImageOps.exif_transpose(image).convert("RGB")
    preprocess_started = time.perf_counter()
    tensor = preprocess(original, variant.input_size)
    preprocess_seconds = time.perf_counter() - preprocess_started
    inference_started = time.perf_counter()
    output = runtime.compiled_model({runtime.input_node: tensor})[runtime.output_node]
    inference_seconds = time.perf_counter() - inference_started
    postprocess_started = time.perf_counter()
    # Resize the mask back to the original image size and apply it as alpha.
    mask_image = postprocess_mask(output, original.size)
    cutout = original.convert("RGBA")
    cutout.putalpha(mask_image)
    postprocess_seconds = time.perf_counter() - postprocess_started
    total_seconds = time.perf_counter() - total_started
    timing = (
        f"Total: {total_seconds:.3f} s\n"
        f"Preprocess: {preprocess_seconds * 1000:.1f} ms\n"
        f"Inference: {inference_seconds * 1000:.1f} ms\n"
        f"Postprocess: {postprocess_seconds * 1000:.1f} ms"
    )
    # Machine-readable run report for the JSON panel.
    specs = {
        "model": MODEL_REPO_ID,
        "variant": variant.key,
        "variant_label": variant.label,
        "model_xml": variant.xml,
        "device": runtime.device,
        "precision": variant.precision,
        "model_input_size": f"{variant.input_size}x{variant.input_size}",
        "benchmark_ms": variant.benchmark_ms,
        "benchmark_fps": variant.benchmark_fps,
        "variant_notes": variant.notes,
        "uploaded_image_size": f"{original.width}x{original.height}",
        "input_tensor_shape": list(tensor.shape),
        "output_tensor_shape": list(np.asarray(output).shape),
        "model_load_seconds": round(runtime.load_seconds, 3),
        "total_seconds": round(total_seconds, 3),
        "preprocess_ms": round(preprocess_seconds * 1000, 1),
        "inference_ms": round(inference_seconds * 1000, 1),
        "postprocess_ms": round(postprocess_seconds * 1000, 1),
    }
    return mask_image, cutout, timing, specs
# --- Gradio UI definition --------------------------------------------------
with gr.Blocks(title="BiRefNet OpenVINO") as demo:
    gr.Markdown("# BiRefNet OpenVINO")
    with gr.Row():
        input_image = gr.Image(label="Image", type="pil")
        model_dropdown = gr.Dropdown(
            label="Model variant",
            # (label, value) pairs; value is the ModelVariant.key.
            choices=[(variant.label, variant.key) for variant in MODEL_VARIANTS],
            value=get_model_variant(DEFAULT_MODEL_VARIANT_KEY).key,
            interactive=True,
        )
    run_button = gr.Button("Run", variant="primary")
    with gr.Row():
        mask_output = gr.Image(label="Mask", type="pil")
        cutout_output = gr.Image(label="Background removed", type="pil", format="png")
    with gr.Row():
        timing_output = gr.Textbox(label="Processing time", lines=4)
        specs_output = gr.JSON(label="Specs")
    # Explicit run via the button...
    run_button.click(
        fn=remove_background,
        inputs=[input_image, model_dropdown],
        outputs=[mask_output, cutout_output, timing_output, specs_output],
    )
    # ...and an automatic run as soon as an image is uploaded.
    input_image.upload(
        fn=remove_background,
        inputs=[input_image, model_dropdown],
        outputs=[mask_output, cutout_output, timing_output, specs_output],
    )
if __name__ == "__main__":
    # Bounded queue keeps concurrent requests from piling up on the CPU.
    demo.queue(max_size=8).launch()