LI Junxing commited on
Commit
e920cb4
·
1 Parent(s): 7063f69

Add hybrid alpha blending

Browse files
Files changed (3) hide show
  1. app.py +64 -24
  2. app_local.py +64 -24
  3. docs/rankseg_refine_foreground_analysis.md +455 -0
app.py CHANGED
@@ -167,13 +167,42 @@ def refine_foreground(image, mask, r=90, device='cuda'):
167
  return estimated_foreground
168
 
169
 
170
- def get_rankseg_mask(pred: torch.Tensor, metric: str) -> Image.Image:
171
  # BiRefNet produces a single foreground probability map, so RankSEG should
172
  # return a binary mask for that one channel instead of a multiclass map.
173
  rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
174
- probs = pred.unsqueeze(0).unsqueeze(0)
175
- rankseg_pred = rankseg.predict(probs).squeeze(0).squeeze(0).to(torch.float32)
176
- return transforms.ToPILImage()(rankseg_pred)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
 
179
  def build_masked_image(image: Image.Image, mask: Image.Image) -> Image.Image:
@@ -269,7 +298,8 @@ def predict(images, resolution, weights_file, enable_rankseg, rankseg_metric):
269
 
270
  if isinstance(images, list):
271
  raw_save_paths = []
272
- rankseg_save_paths = []
 
273
  save_dir = 'preds-BiRefNet'
274
  if not os.path.exists(save_dir):
275
  os.makedirs(save_dir)
@@ -295,37 +325,44 @@ def predict(images, resolution, weights_file, enable_rankseg, rankseg_metric):
295
 
296
  # Prediction
297
  with torch.no_grad():
298
- preds = birefnet(image_proc.to(device).half())[-1].sigmoid().cpu()
299
  pred = preds[0].squeeze()
300
 
301
  pred_pil = transforms.ToPILImage()(pred)
302
- raw_image_masked = build_masked_image(image, pred_pil)
303
- rankseg_image_masked = None
 
304
  if enable_rankseg:
305
- rankseg_mask = get_rankseg_mask(pred, rankseg_metric)
306
- rankseg_image_masked = build_masked_image(image, rankseg_mask)
 
 
307
 
308
  if device == 'cuda':
309
  torch.cuda.empty_cache()
310
 
311
  if tab_is_batch:
312
  image_name = os.path.splitext(os.path.basename(image_src))[0]
313
- raw_save_file_path = os.path.join(save_dir, f"{image_name}_raw.png")
314
  raw_image_masked.save(raw_save_file_path)
315
  raw_save_paths.append(raw_save_file_path)
316
- if enable_rankseg and rankseg_image_masked is not None:
317
- rankseg_save_file_path = os.path.join(save_dir, f"{image_name}_rankseg.png")
318
- rankseg_image_masked.save(rankseg_save_file_path)
319
- rankseg_save_paths.append(rankseg_save_file_path)
 
 
 
 
320
 
321
  if tab_is_batch:
322
  zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
323
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
324
- for file in raw_save_paths + rankseg_save_paths:
325
  zipf.write(file, os.path.basename(file))
326
- return raw_save_paths, rankseg_save_paths, zip_file_path
327
  else:
328
- return image, raw_image_masked, rankseg_image_masked
329
 
330
 
331
  examples = [[_] for _ in glob('examples/*')][:]
@@ -361,8 +398,9 @@ tab_image = gr.Interface(
361
  ],
362
  outputs=[
363
  gr.Image(label="Original image", type="pil", format='png'),
364
- gr.Image(label="BiRefNet result", type="pil", format='png'),
365
- gr.Image(label="BiRefNet + RankSEG", type="pil", format='png'),
 
366
  ],
367
  examples=examples,
368
  api_name="image",
@@ -380,8 +418,9 @@ tab_text = gr.Interface(
380
  ],
381
  outputs=[
382
  gr.Image(label="Original image", type="pil", format='png'),
383
- gr.Image(label="BiRefNet result", type="pil", format='png'),
384
- gr.Image(label="BiRefNet + RankSEG", type="pil", format='png'),
 
385
  ],
386
  examples=examples_url,
387
  api_name="URL",
@@ -398,8 +437,9 @@ tab_batch = gr.Interface(
398
  gr.Radio(RANKSEG_METRICS, value='dice', label="RankSEG metric", info="Choose the target metric for RankSEG post-processing.")
399
  ],
400
  outputs=[
401
- gr.Gallery(label="BiRefNet results"),
402
- gr.Gallery(label="BiRefNet + RankSEG results"),
 
403
  gr.File(label="Download masked images."),
404
  ],
405
  api_name="batch",
 
167
  return estimated_foreground
168
 
169
 
170
+ def get_rankseg_pred(pred: torch.Tensor, metric: str) -> torch.Tensor:
171
  # BiRefNet produces a single foreground probability map, so RankSEG should
172
  # return a binary mask for that one channel instead of a multiclass map.
173
  rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
