Tyler Ng committed on
Commit
779ae5b
·
verified ·
1 Parent(s): 5250542

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -124
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import io
3
  import time
4
  import glob
5
  import math
@@ -15,11 +14,12 @@ from PIL import Image
15
  import torch
16
  from torchvision import transforms
17
  from transformers import AutoModelForImageSegmentation
 
18
 
19
- # Slider component (same one BRIA Space uses)
20
  from gradio_imageslider import ImageSlider
21
 
22
- # InSPyReNet wrapper (same approach as your sample Space)
23
  from transparent_background import Remover
24
 
25
  # rembg (U2Net + IS-Net via ONNX)
@@ -46,17 +46,12 @@ def make_checkerboard(w: int, h: int, block: int = 16) -> Image.Image:
46
  cols = int(math.ceil(w / block))
47
  rows = int(math.ceil(h / block))
48
  board = np.zeros((rows * block, cols * block, 3), dtype=np.uint8)
49
-
50
- c1 = np.array([235, 235, 235], dtype=np.uint8)
51
- c2 = np.array([200, 200, 200], dtype=np.uint8)
52
-
53
  for r in range(rows):
54
  for c in range(cols):
55
  color = c1 if (r + c) % 2 == 0 else c2
56
  board[r * block:(r + 1) * block, c * block:(c + 1) * block] = color
57
-
58
- board = board[:h, :w, :]
59
- return Image.fromarray(board, mode="RGB")
60
 
61
 
62
  def rgba_on_checkerboard(rgba: Image.Image) -> Image.Image:
@@ -104,32 +99,19 @@ class Timing:
104
  # ----------------------------
105
 
106
  class ModelManager:
107
- """
108
- Loads and runs:
109
- 1) InSPyReNet via transparent_background.Remover()
110
- 2) BiRefNet via AutoModelForImageSegmentation("ZhengPeng7/BiRefNet", trust_remote_code=True)
111
- 3) U2Net via rembg (onnxruntime; can use CUDA provider if available)
112
- 4) BRIA RMBG 2.0 via AutoModelForImageSegmentation("briaai/RMBG-2.0", trust_remote_code=True)
113
- 5) IS-Net (isnet-general-use) via rembg
114
- """
115
  def __init__(self):
116
- # NOTE: Don't cache device here - ZeroGPU allocates GPU later
117
  self._inspy: Optional[Remover] = None
118
-
119
- self._torch_models: Dict[str, AutoModelForImageSegmentation] = {}
120
  self._torch_model_on_gpu: Optional[str] = None
121
-
122
- # rembg sessions
123
  self._rembg_sessions: Dict[str, object] = {}
 
124
 
125
- # Common transforms for BiRefNet / BRIA RMBG inference
126
  self._tf_1024 = transforms.Compose([
127
  transforms.Resize((1024, 1024)),
128
  transforms.ToTensor(),
129
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
130
  ])
131
 
132
- # Try to set matmul precision nicely
133
  try:
134
  torch.set_float32_matmul_precision("high")
135
  except Exception:
@@ -154,55 +136,69 @@ class ModelManager:
154
  self._torch_model_on_gpu = None
155
  torch.cuda.empty_cache()
156
 
157
- def _load_torch_model(self, key: str) -> AutoModelForImageSegmentation:
158
- """
159
- key in {"birefnet", "bria_rmbg_2"}
160
- """
161
  if key in self._torch_models:
162
  return self._torch_models[key]
 
 
 
163
 
164
- if key == "birefnet":
165
- model_id = "ZhengPeng7/BiRefNet"
166
- elif key == "bria_rmbg_2":
167
- model_id = "briaai/RMBG-2.0"
168
- else:
169
- raise ValueError(f"Unknown torch model key: {key}")
 
170
 
