Tyler Ng committed on
Commit
b13ba47
·
verified ·
1 Parent(s): f848020

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -51
app.py CHANGED
@@ -43,7 +43,6 @@ def ensure_rgba(pil: Image.Image) -> Image.Image:
43
 
44
 
45
  def make_checkerboard(w: int, h: int, block: int = 16) -> Image.Image:
46
- # Neutral checkerboard
47
  cols = int(math.ceil(w / block))
48
  rows = int(math.ceil(h / block))
49
  board = np.zeros((rows * block, cols * block, 3), dtype=np.uint8)
@@ -79,6 +78,11 @@ def now_ms() -> float:
79
  return time.perf_counter() * 1000.0
80
 
81
 
 
 
 
 
 
82
  @dataclass
83
  class Timing:
84
  preprocess_ms: float
@@ -109,7 +113,7 @@ class ModelManager:
109
  5) IS-Net (isnet-general-use) via rembg
110
  """
111
  def __init__(self):
112
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
113
  self._inspy: Optional[Remover] = None
114
 
115
  self._torch_models: Dict[str, AutoModelForImageSegmentation] = {}
@@ -132,17 +136,16 @@ class ModelManager:
132
  pass
133
 
134
  def _maybe_sync(self):
135
- if self.device == "cuda":
136
  torch.cuda.synchronize()
137
 
138
  def _load_inspy(self) -> Remover:
139
  if self._inspy is None:
140
- # jit=False like your sample
141
  self._inspy = Remover(jit=False)
142
  return self._inspy
143
 
144
  def _offload_torch_models_from_gpu(self, keep_name: str):
145
- if self.device != "cuda":
146
  return
147
  if self._torch_model_on_gpu and self._torch_model_on_gpu != keep_name:
148
  prev = self._torch_models.get(self._torch_model_on_gpu)
@@ -167,7 +170,7 @@ class ModelManager:
167
 
168
  m = AutoModelForImageSegmentation.from_pretrained(model_id, trust_remote_code=True)
169
  m.eval()
170
- # Keep on CPU initially; move to GPU on-demand to avoid T4 OOM.
171
  m.to("cpu")
172
  self._torch_models[key] = m
173
  return m
@@ -179,27 +182,28 @@ class ModelManager:
179
  if name in self._rembg_sessions:
180
  return self._rembg_sessions[name]
181
 
182
- # Prefer CUDA provider if onnxruntime-gpu is installed; otherwise CPU works.
183
- # rembg will pass this into onnxruntime internally.
184
- providers = None
185
  try:
186
- providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
187
  except Exception:
188
- providers = None
189
-
190
- sess = new_session(name, providers=providers) if providers else new_session(name)
191
  self._rembg_sessions[name] = sess
192
  return sess
193
 
194
  def _run_torch_alpha_model(self, model_key: str, image_rgb: Image.Image) -> Image.Image:
195
  """
196
- Runs a torch segmentation model that returns a single-channel mask (alpha matte-ish).
197
  Returns RGBA (with alpha).
