Tyler Ng
commited on
Update app.py
Browse files
app.py
CHANGED
|
@@ -43,7 +43,6 @@ def ensure_rgba(pil: Image.Image) -> Image.Image:
|
|
| 43 |
|
| 44 |
|
| 45 |
def make_checkerboard(w: int, h: int, block: int = 16) -> Image.Image:
|
| 46 |
-
# Neutral checkerboard
|
| 47 |
cols = int(math.ceil(w / block))
|
| 48 |
rows = int(math.ceil(h / block))
|
| 49 |
board = np.zeros((rows * block, cols * block, 3), dtype=np.uint8)
|
|
@@ -79,6 +78,11 @@ def now_ms() -> float:
|
|
| 79 |
return time.perf_counter() * 1000.0
|
| 80 |
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
@dataclass
|
| 83 |
class Timing:
|
| 84 |
preprocess_ms: float
|
|
@@ -109,7 +113,7 @@ class ModelManager:
|
|
| 109 |
5) IS-Net (isnet-general-use) via rembg
|
| 110 |
"""
|
| 111 |
def __init__(self):
|
| 112 |
-
|
| 113 |
self._inspy: Optional[Remover] = None
|
| 114 |
|
| 115 |
self._torch_models: Dict[str, AutoModelForImageSegmentation] = {}
|
|
@@ -132,17 +136,16 @@ class ModelManager:
|
|
| 132 |
pass
|
| 133 |
|
| 134 |
def _maybe_sync(self):
|
| 135 |
-
if
|
| 136 |
torch.cuda.synchronize()
|
| 137 |
|
| 138 |
def _load_inspy(self) -> Remover:
|
| 139 |
if self._inspy is None:
|
| 140 |
-
# jit=False like your sample
|
| 141 |
self._inspy = Remover(jit=False)
|
| 142 |
return self._inspy
|
| 143 |
|
| 144 |
def _offload_torch_models_from_gpu(self, keep_name: str):
|
| 145 |
-
if
|
| 146 |
return
|
| 147 |
if self._torch_model_on_gpu and self._torch_model_on_gpu != keep_name:
|
| 148 |
prev = self._torch_models.get(self._torch_model_on_gpu)
|
|
@@ -167,7 +170,7 @@ class ModelManager:
|
|
| 167 |
|
| 168 |
m = AutoModelForImageSegmentation.from_pretrained(model_id, trust_remote_code=True)
|
| 169 |
m.eval()
|
| 170 |
-
# Keep on CPU initially; move to GPU on-demand
|
| 171 |
m.to("cpu")
|
| 172 |
self._torch_models[key] = m
|
| 173 |
return m
|
|
@@ -179,27 +182,28 @@ class ModelManager:
|
|
| 179 |
if name in self._rembg_sessions:
|
| 180 |
return self._rembg_sessions[name]
|
| 181 |
|
| 182 |
-
# Prefer CUDA provider if
|
| 183 |
-
|
| 184 |
-
|
| 185 |
try:
|
| 186 |
-
|
| 187 |
except Exception:
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
self._rembg_sessions[name] = sess
|
| 192 |
return sess
|
| 193 |
|
| 194 |
def _run_torch_alpha_model(self, model_key: str, image_rgb: Image.Image) -> Image.Image:
|
| 195 |
"""
|
| 196 |
-
Runs a torch segmentation model that returns a single-channel mask
|
| 197 |
Returns RGBA (with alpha).
|
| 198 |
"""
|
|
|
|
| 199 |
m = self._load_torch_model(model_key)
|
| 200 |
|
| 201 |
# Put model on GPU for inference if possible
|
| 202 |
-
if
|
| 203 |
self._offload_torch_models_from_gpu(keep_name=model_key)
|
| 204 |
if self._torch_model_on_gpu != model_key:
|
| 205 |
m.to("cuda")
|
|
@@ -209,10 +213,10 @@ class ModelManager:
|
|
| 209 |
orig_size = image_rgb.size
|
| 210 |
|
| 211 |
x = self._tf_1024(image_rgb).unsqueeze(0)
|
| 212 |
-
x = x.to(
|
| 213 |
|
| 214 |
with torch.inference_mode():
|
| 215 |
-
if
|
| 216 |
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
| 217 |
preds = m(x)[-1].sigmoid()
|
| 218 |
else:
|
|
@@ -220,7 +224,7 @@ class ModelManager:
|
|
| 220 |
|
| 221 |
# Convert prediction to PIL alpha channel
|
| 222 |
pred = preds[0].squeeze().detach().float().cpu()
|
| 223 |
-
alpha = transforms.ToPILImage()(pred).resize(orig_size)
|
| 224 |
|
| 225 |
out = image_rgb.convert("RGBA")
|
| 226 |
out.putalpha(alpha)
|
|
@@ -244,7 +248,6 @@ class ModelManager:
|
|
| 244 |
inf0 = now_ms()
|
| 245 |
if model_name == "InSPyReNet":
|
| 246 |
remover = self._load_inspy()
|
| 247 |
-
# The library returns various modes; we want alpha mask and apply ourselves for consistent output
|
| 248 |
mask = remover.process(input_image, type="map")
|
| 249 |
if isinstance(mask, Image.Image):
|
| 250 |
mask = mask.convert("L")
|
|
@@ -259,17 +262,18 @@ class ModelManager:
|
|
| 259 |
|
| 260 |
elif model_name == "U2Net":
|
| 261 |
sess = self._get_rembg_session("u2net")
|
| 262 |
-
# rembg returns
|
| 263 |
-
|
| 264 |
-
out =
|
| 265 |
|
| 266 |
elif model_name == "BRIA RMBG 2.0":
|
| 267 |
out = self._run_torch_alpha_model("bria_rmbg_2", img_rgb)
|
| 268 |
|
| 269 |
elif model_name == "IS-Net":
|
| 270 |
sess = self._get_rembg_session("isnet-general-use")
|
| 271 |
-
|
| 272 |
-
out =
|
|
|
|
| 273 |
|
| 274 |
else:
|
| 275 |
raise ValueError(f"Unknown model: {model_name}")
|
|
@@ -314,27 +318,20 @@ def run_single(model_name: str, image: Image.Image):
|
|
| 314 |
if image is None:
|
| 315 |
return None, None, "Upload an image first.", None
|
| 316 |
|
| 317 |
-
# Warmup-ish for fairer timing (tiny; avoids huge overhead in UI)
|
| 318 |
-
# Note: real benchmark tab does proper warmups.
|
| 319 |
out_rgba, timing = MANAGER.run(model_name, image)
|
| 320 |
|
| 321 |
-
# Slider wants (processed, original) or (after, before) depending on component;
|
| 322 |
-
# we’ll show: left=original, right=on-checkerboard preview of transparent output.
|
| 323 |
preview = rgba_on_checkerboard(out_rgba)
|
| 324 |
-
|
| 325 |
out_path = save_temp_png(out_rgba)
|
| 326 |
return (image, preview), out_rgba, timing.to_text(), out_path
|
| 327 |
|
| 328 |
|
| 329 |
def list_bench_images() -> List[str]:
|
| 330 |
-
# Put your 10–15 images under bench/
|
| 331 |
exts = ("*.jpg", "*.jpeg", "*.png", "*.webp")
|
| 332 |
files = []
|
| 333 |
for e in exts:
|
| 334 |
files += glob.glob(os.path.join("bench", e))
|
| 335 |
files = sorted(files)
|
| 336 |
|
| 337 |
-
# Fallback to repo-root examples like your sample Space
|
| 338 |
if not files:
|
| 339 |
fallback = []
|
| 340 |
for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]:
|
|
@@ -348,7 +345,8 @@ def list_bench_images() -> List[str]:
|
|
| 348 |
def run_benchmark(model_name: str, repeats: int = 1):
|
| 349 |
files = list_bench_images()
|
| 350 |
if not files:
|
| 351 |
-
|
|
|
|
| 352 |
|
| 353 |
# Warmup: 2 runs on first image (not timed)
|
| 354 |
warm_img = Image.open(files[0]).convert("RGB")
|
|
@@ -363,12 +361,12 @@ def run_benchmark(model_name: str, repeats: int = 1):
|
|
| 363 |
img = Image.open(f).convert("RGB")
|
| 364 |
for r in range(repeats):
|
| 365 |
out, timing = MANAGER.run(model_name, img)
|
| 366 |
-
rows.append(
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
total_ms += timing.total_ms
|
| 373 |
n_images += 1
|
| 374 |
|
|
@@ -380,26 +378,21 @@ def run_benchmark(model_name: str, repeats: int = 1):
|
|
| 380 |
f"Images: {len(files)} (repeats={repeats}) => runs={n_images}\n"
|
| 381 |
f"Avg total: {avg_ms:.2f} ms\n"
|
| 382 |
f"Estimated throughput: {ips:.2f} images/sec\n"
|
| 383 |
-
f"Device: {'
|
| 384 |
)
|
| 385 |
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
value=[[r["file"], r["repeat"], r["total_ms"], r["inference_ms"]] for r in rows],
|
| 389 |
-
datatype=["str", "number", "number", "number"],
|
| 390 |
-
interactive=False
|
| 391 |
-
)
|
| 392 |
-
return df, summary
|
| 393 |
|
| 394 |
|
| 395 |
# ----------------------------
|
| 396 |
# UI
|
| 397 |
# ----------------------------
|
| 398 |
|
| 399 |
-
with gr.Blocks(title="Background Removal Benchmark
|
| 400 |
gr.Markdown(
|
| 401 |
"""
|
| 402 |
-
# Background Removal Benchmark
|
| 403 |
|
| 404 |
Benchmarked models:
|
| 405 |
1) InSPyReNet
|
|
@@ -437,7 +430,7 @@ Benchmarked models:
|
|
| 437 |
with gr.Row():
|
| 438 |
with gr.Column(scale=1):
|
| 439 |
bench_model = gr.Dropdown(choices=MODEL_CHOICES, value="InSPyReNet", label="Model")
|
| 440 |
-
repeats = gr.Slider(1, 5, value=1, step=1, label="Repeats per image
|
| 441 |
bench_btn = gr.Button("Run benchmark", variant="primary")
|
| 442 |
with gr.Column(scale=2):
|
| 443 |
bench_table = gr.Dataframe(
|
|
@@ -453,7 +446,6 @@ Benchmarked models:
|
|
| 453 |
outputs=[bench_table, bench_summary]
|
| 454 |
)
|
| 455 |
|
| 456 |
-
# Examples (optional) — if these files exist, they show up like your sample Space
|
| 457 |
example_files = []
|
| 458 |
for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]:
|
| 459 |
if os.path.exists(f):
|
|
@@ -466,4 +458,4 @@ Benchmarked models:
|
|
| 466 |
)
|
| 467 |
|
| 468 |
if __name__ == "__main__":
|
| 469 |
-
demo.launch(show_error=True)
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def make_checkerboard(w: int, h: int, block: int = 16) -> Image.Image:
|
|
|
|
| 46 |
cols = int(math.ceil(w / block))
|
| 47 |
rows = int(math.ceil(h / block))
|
| 48 |
board = np.zeros((rows * block, cols * block, 3), dtype=np.uint8)
|
|
|
|
| 78 |
return time.perf_counter() * 1000.0
|
| 79 |
|
| 80 |
|
| 81 |
+
def get_device() -> str:
|
| 82 |
+
"""Get device at runtime (important for ZeroGPU)."""
|
| 83 |
+
return "cuda" if torch.cuda.is_available() else "cpu"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
@dataclass
|
| 87 |
class Timing:
|
| 88 |
preprocess_ms: float
|
|
|
|
| 113 |
5) IS-Net (isnet-general-use) via rembg
|
| 114 |
"""
|
| 115 |
def __init__(self):
|
| 116 |
+
# NOTE: Don't cache device here - ZeroGPU allocates GPU later
|
| 117 |
self._inspy: Optional[Remover] = None
|
| 118 |
|
| 119 |
self._torch_models: Dict[str, AutoModelForImageSegmentation] = {}
|
|
|
|
| 136 |
pass
|
| 137 |
|
| 138 |
def _maybe_sync(self):
|
| 139 |
+
if get_device() == "cuda":
|
| 140 |
torch.cuda.synchronize()
|
| 141 |
|
| 142 |
def _load_inspy(self) -> Remover:
|
| 143 |
if self._inspy is None:
|
|
|
|
| 144 |
self._inspy = Remover(jit=False)
|
| 145 |
return self._inspy
|
| 146 |
|
| 147 |
def _offload_torch_models_from_gpu(self, keep_name: str):
|
| 148 |
+
if get_device() != "cuda":
|
| 149 |
return
|
| 150 |
if self._torch_model_on_gpu and self._torch_model_on_gpu != keep_name:
|
| 151 |
prev = self._torch_models.get(self._torch_model_on_gpu)
|
|
|
|
| 170 |
|
| 171 |
m = AutoModelForImageSegmentation.from_pretrained(model_id, trust_remote_code=True)
|
| 172 |
m.eval()
|
| 173 |
+
# Keep on CPU initially; move to GPU on-demand
|
| 174 |
m.to("cpu")
|
| 175 |
self._torch_models[key] = m
|
| 176 |
return m
|
|
|
|
| 182 |
if name in self._rembg_sessions:
|
| 183 |
return self._rembg_sessions[name]
|
| 184 |
|
| 185 |
+
# Prefer CUDA provider if available
|
| 186 |
+
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
| 187 |
+
|
| 188 |
try:
|
| 189 |
+
sess = new_session(name, providers=providers)
|
| 190 |
except Exception:
|
| 191 |
+
# Fallback to default providers
|
| 192 |
+
sess = new_session(name)
|
| 193 |
+
|
| 194 |
self._rembg_sessions[name] = sess
|
| 195 |
return sess
|
| 196 |
|
| 197 |
def _run_torch_alpha_model(self, model_key: str, image_rgb: Image.Image) -> Image.Image:
|
| 198 |
"""
|
| 199 |
+
Runs a torch segmentation model that returns a single-channel mask.
|
| 200 |
Returns RGBA (with alpha).
|
| 201 |
"""
|
| 202 |
+
device = get_device() # Check device at runtime!
|
| 203 |
m = self._load_torch_model(model_key)
|
| 204 |
|
| 205 |
# Put model on GPU for inference if possible
|
| 206 |
+
if device == "cuda":
|
| 207 |
self._offload_torch_models_from_gpu(keep_name=model_key)
|
| 208 |
if self._torch_model_on_gpu != model_key:
|
| 209 |
m.to("cuda")
|
|
|
|
| 213 |
orig_size = image_rgb.size
|
| 214 |
|
| 215 |
x = self._tf_1024(image_rgb).unsqueeze(0)
|
| 216 |
+
x = x.to(device)
|
| 217 |
|
| 218 |
with torch.inference_mode():
|
| 219 |
+
if device == "cuda":
|
| 220 |
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
| 221 |
preds = m(x)[-1].sigmoid()
|
| 222 |
else:
|
|
|
|
| 224 |
|
| 225 |
# Convert prediction to PIL alpha channel
|
| 226 |
pred = preds[0].squeeze().detach().float().cpu()
|
| 227 |
+
alpha = transforms.ToPILImage()(pred).resize(orig_size, Image.BILINEAR)
|
| 228 |
|
| 229 |
out = image_rgb.convert("RGBA")
|
| 230 |
out.putalpha(alpha)
|
|
|
|
| 248 |
inf0 = now_ms()
|
| 249 |
if model_name == "InSPyReNet":
|
| 250 |
remover = self._load_inspy()
|
|
|
|
| 251 |
mask = remover.process(input_image, type="map")
|
| 252 |
if isinstance(mask, Image.Image):
|
| 253 |
mask = mask.convert("L")
|
|
|
|
| 262 |
|
| 263 |
elif model_name == "U2Net":
|
| 264 |
sess = self._get_rembg_session("u2net")
|
| 265 |
+
# FIX: rembg returns PIL Image when given PIL Image, not bytes!
|
| 266 |
+
out = rembg_remove(img_rgb, session=sess)
|
| 267 |
+
out = ensure_rgba(out)
|
| 268 |
|
| 269 |
elif model_name == "BRIA RMBG 2.0":
|
| 270 |
out = self._run_torch_alpha_model("bria_rmbg_2", img_rgb)
|
| 271 |
|
| 272 |
elif model_name == "IS-Net":
|
| 273 |
sess = self._get_rembg_session("isnet-general-use")
|
| 274 |
+
# FIX: rembg returns PIL Image when given PIL Image, not bytes!
|
| 275 |
+
out = rembg_remove(img_rgb, session=sess)
|
| 276 |
+
out = ensure_rgba(out)
|
| 277 |
|
| 278 |
else:
|
| 279 |
raise ValueError(f"Unknown model: {model_name}")
|
|
|
|
| 318 |
if image is None:
|
| 319 |
return None, None, "Upload an image first.", None
|
| 320 |
|
|
|
|
|
|
|
| 321 |
out_rgba, timing = MANAGER.run(model_name, image)
|
| 322 |
|
|
|
|
|
|
|
| 323 |
preview = rgba_on_checkerboard(out_rgba)
|
|
|
|
| 324 |
out_path = save_temp_png(out_rgba)
|
| 325 |
return (image, preview), out_rgba, timing.to_text(), out_path
|
| 326 |
|
| 327 |
|
| 328 |
def list_bench_images() -> List[str]:
|
|
|
|
| 329 |
exts = ("*.jpg", "*.jpeg", "*.png", "*.webp")
|
| 330 |
files = []
|
| 331 |
for e in exts:
|
| 332 |
files += glob.glob(os.path.join("bench", e))
|
| 333 |
files = sorted(files)
|
| 334 |
|
|
|
|
| 335 |
if not files:
|
| 336 |
fallback = []
|
| 337 |
for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]:
|
|
|
|
| 345 |
def run_benchmark(model_name: str, repeats: int = 1):
|
| 346 |
files = list_bench_images()
|
| 347 |
if not files:
|
| 348 |
+
# FIX: Return data values, not gr.Dataframe component
|
| 349 |
+
return [], "No benchmark images found. Add 10–15 images under bench/."
|
| 350 |
|
| 351 |
# Warmup: 2 runs on first image (not timed)
|
| 352 |
warm_img = Image.open(files[0]).convert("RGB")
|
|
|
|
| 361 |
img = Image.open(f).convert("RGB")
|
| 362 |
for r in range(repeats):
|
| 363 |
out, timing = MANAGER.run(model_name, img)
|
| 364 |
+
rows.append([
|
| 365 |
+
os.path.basename(f),
|
| 366 |
+
r + 1,
|
| 367 |
+
round(timing.total_ms, 2),
|
| 368 |
+
round(timing.inference_ms, 2),
|
| 369 |
+
])
|
| 370 |
total_ms += timing.total_ms
|
| 371 |
n_images += 1
|
| 372 |
|
|
|
|
| 378 |
f"Images: {len(files)} (repeats={repeats}) => runs={n_images}\n"
|
| 379 |
f"Avg total: {avg_ms:.2f} ms\n"
|
| 380 |
f"Estimated throughput: {ips:.2f} images/sec\n"
|
| 381 |
+
f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}"
|
| 382 |
)
|
| 383 |
|
| 384 |
+
# FIX: Return the data directly, not a gr.Dataframe component
|
| 385 |
+
return rows, summary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
|
| 387 |
|
| 388 |
# ----------------------------
|
| 389 |
# UI
|
| 390 |
# ----------------------------
|
| 391 |
|
| 392 |
+
with gr.Blocks(title="Background Removal Benchmark") as demo:
|
| 393 |
gr.Markdown(
|
| 394 |
"""
|
| 395 |
+
# Background Removal Benchmark
|
| 396 |
|
| 397 |
Benchmarked models:
|
| 398 |
1) InSPyReNet
|
|
|
|
| 430 |
with gr.Row():
|
| 431 |
with gr.Column(scale=1):
|
| 432 |
bench_model = gr.Dropdown(choices=MODEL_CHOICES, value="InSPyReNet", label="Model")
|
| 433 |
+
repeats = gr.Slider(1, 5, value=1, step=1, label="Repeats per image")
|
| 434 |
bench_btn = gr.Button("Run benchmark", variant="primary")
|
| 435 |
with gr.Column(scale=2):
|
| 436 |
bench_table = gr.Dataframe(
|
|
|
|
| 446 |
outputs=[bench_table, bench_summary]
|
| 447 |
)
|
| 448 |
|
|
|
|
| 449 |
example_files = []
|
| 450 |
for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]:
|
| 451 |
if os.path.exists(f):
|
|
|
|
| 458 |
)
|
| 459 |
|
| 460 |
if __name__ == "__main__":
|
| 461 |
+
demo.launch(show_error=True)
|