174
+ probs = pred.unsqueeze(0).unsqueeze(0).to(torch.float32)
175
+ return rankseg.predict(probs).squeeze(0).squeeze(0).to(torch.float32)
176
+
177
+
178
+ def get_soft_gate(rankseg_mask: torch.Tensor, dilate_kernel: int = 9, blur_kernel: int = 15) -> torch.Tensor:
179
+ support = rankseg_mask.unsqueeze(0).unsqueeze(0).to(torch.float32)
180
+ dilated = torch.nn.functional.max_pool2d(
181
+ support,
182
+ kernel_size=dilate_kernel,
183
+ stride=1,
184
+ padding=dilate_kernel // 2,
185
+ )
186
+ soft_gate = torch.nn.functional.avg_pool2d(
187
+ dilated,
188
+ kernel_size=blur_kernel,
189
+ stride=1,
190
+ padding=blur_kernel // 2,
191
+ )
192
+ return soft_gate.squeeze(0).squeeze(0).clamp(0, 1)
193
+
194
+
195
+ def build_hybrid_alpha(pred: torch.Tensor, rankseg_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
196
+ soft_gate = get_soft_gate(rankseg_mask)
197
+ hard_alpha = (pred * rankseg_mask.to(torch.float32)).clamp(0, 1)
198
+ soft_alpha = torch.where(rankseg_mask > 0, pred, pred * soft_gate).clamp(0, 1)
199
+ return hard_alpha, soft_alpha
200
+
201
+
202
+ def build_alpha_cutout(image: Image.Image, mask: Image.Image) -> Image.Image:
203
+ output = image.copy()
204
+ output.putalpha(mask.resize(image.size))
205
+ return output
206
 
207
 
208
  def build_masked_image(image: Image.Image, mask: Image.Image) -> Image.Image:
 
298
 
299
  if isinstance(images, list):
300
  raw_save_paths = []
301
+ rankseg_hard_save_paths = []
302
+ rankseg_soft_save_paths = []
303
  save_dir = 'preds-BiRefNet'
304
  if not os.path.exists(save_dir):
305
  os.makedirs(save_dir)
 
325
 
326
  # Prediction
327
  with torch.no_grad():
328
+ preds = birefnet(image_proc.to(device).half())[-1].sigmoid().float().cpu()
329
  pred = preds[0].squeeze()
330
 
331
  pred_pil = transforms.ToPILImage()(pred)
332
+ raw_image_masked = build_alpha_cutout(image, pred_pil)
333
+ rankseg_hard_image_masked = None
334
+ rankseg_soft_image_masked = None
335
  if enable_rankseg:
336
+ rankseg_pred = get_rankseg_pred(pred, rankseg_metric)
337
+ hard_alpha, soft_alpha = build_hybrid_alpha(pred, rankseg_pred)
338
+ rankseg_hard_image_masked = build_alpha_cutout(image, transforms.ToPILImage()(hard_alpha))
339
+ rankseg_soft_image_masked = build_alpha_cutout(image, transforms.ToPILImage()(soft_alpha))
340
 
341
  if device == 'cuda':
342
  torch.cuda.empty_cache()
343
 
344
  if tab_is_batch:
345
  image_name = os.path.splitext(os.path.basename(image_src))[0]
346
+ raw_save_file_path = os.path.join(save_dir, f"{image_name}_pred.png")
347
  raw_image_masked.save(raw_save_file_path)
348
  raw_save_paths.append(raw_save_file_path)
349
+ if enable_rankseg and rankseg_hard_image_masked is not None:
350
+ rankseg_hard_save_file_path = os.path.join(save_dir, f"{image_name}_pred_rankseg.png")
351
+ rankseg_hard_image_masked.save(rankseg_hard_save_file_path)
352
+ rankseg_hard_save_paths.append(rankseg_hard_save_file_path)
353
+ if enable_rankseg and rankseg_soft_image_masked is not None:
354
+ rankseg_soft_save_file_path = os.path.join(save_dir, f"{image_name}_pred_softgate.png")
355
+ rankseg_soft_image_masked.save(rankseg_soft_save_file_path)
356
+ rankseg_soft_save_paths.append(rankseg_soft_save_file_path)
357
 
358
  if tab_is_batch:
359
  zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
360
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
361
+ for file in raw_save_paths + rankseg_hard_save_paths + rankseg_soft_save_paths:
362
  zipf.write(file, os.path.basename(file))
363
+ return raw_save_paths, rankseg_hard_save_paths, rankseg_soft_save_paths, zip_file_path
364
  else:
365
+ return image, raw_image_masked, rankseg_hard_image_masked, rankseg_soft_image_masked
366
 
367
 
368
  examples = [[_] for _ in glob('examples/*')][:]
 
398
  ],
399
  outputs=[
400
  gr.Image(label="Original image", type="pil", format='png'),
401
+ gr.Image(label="BiRefNet pred alpha", type="pil", format='png'),
402
+ gr.Image(label="BiRefNet pred x RankSEG", type="pil", format='png'),
403
+ gr.Image(label="BiRefNet soft-gated hybrid", type="pil", format='png'),
404
  ],
405
  examples=examples,
406
  api_name="image",
 
418
  ],
419
  outputs=[
420
  gr.Image(label="Original image", type="pil", format='png'),
421
+ gr.Image(label="BiRefNet pred alpha", type="pil", format='png'),
422
+ gr.Image(label="BiRefNet pred x RankSEG", type="pil", format='png'),
423
+ gr.Image(label="BiRefNet soft-gated hybrid", type="pil", format='png'),
424
  ],
425
  examples=examples_url,
426
  api_name="URL",
 
437
  gr.Radio(RANKSEG_METRICS, value='dice', label="RankSEG metric", info="Choose the target metric for RankSEG post-processing.")
438
  ],
439
  outputs=[
440
+ gr.Gallery(label="BiRefNet pred alpha results"),
441
+ gr.Gallery(label="BiRefNet pred x RankSEG results"),
442
+ gr.Gallery(label="BiRefNet soft-gated hybrid results"),
443
  gr.File(label="Download masked images."),
444
  ],