198
  """
 
199
  m = self._load_torch_model(model_key)
200
 
201
  # Put model on GPU for inference if possible
202
- if self.device == "cuda":
203
  self._offload_torch_models_from_gpu(keep_name=model_key)
204
  if self._torch_model_on_gpu != model_key:
205
  m.to("cuda")
@@ -209,10 +213,10 @@ class ModelManager:
209
  orig_size = image_rgb.size
210
 
211
  x = self._tf_1024(image_rgb).unsqueeze(0)
212
- x = x.to(self.device)
213
 
214
  with torch.inference_mode():
215
- if self.device == "cuda":
216
  with torch.autocast(device_type="cuda", dtype=torch.float16):
217
  preds = m(x)[-1].sigmoid()
218
  else:
@@ -220,7 +224,7 @@ class ModelManager:
220
 
221
  # Convert prediction to PIL alpha channel
222
  pred = preds[0].squeeze().detach().float().cpu()
223
- alpha = transforms.ToPILImage()(pred).resize(orig_size)
224
 
225
  out = image_rgb.convert("RGBA")
226
  out.putalpha(alpha)
@@ -244,7 +248,6 @@ class ModelManager:
244
  inf0 = now_ms()
245
  if model_name == "InSPyReNet":
246
  remover = self._load_inspy()
247
- # The library returns various modes; we want alpha mask and apply ourselves for consistent output
248
  mask = remover.process(input_image, type="map")
249
  if isinstance(mask, Image.Image):
250
  mask = mask.convert("L")
@@ -259,17 +262,18 @@ class ModelManager:
259
 
260
  elif model_name == "U2Net":
261
  sess = self._get_rembg_session("u2net")
262
- # rembg returns bytes (PNG RGBA)
263
- out_bytes = rembg_remove(img_rgb, session=sess)
264
- out = Image.open(io.BytesIO(out_bytes)).convert("RGBA")
265
 
266
  elif model_name == "BRIA RMBG 2.0":
267
  out = self._run_torch_alpha_model("bria_rmbg_2", img_rgb)
268
 
269
  elif model_name == "IS-Net":
270
  sess = self._get_rembg_session("isnet-general-use")
271
- out_bytes = rembg_remove(img_rgb, session=sess)
272
- out = Image.open(io.BytesIO(out_bytes)).convert("RGBA")
 
273
 
274
  else:
275
  raise ValueError(f"Unknown model: {model_name}")
@@ -314,27 +318,20 @@ def run_single(model_name: str, image: Image.Image):
314
  if image is None:
315
  return None, None, "Upload an image first.", None
316
 
317
- # Warmup-ish for fairer timing (tiny; avoids huge overhead in UI)
318
- # Note: real benchmark tab does proper warmups.
319
  out_rgba, timing = MANAGER.run(model_name, image)
320
 
321
- # Slider wants (processed, original) or (after, before) depending on component;
322
- # we’ll show: left=original, right=on-checkerboard preview of transparent output.
323
  preview = rgba_on_checkerboard(out_rgba)
324
-
325
  out_path = save_temp_png(out_rgba)
326
  return (image, preview), out_rgba, timing.to_text(), out_path
327
 
328
 
329
  def list_bench_images() -> List[str]:
330
- # Put your 10–15 images under bench/
331
  exts = ("*.jpg", "*.jpeg", "*.png", "*.webp")
332
  files = []
333
  for e in exts:
334
  files += glob.glob(os.path.join("bench", e))
335
  files = sorted(files)
336
 
337
- # Fallback to repo-root examples like your sample Space
338
  if not files:
339
  fallback = []
340
  for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]:
@@ -348,7 +345,8 @@ def list_bench_images() -> List[str]:
348
  def run_benchmark(model_name: str, repeats: int = 1):
349
  files = list_bench_images()
350
  if not files:
351
- return gr.Dataframe(value=[]), "No benchmark images found. Add 10–15 images under bench/."
 
352
 
353
  # Warmup: 2 runs on first image (not timed)
354
  warm_img = Image.open(files[0]).convert("RGB")
@@ -363,12 +361,12 @@ def run_benchmark(model_name: str, repeats: int = 1):
363
  img = Image.open(f).convert("RGB")
364
  for r in range(repeats):
365
  out, timing = MANAGER.run(model_name, img)
366
- rows.append({
367
- "file": os.path.basename(f),
368
- "repeat": r + 1,
369
- "total_ms": round(timing.total_ms, 2),
370
- "inference_ms": round(timing.inference_ms, 2),
371
- })
372
  total_ms += timing.total_ms
373
  n_images += 1
374
 
@@ -380,26 +378,21 @@ def run_benchmark(model_name: str, repeats: int = 1):
380
  f"Images: {len(files)} (repeats={repeats}) => runs={n_images}\n"
381
  f"Avg total: {avg_ms:.2f} ms\n"
382
  f"Estimated throughput: {ips:.2f} images/sec\n"
383
- f"Device: {'T4 GPU' if torch.cuda.is_available() else 'CPU'}"
384
  )
385
 
386
- df = gr.Dataframe(
387
- headers=["file", "repeat", "total_ms", "inference_ms"],
388
- value=[[r["file"], r["repeat"], r["total_ms"], r["inference_ms"]] for r in rows],
389
- datatype=["str", "number", "number", "number"],
390
- interactive=False
391
- )
392
- return df, summary
393
 
394
 
395
  # ----------------------------
396
  # UI
397
  # ----------------------------
398
 
399
- with gr.Blocks(title="Background Removal Benchmark (T4)") as demo:
400
  gr.Markdown(
401
  """
402
- # Background Removal Benchmark (T4)
403
 
404
  Benchmarked models:
405
  1) InSPyReNet