171
- m = AutoModelForImageSegmentation.from_pretrained(model_id, trust_remote_code=True)
172
- m.eval()
173
- # Keep on CPU initially; move to GPU on-demand
174
- m.to("cpu")
175
- self._torch_models[key] = m
176
- return m
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  def _get_rembg_session(self, name: str):
179
- """
180
- name: "u2net" or "isnet-general-use"
181
- """
182
  if name in self._rembg_sessions:
183
  return self._rembg_sessions[name]
184
 
185
- # Prefer CUDA provider if available
186
  providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
187
-
188
  try:
189
  sess = new_session(name, providers=providers)
190
  except Exception:
191
- # Fallback to default providers
192
  sess = new_session(name)
193
 
194
  self._rembg_sessions[name] = sess
195
  return sess
196
 
197
  def _run_torch_alpha_model(self, model_key: str, image_rgb: Image.Image) -> Image.Image:
198
- """
199
- Runs a torch segmentation model that returns a single-channel mask.
200
- Returns RGBA (with alpha).
201
- """
202
- device = get_device() # Check device at runtime!
203
  m = self._load_torch_model(model_key)
204
 
205
- # Put model on GPU for inference if possible
206
  if device == "cuda":
207
  self._offload_torch_models_from_gpu(keep_name=model_key)
208
  if self._torch_model_on_gpu != model_key:
@@ -212,8 +208,7 @@ class ModelManager:
212
  image_rgb = pil_to_rgb(image_rgb)
213
  orig_size = image_rgb.size
214
 
215
- x = self._tf_1024(image_rgb).unsqueeze(0)
216
- x = x.to(device)
217
 
218
  with torch.inference_mode():
219
  if device == "cuda":
@@ -222,7 +217,6 @@ class ModelManager:
222
  else:
223
  preds = m(x)[-1].sigmoid()
224
 
225
- # Convert prediction to PIL alpha channel
226
  pred = preds[0].squeeze().detach().float().cpu()
227
  alpha = transforms.ToPILImage()(pred).resize(orig_size, Image.BILINEAR)
228
 
@@ -231,21 +225,19 @@ class ModelManager:
231
  return out
232
 
233
  def run(self, model_name: str, input_image: Image.Image) -> Tuple[Image.Image, Timing]:
234
- """
235
- Returns (output_rgba, timing).
236
- """
237
  if input_image is None:
238
  raise ValueError("No input image")
239
 
240
  t0 = now_ms()
241
 
242
- # --- preprocess ---
243
  pre0 = now_ms()
244
  img_rgb = pil_to_rgb(input_image)
245
  pre1 = now_ms()
246
 
247
- # --- inference ---
248
  inf0 = now_ms()
 
249
  if model_name == "InSPyReNet":
250
  remover = self._load_inspy()
251
  mask = remover.process(input_image, type="map")
@@ -253,7 +245,6 @@ class ModelManager:
253
  mask = mask.convert("L")
254
  else:
255
  mask = Image.fromarray((mask * 255).astype(np.uint8), mode="L")
256
-
257
  out = img_rgb.convert("RGBA")
258
  out.putalpha(mask)
259
 
@@ -262,7 +253,6 @@ class ModelManager:
262
 
263
  elif model_name == "U2Net":
264
  sess = self._get_rembg_session("u2net")
265
- # FIX: rembg returns PIL Image when given PIL Image, not bytes!
266
  out = rembg_remove(img_rgb, session=sess)
267
  out = ensure_rgba(out)
268
 
@@ -271,18 +261,16 @@ class ModelManager:
271
 
272
  elif model_name == "IS-Net":
273
  sess = self._get_rembg_session("isnet-general-use")
274
- # FIX: rembg returns PIL Image when given PIL Image, not bytes!
275
  out = rembg_remove(img_rgb, session=sess)
276
  out = ensure_rgba(out)
277
 
278
  else:
279
  raise ValueError(f"Unknown model: {model_name}")
280
 
281
- # Make sure GPU timing is accurate
282
  self._maybe_sync()
283
  inf1 = now_ms()
284
 
285
- # --- postprocess ---
286
  post0 = now_ms()
