|
|
import os |
|
|
import time |
|
|
import glob |
|
|
import math |
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, Optional, Tuple, List |
|
|
|
|
|
import gradio as gr |
|
|
import spaces |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from transformers import AutoModelForImageSegmentation |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
from gradio_imageslider import ImageSlider |
|
|
|
|
|
|
|
|
from transparent_background import Remover |
|
|
|
|
|
|
|
|
from rembg import new_session, remove as rembg_remove |
|
|
|
|
|
|
|
|
from withoutbg import WithoutBG |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pil_to_rgb(pil: Image.Image) -> Image.Image:
    """Return *pil* in RGB mode.

    An image that is already RGB is handed back unchanged (same object, no
    copy); any other mode is converted with ``Image.convert("RGB")``.
    """
    return pil if pil.mode == "RGB" else pil.convert("RGB")
|
|
|
|
|
|
|
|
def ensure_rgba(pil: Image.Image) -> Image.Image:
    """Return *pil* in RGBA mode; an RGBA input is passed through untouched."""
    already_rgba = pil.mode == "RGBA"
    return pil if already_rgba else pil.convert("RGBA")
|
|
|
|
|
|
|
|
def make_checkerboard(w: int, h: int, block: int = 16) -> Image.Image:
    """Build a ``w`` x ``h`` RGB checkerboard used as a transparency backdrop.

    Squares are ``block`` pixels on a side and alternate between light grey
    (235, 235, 235) and darker grey (200, 200, 200); the top-left square is
    the light one.

    Args:
        w: Output width in pixels.
        h: Output height in pixels.
        block: Side length of one square in pixels; must be positive.

    Returns:
        An RGB ``PIL.Image`` of size ``(w, h)``.

    Raises:
        ValueError: If ``block`` is not positive.
    """
    if block <= 0:
        raise ValueError(f"block must be positive, got {block}")

    # Parity of the square index along each axis (0/1), kept as uint8 so the
    # broadcast (h, w) map costs one byte per pixel.
    row_parity = ((np.arange(h) // block) % 2).astype(np.uint8)
    col_parity = ((np.arange(w) // block) % 2).astype(np.uint8)
    # XOR of the two parities equals (row + col) % 2: 0 -> light, 1 -> dark.
    parity = row_parity[:, None] ^ col_parity[None, :]

    palette = np.array([[235, 235, 235], [200, 200, 200]], dtype=np.uint8)
    board = palette[parity]  # (h, w, 3) uint8 -> inferred as RGB below
    return Image.fromarray(board)
|
|
|
|
|
|
|
|
def rgba_on_checkerboard(rgba: Image.Image) -> Image.Image:
    """Flatten an image onto a checkerboard so its transparency is visible.

    The input is coerced to RGBA, alpha-composited over a checkerboard of the
    same size, and the composite is returned as an RGB image.
    """
    foreground = ensure_rgba(rgba)
    backdrop = make_checkerboard(*foreground.size).convert("RGBA")
    return Image.alpha_composite(backdrop, foreground).convert("RGB")
|
|
|
|
|
|
|
|
def save_temp_png(rgba: Image.Image, out_dir: str = "output_images") -> str:
    """Save *rgba* as ``no_bg.png`` inside *out_dir* and return its path.

    *out_dir* is created if it does not exist. The image is coerced to RGBA
    so the PNG keeps its alpha channel.
    """
    # NOTE(review): the filename is fixed, so every call overwrites the same
    # file; concurrent callers would clobber each other's output.
    target = os.path.join(out_dir, "no_bg.png")
    os.makedirs(out_dir, exist_ok=True)
    ensure_rgba(rgba).save(target, format="PNG")
    return target
|
|
|
|
|
|
|
|
def now_ms() -> float:
    """Return the ``time.perf_counter()`` reading expressed in milliseconds."""
    seconds = time.perf_counter()
    return seconds * 1000.0
|
|
|
|
|
|
|
|
def get_device() -> str:
    """Return ``"cuda"`` if CUDA is usable right now, else ``"cpu"``.

    Evaluated on every call rather than cached, because under ZeroGPU the
    GPU may only be available inside GPU-decorated calls.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
|
|
|
|
|
|
|
@dataclass
class Timing:
    """Wall-clock durations, in milliseconds, for the stages of one model run."""

    preprocess_ms: float   # time spent preparing the input
    inference_ms: float    # time spent in the model call
    postprocess_ms: float  # time spent preparing the output
    total_ms: float        # whole run, start to finish

    def to_text(self) -> str:
        """Render the four durations as one ``label: N.NN ms`` line each."""
        entries = (
            ("preprocess", self.preprocess_ms),
            ("inference", self.inference_ms),
            ("postprocess", self.postprocess_ms),
            ("TOTAL", self.total_ms),
        )
        return "\n".join(f"{label}: {value:.2f} ms" for label, value in entries)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelManager:
    """Lazily load and run the background-removal backends offered in the UI.

    Every backend is created on first use and cached on this instance. The
    Hugging Face torch models (BiRefNet, BRIA RMBG 2.0) are loaded on CPU;
    when CUDA is available only the model currently in use is moved to the
    GPU, and the previously used one is moved back to CPU first.
    """

    # Hugging Face repo ids for the torch segmentation models, keyed by the
    # internal names passed to ``_run_torch_alpha_model``.
    _TORCH_MODEL_IDS: Dict[str, str] = {
        "birefnet": "ZhengPeng7/BiRefNet",
        "bria_rmbg_2": "briaai/RMBG-2.0",
    }

    def __init__(self):
        self._inspy: Optional[Remover] = None
        self._withoutbg: Optional[object] = None
        # Whether CUDA was available when the withoutBG model was built; used
        # to rebuild it once a GPU becomes available.
        self._withoutbg_had_gpu: bool = False
        self._torch_models: Dict[str, torch.nn.Module] = {}
        # Key of the torch model currently moved to the GPU, if any.
        self._torch_model_on_gpu: Optional[str] = None
        self._rembg_sessions: Dict[str, object] = {}
        # Load failures that retrying cannot fix (license/gating, missing
        # packages). Other failures are deliberately not recorded so a later
        # request retries the load.
        self._model_load_errors: Dict[str, str] = {}

        # Torch-model preprocessing: fixed 1024x1024 resize, then
        # normalization with the standard ImageNet mean/std.
        self._tf_1024 = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        try:
            torch.set_float32_matmul_precision("high")
        except Exception:
            # Best-effort performance hint; failure here is not fatal.
            pass

    def _maybe_sync(self):
        """Wait for queued CUDA work so the inference timing is accurate."""
        if get_device() == "cuda":
            torch.cuda.synchronize()

    def _load_inspy(self) -> Remover:
        """Return the cached InSPyReNet ``Remover``, creating it on first use."""
        if self._inspy is None:
            self._inspy = Remover(jit=False)
        return self._inspy

    def _load_withoutbg(self, force_reload: bool = False):
        """Return the withoutBG model, (re)building it when needed.

        The model is rebuilt when ``force_reload`` is set, when it has not
        been built yet, or when CUDA is available now but was not available
        when it was last built.
        """
        gpu_available_now = torch.cuda.is_available()
        need_reload = (
            force_reload
            or self._withoutbg is None
            or (gpu_available_now and not self._withoutbg_had_gpu)
        )
        if need_reload:
            self._withoutbg = WithoutBG.opensource()
            self._withoutbg_had_gpu = gpu_available_now
        return self._withoutbg

    def _offload_torch_models_from_gpu(self, keep_name: str):
        """Move the torch model currently on the GPU to CPU unless it is ``keep_name``."""
        if get_device() != "cuda":
            return
        if self._torch_model_on_gpu and self._torch_model_on_gpu != keep_name:
            prev = self._torch_models.get(self._torch_model_on_gpu)
            if prev is not None:
                prev.to("cpu")
            self._torch_model_on_gpu = None
            torch.cuda.empty_cache()

    def _load_torch_model(self, key: str) -> torch.nn.Module:
        """Return the BiRefNet / BRIA RMBG 2.0 model for ``key``, loading it on CPU if needed.

        Raises:
            ValueError: ``key`` is not a known torch model.
            RuntimeError: loading failed; the message explains how to fix it.
        """
        if key in self._torch_models:
            return self._torch_models[key]

        if key in self._model_load_errors:
            raise RuntimeError(self._model_load_errors[key])

        if key not in self._TORCH_MODEL_IDS:
            raise ValueError(f"Unknown model key: {key}")
        model_id = self._TORCH_MODEL_IDS[key]

        try:
            m = AutoModelForImageSegmentation.from_pretrained(
                model_id,
                trust_remote_code=True,
            )
        except OSError as e:
            error_msg = str(e)
            lowered = error_msg.lower()
            if "gated" in lowered or "401" in error_msg or "access" in lowered:
                msg = (
                    f"Model '{model_id}' requires license acceptance.\n"
                    f"1. Go to https://huggingface.co/{model_id}\n"
                    f"2. Accept the license agreement\n"
                    f"3. Add HF_TOKEN secret to your Space settings"
                )
                # Retrying cannot succeed until access is granted, so remember it.
                self._model_load_errors[key] = msg
                raise RuntimeError(msg) from e
            # Possibly transient (network, disk): not cached, next call retries.
            raise RuntimeError(f"Failed to load {model_id}: {error_msg}") from e
        except ImportError as e:
            msg = (
                f"Import error loading {model_id}: {e}\n"
                f"Make sure 'timm' is in requirements.txt"
            )
            # A missing package will not appear without a redeploy.
            self._model_load_errors[key] = msg
            raise RuntimeError(msg) from e

        m.eval()
        m.to("cpu")
        self._torch_models[key] = m
        return m

    def _get_rembg_session(self, name: str):
        """Return the cached rembg session for model ``name``, creating it on first use."""
        if name in self._rembg_sessions:
            return self._rembg_sessions[name]

        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        try:
            sess = new_session(name, providers=providers)
        except Exception:
            # Fall back to rembg's default providers if the explicit list is rejected.
            sess = new_session(name)

        self._rembg_sessions[name] = sess
        return sess

    def _run_torch_alpha_model(self, model_key: str, image_rgb: Image.Image) -> Image.Image:
        """Predict an alpha matte with a torch model and attach it to the image.

        The image is resized to 1024x1024 for inference and the predicted
        matte is resized back to the original size with bilinear filtering.

        Returns:
            An RGBA copy of ``image_rgb`` whose alpha channel is the matte.
        """
        device = get_device()
        m = self._load_torch_model(model_key)

        if device == "cuda":
            # Keep at most one torch model resident on the GPU.
            self._offload_torch_models_from_gpu(keep_name=model_key)
            if self._torch_model_on_gpu != model_key:
                m.to("cuda")
                self._torch_model_on_gpu = model_key

        image_rgb = pil_to_rgb(image_rgb)
        orig_size = image_rgb.size

        x = self._tf_1024(image_rgb).unsqueeze(0).to(device)

        # The last element of the model output is used as the prediction.
        with torch.inference_mode():
            if device == "cuda":
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    preds = m(x)[-1].sigmoid()
            else:
                preds = m(x)[-1].sigmoid()

        # First batch item, singleton dims squeezed away (presumably a
        # (1, 1, H, W) output -> (H, W)), as float32 on CPU.
        pred = preds[0].squeeze().detach().float().cpu()
        alpha = transforms.ToPILImage()(pred).resize(orig_size, Image.BILINEAR)

        out = image_rgb.convert("RGBA")
        out.putalpha(alpha)
        return out

    @staticmethod
    def _mask_to_l(mask) -> Image.Image:
        """Convert an InSPyReNet matte (PIL image or ndarray) to a mode-"L" image.

        A 3-D array keeps only its first channel. A uint8 array is used as-is;
        any other dtype is treated as a [0, 1] probability map and scaled to
        0..255.
        """
        if isinstance(mask, Image.Image):
            return mask.convert("L")
        arr = np.asarray(mask)
        if arr.ndim == 3:
            arr = arr[..., 0]
        if arr.dtype != np.uint8:
            # Clip before scaling so out-of-range values cannot wrap around.
            arr = (np.clip(arr.astype(np.float32), 0.0, 1.0) * 255).astype(np.uint8)
        return Image.fromarray(np.ascontiguousarray(arr))

    def run(self, model_name: str, input_image: Image.Image) -> Tuple[Image.Image, Timing]:
        """Remove the background from ``input_image`` with the named model.

        Args:
            model_name: Display name of the model (one of the UI choices).
            input_image: PIL image in any mode; it is converted to RGB first.

        Returns:
            ``(rgba, timing)``: the RGBA cut-out and per-stage timings.

        Raises:
            ValueError: ``input_image`` is None or ``model_name`` is unknown.
            RuntimeError: a torch model could not be loaded.
        """
        if input_image is None:
            raise ValueError("No input image")

        t0 = now_ms()

        pre0 = now_ms()
        img_rgb = pil_to_rgb(input_image)
        pre1 = now_ms()

        inf0 = now_ms()

        if model_name == "InSPyReNet":
            remover = self._load_inspy()
            # Feed the RGB copy like every other backend; the raw upload may
            # be RGBA, palette, or greyscale.
            mask = remover.process(img_rgb, type="map")
            out = img_rgb.convert("RGBA")
            out.putalpha(self._mask_to_l(mask))

        elif model_name == "BiRefNet":
            out = self._run_torch_alpha_model("birefnet", img_rgb)

        elif model_name == "U2Net":
            sess = self._get_rembg_session("u2net")
            out = ensure_rgba(rembg_remove(img_rgb, session=sess))

        elif model_name == "BRIA RMBG 2.0":
            out = self._run_torch_alpha_model("bria_rmbg_2", img_rgb)

        elif model_name == "IS-Net":
            sess = self._get_rembg_session("isnet-general-use")
            out = ensure_rgba(rembg_remove(img_rgb, session=sess))

        elif model_name == "withoutBG":
            model = self._load_withoutbg()
            out = ensure_rgba(model.remove_background(img_rgb))

        else:
            raise ValueError(f"Unknown model: {model_name}")

        # Wait for pending GPU work so the inference time is not understated.
        self._maybe_sync()
        inf1 = now_ms()

        post0 = now_ms()
        out = ensure_rgba(out)
        post1 = now_ms()

        t1 = now_ms()

        timing = Timing(
            preprocess_ms=pre1 - pre0,
            inference_ms=inf1 - inf0,
            postprocess_ms=post1 - post0,
            total_ms=t1 - t0,
        )
        return out, timing
|
|
|
|
|
|
|
|
# Model names shown in both UI dropdowns; ModelManager.run dispatches on
# these exact strings and raises ValueError for anything else.
MODEL_CHOICES = [
    "InSPyReNet",
    "BiRefNet",
    "U2Net",
    "BRIA RMBG 2.0",
    "IS-Net",
    "withoutBG",
]

# One shared manager, so models loaded by one request are reused by later ones.
MANAGER = ModelManager()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU
def run_single(model_name: str, image: Image.Image):
    """Gradio handler for the "Try single image" tab.

    Args:
        model_name: Display name of the model to run.
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple for (slider, output image, timing textbox, download file):
        ``((image, checkerboard_preview), rgba_result, timing_text, png_path)``
        on success, otherwise ``(None, None, message, None)``.
    """
    if image is None:
        return None, None, "Upload an image first.", None

    try:
        rgba, timing = MANAGER.run(model_name, image)
        checker_preview = rgba_on_checkerboard(rgba)
        png_path = save_temp_png(rgba)
        return (image, checker_preview), rgba, timing.to_text(), png_path
    except Exception as exc:
        # RuntimeError carries the curated model-load messages; anything else
        # is reported as unexpected.
        label = "Error" if isinstance(exc, RuntimeError) else "Unexpected error"
        return None, None, f"{label}: {exc}", None
|
|
|
|
|
|
|
|
def list_bench_images(bench_dir: str = "bench") -> List[str]:
    """Return the image files to benchmark, sorted by path.

    Scans *bench_dir* (non-recursively) for ``.jpg``/``.jpeg``/``.png``/
    ``.webp`` files. The extension is matched case-insensitively, so e.g.
    ``PHOTO.JPG`` is found on case-sensitive filesystems too. Hidden entries
    (leading ``.``) and sub-directories are skipped. If nothing is found,
    falls back to whichever of ``1.jpg``, ``2.jpg``, ``3.png``, ``4.webp``
    exist in the working directory, in that order.

    Args:
        bench_dir: Directory to scan; defaults to ``bench``.

    Returns:
        Paths of the form ``os.path.join(bench_dir, name)``, or the bare
        fallback filenames.
    """
    exts = {".jpg", ".jpeg", ".png", ".webp"}

    try:
        names = os.listdir(bench_dir)
    except OSError:
        # Missing or unreadable directory: behave as if it were empty.
        names = []

    files: List[str] = []
    for name in names:
        if name.startswith("."):
            continue  # like glob's "*", don't pick up hidden files
        path = os.path.join(bench_dir, name)
        if os.path.splitext(name)[1].lower() in exts and os.path.isfile(path):
            files.append(path)
    files.sort()

    if not files:
        files = [f for f in ("1.jpg", "2.jpg", "3.png", "4.webp") if os.path.exists(f)]
    return files
|
|
|
|
|
|
|
|
@spaces.GPU
def run_benchmark(model_name: str, repeats: int = 1):
    """Gradio handler: time ``model_name`` over every benchmark image.

    The first image is run twice as a warm-up, and those runs are excluded
    from the results. Then each image is run ``repeats`` times.

    Args:
        model_name: Display name of the model to benchmark.
        repeats: Runs per image; coerced to an int of at least 1 (the UI
            slider may deliver a float).

    Returns:
        ``(rows, summary)``. ``rows`` contains
        ``[file, repeat, total_ms, inference_ms]`` entries and ``summary`` is
        text with the average latency and estimated throughput. On failure
        ``rows`` is empty and ``summary`` holds the error message.
    """
    files = list_bench_images()
    if not files:
        return [], "No benchmark images found. Add 10–15 images under bench/."

    try:
        repeats = max(1, int(repeats))

        # Warm-up so lazy model loading and first-call overheads don't skew
        # the measured runs.
        with Image.open(files[0]) as im:
            warm_img = im.convert("RGB")
        for _ in range(2):
            MANAGER.run(model_name, warm_img)

        rows = []
        total_ms = 0.0
        n_runs = 0

        for f in files:
            # Context manager guarantees the file handle is released.
            with Image.open(f) as im:
                img = im.convert("RGB")
            for r in range(repeats):
                _, timing = MANAGER.run(model_name, img)
                rows.append([
                    os.path.basename(f),
                    r + 1,
                    round(timing.total_ms, 2),
                    round(timing.inference_ms, 2),
                ])
                total_ms += timing.total_ms
                n_runs += 1

        avg_ms = total_ms / max(1, n_runs)
        ips = 1000.0 / avg_ms if avg_ms > 0 else 0.0

        summary = (
            f"Model: {model_name}\n"
            f"Images: {len(files)} (repeats={repeats}) => runs={n_runs}\n"
            f"Avg total: {avg_ms:.2f} ms\n"
            f"Estimated throughput: {ips:.2f} images/sec\n"
            f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}"
        )
        return rows, summary

    except RuntimeError as e:
        return [], f"Error: {e}"
    except Exception as e:
        return [], f"Unexpected error: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(title="Background Removal Benchmark") as demo:
    # Intro text. The mis-encoded dashes, arrows and warning sign in the
    # previous version are restored here.
    gr.Markdown(
        """
        # Background Removal Benchmark

        Benchmarked models:
        1. **InSPyReNet** — transparent-background library
        2. **BiRefNet** — ZhengPeng7/BiRefNet (requires `timm`)
        3. **U2Net** — via rembg/ONNX
        4. **BRIA RMBG 2.0** — briaai/RMBG-2.0 (requires license acceptance)
        5. **IS-Net** — isnet-general-use via rembg
        6. **withoutBG** — 4-stage ONNX pipeline (Depth → ISNet → Matting → Refiner)

        **Notes**
        - Output is true transparent PNG (RGBA)
        - Slider preview shows result on checkerboard
        - For benchmarks, add images under `bench/` folder

        ⚠️ **BRIA RMBG 2.0**: Requires accepting license at [huggingface.co/briaai/RMBG-2.0](https://huggingface.co/briaai/RMBG-2.0) and adding `HF_TOKEN` secret to Space settings.
        """
    )

    with gr.Tab("Try single image"):
        with gr.Row():
            with gr.Column(scale=1):
                inp = gr.Image(type="pil", label="Upload image", height=420)
                model = gr.Dropdown(choices=MODEL_CHOICES, value="InSPyReNet", label="Model")
                run_btn = gr.Button("Run", variant="primary")
            with gr.Column(scale=2):
                slider = ImageSlider(label="Before / After", type="pil")
                out_img = gr.Image(type="pil", label="Output (RGBA)", height=420)
                timing_box = gr.Textbox(label="Timing / Errors", lines=5)
                out_file = gr.File(label="Download PNG (transparent)")

        # Output order must match the 4-tuple returned by run_single.
        run_btn.click(
            fn=run_single,
            inputs=[model, inp],
            outputs=[slider, out_img, timing_box, out_file],
        )

    with gr.Tab("Benchmark (throughput estimate)"):
        with gr.Row():
            with gr.Column(scale=1):
                bench_model = gr.Dropdown(choices=MODEL_CHOICES, value="InSPyReNet", label="Model")
                repeats = gr.Slider(1, 5, value=1, step=1, label="Repeats per image")
                bench_btn = gr.Button("Run benchmark", variant="primary")
            with gr.Column(scale=2):
                bench_table = gr.Dataframe(
                    headers=["file", "repeat", "total_ms", "inference_ms"],
                    datatype=["str", "number", "number", "number"],
                    interactive=False,
                )
                bench_summary = gr.Textbox(label="Summary", lines=6)

        bench_btn.click(
            fn=run_benchmark,
            inputs=[bench_model, repeats],
            outputs=[bench_table, bench_summary],
        )

    # Offer the bundled sample images as clickable examples, if present.
    example_files = [
        [f, "InSPyReNet"]
        for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]
        if os.path.exists(f)
    ]
    if example_files:
        gr.Examples(examples=example_files, inputs=[inp, model], label="Examples")


if __name__ == "__main__":
    demo.launch(show_error=True)