@@ -437,7 +430,7 @@ Benchmarked models:
437
  with gr.Row():
438
  with gr.Column(scale=1):
439
  bench_model = gr.Dropdown(choices=MODEL_CHOICES, value="InSPyReNet", label="Model")
440
- repeats = gr.Slider(1, 5, value=1, step=1, label="Repeats per image (higher = more stable averages)")
441
  bench_btn = gr.Button("Run benchmark", variant="primary")
442
  with gr.Column(scale=2):
443
  bench_table = gr.Dataframe(
@@ -453,7 +446,6 @@ Benchmarked models:
453
  outputs=[bench_table, bench_summary]
454
  )
455
 
456
- # Examples (optional) — if these files exist, they show up like your sample Space
457
  example_files = []
458
  for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]:
459
  if os.path.exists(f):
@@ -466,4 +458,4 @@ Benchmarked models:
466
  )
467
 
468
  if __name__ == "__main__":
469
- demo.launch(show_error=True)
 
43
 
44
 
45
  def make_checkerboard(w: int, h: int, block: int = 16) -> Image.Image:
 
46
  cols = int(math.ceil(w / block))
47
  rows = int(math.ceil(h / block))
48
  board = np.zeros((rows * block, cols * block, 3), dtype=np.uint8)
 
78
  return time.perf_counter() * 1000.0
79
 
80
 
81
+ def get_device() -> str:
82
+ """Get device at runtime (important for ZeroGPU)."""
83
+ return "cuda" if torch.cuda.is_available() else "cpu"
84
+
85
+
86
  @dataclass
87
  class Timing:
88
  preprocess_ms: float
 
113
  5) IS-Net (isnet-general-use) via rembg
114
  """
115
  def __init__(self):
116
+ # NOTE: Don't cache device here - ZeroGPU allocates GPU later
117
  self._inspy: Optional[Remover] = None
118
 
119
  self._torch_models: Dict[str, AutoModelForImageSegmentation] = {}
 
136
  pass
137
 
138
  def _maybe_sync(self):
139
+ if get_device() == "cuda":
140
  torch.cuda.synchronize()
141
 
142
  def _load_inspy(self) -> Remover:
143
  if self._inspy is None:
 
144
  self._inspy = Remover(jit=False)
145
  return self._inspy
146
 
147
  def _offload_torch_models_from_gpu(self, keep_name: str):
148
+ if get_device() != "cuda":
149
  return
150
  if self._torch_model_on_gpu and self._torch_model_on_gpu != keep_name:
151
  prev = self._torch_models.get(self._torch_model_on_gpu)
 
170
 
171
  m = AutoModelForImageSegmentation.from_pretrained(model_id, trust_remote_code=True)
172
  m.eval()
173
+ # Keep on CPU initially; move to GPU on-demand
174
  m.to("cpu")
175
  self._torch_models[key] = m
176
  return m
 
182
  if name in self._rembg_sessions:
183
  return self._rembg_sessions[name]
184
 
185
+ # Prefer CUDA provider if available
186
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
187
+
188
  try:
189
+ sess = new_session(name, providers=providers)
190
  except Exception:
191
+ # Fallback to default providers
192
+ sess = new_session(name)
193
+
194
  self._rembg_sessions[name] = sess
195
  return sess
196
 
197
  def _run_torch_alpha_model(self, model_key: str, image_rgb: Image.Image) -> Image.Image:
198
  """
199
+ Runs a torch segmentation model that returns a single-channel mask.
200
  Returns RGBA (with alpha).