445
  api_name="batch",
app_local.py CHANGED
@@ -164,13 +164,42 @@ def refine_foreground(image, mask, r=90, device='cuda'):
164
  return estimated_foreground
165
 
166
 
167
- def get_rankseg_mask(pred: torch.Tensor, metric: str) -> Image.Image:
168
  # BiRefNet produces a single foreground probability map, so RankSEG should
169
  # return a binary mask for that one channel instead of a multiclass map.
170
  rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
171
- probs = pred.unsqueeze(0).unsqueeze(0)
172
- rankseg_pred = rankseg.predict(probs).squeeze(0).squeeze(0).to(torch.float32)
173
- return transforms.ToPILImage()(rankseg_pred)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
 
176
  def build_masked_image(image: Image.Image, mask: Image.Image) -> Image.Image:
@@ -265,7 +294,8 @@ def predict(images, resolution, weights_file, enable_rankseg, rankseg_metric):
265
 
266
  if isinstance(images, list):
267
  raw_save_paths = []
268
- rankseg_save_paths = []
 
269
  save_dir = 'preds-BiRefNet'
270
  if not os.path.exists(save_dir):
271
  os.makedirs(save_dir)
@@ -291,37 +321,44 @@ def predict(images, resolution, weights_file, enable_rankseg, rankseg_metric):
291
 
292
  # Prediction
293
  with torch.no_grad():
294
- preds = birefnet(image_proc.to(device).half())[-1].sigmoid().cpu()
295
  pred = preds[0].squeeze()
296
 
297
  pred_pil = transforms.ToPILImage()(pred)
298
- raw_image_masked = build_masked_image(image, pred_pil)
299
- rankseg_image_masked = None
 
300
  if enable_rankseg:
301
- rankseg_mask = get_rankseg_mask(pred, rankseg_metric)
302
- rankseg_image_masked = build_masked_image(image, rankseg_mask)
 
 
303
 
304
  if device == 'cuda':
305
  torch.cuda.empty_cache()
306
 
307
  if tab_is_batch:
308
  image_name = os.path.splitext(os.path.basename(image_src))[0]
309
- raw_save_file_path = os.path.join(save_dir, f"{image_name}_raw.png")
310
  raw_image_masked.save(raw_save_file_path)
311
  raw_save_paths.append(raw_save_file_path)
312
- if enable_rankseg and rankseg_image_masked is not None:
313
- rankseg_save_file_path = os.path.join(save_dir, f"{image_name}_rankseg.png")
314
- rankseg_image_masked.save(rankseg_save_file_path)
315
- rankseg_save_paths.append(rankseg_save_file_path)
 
 
 
 
316
 
317
  if tab_is_batch:
318
  zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
319
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
320
- for file in raw_save_paths + rankseg_save_paths:
321
  zipf.write(file, os.path.basename(file))
322
- return raw_save_paths, rankseg_save_paths, zip_file_path
323
  else:
324
- return image, raw_image_masked, rankseg_image_masked
325
 
326
 
327
  examples = [[_] for _ in glob('examples/*')][:]
@@ -357,8 +394,9 @@ tab_image = gr.Interface(
357
  ],
358
  outputs=[
359
  gr.Image(label="Original image", type="pil", format='png'),
360
- gr.Image(label="BiRefNet result", type="pil", format='png'),
361
- gr.Image(label="BiRefNet + RankSEG", type="pil", format='png'),
 
362
  ],
363
  examples=examples,
364
  api_name="image",
@@ -376,8 +414,9 @@ tab_text = gr.Interface(
376
  ],
377
  outputs=[
378
  gr.Image(label="Original image", type="pil", format='png'),
379
- gr.Image(label="BiRefNet result", type="pil", format='png'),
380
- gr.Image(label="BiRefNet + RankSEG", type="pil", format='png'),
 
381
  ],
382
  examples=examples_url,
383
  api_name="URL",
@@ -394,8 +433,9 @@ tab_batch = gr.Interface(
394
  gr.Radio(RANKSEG_METRICS, value='dice', label="RankSEG metric", info="Choose the target metric for RankSEG post-processing.")
395
  ],
396
  outputs=[
397
- gr.Gallery(label="BiRefNet results"),
398
- gr.Gallery(label="BiRefNet + RankSEG results"),
 
399
  gr.File(label="Download masked images."),
400
  ],
401
  api_name="batch",
 
164
  return estimated_foreground
165
 
166
 
167
+ def get_rankseg_pred(pred: torch.Tensor, metric: str) -> torch.Tensor:
168
  # BiRefNet produces a single foreground probability map, so RankSEG should
169
  # return a binary mask for that one channel instead of a multiclass map.
170
  rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