287
  out = ensure_rgba(out)
288
  post1 = now_ms()
@@ -318,11 +306,15 @@ def run_single(model_name: str, image: Image.Image):
318
  if image is None:
319
  return None, None, "Upload an image first.", None
320
 
321
- out_rgba, timing = MANAGER.run(model_name, image)
322
-
323
- preview = rgba_on_checkerboard(out_rgba)
324
- out_path = save_temp_png(out_rgba)
325
- return (image, preview), out_rgba, timing.to_text(), out_path
 
 
 
 
326
 
327
 
328
  def list_bench_images() -> List[str]:
@@ -333,11 +325,9 @@ def list_bench_images() -> List[str]:
333
  files = sorted(files)
334
 
335
  if not files:
336
- fallback = []
337
  for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]:
338
  if os.path.exists(f):
339
- fallback.append(f)
340
- files = fallback
341
  return files
342
 
343
 
@@ -345,44 +335,47 @@ def list_bench_images() -> List[str]:
345
  def run_benchmark(model_name: str, repeats: int = 1):
346
  files = list_bench_images()
347
  if not files:
348
- # FIX: Return data values, not gr.Dataframe component
349
  return [], "No benchmark images found. Add 10–15 images under bench/."
350
 
351
- # Warmup: 2 runs on first image (not timed)
352
- warm_img = Image.open(files[0]).convert("RGB")
353
- for _ in range(2):
354
- _ = MANAGER.run(model_name, warm_img)
355
-
356
- rows = []
357
- total_ms = 0.0
358
- n_images = 0
359
-
360
- for f in files:
361
- img = Image.open(f).convert("RGB")
362
- for r in range(repeats):
363
- out, timing = MANAGER.run(model_name, img)
364
- rows.append([
365
- os.path.basename(f),
366
- r + 1,
367
- round(timing.total_ms, 2),
368
- round(timing.inference_ms, 2),
369
- ])
370
- total_ms += timing.total_ms
371
- n_images += 1
372
-
373
- avg_ms = total_ms / max(1, n_images)
374
- ips = 1000.0 / avg_ms if avg_ms > 0 else 0.0
375
-
376
- summary = (
377
- f"Model: {model_name}\n"
378
- f"Images: {len(files)} (repeats={repeats}) => runs={n_images}\n"
379
- f"Avg total: {avg_ms:.2f} ms\n"
380
- f"Estimated throughput: {ips:.2f} images/sec\n"
381
- f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}"
382
- )
383
-
384
- # FIX: Return the data directly, not a gr.Dataframe component
385
- return rows, summary
 
 
 
 
386
 
387
 
388
  # ----------------------------
@@ -395,16 +388,18 @@ with gr.Blocks(title="Background Removal Benchmark") as demo:
395
  # Background Removal Benchmark
396
 
397
  Benchmarked models:
398
- 1) InSPyReNet
399
- 2) BiRefNet
400
- 3) U2Net
401
- 4) BRIA RMBG 2.0
402
- 5) IS-Net (isnet-general-use)
403
 
404
  **Notes**