201
  """
202
+ device = get_device() # Check device at runtime!
203
  m = self._load_torch_model(model_key)
204
 
205
  # Put model on GPU for inference if possible
206
+ if device == "cuda":
207
  self._offload_torch_models_from_gpu(keep_name=model_key)
208
  if self._torch_model_on_gpu != model_key:
209
  m.to("cuda")
 
213
  orig_size = image_rgb.size
214
 
215
  x = self._tf_1024(image_rgb).unsqueeze(0)
216
+ x = x.to(device)
217
 
218
  with torch.inference_mode():
219
+ if device == "cuda":
220
  with torch.autocast(device_type="cuda", dtype=torch.float16):
221
  preds = m(x)[-1].sigmoid()
222
  else:
 
224
 
225
  # Convert prediction to PIL alpha channel
226
  pred = preds[0].squeeze().detach().float().cpu()
227
+ alpha = transforms.ToPILImage()(pred).resize(orig_size, Image.BILINEAR)
228
 
229
  out = image_rgb.convert("RGBA")
230
  out.putalpha(alpha)
 
248
  inf0 = now_ms()
249
  if model_name == "InSPyReNet":
250
  remover = self._load_inspy()
 
251
  mask = remover.process(input_image, type="map")
252
  if isinstance(mask, Image.Image):
253
  mask = mask.convert("L")
 
262
 
263
  elif model_name == "U2Net":
264
  sess = self._get_rembg_session("u2net")
265
+ # FIX: rembg returns PIL Image when given PIL Image, not bytes!
266
+ out = rembg_remove(img_rgb, session=sess)
267
+ out = ensure_rgba(out)
268
 
269
  elif model_name == "BRIA RMBG 2.0":
270
  out = self._run_torch_alpha_model("bria_rmbg_2", img_rgb)
271
 
272
  elif model_name == "IS-Net":
273
  sess = self._get_rembg_session("isnet-general-use")
274
+ # FIX: rembg returns PIL Image when given PIL Image, not bytes!
275
+ out = rembg_remove(img_rgb, session=sess)
276
+ out = ensure_rgba(out)
277
 
278
  else:
279
  raise ValueError(f"Unknown model: {model_name}")
 
318
  if image is None:
319
  return None, None, "Upload an image first.", None
320
 
 
 
321
  out_rgba, timing = MANAGER.run(model_name, image)
322
 
 
 
323
  preview = rgba_on_checkerboard(out_rgba)
 
324
  out_path = save_temp_png(out_rgba)
325
  return (image, preview), out_rgba, timing.to_text(), out_path
326
 
327
 
328
  def list_bench_images() -> List[str]:
 
329
  exts = ("*.jpg", "*.jpeg", "*.png", "*.webp")
330
  files = []
331
  for e in exts:
332
  files += glob.glob(os.path.join("bench", e))
333
  files = sorted(files)
334
 
 
335
  if not files:
336
  fallback = []
337
  for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]:
 
345
  def run_benchmark(model_name: str, repeats: int = 1):
346
  files = list_bench_images()
347
  if not files:
348
+ # FIX: Return data values, not gr.Dataframe component
349
+ return [], "No benchmark images found. Add 10–15 images under bench/."
350
 
351
  # Warmup: 2 runs on first image (not timed)
352
  warm_img = Image.open(files[0]).convert("RGB")
 
361
  img = Image.open(f).convert("RGB")
362
  for r in range(repeats):
363
  out, timing = MANAGER.run(model_name, img)
364
+ rows.append([
365
+ os.path.basename(f),
366
+ r + 1,
367
+ round(timing.total_ms, 2),
368
+ round(timing.inference_ms, 2),
369
+ ])
370
  total_ms += timing.total_ms
371
  n_images += 1
372
 
 
378
  f"Images: {len(files)} (repeats={repeats}) => runs={n_images}\n"
379
  f"Avg total: {avg_ms:.2f} ms\n"
380
  f"Estimated throughput: {ips:.2f} images/sec\n"
381
+ f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}"
382
  )
383
 
384
+ # FIX: Return the data directly, not a gr.Dataframe component
385
+ return rows, summary
 
 
 
 
 
386
 
387
 
388
  # ----------------------------
389
  # UI
390
  # ----------------------------
391
 
392
+ with gr.Blocks(title="Background Removal Benchmark") as demo:
393
  gr.Markdown(
394
  """
395
+ # Background Removal Benchmark
396
 
397
  Benchmarked models:
398
  1) InSPyReNet
 
430
  with gr.Row():
431
  with gr.Column(scale=1):
432
  bench_model = gr.Dropdown(choices=MODEL_CHOICES, value="InSPyReNet", label="Model")
433
+ repeats = gr.Slider(1, 5, value=1, step=1, label="Repeats per image")
434
  bench_btn = gr.Button("Run benchmark", variant="primary")
435
  with gr.Column(scale=2):
436
  bench_table = gr.Dataframe(
 
446
  outputs=[bench_table, bench_summary]
447
  )
448
 
 
449
  example_files = []
450
  for f in ["1.jpg", "2.jpg", "3.png", "4.webp"]:
451
  if os.path.exists(f):
 
458
  )
459
 
460
  if __name__ == "__main__":
461
+ demo.launch(show_error=True)