171
+ probs = pred.unsqueeze(0).unsqueeze(0).to(torch.float32)
172
+ return rankseg.predict(probs).squeeze(0).squeeze(0).to(torch.float32)
173
+
174
+
175
+ def get_soft_gate(rankseg_mask: torch.Tensor, dilate_kernel: int = 9, blur_kernel: int = 15) -> torch.Tensor:
176
+ support = rankseg_mask.unsqueeze(0).unsqueeze(0).to(torch.float32)
177
+ dilated = torch.nn.functional.max_pool2d(
178
+ support,
179
+ kernel_size=dilate_kernel,
180
+ stride=1,
181
+ padding=dilate_kernel // 2,
182
+ )
183
+ soft_gate = torch.nn.functional.avg_pool2d(
184
+ dilated,
185
+ kernel_size=blur_kernel,
186
+ stride=1,
187
+ padding=blur_kernel // 2,
188
+ )
189
+ return soft_gate.squeeze(0).squeeze(0).clamp(0, 1)
190
+
191
+
192
+ def build_hybrid_alpha(pred: torch.Tensor, rankseg_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
193
+ soft_gate = get_soft_gate(rankseg_mask)
194
+ hard_alpha = (pred * rankseg_mask.to(torch.float32)).clamp(0, 1)
195
+ soft_alpha = torch.where(rankseg_mask > 0, pred, pred * soft_gate).clamp(0, 1)
196
+ return hard_alpha, soft_alpha
197
+
198
+
199
+ def build_alpha_cutout(image: Image.Image, mask: Image.Image) -> Image.Image:
200
+ output = image.copy()
201
+ output.putalpha(mask.resize(image.size))
202
+ return output
203
 
204
 
205
  def build_masked_image(image: Image.Image, mask: Image.Image) -> Image.Image:
 
294
 
295
  if isinstance(images, list):
296
  raw_save_paths = []
297
+ rankseg_hard_save_paths = []
298
+ rankseg_soft_save_paths = []
299
  save_dir = 'preds-BiRefNet'
300
  if not os.path.exists(save_dir):
301
  os.makedirs(save_dir)
 
321
 
322
  # Prediction
323
  with torch.no_grad():
324
+ preds = birefnet(image_proc.to(device).half())[-1].sigmoid().float().cpu()
325
  pred = preds[0].squeeze()
326
 
327
  pred_pil = transforms.ToPILImage()(pred)
328
+ raw_image_masked = build_alpha_cutout(image, pred_pil)
329
+ rankseg_hard_image_masked = None
330
+ rankseg_soft_image_masked = None
331
  if enable_rankseg:
332
+ rankseg_pred = get_rankseg_pred(pred, rankseg_metric)
333
+ hard_alpha, soft_alpha = build_hybrid_alpha(pred, rankseg_pred)
334
+ rankseg_hard_image_masked = build_alpha_cutout(image, transforms.ToPILImage()(hard_alpha))
335
+ rankseg_soft_image_masked = build_alpha_cutout(image, transforms.ToPILImage()(soft_alpha))
336
 
337
  if device == 'cuda':
338
  torch.cuda.empty_cache()
339
 
340
  if tab_is_batch:
341
  image_name = os.path.splitext(os.path.basename(image_src))[0]
342
+ raw_save_file_path = os.path.join(save_dir, f"{image_name}_pred.png")
343
  raw_image_masked.save(raw_save_file_path)
344
  raw_save_paths.append(raw_save_file_path)
345
+ if enable_rankseg and rankseg_hard_image_masked is not None:
346
+ rankseg_hard_save_file_path = os.path.join(save_dir, f"{image_name}_pred_rankseg.png")
347
+ rankseg_hard_image_masked.save(rankseg_hard_save_file_path)
348
+ rankseg_hard_save_paths.append(rankseg_hard_save_file_path)
349
+ if enable_rankseg and rankseg_soft_image_masked is not None:
350
+ rankseg_soft_save_file_path = os.path.join(save_dir, f"{image_name}_pred_softgate.png")
351
+ rankseg_soft_image_masked.save(rankseg_soft_save_file_path)
352
+ rankseg_soft_save_paths.append(rankseg_soft_save_file_path)
353
 
354
  if tab_is_batch:
355
  zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
356
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
357
+ for file in raw_save_paths + rankseg_hard_save_paths + rankseg_soft_save_paths:
358
  zipf.write(file, os.path.basename(file))
359
+ return raw_save_paths, rankseg_hard_save_paths, rankseg_soft_save_paths, zip_file_path
360
  else:
361
+ return image, raw_image_masked, rankseg_hard_image_masked, rankseg_soft_image_masked
362
 
363
 
364
  examples = [[_] for _ in glob('examples/*')][:]
 
394
  ],
395
  outputs=[
396
  gr.Image(label="Original image", type="pil", format='png'),
397
+ gr.Image(label="BiRefNet pred alpha", type="pil", format='png'),
398
+ gr.Image(label="BiRefNet pred x RankSEG", type="pil", format='png'),
399
+ gr.Image(label="BiRefNet soft-gated hybrid", type="pil", format='png'),
400
  ],
401
  examples=examples,
402
  api_name="image",
 
414
  ],
415
  outputs=[
416
  gr.Image(label="Original image", type="pil", format='png'),
417
+ gr.Image(label="BiRefNet pred alpha", type="pil", format='png'),
418
+ gr.Image(label="BiRefNet pred x RankSEG", type="pil", format='png'),
419
+ gr.Image(label="BiRefNet soft-gated hybrid", type="pil", format='png'),
420
  ],
421
  examples=examples_url,
422
  api_name="URL",
 
433
  gr.Radio(RANKSEG_METRICS, value='dice', label="RankSEG metric", info="Choose the target metric for RankSEG post-processing.")
434
  ],
435
  outputs=[
436
+ gr.Gallery(label="BiRefNet pred alpha results"),
437
+ gr.Gallery(label="BiRefNet pred x RankSEG results"),
438
+ gr.Gallery(label="BiRefNet soft-gated hybrid results"),
439
  gr.File(label="Download masked images."),
440
  ],
