Tyler Ng commited on
Commit
9c24f06
·
verified ·
1 Parent(s): 1ece687

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +469 -0
app.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import time
4
+ import glob
5
+ import math
6
+ from dataclasses import dataclass
7
+ from typing import Dict, Optional, Tuple, List
8
+
9
+ import gradio as gr
10
+ import spaces
11
+
12
+ import numpy as np
13
+ from PIL import Image
14
+
15
+ import torch
16
+ from torchvision import transforms
17
+ from transformers import AutoModelForImageSegmentation
18
+
19
+ # Slider component (same one BRIA Space uses)
20
+ from gradio_imageslider import ImageSlider
21
+
22
+ # InSPyReNet wrapper (same approach as your sample Space)
23
+ from transparent_background import Remover
24
+
25
+ # rembg (U2Net + IS-Net via ONNX)
26
+ from rembg import new_session, remove as rembg_remove
27
+
28
+
29
+ # ----------------------------
30
+ # Utilities
31
+ # ----------------------------
32
+
33
+ def pil_to_rgb(pil: Image.Image) -> Image.Image:
34
+ if pil.mode != "RGB":
35
+ return pil.convert("RGB")
36
+ return pil
37
+
38
+
39
+ def ensure_rgba(pil: Image.Image) -> Image.Image:
40
+ if pil.mode != "RGBA":
41
+ return pil.convert("RGBA")
42
+ return pil
43
+
44
+
45
+ def make_checkerboard(w: int, h: int, block: int = 16) -> Image.Image:
46
+ # Neutral checkerboard
47
+ cols = int(math.ceil(w / block))
48
+ rows = int(math.ceil(h / block))
49
+ board = np.zeros((rows * block, cols * block, 3), dtype=np.uint8)
50
+
51
+ c1 = np.array([235, 235, 235], dtype=np.uint8)
52
+ c2 = np.array([200, 200, 200], dtype=np.uint8)
53
+
54
+ for r in range(rows):
55
+ for c in range(cols):
56
+ color = c1 if (r + c) % 2 == 0 else c2
57
+ board[r * block:(r + 1) * block, c * block:(c + 1) * block] = color
58
+
59
+ board = board[:h, :w, :]
60
+ return Image.fromarray(board, mode="RGB")
61
+
62
+
63
+ def rgba_on_checkerboard(rgba: Image.Image) -> Image.Image:
64
+ rgba = ensure_rgba(rgba)
65
+ w, h = rgba.size
66
+ bg = make_checkerboard(w, h)
67
+ comp = Image.alpha_composite(bg.convert("RGBA"), rgba)
68
+ return comp.convert("RGB")
69
+
70
+
71
+ def save_temp_png(rgba: Image.Image, out_dir: str = "output_images") -> str:
72
+ os.makedirs(out_dir, exist_ok=True)
73
+ path = os.path.join(out_dir, "no_bg.png")
74
+ ensure_rgba(rgba).save(path, format="PNG")
75
+ return path
76
+
77
+
78
+ def now_ms() -> float:
79
+ return time.perf_counter() * 1000.0
80
+
81
+
82
+ @dataclass
83
+ class Timing:
84
+ preprocess_ms: float
85
+ inference_ms: float
86
+ postprocess_ms: float
87
+ total_ms: float
88
+
89
+ def to_text(self) -> str:
90
+ return (
91
+ f"preprocess: {self.preprocess_ms:.2f} ms\n"
92
+ f"inference: {self.inference_ms:.2f} ms\n"
93
+ f"postprocess: {self.postprocess_ms:.2f} ms\n"
94
+ f"TOTAL: {self.total_ms:.2f} ms"
95
+ )
96
+
97
+
98
+ # ----------------------------
99
+ # Model Manager
100
+ # ----------------------------
101
+
102
+ class ModelManager:
103
+ """
104
+ Loads and runs:
105
+ 1) InSPyReNet via transparent_background.Remover()
106
+ 2) BiRefNet via AutoModelForImageSegmentation("ZhengPeng7/BiRefNet", trust_remote_code=True)
107
+ 3) U2Net via rembg (onnxruntime; can use CUDA provider if available)
108
+ 4) BRIA RMBG 2.0 via AutoModelForImageSegmentation("briaai/RMBG-2.0", trust_remote_code=True)
109
+ 5) IS-Net (isnet-general-use) via rembg
110
+ """
111
+ def __init__(self):
112
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
113
+ self._inspy: Optional[Remover] = None
114
+
115
+ self._torch_models: Dict[str, AutoModelForImageSegmentation] = {}
116
+ self._torch_model_on_gpu: Optional[str] = None
117
+
118
+ # rembg sessions
119
+ self._rembg_sessions: Dict[str, object] = {}
120
+
121
+ # Common transforms for BiRefNet / BRIA RMBG inference
122
+ self._tf_1024 = transforms.Compose([
123
+ transforms.Resize((1024, 1024)),
124
+ transforms.ToTensor(),
125
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
126
+ ])
127
+
128
+ # Try to set matmul precision nicely
129
+ try:
130
+ torch.set_float32_matmul_precision("high")
131
+ except Exception:
132
+ pass
133
+
134
+ def _maybe_sync(self):
135
+ if self.device == "cuda":
136
+ torch.cuda.synchronize()
137
+
138
+ def _load_inspy(self) -> Remover:
139
+ if self._inspy is None:
140
+ # jit=False like your sample
141
+ self._inspy = Remover(jit=False)
142
+ return self._inspy
143
+
144
+ def _offload_torch_models_from_gpu(self, keep_name: str):
145
+ if self.device != "cuda":
146
+ return
147
+ if self._torch_model_on_gpu and self._torch_model_on_gpu != keep_name:
148
+ prev = self._torch_models.get(self._torch_model_on_gpu)
149
+ if prev is not None:
150
+ prev.to("cpu")
151
+ self._torch_model_on_gpu = None
152
+ torch.cuda.empty_cache()
153
+
154
+ def _load_torch_model(self, key: str) -> AutoModelForImageSegmentation:
155
+ """
156
+ key in {"birefnet", "bria_rmbg_2"}
157
+ """
158
+ if key in self._torch_models:
159
+ return self._torch_models[key]
160
+
161
+ if key == "birefnet":
162
+ model_id = "ZhengPeng7/BiRefNet"
163
+ elif key == "bria_rmbg_2":
164
+ model_id = "briaai/RMBG-2.0"
165
+ else:
166
+ raise ValueError(f"Unknown torch model key: {key}")
167
+
168
+ m = AutoModelForImageSegmentation.from_pretrained(model_id, trust_remote_code=True)
169
+ m.eval()
170
+ # Keep on CPU initially; move to GPU on-demand to avoid T4 OOM.
171
+ m.to("cpu")
172
+ self._torch_models[key] = m
173
+ return m
174
+
175
+ def _get_rembg_session(self, name: str):
176
+ """
177
+ name: "u2net" or "isnet-general-use"
178
+ """
179
+ if name in self._rembg_sessions:
180
+ return self._rembg_sessions[name]
181
+
182
+ # Prefer CUDA provider if onnxruntime-gpu is installed; otherwise CPU works.
183
+ # rembg will pass this into onnxruntime internally.
184
+ providers = None
185
+ try:
186
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
187
+ except Exception:
188
+ providers = None
189
+
190
+ sess = new_session(name, providers=providers) if providers else new_session(name)
191
+ self._rembg_sessions[name] = sess
192
+ return sess
193
+
194
+ def _run_torch_alpha_model(self, model_key: str, image_rgb: Image.Image) -> Image.Image:
195
+ """
196
+ Runs a torch segmentation model that returns a single-channel mask (alpha matte-ish).
197
+ Returns RGBA (with alpha).
198
+ """
199
+ m = self._load_torch_model(model_key)
200
+
201
+ # Put model on GPU for inference if possible
202
+ if self.device == "cuda":
203
+ self._offload_torch_models_from_gpu(keep_name=model_key)
204
+ if self._torch_model_on_gpu != model_key:
205
+ m.to("cuda")
206
+ self._torch_model_on_gpu = model_key
207
+
208
+ image_rgb = pil_to_rgb(image_rgb)
209
+ orig_size = image_rgb.size
210
+
211
+ x = self._tf_1024(image_rgb).unsqueeze(0)
212
+ x = x.to(self.device)
213
+
214
+ with torch.inference_mode():
215
+ if self.device == "cuda":
216
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
217
+ preds = m(x)[-1].sigmoid()
218
+ else:
219
+ preds = m(x)[-1].sigmoid()
220
+
221
+ # Convert prediction to PIL alpha channel
222
+ pred = preds[0].squeeze().detach().float().cpu()
223
+ alpha = transforms.ToPILImage()(pred).resize(orig_size)
224
+
225
+ out = image_rgb.convert("RGBA")
226
+ out.putalpha(alpha)
227
+ return out
228
+
229
+ def run(self, model_name: str, input_image: Image.Image) -> Tuple[Image.Image, Timing]:
230
+ """
231
+ Returns (output_rgba, timing).
232
+ """
233
+ if input_image is None:
234
+ raise ValueError("No input image")
235
+
236
+ t0 = now_ms()
237
+
238
+ # --- preprocess ---
239
+ pre0 = now_ms()
240
+ img_rgb = pil_to_rgb(input_image)
241
+ pre1 = now_ms()
242
+
243
+ # --- inference ---
244
+ inf0 = now_ms()
245
+ if model_name == "InSPyReNet":
246
+ remover = self._load_inspy()
247
+ # The library returns various modes; we want alpha mask and apply ourselves for consistent output
248
+ mask = remover.process(input_image, type="map")
249
+ if isinstance(mask, Image.Image):
250
+ mask = mask.convert("L")
251
+ else:
252
+ mask = Image.fromarray((mask * 255).astype(np.uint8), mode="L")
253
+
254
+ out = img_rgb.convert("RGBA")
255
+ out.putalpha(mask)
256
+
257
+ elif model_name == "BiRefNet":
258
+ out = self._run_torch_alpha_model("birefnet", img_rgb)
259
+
260
+ elif model_name == "U2Net":
261
+ sess = self._get_rembg_session("u2net")
262
+ # rembg returns bytes (PNG RGBA)
263
+ out_bytes = rembg_remove(img_rgb, session=sess)
264
+ out = Image.open(io.BytesIO(out_bytes)).convert("RGBA")
265
+
266
+ elif model_name == "BRIA RMBG 2.0":
267
+ out = self._run_torch_alpha_model("bria_rmbg_2", img_rgb)
268
+
269
+ elif model_name == "IS-Net":
270
+ sess = self._get_rembg_session("isnet-general-use")
271
+ out_bytes = rembg_remove(img_rgb, session=sess)
272
+ out = Image.open(io.BytesIO(out_bytes)).convert("RGBA")
273
+
274
+ else:
275
+ raise ValueError(f"Unknown model: {model_name}")
276
+
277
+ # Make sure GPU timing is accurate
278
+ self._maybe_sync()
279
+ inf1 = now_ms()
280
+
281
+ # --- postprocess ---
282
+ post0 = now_ms()
283
+ out = ensure_rgba(out)
284
+ post1 = now_ms()
285
+
286
+ t1 = now_ms()
287
+
288
+ timing = Timing(
289
+ preprocess_ms=pre1 - pre0,
290
+ inference_ms=inf1 - inf0,
291
+ postprocess_ms=post1 - post0,
292
+ total_ms=t1 - t0,
293
+ )
294
+ return out, timing
295
+
296
+
297
+ MANAGER = ModelManager()
298
+
299
+ MODEL_CHOICES = [
300
+ "InSPyReNet",
301
+ "BiRefNet",
302
+ "U2Net",
303
+ "BRIA RMBG 2.0",
304
+ "IS-Net",
305
+ ]
306
+
307
+
308
+ # ----------------------------
309
+ # Gradio handlers
310
+ # ----------------------------
311
+
312
+ @spaces.GPU
313
+ def run_single(model_name: str, image: Image.Image):
314
+ if image is None:
315
+ return None, None, "Upload an image first.", None
316
+
317
+ # Warmup-ish for fairer timing (tiny; avoids huge overhead in UI)
318
+ # Note: real benchmark tab does proper warmups.
319
+ out_rgba, timing = MANAGER.run(model_name, image)
320
+
321
+ # Slider wants (processed, original) or (after, before) depending on component;
322
+ # we’ll show: left=original, right=on-checkerboard preview of transparent output.
323
+ preview = rgba_on_checkerboard(out_rgba)
324
+
325
+ out_path = save_temp_png(out_rgba)
326
+ return (image, preview), out_rgba, timing.to_text(), out_path
327
+
328
+
329
+ def list_bench_images() -> List[str]:
330
+ # Put your 10–15 images under bench/
331
+ exts = ("*.jpg", "*.jpeg", "*.png", "*.webp")
332
+ files = []
333
+ for e in exts:
334
+ files += glob.glob(os.path.join("bench", e))
335
+ files = sorted(files)
336
+
337
+ # Fallback to repo-root examples like your sample Space
338
+ if not files:
339
+ fallback = []
340
+ for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]:
341
+ if os.path.exists(f):
342
+ fallback.append(f)
343
+ files = fallback
344
+ return files
345
+
346
+
347
+ @spaces.GPU
348
+ def run_benchmark(model_name: str, repeats: int = 1):
349
+ files = list_bench_images()
350
+ if not files:
351
+ return gr.Dataframe(value=[]), "No benchmark images found. Add 10–15 images under bench/."
352
+
353
+ # Warmup: 2 runs on first image (not timed)
354
+ warm_img = Image.open(files[0]).convert("RGB")
355
+ for _ in range(2):
356
+ _ = MANAGER.run(model_name, warm_img)
357
+
358
+ rows = []
359
+ total_ms = 0.0
360
+ n_images = 0
361
+
362
+ for f in files:
363
+ img = Image.open(f).convert("RGB")
364
+ for r in range(repeats):
365
+ out, timing = MANAGER.run(model_name, img)
366
+ rows.append({
367
+ "file": os.path.basename(f),
368
+ "repeat": r + 1,
369
+ "total_ms": round(timing.total_ms, 2),
370
+ "inference_ms": round(timing.inference_ms, 2),
371
+ })
372
+ total_ms += timing.total_ms
373
+ n_images += 1
374
+
375
+ avg_ms = total_ms / max(1, n_images)
376
+ ips = 1000.0 / avg_ms if avg_ms > 0 else 0.0
377
+
378
+ summary = (
379
+ f"Model: {model_name}\n"
380
+ f"Images: {len(files)} (repeats={repeats}) => runs={n_images}\n"
381
+ f"Avg total: {avg_ms:.2f} ms\n"
382
+ f"Estimated throughput: {ips:.2f} images/sec\n"
383
+ f"Device: {'T4 GPU' if torch.cuda.is_available() else 'CPU'}"
384
+ )
385
+
386
+ df = gr.Dataframe(
387
+ headers=["file", "repeat", "total_ms", "inference_ms"],
388
+ value=[[r["file"], r["repeat"], r["total_ms"], r["inference_ms"]] for r in rows],
389
+ datatype=["str", "number", "number", "number"],
390
+ interactive=False
391
+ )
392
+ return df, summary
393
+
394
+
395
+ # ----------------------------
396
+ # UI
397
+ # ----------------------------
398
+
399
+ with gr.Blocks(title="Background Removal Benchmark (T4)") as demo:
400
+ gr.Markdown(
401
+ """
402
+ # Background Removal Benchmark (T4)
403
+
404
+ Benchmarked models:
405
+ 1) InSPyReNet
406
+ 2) BiRefNet
407
+ 3) U2Net
408
+ 4) BRIA RMBG 2.0
409
+ 5) IS-Net (isnet-general-use)
410
+
411
+ **Notes**
412
+ - Output download is a true transparent PNG (RGBA).
413
+ - The slider preview composites the transparent result over a checkerboard for visibility.
414
+ - For the benchmark tab, add **10–15 images** under `bench/` in your Space repo.
415
+ """
416
+ )
417
+
418
+ with gr.Tab("Try single image"):
419
+ with gr.Row():
420
+ with gr.Column(scale=1):
421
+ inp = gr.Image(type="pil", label="Upload image", height=420)
422
+ model = gr.Dropdown(choices=MODEL_CHOICES, value="InSPyReNet", label="Model")
423
+ run_btn = gr.Button("Run", variant="primary")
424
+ with gr.Column(scale=2):
425
+ slider = ImageSlider(label="Before / After", type="pil")
426
+ out_img = gr.Image(type="pil", label="Output (RGBA)", height=420)
427
+ timing_box = gr.Textbox(label="Timing", lines=5)
428
+ out_file = gr.File(label="Download PNG (transparent)")
429
+
430
+ run_btn.click(
431
+ fn=run_single,
432
+ inputs=[model, inp],
433
+ outputs=[slider, out_img, timing_box, out_file]
434
+ )
435
+
436
+ with gr.Tab("Benchmark (throughput estimate)"):
437
+ with gr.Row():
438
+ with gr.Column(scale=1):
439
+ bench_model = gr.Dropdown(choices=MODEL_CHOICES, value="InSPyReNet", label="Model")
440
+ repeats = gr.Slider(1, 5, value=1, step=1, label="Repeats per image (higher = more stable averages)")
441
+ bench_btn = gr.Button("Run benchmark", variant="primary")
442
+ with gr.Column(scale=2):
443
+ bench_table = gr.Dataframe(
444
+ headers=["file", "repeat", "total_ms", "inference_ms"],
445
+ datatype=["str", "number", "number", "number"],
446
+ interactive=False
447
+ )
448
+ bench_summary = gr.Textbox(label="Summary", lines=6)
449
+
450
+ bench_btn.click(
451
+ fn=run_benchmark,
452
+ inputs=[bench_model, repeats],
453
+ outputs=[bench_table, bench_summary]
454
+ )
455
+
456
+ # Examples (optional) — if these files exist, they show up like your sample Space
457
+ example_files = []
458
+ for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]:
459
+ if os.path.exists(f):
460
+ example_files.append([f, "InSPyReNet"])
461
+ if example_files:
462
+ gr.Examples(
463
+ examples=example_files,
464
+ inputs=[inp, model],
465
+ label="Examples"
466
+ )
467
+
468
+ if __name__ == "__main__":
469
+ demo.launch(show_error=True)