405
- - Output download is a true transparent PNG (RGBA).
406
- - The slider preview composites the transparent result over a checkerboard for visibility.
407
- - For the benchmark tab, add **10–15 images** under `bench/` in your Space repo.
 
 
408
  """
409
  )
410
 
@@ -417,7 +412,7 @@ Benchmarked models:
417
  with gr.Column(scale=2):
418
  slider = ImageSlider(label="Before / After", type="pil")
419
  out_img = gr.Image(type="pil", label="Output (RGBA)", height=420)
420
- timing_box = gr.Textbox(label="Timing", lines=5)
421
  out_file = gr.File(label="Download PNG (transparent)")
422
 
423
  run_btn.click(
@@ -451,11 +446,7 @@ Benchmarked models:
451
  if os.path.exists(f):
452
  example_files.append([f, "InSPyReNet"])
453
  if example_files:
454
- gr.Examples(
455
- examples=example_files,
456
- inputs=[inp, model],
457
- label="Examples"
458
- )
459
 
460
  if __name__ == "__main__":
461
  demo.launch(show_error=True)
 
1
  import os
 
2
  import time
3
  import glob
4
  import math
 
14
  import torch
15
  from torchvision import transforms
16
  from transformers import AutoModelForImageSegmentation
17
+ from huggingface_hub import hf_hub_download
18
 
19
+ # Slider component
20
  from gradio_imageslider import ImageSlider
21
 
22
+ # InSPyReNet wrapper
23
  from transparent_background import Remover
24
 
25
  # rembg (U2Net + IS-Net via ONNX)
 
46
  cols = int(math.ceil(w / block))
47
  rows = int(math.ceil(h / block))
48
  board = np.zeros((rows * block, cols * block, 3), dtype=np.uint8)
49
+ c1, c2 = np.array([235, 235, 235], dtype=np.uint8), np.array([200, 200, 200], dtype=np.uint8)
 
 
 
50
  for r in range(rows):
51
  for c in range(cols):
52
  color = c1 if (r + c) % 2 == 0 else c2
53
  board[r * block:(r + 1) * block, c * block:(c + 1) * block] = color
54
+ return Image.fromarray(board[:h, :w, :], mode="RGB")
 
 
55
 
56
 
57
  def rgba_on_checkerboard(rgba: Image.Image) -> Image.Image:
 
99
  # ----------------------------
100
 
101
  class ModelManager:
 
 
 
 
 
 
 
 
102
  def __init__(self):
 
103
  self._inspy: Optional[Remover] = None
104
+ self._torch_models: Dict[str, torch.nn.Module] = {}
 
105
  self._torch_model_on_gpu: Optional[str] = None
 
 
106
  self._rembg_sessions: Dict[str, object] = {}
107
+ self._model_load_errors: Dict[str, str] = {}
108
 
 
109
  self._tf_1024 = transforms.Compose([
110
  transforms.Resize((1024, 1024)),
111
  transforms.ToTensor(),
112
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
113
  ])
114
 
 
115
  try:
116
  torch.set_float32_matmul_precision("high")
117
  except Exception:
 
136
  self._torch_model_on_gpu = None
137
  torch.cuda.empty_cache()
138
 
139
def _load_torch_model(self, key: str) -> torch.nn.Module:
    """Return the torch segmentation model for ``key``, loading it on first use.

    Args:
        key: ``"birefnet"`` or ``"bria_rmbg_2"``.

    Returns:
        The model, switched to eval mode. A freshly loaded model is moved to
        the CPU and cached in ``self._torch_models``.

    Raises:
        ValueError: If ``key`` is not one of the known model keys.
        RuntimeError: If loading fails with ``OSError`` or ``ImportError``.
            The message is stored in ``self._model_load_errors`` so that
            later calls re-raise it without attempting the load again.
    """
    if key in self._torch_models:
        return self._torch_models[key]

    # An earlier attempt already failed: report the same message again
    # instead of repeating the load.
    if key in self._model_load_errors:
        raise RuntimeError(self._model_load_errors[key])

    model_configs = {
        "birefnet": "ZhengPeng7/BiRefNet",
        "bria_rmbg_2": "briaai/RMBG-2.0",
    }

    if key not in model_configs:
        raise ValueError(f"Unknown model key: {key}")

    model_id = model_configs[key]

    try:
        m = AutoModelForImageSegmentation.from_pretrained(
            model_id,
            trust_remote_code=True,
        )
        m.eval()
        m.to("cpu")
    except OSError as e:
        error_msg = str(e)
        # Heuristic on the message text: these markers are treated as a
        # gated-repository / missing-license problem and get step-by-step help.
        if "gated" in error_msg.lower() or "401" in error_msg or "access" in error_msg.lower():
            self._model_load_errors[key] = (
                f"Model '{model_id}' requires license acceptance.\n"
                f"1. Go to https://huggingface.co/{model_id}\n"
                f"2. Accept the license agreement\n"
                f"3. Add HF_TOKEN secret to your Space settings"
            )
        else:
            self._model_load_errors[key] = f"Failed to load {model_id}: {error_msg}"
        # Chain explicitly so the original exception stays attached as the cause.
        raise RuntimeError(self._model_load_errors[key]) from e
    except ImportError as e:
        self._model_load_errors[key] = (
            f"Import error loading {model_id}: {e}\n"
            f"Make sure 'timm' is in requirements.txt"
        )
        raise RuntimeError(self._model_load_errors[key]) from e

    self._torch_models[key] = m
    return m
184
 
185
  def _get_rembg_session(self, name: str):
 
 
 
186
  if name in self._rembg_sessions:
187
  return self._rembg_sessions[name]
188
 
 
189
  providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
 
190
  try:
191
  sess = new_session(name, providers=providers)
192
  except Exception:
 
193
  sess = new_session(name)
194
 
195
  self._rembg_sessions[name] = sess
196
  return sess
197
 
198
  def _run_torch_alpha_model(self, model_key: str, image_rgb: Image.Image) -> Image.Image:
199
+ device = get_device()
 
 
 
 
200
  m = self._load_torch_model(model_key)
201
 
 
202
  if device == "cuda":
203
  self._offload_torch_models_from_gpu(keep_name=model_key)
204
  if self._torch_model_on_gpu != model_key:
 
208
  image_rgb = pil_to_rgb(image_rgb)
209
  orig_size = image_rgb.size
210
 
211
+ x = self._tf_1024(image_rgb).unsqueeze(0).to(device)
 
212
 
213
  with torch.inference_mode():
214
  if device == "cuda":
 
217
  else:
218
  preds = m(x)[-1].sigmoid()
219
 
 
220
  pred = preds[0].squeeze().detach().float().cpu()
221
  alpha = transforms.ToPILImage()(pred).resize(orig_size, Image.BILINEAR)
222
 
 
225
  return out
226
 
227
  def run(self, model_name: str, input_image: Image.Image) -> Tuple[Image.Image, Timing]:
 
 
 
228
  if input_image is None:
229
  raise ValueError("No input image")
230
 
231
  t0 = now_ms()
232
 
233
+ # Preprocess
234
  pre0 = now_ms()
235
  img_rgb = pil_to_rgb(input_image)
236
  pre1 = now_ms()
237
 
238
+ # Inference
239
  inf0 = now_ms()
240
+
241
  if model_name == "InSPyReNet":
242
  remover = self._load_inspy()
243
  mask = remover.process(input_image, type="map")
 
245
  mask = mask.convert("L")
246
  else:
247
  mask = Image.fromarray((mask * 255).astype(np.uint8), mode="L")
 
248
  out = img_rgb.convert("RGBA")
249
  out.putalpha(mask)
250
 
 
253
 
254
  elif model_name == "U2Net":
255
  sess = self._get_rembg_session("u2net")
 
256
  out = rembg_remove(img_rgb, session=sess)
257
  out = ensure_rgba(out)
258
 
 
261
 
262
  elif model_name == "IS-Net":
263
  sess = self._get_rembg_session("isnet-general-use")
 
264
  out = rembg_remove(img_rgb, session=sess)
265
  out = ensure_rgba(out)
266
 
267
  else:
268
  raise ValueError(f"Unknown model: {model_name}")
269
 
 
270
  self._maybe_sync()
271
  inf1 = now_ms()
272
 
273
+ # Postprocess
274
  post0 = now_ms()
275
  out = ensure_rgba(out)
276
  post1 = now_ms()
 
306
  if image is None:
307
  return None, None, "Upload an image first.", None
308
 
309
+ try:
310
+ out_rgba, timing = MANAGER.run(model_name, image)
311
+ preview = rgba_on_checkerboard(out_rgba)
312
+ out_path = save_temp_png(out_rgba)
313
+ return (image, preview), out_rgba, timing.to_text(), out_path
314
+ except RuntimeError as e:
315
+ return None, None, f"Error: {str(e)}", None
316
+ except Exception as e:
317
+ return None, None, f"Unexpected error: {str(e)}", None
318
 
319
 
320
  def list_bench_images() -> List[str]:
 
325
  files = sorted(files)
326
 
327
  if not files:
 
328
  for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]:
329
  if os.path.exists(f):
330
+ files.append(f)
 
331
  return files
332
 
333
 
 
335
  def run_benchmark(model_name: str, repeats: int = 1):
336
  files = list_bench_images()
337
  if not files:
 
338
  return [], "No benchmark images found. Add 10–15 images under bench/."
339
 
340
+ try:
341
+ # Warmup
342
+ warm_img = Image.open(files[0]).convert("RGB")
343
+ for _ in range(2):
344
+ _ = MANAGER.run(model_name, warm_img)
345
+
346
+ rows = []
347
+ total_ms = 0.0
348
+ n_images = 0
349
+
350
+ for f in files:
351
+ img = Image.open(f).convert("RGB")
352
+ for r in range(repeats):
353
+ out, timing = MANAGER.run(model_name, img)
354
+ rows.append([
355
+ os.path.basename(f),
356
+ r + 1,
357
+ round(timing.total_ms, 2),
358
+ round(timing.inference_ms, 2),
359
+ ])
360
+ total_ms += timing.total_ms
361
+ n_images += 1
362
+
363
+ avg_ms = total_ms / max(1, n_images)
364
+ ips = 1000.0 / avg_ms if avg_ms > 0 else 0.0
365
+
366
+ summary = (
367
+ f"Model: {model_name}\n"
368
+ f"Images: {len(files)} (repeats={repeats}) => runs={n_images}\n"
369
+ f"Avg total: {avg_ms:.2f} ms\n"
370
+ f"Estimated throughput: {ips:.2f} images/sec\n"
371
+ f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}"
372
+ )
373
+ return rows, summary
374
+
375
+ except RuntimeError as e:
376
+ return [], f"Error: {str(e)}"
377
+ except Exception as e:
378
+ return [], f"Unexpected error: {str(e)}"
379
 
380
 
381
  # ----------------------------
 
388
  # Background Removal Benchmark
389
 
390
  Benchmarked models:
391
+ 1. **InSPyReNet** — transparent-background library
392
+ 2. **BiRefNet** — ZhengPeng7/BiRefNet (requires `timm`)
393
+ 3. **U2Net** — via rembg/ONNX
394
+ 4. **BRIA RMBG 2.0** — briaai/RMBG-2.0 (requires license acceptance)
395
+ 5. **IS-Net** isnet-general-use via rembg
396
 
397
  **Notes**
398
+ - Output is true transparent PNG (RGBA)
399
+ - Slider preview shows result on checkerboard
400
+ - For benchmarks, add images under `bench/` folder
401
+
402
+ ⚠️ **BRIA RMBG 2.0**: Requires accepting license at [huggingface.co/briaai/RMBG-2.0](https://huggingface.co/briaai/RMBG-2.0) and adding `HF_TOKEN` secret to Space settings.
403
  """
404
  )
405
 
 
412
  with gr.Column(scale=2):
413
  slider = ImageSlider(label="Before / After", type="pil")
414
  out_img = gr.Image(type="pil", label="Output (RGBA)", height=420)
415
+ timing_box = gr.Textbox(label="Timing / Errors", lines=5)
416
  out_file = gr.File(label="Download PNG (transparent)")
417
 
418
  run_btn.click(
 
446
  if os.path.exists(f):
447
  example_files.append([f, "InSPyReNet"])
448
  if example_files:
449
+ gr.Examples(examples=example_files, inputs=[inp, model], label="Examples")
 
 
 
 
450
 
451
  if __name__ == "__main__":
452
  demo.launch(show_error=True)