441
  api_name="batch",
docs/rankseg_refine_foreground_analysis.md ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RankSEG 与 `refine_foreground` 管线分析报告
2
+
3
+ ## 1. 结论摘要
4
+
5
+ 1. `refine_foreground` 支持 soft mask,也支持 binary mask。它在实现上把 `mask` 当作 `alpha` 使用,要求数值范围在 `[0, 1]`;binary mask 只是 soft mask 的特例。
6
+ 2. 当前 `rankseg` 版本的 `RankSEG.predict()` 不输出 soft mask。官方文档和本地源码都表明,它返回的是离散预测,值为 `0/1` 或布尔值,而不是新的概率图。
7
+ 3. 对二分类单通道场景,RankSEG 不是简单地把固定阈值 `0.5` 换成另一个全局阈值 `tau`。更准确地说,它会基于整张图的概率排序,求一个当前图像、当前类别对应的最优 `top-k` 截断位置 `opt_tau`,再把前 `k` 个像素设为前景。
8
+ 4. 因此,`refine_foreground(rankseg_mask)` 在代码上是成立的,但它用到的是 hard alpha。这样更利于区域判定干净,不一定更利于发丝、半透明边缘等 soft alpha 细节。
9
+ 5. 如果目标是“让 RankSEG 的区域选择能力也帮助最终抠图质量”,更合理的方案通常不是要求 RankSEG 直接输出 soft mask,而是让 RankSEG 作为区域约束,原始 `pred` 继续提供 soft alpha。
10
+
11
+ ## 2. 问题与回答
12
+
13
+ ### Q1. `refine_foreground` 支持 soft mask 吗?
14
+
15
+ 支持。
16
+
17
+ 本地实现里,`refine_foreground` 明确写着:
18
+
19
+ - `image` 和 `mask` 都应在 `[0, 1]` 范围内
20
+ - `mask` 会被转成 `float32`
21
+ - 后续计算里 `mask` 以 `alpha` 形式参与模糊、混合和前景估计
22
+
23
+ 这说明它期望的是连续 alpha,而不是仅限于 0/1 标签。
24
+
25
+ ### Q2. `refine_foreground` 能接受 binary mask 吗?
26
+
27
+ 能。
28
+
29
+ binary mask 的取值是 `{0, 1}`,本身就是 `[0, 1]` 的子集,因此在数学上和实现上都合法。只是 binary mask 不携带边界过渡信息,所以对前景颜色估计来说通常不如 soft mask 丰富。
30
+
31
+ ### Q3. RankSEG 能输出 soft mask 吗?
32
+
33
+ 就当前项目安装的 `rankseg==0.0.4` 而言,不能。
34
+
35
+ `RankSEG.predict()` 的官方文档和本地源码都写明它返回的是 binary segmentation predictions。`output_mode='multilabel'` 的含义是“每个类返回一张二值 mask”,不是“输出 soft mask”。
36
+
37
+ ### Q4. RankSEG 是否等价于“找一个更好的全局阈值 `tau`”?
38
+
39
+ 不完全是。
40
+
41
+ 对你现在的单图、单类、binary segmentation 使用场景,可以把它近似理解成“找到一个比 0.5 更合适的自适应截断位置”。但从源码看,它优化的是排序后的 `top-k` 选择,不是直接在原概率空间里学习一个固定全局阈值。
42
+
43
+ 更精确地说:
44
+
45
+ - `0.5 threshold` 是固定规则
46
+ - RankSEG 是对当前图像的所有像素概率排序
47
+ - 然后根据目标指标估计最佳 `opt_tau`
48
+ - 最终保留前 `opt_tau` 个像素为前景
49
+
50
+ 在单通道情况下,这个过程可以等价成一个“当前图像专属”的隐式阈值切点,但不是数据集级别的全局常数阈值。
51
+
52
+ ## 3. `refine_foreground` 的来源、术语与技术脉络
53
+
54
+ ### 3.1 `FB` 是什么意思?
55
+
56
+ 这里的 `F` 和 `B` 是 matting / compositing 领域里的经典记号:
57
+
58
+ - `F` = Foreground,前景颜色
59
+ - `B` = Background,背景颜色
60
+ - `alpha` = 透明度或前景占比
61
+
62
+ 经典合成公式可以写成:
63
+
64
+ ```text
65
+ I = alpha * F + (1 - alpha) * B
66
+ ```
67
+
68
+ 其中:
69
+
70
+ - `I` 是观测到的输入图像
71
+ - `F` 是希望恢复的真实前景颜色
72
+ - `B` 是背景颜色
73
+ - `alpha` 控制前景和背景的混合比例
74
+
75
+ 所以这份代码里的 `FB_blur_fusion_foreground_estimator_*`,名字可以直接读成:
76
+
77
+ ```text
78
+ 一种基于模糊融合(blur fusion)的 F/B 前景估计器
79
+ ```
80
+
81
+ 它的目标不是再次估计分割类别,而是在已知 `image` 和 `alpha(mask)` 后,估计更干净的 `F`。
82
+
83
+ ### 3.2 当前实现最直接的代码来源
84
+
85
+ 从项目代码注释看,当前 `refine_foreground` 链路最直接依赖了两个来源:
86
+
87
+ 1. Photoroom 的 `fast-foreground-estimation`
88
+ 2. BiRefNet 社区里的 GPU 改写版本
89
+
90
+ 项目本地注释已经写明:
91
+
92
+ - CPU 版本参考了 Photoroom 仓库
93
+ - GPU 双阶段版本参考了 BiRefNet issue comment
94
+
95
+ 对应代码见 [app.py](/Users/lev1s/Documents/BiRefNet_demo/app.py#L137)。
96
+
97
+ BiRefNet 官方仓库后续也明确写到,`refine_foreground` 的提速来自 “the GPU implementation of fast-fg-est”。这说明当前仓库里的 `refine_foreground`,本质上不是 BiRefNet 原论文提出的新理论模块,而是把现有 foreground estimation 方法接到了推理后处理中。
98
+
99
+ ### 3.3 直接可追溯的论文来源
100
+
101
+ 当前这条实现链最清楚的论文来源分两层:
102
+
103
+ #### 第一层:近似实现
104
+
105
+ Photoroom 的开源仓库 `fast-foreground-estimation` 明确写着,它是论文 “Approximate Fast Foreground Colour Estimation” 的官方仓库,作者是 Marco Forte,发表于 ICIP 2021。
106
+
107
+ 这个仓库 README 又明确说明:
108
+
109
+ - 该方法是一个很快的 foreground estimation technique
110
+ - 它“yields comparable results to the full approach [1], while also being faster”
111
+ - 其中 `[1]` 指向的就是 `Fast Multi-Level Foreground Estimation`
112
+
113
+ 也就是说,当前代码中这种非常短小、基于 blur fusion 的实现,最直接对应的是:
114
+
115
+ ```text
116
+ Marco Forte, "Approximate Fast Foreground Colour Estimation", ICIP 2021
117
+ ```
118
+
119
+ #### 第二层:更早的完整方法
120
+
121
+ Photoroom 官方材料和 PyMatting 文档都把上面的近似法指向了更早的完整方法:
122
+
123
+ ```text
124
+ Thomas Germer, Tobias Uelwer, Stefan Conrad, Stefan Harmeling,
125
+ "Fast Multi-Level Foreground Estimation", ICPR 2020
126
+ ```
127
+
128
+ 这篇工作讨论的是:
129
+
130
+ - 已知 alpha matte
131
+ - 如何估计 foreground colours
132
+ - 以避免直接抠图后出现边缘 bleed-through
133
+
134
+ 所以从“研究脉络”上讲:
135
+
136
+ ```text
137
+ 当前 refine_foreground
138
+ <- 工程上更接近 Photoroom 的 Approximate Fast Foreground Colour Estimation
139
+ <- 理论上又是对 Germer 等人 Fast Multi-Level Foreground Estimation 的近似实现
140
+ ```
141
+
142
+ ### 3.4 它和传统 matting 文献的关系
143
+
144
+ 这条线属于 image matting / foreground estimation,而不是纯 segmentation。
145
+
146
+ 更具体地说,它属于:
147
+
148
+ - 已知 alpha 或已有 mask
149
+ - 进一步恢复前景颜色 `F`
150
+ - 让最终合成结果更自然
151
+
152
+ 这和只输出一个二值 mask 的分割工作不同,也和只估计 alpha matte 的方法不同。它解决的是:
153
+
154
+ ```text
155
+ “mask 有了之后,怎么恢复更干净的前景颜色”。
156
+ ```
157
+
158
+ 所以它在定位上更接近:
159
+
160
+ - foreground colour estimation
161
+ - alpha matting 的后处理
162
+ - compositing quality enhancement
163
+
164
+ 而不是:
165
+
166
+ - segmentation mask optimization
167
+
168
+ ### 3.5 对当前项目的准确表述
169
+
170
+ 如果在报告或提交说明里需要一句准确表述,我建议写成:
171
+
172
+ ```text
173
+ 本项目中的 refine_foreground 并非 BiRefNet 原始分割网络的一部分,
174
+ 而是一个用于前景颜色估计的后处理模块。
175
+ 其当前实现最直接来源于 Photoroom 开源的
176
+ "Approximate Fast Foreground Colour Estimation"(Marco Forte, ICIP 2021),
177
+ 该方法又是对 Germer 等人
178
+ "Fast Multi-Level Foreground Estimation"(ICPR 2020)
179
+ 的快速近似。
180
+ ```
181
+
182
+ ## 4. 源码证据
183
+
184
+ ### 4.1 `refine_foreground` 如何使用 mask
185
+
186
+ 项目里的实现见 [app.py](/Users/lev1s/Documents/BiRefNet_demo/app.py#L143)。
187
+
188
+ 关键信息:
189
+
190
+ ```python
191
+ def refine_foreground(image, mask, r=90, device='cuda'):
192
+ """both image and mask are in range of [0, 1]"""
193
+ ```
194
+
195
+ 以及:
196
+
197
+ ```python
198
+ mask = transforms.functional.to_tensor(mask).float().cuda()
199
+ ```
200
+
201
+ 和:
202
+
203
+ ```python
204
+ blurred_alpha = mean_blur(alpha, kernel_size=r)
205
+ blurred_FGA = mean_blur(FG * alpha, kernel_size=r)
206
+ blurred_B1A = mean_blur(B * (1 - alpha), kernel_size=r)
207
+ ```
208
+
209
+ 这些实现说明:
210
+
211
+ 1. `mask` 被当成连续 `alpha`
212
+ 2. 算法内部直接使用 `alpha` 和 `1 - alpha`
213
+ 3. soft alpha 会影响模糊融合结果
214
+ 4. binary alpha 也可运行,但只是一种退化情况
215
+
216
+ ### 4.2 当前项目中 RankSEG 的接法
217
+
218
+ 见 [app.py](/Users/lev1s/Documents/BiRefNet_demo/app.py#L170):
219
+
220
+ ```python
221
+ def get_rankseg_mask(pred: torch.Tensor, metric: str) -> Image.Image:
222
+ rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
223
+ probs = pred.unsqueeze(0).unsqueeze(0).to(torch.float32)
224
+ rankseg_pred = rankseg.predict(probs).squeeze(0).squeeze(0).to(torch.float32)
225
+ return transforms.ToPILImage()(rankseg_pred)
226
+ ```
227
+
228
+ 这里输入 `probs` 是 `(1, 1, H, W)` 的 soft probability map,但 `rankseg_pred` 是离散预测后再转回 `float32` 以便转图像,不代表它重新变成了 soft mask。
229
+
230
+ ### 4.3 RankSEG 官方语义
231
+
232
+ 本地安装包源码见 [rankseg/_rankseg.py](/Users/lev1s/Documents/BiRefNet_demo/.venv/lib/python3.12/site-packages/rankseg/_rankseg.py#L94):
233
+
234
+ ```python
235
+ def predict(self, probs):
236
+ """Convert probability maps to binary segmentation predictions.
237
+ ```
238
+
239
+ 返回说明见同文件 [rankseg/_rankseg.py](/Users/lev1s/Documents/BiRefNet_demo/.venv/lib/python3.12/site-packages/rankseg/_rankseg.py#L105):
240
+
241
+ ```python
242
+ preds : torch.Tensor
243
+ Binary segmentation predictions ...
244
+ Values are 0 or 1 (or boolean True/False depending on solver).
245
+ ```
246
+
247
+ 这已经直接排除了“当前 `predict()` 输出 soft mask”的解释。
248
+
249
+ ### 4.4 RankSEG 的 `tau` 到底是什么
250
+
251
+ 实现见 [rankseg/_rankseg_algo.py](/Users/lev1s/Documents/BiRefNet_demo/.venv/lib/python3.12/site-packages/rankseg/_rankseg_algo.py#L250)。
252
+
253
+ 关键逻辑:
254
+
255
+ ```python
256
+ opt_tau = torch.argmax(metric_values, dim=-1) + 1
257
+ ```
258
+
259
+ 以及:
260
+
261
+ ```python
262
+ overlap_preds[b, c, top_index[b, c, :opt_tau[b, c]]] = True
263
+ ```
264
+
265
+ 这说明:
266
+
267
+ 1. 先对概率从大到小排序
268
+ 2. 估计不同 `tau` 下的指标值
269
+ 3. 选择最优 `opt_tau`
270
+ 4. 将前 `opt_tau` 个像素直接置为前景
271
+
272
+ 因此 `tau` 是“保留多少个最高概率像素”的截断位置,本质上是 `top-k` 选择,不是简单地把所有像素与某个固定常数阈值比较。
273
+
274
+ ## 5. 解释:为什么 `refine_foreground` 更偏爱 soft mask
275
+
276
+ `refine_foreground` 做的不是再次分割,而是前景颜色估计。它本质上在解下面这类问题:
277
+
278
+ ```text
279
+ image ~= foreground * alpha + background * (1 - alpha)
280
+ ```
281
+
282
+ 这里的 `alpha` 越连续,算法越能利用边缘过渡信息去恢复更自然的前景颜色。
283
+
284
+ 所以:
285
+
286
+ - soft mask 更利于柔边、发丝、半透明边缘
287
+ - binary mask 更利于轮廓果断、区域清晰
288
+
289
+ 两者优化目标不同:
290
+
291
+ - RankSEG 更偏向 mask 指标优化,如 Dice / IoU
292
+ - `refine_foreground` 更偏向最终抠图观感优化
293
+
294
+ ## 6. 回答你的核心判断
295
+
296
+ ### 5.1 “如果 RankSEG 输出 soft mask,就能更好接上 `refine_foreground` 吗?”
297
+
298
+ 原则上是的。
299
+
300
+ 如果 RankSEG 能输出经过其全局优化后的 soft alpha,那它会比 hard mask 更适合 `refine_foreground`。但当前版本并没有这样的 API。
301
+
302
+ ### 5.2 “既然 RankSEG 不能输出 soft mask,那当前 `refine_foreground(rankseg_mask)` 是否没有意义?”
303
+
304
+ 不是没有意义,而是意义不同。
305
+
306
+ 当前做法的价值是:
307
+
308
+ - 利用 RankSEG 提升区域级的前景/背景判定
309
+ - 再利用 `refine_foreground` 改善前景颜色和合成效果
310
+
311
+ 但它的限制也很明确:
312
+
313
+ - `refine_foreground` 用到的是 hard alpha
314
+ - 边缘渐变信息已经丢失
315
+ - 所以它无法完全发挥 soft alpha matting 的优势
316
+
317
+ ## 7. 最合理的工程建议
318
+
319
+ ### 方案 A. 保持现状
320
+
321
+ 管线:
322
+
323
+ ```text
324
+ pred (soft prob) -> RankSEG -> binary mask -> refine_foreground
325
+ ```
326
+
327
+ 优点:
328
+
329
+ - 实现简单
330
+ - 区域判定更干净
331
+
332
+ 缺点:
333
+
334
+ - 边缘容易变硬
335
+ - 发丝、半透明细节可能不如原始 soft mask
336
+
337
+ 适合:
338
+
339
+ - 更重视区域分割是否准确
340
+ - 更像“分割后抠图”
341
+
342
+ ### 方案 B. 原始 soft mask 直接抠图
343
+
344
+ 管线:
345
+
346
+ ```text
347
+ pred (soft prob) -> refine_foreground
348
+ ```
349
+
350
+ 优点:
351
+
352
+ - 边缘更自然
353
+ - 更像 matting
354
+
355
+ 缺点:
356
+
357
+ - 区域级误检会直接进入最终结果
358
+
359
+ 适合:
360
+
361
+ - 更重视最终视觉观感
362
+
363
+ ### 方案 C. 推荐的混合方案
364
+
365
+ 管线:
366
+
367
+ ```text
368
+ pred (soft prob) -> RankSEG(binary support)
369
+ -> hybrid_alpha = pred * rankseg_mask
370
+ -> refine_foreground(hybrid_alpha)
371
+ ```
372
+
373
+ 或更保守地:
374
+
375
+ ```python
376
+ hybrid_alpha = torch.where(rankseg_mask > 0, pred, torch.zeros_like(pred))
377
+ ```
378
+
379
+ 这个思路的含义是:
380
+
381
+ - RankSEG 负责“哪些区域允许成为前景”
382
+ - 原始 `pred` 负责“这些前景区域内部的透明度变化”
383
+
384
+ 优点:
385
+
386
+ - 同时利用 RankSEG 的区域约束能力
387
+ - 保留原始 soft mask 的边缘渐变
388
+
389
+ 这是当前代码架构下最有希望同时兼顾 Dice/IoU 与视觉质量的方案。
390
+
391
+ ## 8. 评估建议
392
+
393
+ 如果你的目标是论文式或工程式对比,建议把评估拆成两组,不要混为一谈。
394
+
395
+ ### 8.1 分割质量评估
396
+
397
+ 比较对象:
398
+
399
+ - 原始 `pred` 阈值化后的 mask
400
+ - RankSEG 输出 mask
401
+ - 混合方案的最终 alpha 再阈值化的 mask
402
+
403
+ 指标:
404
+
405
+ - IoU
406
+ - Dice
407
+ - Pixel Accuracy
408
+
409
+ 注意:
410
+
411
+ - `refine_foreground` 不应作为 IoU/Dice 的评估对象
412
+ - 因为它输出的是前景图像,不是 mask 优化本身
413
+
414
+ ### 8.2 抠图观感评估
415
+
416
+ 比较对象:
417
+
418
+ - raw soft mask + refine
419
+ - rankseg binary mask + refine
420
+ - hybrid alpha + refine
421
+
422
+ 关注点:
423
+
424
+ - 发丝
425
+ - 透明边缘
426
+ - 白边/黑边
427
+ - 前景颜色污染
428
+
429
+ 如果有 matting 数据,可以进一步评估 alpha 或前景误差;如果没有,至少做稳定的可视化对比。
430
+
431
+ ## 9. 最终结论
432
+
433
+ 1. `refine_foreground` 同时支持 soft mask 与 binary mask,但从算法性质上更适合 soft alpha。
434
+ 2. 当前 RankSEG 版本不输出 soft mask,只输出离散预测。
435
+ 3. RankSEG 不应被简单理解成“把 0.5 换成一个更好的固定阈值”;它更接近于基于概率排序和目标指标的自适应 `top-k` 截断。
436
+ 4. 因此,“让 `refine_foreground` 也增强 RankSEG 的效果”这件事,不应期待 RankSEG 直接给 soft mask,而应考虑让 RankSEG 提供区域约束、让原始 `pred` 提供 soft alpha。
437
+ 5. 在当前工程里,最值得实验的下一步是混合方案,而不是继续追问 RankSEG 是否已有 soft output API。
438
+
439
+ ## 10. 参考来源
440
+
441
+ ### 本地代码
442
+
443
+ - [app.py](/Users/lev1s/Documents/BiRefNet_demo/app.py)
444
+ - [rankseg/_rankseg.py](/Users/lev1s/Documents/BiRefNet_demo/.venv/lib/python3.12/site-packages/rankseg/_rankseg.py)
445
+ - [rankseg/_rankseg_algo.py](/Users/lev1s/Documents/BiRefNet_demo/.venv/lib/python3.12/site-packages/rankseg/_rankseg_algo.py)
446
+
447
+ ### 外部资料
448
+
449
+ - RankSEG getting started: [GitHub](https://github.com/rankseg/rankseg/blob/main/doc/source/getting_started.md)
450
+ - RankSEG auto API: [GitHub](https://github.com/rankseg/rankseg/blob/main/doc/source/autoapi/rankseg/rankseg/index.md)
451
+ - BiRefNet repository: [GitHub](https://github.com/ZhengPeng7/BiRefNet)
452
+ - Approximate Fast Foreground Colour Estimation repository: [GitHub](https://github.com/Photoroom/fast-foreground-estimation)
453
+ - Approximate Fast Foreground Colour Estimation talk page: [IEEE SPS Resource Center](https://resourcecenter.ieee.org/conferences/icip-2021/spsicip21vid215)
454
+ - Fast Multi-Level Foreground Estimation reference entry: [DBLP](https://dblp.org/rec/conf/icpr/GermerU0H20)
455
+ - PyMatting foreground estimation docs: [PyMatting](https://pymatting.github.io/)