Spaces:
Running
Running
LI Junxing commited on
Commit ·
e920cb4
1
Parent(s): 7063f69
Add hybrid alpha blending
Browse files- app.py +64 -24
- app_local.py +64 -24
- docs/rankseg_refine_foreground_analysis.md +455 -0
app.py
CHANGED
|
@@ -167,13 +167,42 @@ def refine_foreground(image, mask, r=90, device='cuda'):
|
|
| 167 |
return estimated_foreground
|
| 168 |
|
| 169 |
|
| 170 |
-
def
|
| 171 |
# BiRefNet produces a single foreground probability map, so RankSEG should
|
| 172 |
# return a binary mask for that one channel instead of a multiclass map.
|
| 173 |
rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
|
| 174 |
-
probs = pred.unsqueeze(0).unsqueeze(0)
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
|
| 179 |
def build_masked_image(image: Image.Image, mask: Image.Image) -> Image.Image:
|
|
@@ -269,7 +298,8 @@ def predict(images, resolution, weights_file, enable_rankseg, rankseg_metric):
|
|
| 269 |
|
| 270 |
if isinstance(images, list):
|
| 271 |
raw_save_paths = []
|
| 272 |
-
|
|
|
|
| 273 |
save_dir = 'preds-BiRefNet'
|
| 274 |
if not os.path.exists(save_dir):
|
| 275 |
os.makedirs(save_dir)
|
|
@@ -295,37 +325,44 @@ def predict(images, resolution, weights_file, enable_rankseg, rankseg_metric):
|
|
| 295 |
|
| 296 |
# Prediction
|
| 297 |
with torch.no_grad():
|
| 298 |
-
preds = birefnet(image_proc.to(device).half())[-1].sigmoid().cpu()
|
| 299 |
pred = preds[0].squeeze()
|
| 300 |
|
| 301 |
pred_pil = transforms.ToPILImage()(pred)
|
| 302 |
-
raw_image_masked =
|
| 303 |
-
|
|
|
|
| 304 |
if enable_rankseg:
|
| 305 |
-
|
| 306 |
-
|
|
|
|
|
|
|
| 307 |
|
| 308 |
if device == 'cuda':
|
| 309 |
torch.cuda.empty_cache()
|
| 310 |
|
| 311 |
if tab_is_batch:
|
| 312 |
image_name = os.path.splitext(os.path.basename(image_src))[0]
|
| 313 |
-
raw_save_file_path = os.path.join(save_dir, f"{image_name}
|
| 314 |
raw_image_masked.save(raw_save_file_path)
|
| 315 |
raw_save_paths.append(raw_save_file_path)
|
| 316 |
-
if enable_rankseg and
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
if tab_is_batch:
|
| 322 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
| 323 |
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
| 324 |
-
for file in raw_save_paths +
|
| 325 |
zipf.write(file, os.path.basename(file))
|
| 326 |
-
return raw_save_paths,
|
| 327 |
else:
|
| 328 |
-
return image, raw_image_masked,
|
| 329 |
|
| 330 |
|
| 331 |
examples = [[_] for _ in glob('examples/*')][:]
|
|
@@ -361,8 +398,9 @@ tab_image = gr.Interface(
|
|
| 361 |
],
|
| 362 |
outputs=[
|
| 363 |
gr.Image(label="Original image", type="pil", format='png'),
|
| 364 |
-
gr.Image(label="BiRefNet
|
| 365 |
-
gr.Image(label="BiRefNet
|
|
|
|
| 366 |
],
|
| 367 |
examples=examples,
|
| 368 |
api_name="image",
|
|
@@ -380,8 +418,9 @@ tab_text = gr.Interface(
|
|
| 380 |
],
|
| 381 |
outputs=[
|
| 382 |
gr.Image(label="Original image", type="pil", format='png'),
|
| 383 |
-
gr.Image(label="BiRefNet
|
| 384 |
-
gr.Image(label="BiRefNet
|
|
|
|
| 385 |
],
|
| 386 |
examples=examples_url,
|
| 387 |
api_name="URL",
|
|
@@ -398,8 +437,9 @@ tab_batch = gr.Interface(
|
|
| 398 |
gr.Radio(RANKSEG_METRICS, value='dice', label="RankSEG metric", info="Choose the target metric for RankSEG post-processing.")
|
| 399 |
],
|
| 400 |
outputs=[
|
| 401 |
-
gr.Gallery(label="BiRefNet results"),
|
| 402 |
-
gr.Gallery(label="BiRefNet
|
|
|
|
| 403 |
gr.File(label="Download masked images."),
|
| 404 |
],
|
| 405 |
api_name="batch",
|
|
|
|
| 167 |
return estimated_foreground
|
| 168 |
|
| 169 |
|
| 170 |
+
def get_rankseg_pred(pred: torch.Tensor, metric: str) -> torch.Tensor:
    """Run RankSEG on a single-channel probability map and return a binary mask.

    BiRefNet emits one foreground probability map, so RankSEG is used in
    'multilabel' mode on a single channel: it returns a 0/1 mask for that
    one channel rather than a multiclass label map.

    Args:
        pred: 2-D (H, W) foreground probability tensor in [0, 1].
        metric: target metric name forwarded to RankSEG (e.g. 'dice').

    Returns:
        2-D (H, W) float32 tensor whose values are 0 or 1.
    """
    solver = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
    # RankSEG expects a (batch, channel, H, W) input; add the singleton dims,
    # then strip them again and cast to float32 so the result composes with `pred`.
    batched = pred.to(torch.float32).unsqueeze(0).unsqueeze(0)
    decision = solver.predict(batched)
    return decision.squeeze(0).squeeze(0).to(torch.float32)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def get_soft_gate(rankseg_mask: torch.Tensor, dilate_kernel: int = 9, blur_kernel: int = 15) -> torch.Tensor:
    """Turn a hard RankSEG mask into a soft spatial gate in [0, 1].

    The binary mask is first dilated (max-pool) so the gate extends slightly
    beyond the selected region, then box-blurred (avg-pool) so it falls off
    smoothly instead of cutting hard at the region boundary.

    Args:
        rankseg_mask: 2-D (H, W) tensor with values in {0, 1}.
        dilate_kernel: odd window size for the dilation step.
        blur_kernel: odd window size for the blur step.

    Returns:
        2-D (H, W) float32 tensor in [0, 1], same spatial size as the input.

    Raises:
        ValueError: if either kernel size is even — an even kernel combined
            with ``k // 2`` padding would silently change the output's
            spatial size instead of preserving H x W.
    """
    if dilate_kernel % 2 == 0 or blur_kernel % 2 == 0:
        raise ValueError("dilate_kernel and blur_kernel must both be odd")
    support = rankseg_mask.unsqueeze(0).unsqueeze(0).to(torch.float32)
    # Morphological dilation via max-pooling: stride 1 with k//2 padding keeps H x W.
    dilated = torch.nn.functional.max_pool2d(
        support,
        kernel_size=dilate_kernel,
        stride=1,
        padding=dilate_kernel // 2,
    )
    # Box blur via average pooling softens the dilated edge into a smooth ramp.
    soft_gate = torch.nn.functional.avg_pool2d(
        dilated,
        kernel_size=blur_kernel,
        stride=1,
        padding=blur_kernel // 2,
    )
    # clamp is defensive; an average of values in [0, 1] already lies in [0, 1].
    return soft_gate.squeeze(0).squeeze(0).clamp(0, 1)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def build_hybrid_alpha(pred: torch.Tensor, rankseg_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Combine the soft prediction with the RankSEG mask into two alpha maps.

    Returns a ``(hard_alpha, soft_alpha)`` pair:
      * ``hard_alpha`` zeroes the prediction wherever RankSEG rejected a pixel.
      * ``soft_alpha`` keeps the raw prediction inside the RankSEG region and
        attenuates it by a blurred gate outside, giving a gentler falloff.

    Args:
        pred: 2-D (H, W) foreground probability tensor in [0, 1].
        rankseg_mask: 2-D (H, W) binary (0/1) mask from RankSEG.

    Returns:
        Tuple of two 2-D float tensors, each clamped to [0, 1].
    """
    gate = get_soft_gate(rankseg_mask)
    inside = rankseg_mask > 0
    hard_alpha = torch.clamp(pred * rankseg_mask.to(torch.float32), 0, 1)
    soft_alpha = torch.clamp(torch.where(inside, pred, pred * gate), 0, 1)
    return hard_alpha, soft_alpha
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def build_alpha_cutout(image: Image.Image, mask: Image.Image) -> Image.Image:
    """Return a copy of ``image`` with ``mask`` installed as its alpha channel.

    The mask is resized to the image's size before being attached, so the two
    may differ in resolution (e.g. model output vs. the original photo).

    Args:
        image: source PIL image; it is not modified.
        mask: single-channel PIL image used as the alpha band.

    Returns:
        A new PIL image whose alpha channel is the (resized) mask.
    """
    cutout = image.copy()
    alpha = mask.resize(image.size)
    cutout.putalpha(alpha)
    return cutout
|
| 206 |
|
| 207 |
|
| 208 |
def build_masked_image(image: Image.Image, mask: Image.Image) -> Image.Image:
|
|
|
|
| 298 |
|
| 299 |
if isinstance(images, list):
|
| 300 |
raw_save_paths = []
|
| 301 |
+
rankseg_hard_save_paths = []
|
| 302 |
+
rankseg_soft_save_paths = []
|
| 303 |
save_dir = 'preds-BiRefNet'
|
| 304 |
if not os.path.exists(save_dir):
|
| 305 |
os.makedirs(save_dir)
|
|
|
|
| 325 |
|
| 326 |
# Prediction
|
| 327 |
with torch.no_grad():
|
| 328 |
+
preds = birefnet(image_proc.to(device).half())[-1].sigmoid().float().cpu()
|
| 329 |
pred = preds[0].squeeze()
|
| 330 |
|
| 331 |
pred_pil = transforms.ToPILImage()(pred)
|
| 332 |
+
raw_image_masked = build_alpha_cutout(image, pred_pil)
|
| 333 |
+
rankseg_hard_image_masked = None
|
| 334 |
+
rankseg_soft_image_masked = None
|
| 335 |
if enable_rankseg:
|
| 336 |
+
rankseg_pred = get_rankseg_pred(pred, rankseg_metric)
|
| 337 |
+
hard_alpha, soft_alpha = build_hybrid_alpha(pred, rankseg_pred)
|
| 338 |
+
rankseg_hard_image_masked = build_alpha_cutout(image, transforms.ToPILImage()(hard_alpha))
|
| 339 |
+
rankseg_soft_image_masked = build_alpha_cutout(image, transforms.ToPILImage()(soft_alpha))
|
| 340 |
|
| 341 |
if device == 'cuda':
|
| 342 |
torch.cuda.empty_cache()
|
| 343 |
|
| 344 |
if tab_is_batch:
|
| 345 |
image_name = os.path.splitext(os.path.basename(image_src))[0]
|
| 346 |
+
raw_save_file_path = os.path.join(save_dir, f"{image_name}_pred.png")
|
| 347 |
raw_image_masked.save(raw_save_file_path)
|
| 348 |
raw_save_paths.append(raw_save_file_path)
|
| 349 |
+
if enable_rankseg and rankseg_hard_image_masked is not None:
|
| 350 |
+
rankseg_hard_save_file_path = os.path.join(save_dir, f"{image_name}_pred_rankseg.png")
|
| 351 |
+
rankseg_hard_image_masked.save(rankseg_hard_save_file_path)
|
| 352 |
+
rankseg_hard_save_paths.append(rankseg_hard_save_file_path)
|
| 353 |
+
if enable_rankseg and rankseg_soft_image_masked is not None:
|
| 354 |
+
rankseg_soft_save_file_path = os.path.join(save_dir, f"{image_name}_pred_softgate.png")
|
| 355 |
+
rankseg_soft_image_masked.save(rankseg_soft_save_file_path)
|
| 356 |
+
rankseg_soft_save_paths.append(rankseg_soft_save_file_path)
|
| 357 |
|
| 358 |
if tab_is_batch:
|
| 359 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
| 360 |
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
| 361 |
+
for file in raw_save_paths + rankseg_hard_save_paths + rankseg_soft_save_paths:
|
| 362 |
zipf.write(file, os.path.basename(file))
|
| 363 |
+
return raw_save_paths, rankseg_hard_save_paths, rankseg_soft_save_paths, zip_file_path
|
| 364 |
else:
|
| 365 |
+
return image, raw_image_masked, rankseg_hard_image_masked, rankseg_soft_image_masked
|
| 366 |
|
| 367 |
|
| 368 |
examples = [[_] for _ in glob('examples/*')][:]
|
|
|
|
| 398 |
],
|
| 399 |
outputs=[
|
| 400 |
gr.Image(label="Original image", type="pil", format='png'),
|
| 401 |
+
gr.Image(label="BiRefNet pred alpha", type="pil", format='png'),
|
| 402 |
+
gr.Image(label="BiRefNet pred x RankSEG", type="pil", format='png'),
|
| 403 |
+
gr.Image(label="BiRefNet soft-gated hybrid", type="pil", format='png'),
|
| 404 |
],
|
| 405 |
examples=examples,
|
| 406 |
api_name="image",
|
|
|
|
| 418 |
],
|
| 419 |
outputs=[
|
| 420 |
gr.Image(label="Original image", type="pil", format='png'),
|
| 421 |
+
gr.Image(label="BiRefNet pred alpha", type="pil", format='png'),
|
| 422 |
+
gr.Image(label="BiRefNet pred x RankSEG", type="pil", format='png'),
|
| 423 |
+
gr.Image(label="BiRefNet soft-gated hybrid", type="pil", format='png'),
|
| 424 |
],
|
| 425 |
examples=examples_url,
|
| 426 |
api_name="URL",
|
|
|
|
| 437 |
gr.Radio(RANKSEG_METRICS, value='dice', label="RankSEG metric", info="Choose the target metric for RankSEG post-processing.")
|
| 438 |
],
|
| 439 |
outputs=[
|
| 440 |
+
gr.Gallery(label="BiRefNet pred alpha results"),
|
| 441 |
+
gr.Gallery(label="BiRefNet pred x RankSEG results"),
|
| 442 |
+
gr.Gallery(label="BiRefNet soft-gated hybrid results"),
|
| 443 |
gr.File(label="Download masked images."),
|
| 444 |
],
|
| 445 |
api_name="batch",
|
app_local.py
CHANGED
|
@@ -164,13 +164,42 @@ def refine_foreground(image, mask, r=90, device='cuda'):
|
|
| 164 |
return estimated_foreground
|
| 165 |
|
| 166 |
|
| 167 |
-
def
|
| 168 |
# BiRefNet produces a single foreground probability map, so RankSEG should
|
| 169 |
# return a binary mask for that one channel instead of a multiclass map.
|
| 170 |
rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
|
| 171 |
-
probs = pred.unsqueeze(0).unsqueeze(0)
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
def build_masked_image(image: Image.Image, mask: Image.Image) -> Image.Image:
|
|
@@ -265,7 +294,8 @@ def predict(images, resolution, weights_file, enable_rankseg, rankseg_metric):
|
|
| 265 |
|
| 266 |
if isinstance(images, list):
|
| 267 |
raw_save_paths = []
|
| 268 |
-
|
|
|
|
| 269 |
save_dir = 'preds-BiRefNet'
|
| 270 |
if not os.path.exists(save_dir):
|
| 271 |
os.makedirs(save_dir)
|
|
@@ -291,37 +321,44 @@ def predict(images, resolution, weights_file, enable_rankseg, rankseg_metric):
|
|
| 291 |
|
| 292 |
# Prediction
|
| 293 |
with torch.no_grad():
|
| 294 |
-
preds = birefnet(image_proc.to(device).half())[-1].sigmoid().cpu()
|
| 295 |
pred = preds[0].squeeze()
|
| 296 |
|
| 297 |
pred_pil = transforms.ToPILImage()(pred)
|
| 298 |
-
raw_image_masked =
|
| 299 |
-
|
|
|
|
| 300 |
if enable_rankseg:
|
| 301 |
-
|
| 302 |
-
|
|
|
|
|
|
|
| 303 |
|
| 304 |
if device == 'cuda':
|
| 305 |
torch.cuda.empty_cache()
|
| 306 |
|
| 307 |
if tab_is_batch:
|
| 308 |
image_name = os.path.splitext(os.path.basename(image_src))[0]
|
| 309 |
-
raw_save_file_path = os.path.join(save_dir, f"{image_name}
|
| 310 |
raw_image_masked.save(raw_save_file_path)
|
| 311 |
raw_save_paths.append(raw_save_file_path)
|
| 312 |
-
if enable_rankseg and
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
if tab_is_batch:
|
| 318 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
| 319 |
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
| 320 |
-
for file in raw_save_paths +
|
| 321 |
zipf.write(file, os.path.basename(file))
|
| 322 |
-
return raw_save_paths,
|
| 323 |
else:
|
| 324 |
-
return image, raw_image_masked,
|
| 325 |
|
| 326 |
|
| 327 |
examples = [[_] for _ in glob('examples/*')][:]
|
|
@@ -357,8 +394,9 @@ tab_image = gr.Interface(
|
|
| 357 |
],
|
| 358 |
outputs=[
|
| 359 |
gr.Image(label="Original image", type="pil", format='png'),
|
| 360 |
-
gr.Image(label="BiRefNet
|
| 361 |
-
gr.Image(label="BiRefNet
|
|
|
|
| 362 |
],
|
| 363 |
examples=examples,
|
| 364 |
api_name="image",
|
|
@@ -376,8 +414,9 @@ tab_text = gr.Interface(
|
|
| 376 |
],
|
| 377 |
outputs=[
|
| 378 |
gr.Image(label="Original image", type="pil", format='png'),
|
| 379 |
-
gr.Image(label="BiRefNet
|
| 380 |
-
gr.Image(label="BiRefNet
|
|
|
|
| 381 |
],
|
| 382 |
examples=examples_url,
|
| 383 |
api_name="URL",
|
|
@@ -394,8 +433,9 @@ tab_batch = gr.Interface(
|
|
| 394 |
gr.Radio(RANKSEG_METRICS, value='dice', label="RankSEG metric", info="Choose the target metric for RankSEG post-processing.")
|
| 395 |
],
|
| 396 |
outputs=[
|
| 397 |
-
gr.Gallery(label="BiRefNet results"),
|
| 398 |
-
gr.Gallery(label="BiRefNet
|
|
|
|
| 399 |
gr.File(label="Download masked images."),
|
| 400 |
],
|
| 401 |
api_name="batch",
|
|
|
|
| 164 |
return estimated_foreground
|
| 165 |
|
| 166 |
|
| 167 |
+
def get_rankseg_pred(pred: torch.Tensor, metric: str) -> torch.Tensor:
|
| 168 |
# BiRefNet produces a single foreground probability map, so RankSEG should
|
| 169 |
# return a binary mask for that one channel instead of a multiclass map.
|
| 170 |
rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
|
| 171 |
+
probs = pred.unsqueeze(0).unsqueeze(0).to(torch.float32)
|
| 172 |
+
return rankseg.predict(probs).squeeze(0).squeeze(0).to(torch.float32)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def get_soft_gate(rankseg_mask: torch.Tensor, dilate_kernel: int = 9, blur_kernel: int = 15) -> torch.Tensor:
|
| 176 |
+
support = rankseg_mask.unsqueeze(0).unsqueeze(0).to(torch.float32)
|
| 177 |
+
dilated = torch.nn.functional.max_pool2d(
|
| 178 |
+
support,
|
| 179 |
+
kernel_size=dilate_kernel,
|
| 180 |
+
stride=1,
|
| 181 |
+
padding=dilate_kernel // 2,
|
| 182 |
+
)
|
| 183 |
+
soft_gate = torch.nn.functional.avg_pool2d(
|
| 184 |
+
dilated,
|
| 185 |
+
kernel_size=blur_kernel,
|
| 186 |
+
stride=1,
|
| 187 |
+
padding=blur_kernel // 2,
|
| 188 |
+
)
|
| 189 |
+
return soft_gate.squeeze(0).squeeze(0).clamp(0, 1)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def build_hybrid_alpha(pred: torch.Tensor, rankseg_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 193 |
+
soft_gate = get_soft_gate(rankseg_mask)
|
| 194 |
+
hard_alpha = (pred * rankseg_mask.to(torch.float32)).clamp(0, 1)
|
| 195 |
+
soft_alpha = torch.where(rankseg_mask > 0, pred, pred * soft_gate).clamp(0, 1)
|
| 196 |
+
return hard_alpha, soft_alpha
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def build_alpha_cutout(image: Image.Image, mask: Image.Image) -> Image.Image:
|
| 200 |
+
output = image.copy()
|
| 201 |
+
output.putalpha(mask.resize(image.size))
|
| 202 |
+
return output
|
| 203 |
|
| 204 |
|
| 205 |
def build_masked_image(image: Image.Image, mask: Image.Image) -> Image.Image:
|
|
|
|
| 294 |
|
| 295 |
if isinstance(images, list):
|
| 296 |
raw_save_paths = []
|
| 297 |
+
rankseg_hard_save_paths = []
|
| 298 |
+
rankseg_soft_save_paths = []
|
| 299 |
save_dir = 'preds-BiRefNet'
|
| 300 |
if not os.path.exists(save_dir):
|
| 301 |
os.makedirs(save_dir)
|
|
|
|
| 321 |
|
| 322 |
# Prediction
|
| 323 |
with torch.no_grad():
|
| 324 |
+
preds = birefnet(image_proc.to(device).half())[-1].sigmoid().float().cpu()
|
| 325 |
pred = preds[0].squeeze()
|
| 326 |
|
| 327 |
pred_pil = transforms.ToPILImage()(pred)
|
| 328 |
+
raw_image_masked = build_alpha_cutout(image, pred_pil)
|
| 329 |
+
rankseg_hard_image_masked = None
|
| 330 |
+
rankseg_soft_image_masked = None
|
| 331 |
if enable_rankseg:
|
| 332 |
+
rankseg_pred = get_rankseg_pred(pred, rankseg_metric)
|
| 333 |
+
hard_alpha, soft_alpha = build_hybrid_alpha(pred, rankseg_pred)
|
| 334 |
+
rankseg_hard_image_masked = build_alpha_cutout(image, transforms.ToPILImage()(hard_alpha))
|
| 335 |
+
rankseg_soft_image_masked = build_alpha_cutout(image, transforms.ToPILImage()(soft_alpha))
|
| 336 |
|
| 337 |
if device == 'cuda':
|
| 338 |
torch.cuda.empty_cache()
|
| 339 |
|
| 340 |
if tab_is_batch:
|
| 341 |
image_name = os.path.splitext(os.path.basename(image_src))[0]
|
| 342 |
+
raw_save_file_path = os.path.join(save_dir, f"{image_name}_pred.png")
|
| 343 |
raw_image_masked.save(raw_save_file_path)
|
| 344 |
raw_save_paths.append(raw_save_file_path)
|
| 345 |
+
if enable_rankseg and rankseg_hard_image_masked is not None:
|
| 346 |
+
rankseg_hard_save_file_path = os.path.join(save_dir, f"{image_name}_pred_rankseg.png")
|
| 347 |
+
rankseg_hard_image_masked.save(rankseg_hard_save_file_path)
|
| 348 |
+
rankseg_hard_save_paths.append(rankseg_hard_save_file_path)
|
| 349 |
+
if enable_rankseg and rankseg_soft_image_masked is not None:
|
| 350 |
+
rankseg_soft_save_file_path = os.path.join(save_dir, f"{image_name}_pred_softgate.png")
|
| 351 |
+
rankseg_soft_image_masked.save(rankseg_soft_save_file_path)
|
| 352 |
+
rankseg_soft_save_paths.append(rankseg_soft_save_file_path)
|
| 353 |
|
| 354 |
if tab_is_batch:
|
| 355 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
| 356 |
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
| 357 |
+
for file in raw_save_paths + rankseg_hard_save_paths + rankseg_soft_save_paths:
|
| 358 |
zipf.write(file, os.path.basename(file))
|
| 359 |
+
return raw_save_paths, rankseg_hard_save_paths, rankseg_soft_save_paths, zip_file_path
|
| 360 |
else:
|
| 361 |
+
return image, raw_image_masked, rankseg_hard_image_masked, rankseg_soft_image_masked
|
| 362 |
|
| 363 |
|
| 364 |
examples = [[_] for _ in glob('examples/*')][:]
|
|
|
|
| 394 |
],
|
| 395 |
outputs=[
|
| 396 |
gr.Image(label="Original image", type="pil", format='png'),
|
| 397 |
+
gr.Image(label="BiRefNet pred alpha", type="pil", format='png'),
|
| 398 |
+
gr.Image(label="BiRefNet pred x RankSEG", type="pil", format='png'),
|
| 399 |
+
gr.Image(label="BiRefNet soft-gated hybrid", type="pil", format='png'),
|
| 400 |
],
|
| 401 |
examples=examples,
|
| 402 |
api_name="image",
|
|
|
|
| 414 |
],
|
| 415 |
outputs=[
|
| 416 |
gr.Image(label="Original image", type="pil", format='png'),
|
| 417 |
+
gr.Image(label="BiRefNet pred alpha", type="pil", format='png'),
|
| 418 |
+
gr.Image(label="BiRefNet pred x RankSEG", type="pil", format='png'),
|
| 419 |
+
gr.Image(label="BiRefNet soft-gated hybrid", type="pil", format='png'),
|
| 420 |
],
|
| 421 |
examples=examples_url,
|
| 422 |
api_name="URL",
|
|
|
|
| 433 |
gr.Radio(RANKSEG_METRICS, value='dice', label="RankSEG metric", info="Choose the target metric for RankSEG post-processing.")
|
| 434 |
],
|
| 435 |
outputs=[
|
| 436 |
+
gr.Gallery(label="BiRefNet pred alpha results"),
|
| 437 |
+
gr.Gallery(label="BiRefNet pred x RankSEG results"),
|
| 438 |
+
gr.Gallery(label="BiRefNet soft-gated hybrid results"),
|
| 439 |
gr.File(label="Download masked images."),
|
| 440 |
],
|
| 441 |
api_name="batch",
|
docs/rankseg_refine_foreground_analysis.md
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RankSEG 与 `refine_foreground` 管线分析报告
|
| 2 |
+
|
| 3 |
+
## 1. 结论摘要
|
| 4 |
+
|
| 5 |
+
1. `refine_foreground` 支持 soft mask,也支持 binary mask。它在实现上把 `mask` 当作 `alpha` 使用,要求数值范围在 `[0, 1]`;binary mask 只是 soft mask 的特例。
|
| 6 |
+
2. 当前 `rankseg` 版本的 `RankSEG.predict()` 不输出 soft mask。官方文档和本地源码都表明,它返回的是离散预测,值为 `0/1` 或布尔值,而不是新的概率图。
|
| 7 |
+
3. 对二分类单通道场景,RankSEG 不是简单地把固定阈值 `0.5` 换成另一个全局阈值 `tau`。更准确地说,它会基于整张图的概率排序,求一个当前图像、当前类别对应的最优 `top-k` 截断位置 `opt_tau`,再把前 `k` 个像素设为前景。
|
| 8 |
+
4. 因此,`refine_foreground(rankseg_mask)` 在代码上是成立的,但它用到的是 hard alpha。这样更利于区域判定干净,不一定更利于发丝、半透明边缘等 soft alpha 细节。
|
| 9 |
+
5. 如果目标是“让 RankSEG 的区域选择能力也帮助最终抠图质量”,更合理的方案通常不是要求 RankSEG 直接输出 soft mask,而是让 RankSEG 作为区域约束,原始 `pred` 继续提供 soft alpha。
|
| 10 |
+
|
| 11 |
+
## 2. 问题与回答
|
| 12 |
+
|
| 13 |
+
### Q1. `refine_foreground` 支持 soft mask 吗?
|
| 14 |
+
|
| 15 |
+
支持。
|
| 16 |
+
|
| 17 |
+
本地实现里,`refine_foreground` 明确写着:
|
| 18 |
+
|
| 19 |
+
- `image` 和 `mask` 都应在 `[0, 1]` 范围内
|
| 20 |
+
- `mask` 会被转成 `float32`
|
| 21 |
+
- 后续计算里 `mask` 以 `alpha` 形式参与模糊、混合和前景估计
|
| 22 |
+
|
| 23 |
+
这说明它期望的是连续 alpha,而不是仅限于 0/1 标签。
|
| 24 |
+
|
| 25 |
+
### Q2. `refine_foreground` 能接受 binary mask 吗?
|
| 26 |
+
|
| 27 |
+
能。
|
| 28 |
+
|
| 29 |
+
binary mask 的取值是 `{0, 1}`,本身就是 `[0, 1]` 的子集,因此在数学上和实现上都合法。只是 binary mask 不携带边界过渡信息,所以对前景颜色估计来说通常不如 soft mask 丰富。
|
| 30 |
+
|
| 31 |
+
### Q3. RankSEG 能输出 soft mask 吗?
|
| 32 |
+
|
| 33 |
+
就当前项目安装的 `rankseg==0.0.4` 而言,不能。
|
| 34 |
+
|
| 35 |
+
`RankSEG.predict()` 的官方文档和本地源码都写明它返回的是 binary segmentation predictions。`output_mode='multilabel'` 的含义是“每个类返回一张二值 mask”,不是“输出 soft mask”。
|
| 36 |
+
|
| 37 |
+
### Q4. RankSEG 是否等价于“找一个更好的全局阈值 `tau`”?
|
| 38 |
+
|
| 39 |
+
不完全是。
|
| 40 |
+
|
| 41 |
+
对你现在的单图、单类、binary segmentation 使用场景,可以把它近似理解成“找到一个比 0.5 更合适的自适应截断位置”。但从源码看,它优化的是排序后的 `top-k` 选择,不是直接在原概率空间里学习一个固定全局阈值。
|
| 42 |
+
|
| 43 |
+
更精确地说:
|
| 44 |
+
|
| 45 |
+
- `0.5 threshold` 是固定规则
|
| 46 |
+
- RankSEG 是对当前图像的所有像素概率排序
|
| 47 |
+
- 然后根据目标指标估计最佳 `opt_tau`
|
| 48 |
+
- 最终保留前 `opt_tau` 个像素为前景
|
| 49 |
+
|
| 50 |
+
在单通道情况下,这个过程可以等价成一个“当前图像专属”的隐式阈值切点,但不是数据集级别的全局常数阈值。
|
| 51 |
+
|
| 52 |
+
## 3. `refine_foreground` 的来源、术语与技术脉络
|
| 53 |
+
|
| 54 |
+
### 3.1 `FB` 是什么意思?
|
| 55 |
+
|
| 56 |
+
这里的 `F` 和 `B` 是 matting / compositing 领域里的经典记号:
|
| 57 |
+
|
| 58 |
+
- `F` = Foreground,前景颜色
|
| 59 |
+
- `B` = Background,背景颜色
|
| 60 |
+
- `alpha` = 透明度或前景占比
|
| 61 |
+
|
| 62 |
+
经典合成公式可以写成:
|
| 63 |
+
|
| 64 |
+
```text
|
| 65 |
+
I = alpha * F + (1 - alpha) * B
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
其中:
|
| 69 |
+
|
| 70 |
+
- `I` 是观测到的输入图像
|
| 71 |
+
- `F` 是希望恢复的真实前景颜色
|
| 72 |
+
- `B` 是背景颜色
|
| 73 |
+
- `alpha` 控制前景和背景的混合比例
|
| 74 |
+
|
| 75 |
+
所以这份代码里的 `FB_blur_fusion_foreground_estimator_*`,名字可以直接读成:
|
| 76 |
+
|
| 77 |
+
```text
|
| 78 |
+
一种基于模糊融合(blur fusion)的 F/B 前景估计器
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
它的目标不是再次估计分割类别,而是在已知 `image` 和 `alpha(mask)` 后,估计更干净的 `F`。
|
| 82 |
+
|
| 83 |
+
### 3.2 当前实现最直接的代码来源
|
| 84 |
+
|
| 85 |
+
从项目代码注释看,当前 `refine_foreground` 链路最直接依赖了两个来源:
|
| 86 |
+
|
| 87 |
+
1. Photoroom 的 `fast-foreground-estimation`
|
| 88 |
+
2. BiRefNet 社区里的 GPU 改写版本
|
| 89 |
+
|
| 90 |
+
项目本地注释已经写明:
|
| 91 |
+
|
| 92 |
+
- CPU 版本参考了 Photoroom 仓库
|
| 93 |
+
- GPU 双阶段版本参考了 BiRefNet issue comment
|
| 94 |
+
|
| 95 |
+
对应代码见 [app.py](/Users/lev1s/Documents/BiRefNet_demo/app.py#L137)。
|
| 96 |
+
|
| 97 |
+
BiRefNet 官方仓库后续也明确写到,`refine_foreground` 的提速来自 “the GPU implementation of fast-fg-est”。这说明当前仓库里的 `refine_foreground`,本质上不是 BiRefNet 原论文提出的新理论模块,而是把现有 foreground estimation 方法接到了推理后处理中。
|
| 98 |
+
|
| 99 |
+
### 3.3 直接可追溯的论文来源
|
| 100 |
+
|
| 101 |
+
当前这条实现链最清楚的论文来源分两层:
|
| 102 |
+
|
| 103 |
+
#### 第一层:近似实现
|
| 104 |
+
|
| 105 |
+
Photoroom 的开源仓库 `fast-foreground-estimation` 明确写着,它是论文 “Approximate Fast Foreground Colour Estimation” 的官方仓库,作者是 Marco Forte,发表于 ICIP 2021。
|
| 106 |
+
|
| 107 |
+
这个仓库 README 又明确说明:
|
| 108 |
+
|
| 109 |
+
- 该方法是一个很快的 foreground estimation technique
|
| 110 |
+
- 它“yields comparable results to the full approach [1], while also being faster”
|
| 111 |
+
- 其中 `[1]` 指向的就是 `Fast Multi-Level Foreground Estimation`
|
| 112 |
+
|
| 113 |
+
也就是说,当前代码中这种非常短小、基于 blur fusion 的实现,最直接对应的是:
|
| 114 |
+
|
| 115 |
+
```text
|
| 116 |
+
Marco Forte, "Approximate Fast Foreground Colour Estimation", ICIP 2021
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
#### 第二层:更早的完整方法
|
| 120 |
+
|
| 121 |
+
Photoroom 官方材料和 PyMatting 文档都把上面的近似法指向了更早的完整方法:
|
| 122 |
+
|
| 123 |
+
```text
|
| 124 |
+
Thomas Germer, Tobias Uelwer, Stefan Conrad, Stefan Harmeling,
|
| 125 |
+
"Fast Multi-Level Foreground Estimation", ICPR 2020
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
这篇工作讨论的是:
|
| 129 |
+
|
| 130 |
+
- 已知 alpha matte
|
| 131 |
+
- 如何估计 foreground colours
|
| 132 |
+
- 以避免直接抠图后出现边缘 bleed-through
|
| 133 |
+
|
| 134 |
+
所以从“研究脉络”上讲:
|
| 135 |
+
|
| 136 |
+
```text
|
| 137 |
+
当前 refine_foreground
|
| 138 |
+
<- 工程上更接近 Photoroom 的 Approximate Fast Foreground Colour Estimation
|
| 139 |
+
<- 理论上又是对 Germer 等人 Fast Multi-Level Foreground Estimation 的近似实现
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
### 3.4 它和传统 matting 文献的关系
|
| 143 |
+
|
| 144 |
+
这条线属于 image matting / foreground estimation,而不是纯 segmentation。
|
| 145 |
+
|
| 146 |
+
更具体地说,它属于:
|
| 147 |
+
|
| 148 |
+
- 已知 alpha 或已有 mask
|
| 149 |
+
- 进一步恢复前景颜色 `F`
|
| 150 |
+
- 让最终合成结果更自然
|
| 151 |
+
|
| 152 |
+
这和只输出一个二值 mask 的分割工作不同,也和只估计 alpha matte 的方法不同。它解决的是:
|
| 153 |
+
|
| 154 |
+
```text
|
| 155 |
+
“mask 有了之后,怎么恢复更干净的前景颜色”。
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
所以它在定位上更接近:
|
| 159 |
+
|
| 160 |
+
- foreground colour estimation
|
| 161 |
+
- alpha matting 的后处理
|
| 162 |
+
- compositing quality enhancement
|
| 163 |
+
|
| 164 |
+
而不是:
|
| 165 |
+
|
| 166 |
+
- segmentation mask optimization
|
| 167 |
+
|
| 168 |
+
### 3.5 对当前项目的准确表述
|
| 169 |
+
|
| 170 |
+
如果在报告或提交说明里需要一句准确表述,我建议写成:
|
| 171 |
+
|
| 172 |
+
```text
|
| 173 |
+
本项目中的 refine_foreground 并非 BiRefNet 原始分割网络的一部分,
|
| 174 |
+
而是一个用于前景颜色估计的后处理模块。
|
| 175 |
+
其当前实现最直接来源于 Photoroom 开源的
|
| 176 |
+
"Approximate Fast Foreground Colour Estimation"(Marco Forte, ICIP 2021),
|
| 177 |
+
该方法又是对 Germer 等人
|
| 178 |
+
"Fast Multi-Level Foreground Estimation"(ICPR 2020)
|
| 179 |
+
的快速近似。
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
## 4. 源码证据
|
| 183 |
+
|
| 184 |
+
### 4.1 `refine_foreground` 如何使用 mask
|
| 185 |
+
|
| 186 |
+
项目里的实现见 [app.py](/Users/lev1s/Documents/BiRefNet_demo/app.py#L143)。
|
| 187 |
+
|
| 188 |
+
关键信息:
|
| 189 |
+
|
| 190 |
+
```python
|
| 191 |
+
def refine_foreground(image, mask, r=90, device='cuda'):
|
| 192 |
+
"""both image and mask are in range of [0, 1]"""
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
以及:
|
| 196 |
+
|
| 197 |
+
```python
|
| 198 |
+
mask = transforms.functional.to_tensor(mask).float().cuda()
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
和:
|
| 202 |
+
|
| 203 |
+
```python
|
| 204 |
+
blurred_alpha = mean_blur(alpha, kernel_size=r)
|
| 205 |
+
blurred_FGA = mean_blur(FG * alpha, kernel_size=r)
|
| 206 |
+
blurred_B1A = mean_blur(B * (1 - alpha), kernel_size=r)
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
这些实现说明:
|
| 210 |
+
|
| 211 |
+
1. `mask` 被当成连续 `alpha`
|
| 212 |
+
2. 算法内部直接使用 `alpha` 和 `1 - alpha`
|
| 213 |
+
3. soft alpha 会影响模糊融合结果
|
| 214 |
+
4. binary alpha 也可运行,但只是一种退化情况
|
| 215 |
+
|
| 216 |
+
### 4.2 当前项目中 RankSEG 的接法
|
| 217 |
+
|
| 218 |
+
见 [app.py](/Users/lev1s/Documents/BiRefNet_demo/app.py#L170):
|
| 219 |
+
|
| 220 |
+
```python
|
| 221 |
+
def get_rankseg_mask(pred: torch.Tensor, metric: str) -> Image.Image:
|
| 222 |
+
rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
|
| 223 |
+
probs = pred.unsqueeze(0).unsqueeze(0).to(torch.float32)
|
| 224 |
+
rankseg_pred = rankseg.predict(probs).squeeze(0).squeeze(0).to(torch.float32)
|
| 225 |
+
return transforms.ToPILImage()(rankseg_pred)
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
这里输入 `probs` 是 `(1, 1, H, W)` 的 soft probability map,但 `rankseg_pred` 是离散预测后再转回 `float32` 以便转图像,不代表它重新变成了 soft mask。
|
| 229 |
+
|
| 230 |
+
### 4.3 RankSEG 官方语义
|
| 231 |
+
|
| 232 |
+
本地安装包源码见 [rankseg/_rankseg.py](/Users/lev1s/Documents/BiRefNet_demo/.venv/lib/python3.12/site-packages/rankseg/_rankseg.py#L94):
|
| 233 |
+
|
| 234 |
+
```python
|
| 235 |
+
def predict(self, probs):
|
| 236 |
+
"""Convert probability maps to binary segmentation predictions.
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
返回说明见同文件 [rankseg/_rankseg.py](/Users/lev1s/Documents/BiRefNet_demo/.venv/lib/python3.12/site-packages/rankseg/_rankseg.py#L105):
|
| 240 |
+
|
| 241 |
+
```python
|
| 242 |
+
preds : torch.Tensor
|
| 243 |
+
Binary segmentation predictions ...
|
| 244 |
+
Values are 0 or 1 (or boolean True/False depending on solver).
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
这已经直接排除了“当前 `predict()` 输出 soft mask”的解释。
|
| 248 |
+
|
| 249 |
+
### 4.4 RankSEG 的 `tau` 到底是什么
|
| 250 |
+
|
| 251 |
+
实现见 [rankseg/_rankseg_algo.py](/Users/lev1s/Documents/BiRefNet_demo/.venv/lib/python3.12/site-packages/rankseg/_rankseg_algo.py#L250)。
|
| 252 |
+
|
| 253 |
+
关键逻辑:
|
| 254 |
+
|
| 255 |
+
```python
|
| 256 |
+
opt_tau = torch.argmax(metric_values, dim=-1) + 1
|
| 257 |
+
```
|
| 258 |
+
|
| 259 |
+
以及:
|
| 260 |
+
|
| 261 |
+
```python
|
| 262 |
+
overlap_preds[b, c, top_index[b, c, :opt_tau[b, c]]] = True
|
| 263 |
+
```
|
| 264 |
+
|
| 265 |
+
这说明:
|
| 266 |
+
|
| 267 |
+
1. 先对概率从大到小排序
|
| 268 |
+
2. 估计不同 `tau` 下的指标值
|
| 269 |
+
3. 选择最优 `opt_tau`
|
| 270 |
+
4. 将前 `opt_tau` 个像素直接置为前景
|
| 271 |
+
|
| 272 |
+
因此 `tau` 是“保留多少个最高概率像素”的截断位置,本质上是 `top-k` 选择,不是简单地把所有像素与某个固定常数阈值比较。
|
| 273 |
+
|
| 274 |
+
## 5. 解释:为什么 `refine_foreground` 更偏爱 soft mask
|
| 275 |
+
|
| 276 |
+
`refine_foreground` 做的不是再次分割,而是前景颜色估计。它本质上在解下面这类问题:
|
| 277 |
+
|
| 278 |
+
```text
|
| 279 |
+
image ~= foreground * alpha + background * (1 - alpha)
|
| 280 |
+
```
|
| 281 |
+
|
| 282 |
+
这里的 `alpha` 越连续,算法越能利用边缘过渡信息去恢复更自然的前景颜色。
|
| 283 |
+
|
| 284 |
+
所以:
|
| 285 |
+
|
| 286 |
+
- soft mask 更利于柔边、发丝、半透明边缘
|
| 287 |
+
- binary mask 更利于轮廓果断、区域清晰
|
| 288 |
+
|
| 289 |
+
两者优化目标不同:
|
| 290 |
+
|
| 291 |
+
- RankSEG 更偏向 mask 指标优化,如 Dice / IoU
|
| 292 |
+
- `refine_foreground` 更偏向最终抠图观感优化
|
| 293 |
+
|
| 294 |
+
## 6. 回答你的核心判断
|
| 295 |
+
|
| 296 |
+
### 6.1 “如果 RankSEG 输出 soft mask,就能更好接上 `refine_foreground` 吗?”
|
| 297 |
+
|
| 298 |
+
原则上是的。
|
| 299 |
+
|
| 300 |
+
如果 RankSEG 能输出经过其全局优化后的 soft alpha,那它会比 hard mask 更适合 `refine_foreground`。但当前版本并没有这样的 API。
|
| 301 |
+
|
| 302 |
+
### 6.2 “既然 RankSEG 不能输出 soft mask,那当前 `refine_foreground(rankseg_mask)` 是否没有意义?”
|
| 303 |
+
|
| 304 |
+
不是没有意义,而是意义不同。
|
| 305 |
+
|
| 306 |
+
当前做法的价值是:
|
| 307 |
+
|
| 308 |
+
- 利用 RankSEG 提升区域级的前景/背景判定
|
| 309 |
+
- 再利用 `refine_foreground` 改善前景颜色和合成效果
|
| 310 |
+
|
| 311 |
+
但它的限制也很明确:
|
| 312 |
+
|
| 313 |
+
- `refine_foreground` 用到的是 hard alpha
|
| 314 |
+
- 边缘渐变信息已经丢失
|
| 315 |
+
- 所以它无法完全发挥 soft alpha matting 的优势
|
| 316 |
+
|
| 317 |
+
## 7. 最合理的工程建议
|
| 318 |
+
|
| 319 |
+
### 方案 A. 保持现状
|
| 320 |
+
|
| 321 |
+
管线:
|
| 322 |
+
|
| 323 |
+
```text
|
| 324 |
+
pred (soft prob) -> RankSEG -> binary mask -> refine_foreground
|
| 325 |
+
```
|
| 326 |
+
|
| 327 |
+
优点:
|
| 328 |
+
|
| 329 |
+
- 实现简单
|
| 330 |
+
- 区域判定更干净
|
| 331 |
+
|
| 332 |
+
缺点:
|
| 333 |
+
|
| 334 |
+
- 边缘容易变硬
|
| 335 |
+
- 发丝、半透明细节可能不如原始 soft mask
|
| 336 |
+
|
| 337 |
+
适合:
|
| 338 |
+
|
| 339 |
+
- 更重视区域分割是否准确
|
| 340 |
+
- 更像“分割后抠图”
|
| 341 |
+
|
| 342 |
+
### 方案 B. 原始 soft mask 直接抠图
|
| 343 |
+
|
| 344 |
+
管线:
|
| 345 |
+
|
| 346 |
+
```text
|
| 347 |
+
pred (soft prob) -> refine_foreground
|
| 348 |
+
```
|
| 349 |
+
|
| 350 |
+
优点:
|
| 351 |
+
|
| 352 |
+
- 边缘更自然
|
| 353 |
+
- 更像 matting
|
| 354 |
+
|
| 355 |
+
缺点:
|
| 356 |
+
|
| 357 |
+
- 区域级误检会直接进入最终结果
|
| 358 |
+
|
| 359 |
+
适合:
|
| 360 |
+
|
| 361 |
+
- 更重视最终视觉观感
|
| 362 |
+
|
| 363 |
+
### 方案 C. 推荐的混合方案
|
| 364 |
+
|
| 365 |
+
管线:
|
| 366 |
+
|
| 367 |
+
```text
|
| 368 |
+
pred (soft prob) -> RankSEG(binary support)
|
| 369 |
+
-> hybrid_alpha = pred * rankseg_mask
|
| 370 |
+
-> refine_foreground(hybrid_alpha)
|
| 371 |
+
```
|
| 372 |
+
|
| 373 |
+
或更保守地:
|
| 374 |
+
|
| 375 |
+
```python
|
| 376 |
+
hybrid_alpha = torch.where(rankseg_mask > 0, pred, torch.zeros_like(pred))
|
| 377 |
+
```
|
| 378 |
+
|
| 379 |
+
这个思路的含义是:
|
| 380 |
+
|
| 381 |
+
- RankSEG 负责“哪些区域允许成为前景”
|
| 382 |
+
- 原始 `pred` 负责“这些前景区域内部的透明度变化”
|
| 383 |
+
|
| 384 |
+
优点:
|
| 385 |
+
|
| 386 |
+
- 同时利用 RankSEG 的区域约束能力
|
| 387 |
+
- 保留原始 soft mask 的边缘渐变
|
| 388 |
+
|
| 389 |
+
这是当前代码架构下最有希望同时兼顾 Dice/IoU 与视觉质量的方案。
|
| 390 |
+
|
| 391 |
+
## 8. 评估建议
|
| 392 |
+
|
| 393 |
+
如果你的目标是论文式或工程式对比,建议把评估拆成两组,不要混为一谈。
|
| 394 |
+
|
| 395 |
+
### 8.1 分割质量评估
|
| 396 |
+
|
| 397 |
+
比较对象:
|
| 398 |
+
|
| 399 |
+
- 原始 `pred` 阈值化后的 mask
|
| 400 |
+
- RankSEG 输出 mask
|
| 401 |
+
- 混合方案的最终 alpha 再阈值化的 mask
|
| 402 |
+
|
| 403 |
+
指标:
|
| 404 |
+
|
| 405 |
+
- IoU
|
| 406 |
+
- Dice
|
| 407 |
+
- Pixel Accuracy
|
| 408 |
+
|
| 409 |
+
注意:
|
| 410 |
+
|
| 411 |
+
- `refine_foreground` 不应作为 IoU/Dice 的评估对象
|
| 412 |
+
- 因为它输出的是前景图像,不是 mask 优化本身
|
| 413 |
+
|
| 414 |
+
### 8.2 抠图观感评估
|
| 415 |
+
|
| 416 |
+
比较对象:
|
| 417 |
+
|
| 418 |
+
- raw soft mask + refine
|
| 419 |
+
- rankseg binary mask + refine
|
| 420 |
+
- hybrid alpha + refine
|
| 421 |
+
|
| 422 |
+
关注点:
|
| 423 |
+
|
| 424 |
+
- 发丝
|
| 425 |
+
- 透明边缘
|
| 426 |
+
- 白边/黑边
|
| 427 |
+
- 前景颜色污染
|
| 428 |
+
|
| 429 |
+
如果有 matting 数据,可以进一步评估 alpha 或前景误差;如果没有,至少做稳定的可视化对比。
|
| 430 |
+
|
| 431 |
+
## 9. 最终结论
|
| 432 |
+
|
| 433 |
+
1. `refine_foreground` 同时支持 soft mask 与 binary mask,但从算法性质上更适合 soft alpha。
|
| 434 |
+
2. 当前 RankSEG 版本不输出 soft mask,只输出离散预测。
|
| 435 |
+
3. RankSEG 不应被简单理解成“把 0.5 换成一个更好的固定阈值”;它更接近于基于概率排序和目标指标的自适应 `top-k` 截断。
|
| 436 |
+
4. 因此,“让 `refine_foreground` 也增强 RankSEG 的效果”这件事,不应期待 RankSEG 直接给 soft mask,而应考虑让 RankSEG 提供区域约束、让原始 `pred` 提供 soft alpha。
|
| 437 |
+
5. 在当前工程里,最值得实验的下一步是混合方案,而不是继续追问 RankSEG 是否已有 soft output API。
|
| 438 |
+
|
| 439 |
+
## 10. 参考来源
|
| 440 |
+
|
| 441 |
+
### 本地代码
|
| 442 |
+
|
| 443 |
+
- [app.py](/Users/lev1s/Documents/BiRefNet_demo/app.py)
|
| 444 |
+
- [rankseg/_rankseg.py](/Users/lev1s/Documents/BiRefNet_demo/.venv/lib/python3.12/site-packages/rankseg/_rankseg.py)
|
| 445 |
+
- [rankseg/_rankseg_algo.py](/Users/lev1s/Documents/BiRefNet_demo/.venv/lib/python3.12/site-packages/rankseg/_rankseg_algo.py)
|
| 446 |
+
|
| 447 |
+
### 外部资料
|
| 448 |
+
|
| 449 |
+
- RankSEG getting started: [GitHub](https://github.com/rankseg/rankseg/blob/main/doc/source/getting_started.md)
|
| 450 |
+
- RankSEG auto API: [GitHub](https://github.com/rankseg/rankseg/blob/main/doc/source/autoapi/rankseg/rankseg/index.md)
|
| 451 |
+
- BiRefNet repository: [GitHub](https://github.com/ZhengPeng7/BiRefNet)
|
| 452 |
+
- Approximate Fast Foreground Colour Estimation repository: [GitHub](https://github.com/Photoroom/fast-foreground-estimation)
|
| 453 |
+
- Approximate Fast Foreground Colour Estimation talk page: [IEEE SPS Resource Center](https://resourcecenter.ieee.org/conferences/icip-2021/spsicip21vid215)
|
| 454 |
+
- Fast Multi-Level Foreground Estimation reference entry: [DBLP](https://dblp.org/rec/conf/icpr/GermerU0H20)
|
| 455 |
+
- PyMatting foreground estimation docs: [PyMatting](https://pymatting.github.io/)
|