Spaces:
Sleeping
Sleeping
Update binary_segmentation.py
Browse files- binary_segmentation.py +124 -124
binary_segmentation.py
CHANGED
|
@@ -566,166 +566,166 @@ class BinarySegmenter:
|
|
| 566 |
except ImportError:
|
| 567 |
raise ImportError("RMBG requires: pip install transformers")
|
| 568 |
|
| 569 |
-
def segment(
    self,
    image: np.ndarray,
    threshold: float = 0.5,
    return_type: Literal["mask", "rgba", "both"] = "mask"
) -> Tuple[Optional[np.ndarray], Optional[Image.Image]]:
    """
    Extract the foreground of a colour image as a binary mask and/or RGBA image.

    Args:
        image: Numpy array of shape (H, W, 3), nominally RGB or BGR.
        threshold: Cut-off in the range 0-1, applied to the min-max
            normalised prediction.
        return_type: Which outputs to build - "mask", "rgba", or "both".

    Returns:
        Tuple of (binary_mask, rgba_image). For "mask" the second element
        is None; for "rgba" the first element is None.

    Raises:
        ValueError: If image is not a 3-dimensional array with 3 channels.
    """
    # Guard clause: reject anything that is not (H, W, 3).
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Input must be a color image (H, W, 3)")

    # Channel-order heuristic: swap B and R only when the first and last
    # channel of the top-left pixel differ.
    # NOTE(review): this comparison does not reveal the actual channel
    # order of the input - confirm what callers pass in.
    if image[0, 0, 0] != image[0, 0, 2]:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    else:
        image_rgb = image

    # PIL reports size as (width, height), which is the order cv2.resize
    # expects for its target size.
    image_pil = Image.fromarray(image_rgb)
    target_size = image_pil.size

    # Build a single-item batch on the configured device.
    batch = self.transform(image_pil).unsqueeze(0).to(DEVICE)
    if DEVICE == "cpu":
        batch = batch.float()

    # Forward pass without gradient tracking.
    with torch.no_grad():
        if self.model_type == "u2netp":
            # First output is used as the main prediction.
            pred = self.model(batch)[0]
        else:  # birefnet or rmbg
            pred = self.model(batch)[-1].sigmoid()

    # Drop singleton dimensions and move the prediction to numpy.
    pred = pred.squeeze().cpu().numpy()

    # Scale the prediction back to the input resolution.
    resized = cv2.resize(pred, target_size, interpolation=cv2.INTER_LINEAR)

    # Min-max normalise to 0-255; the epsilon avoids division by zero on a
    # constant prediction.
    lowest = resized.min()
    highest = resized.max()
    scaled = (resized - lowest) / (highest - lowest + 1e-8) * 255

    # Threshold into a 0/255 uint8 mask.
    mask = (scaled > (threshold * 255)).astype(np.uint8) * 255

    # Morphological close followed by open for a cleaner mask.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, ellipse)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, ellipse)

    # Build the RGBA image only when it was requested; the mask becomes
    # the alpha channel.
    rgba_image = None
    if return_type in ["rgba", "both"]:
        stacked = np.dstack([image_rgb, mask])
        rgba_image = Image.fromarray(stacked, mode="RGBA")

    if return_type == "mask":
        return mask, None
    if return_type == "rgba":
        return None, rgba_image
    return mask, rgba_image
|
| 645 |
-
|
| 646 |
# def segment(
|
| 647 |
# self,
|
| 648 |
# image: np.ndarray,
|
| 649 |
# threshold: float = 0.5,
|
| 650 |
# return_type: Literal["mask", "rgba", "both"] = "mask"
|
| 651 |
# ) -> Tuple[Optional[np.ndarray], Optional[Image.Image]]:
|
| 652 |
-
|
| 653 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 654 |
# if len(image.shape) == 3 and image.shape[2] == 3:
|
| 655 |
-
#
|
|
|
|
|
|
|
|
|
|
| 656 |
# else:
|
| 657 |
# raise ValueError("Input must be a color image (H, W, 3)")
|
| 658 |
-
|
| 659 |
-
# #
|
| 660 |
-
# orig_h, orig_w = image.shape[:2]
|
| 661 |
-
|
| 662 |
-
# # Convert to PIL for transforms
|
| 663 |
# image_pil = Image.fromarray(image_rgb)
|
| 664 |
-
|
| 665 |
-
|
|
|
|
| 666 |
# input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
|
| 667 |
# if DEVICE == "cpu":
|
| 668 |
# input_tensor = input_tensor.float()
|
| 669 |
|
|
|
|
| 670 |
# # Inference
|
| 671 |
# with torch.no_grad():
|
| 672 |
# if self.model_type == "u2netp":
|
| 673 |
# outputs = self.model(input_tensor)
|
| 674 |
-
# pred = outputs[0]
|
| 675 |
# else: # birefnet or rmbg
|
| 676 |
# pred = self.model(input_tensor)[-1].sigmoid()
|
| 677 |
-
|
| 678 |
-
# # Post-process
|
| 679 |
# pred = pred.squeeze().cpu().numpy()
|
| 680 |
-
|
| 681 |
-
# #
|
| 682 |
-
#
|
| 683 |
-
|
| 684 |
-
# pred,
|
| 685 |
-
# (orig_w, orig_h), # ← correct order for cv2
|
| 686 |
-
# interpolation=cv2.INTER_LINEAR
|
| 687 |
-
# )
|
| 688 |
-
|
| 689 |
-
# # Verify shape matches original
|
| 690 |
-
# assert pred_resized.shape == (orig_h, orig_w), \
|
| 691 |
-
# f"Shape mismatch! Got {pred_resized.shape}, expected ({orig_h}, {orig_w})"
|
| 692 |
-
|
| 693 |
# # Normalize to 0-255
|
| 694 |
-
# pred_normalized = (
|
| 695 |
-
#
|
| 696 |
-
|
| 697 |
-
#
|
| 698 |
-
|
| 699 |
-
# # Binary mask
|
| 700 |
# binary_mask = (pred_normalized > (threshold * 255)).astype(np.uint8) * 255
|
| 701 |
-
|
| 702 |
-
# # Morphological
|
| 703 |
# kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
| 704 |
# binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
|
| 705 |
# binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
|
| 706 |
-
|
| 707 |
-
# # ✅ Verify final mask dimensions match input
|
| 708 |
-
# assert binary_mask.shape == (orig_h, orig_w), \
|
| 709 |
-
# f"Final mask mismatch! Got {binary_mask.shape}, expected ({orig_h}, {orig_w})"
|
| 710 |
-
|
| 711 |
-
# logger.info(f"Input shape: ({orig_h}, {orig_w}) | Output mask shape: {binary_mask.shape} ✅")
|
| 712 |
-
|
| 713 |
# # Create RGBA if needed
|
| 714 |
# rgba_image = None
|
| 715 |
# if return_type in ["rgba", "both"]:
|
|
|
|
| 716 |
# rgba = np.dstack([image_rgb, binary_mask])
|
| 717 |
# rgba_image = Image.fromarray(rgba, mode='RGBA')
|
| 718 |
-
|
| 719 |
-
#
|
| 720 |
-
# assert rgba_image.size == (orig_w, orig_h), \
|
| 721 |
-
# f"RGBA size mismatch! Got {rgba_image.size}, expected ({orig_w}, {orig_h})"
|
| 722 |
-
|
| 723 |
# if return_type == "mask":
|
| 724 |
# return binary_mask, None
|
| 725 |
# elif return_type == "rgba":
|
| 726 |
# return None, rgba_image
|
| 727 |
-
# else:
|
| 728 |
# return binary_mask, rgba_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
|
| 730 |
def batch_segment(
|
| 731 |
self,
|
|
|
|
| 566 |
except ImportError:
|
| 567 |
raise ImportError("RMBG requires: pip install transformers")
|
| 568 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
# def segment(
|
| 570 |
# self,
|
| 571 |
# image: np.ndarray,
|
| 572 |
# threshold: float = 0.5,
|
| 573 |
# return_type: Literal["mask", "rgba", "both"] = "mask"
|
| 574 |
# ) -> Tuple[Optional[np.ndarray], Optional[Image.Image]]:
|
| 575 |
+
# """
|
| 576 |
+
# Segment foreground object from image.
|
| 577 |
+
|
| 578 |
+
# Args:
|
| 579 |
+
# image: Input image as numpy array (H, W, 3) in RGB or BGR
|
| 580 |
+
# threshold: Threshold for binary mask (0-1)
|
| 581 |
+
# return_type: What to return - "mask", "rgba", or "both"
|
| 582 |
+
|
| 583 |
+
# Returns:
|
| 584 |
+
# Tuple of (binary_mask, rgba_image) based on return_type
|
| 585 |
+
# """
|
| 586 |
+
# # Convert BGR to RGB if needed
|
| 587 |
# if len(image.shape) == 3 and image.shape[2] == 3:
|
| 588 |
+
# if image[0, 0, 0] != image[0, 0, 2]: # Simple heuristic
|
| 589 |
+
# image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 590 |
+
# else:
|
| 591 |
+
# image_rgb = image
|
| 592 |
# else:
|
| 593 |
# raise ValueError("Input must be a color image (H, W, 3)")
|
| 594 |
+
|
| 595 |
+
# # Convert to PIL
|
|
|
|
|
|
|
|
|
|
| 596 |
# image_pil = Image.fromarray(image_rgb)
|
| 597 |
+
# original_size = image_pil.size
|
| 598 |
+
|
| 599 |
+
# # Transform
|
| 600 |
# input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
|
| 601 |
# if DEVICE == "cpu":
|
| 602 |
# input_tensor = input_tensor.float()
|
| 603 |
|
| 604 |
+
|
| 605 |
# # Inference
|
| 606 |
# with torch.no_grad():
|
| 607 |
# if self.model_type == "u2netp":
|
| 608 |
# outputs = self.model(input_tensor)
|
| 609 |
+
# pred = outputs[0] # Main output
|
| 610 |
# else: # birefnet or rmbg
|
| 611 |
# pred = self.model(input_tensor)[-1].sigmoid()
|
| 612 |
+
|
| 613 |
+
# # Post-process
|
| 614 |
# pred = pred.squeeze().cpu().numpy()
|
| 615 |
+
|
| 616 |
+
# # Resize to original
|
| 617 |
+
# pred_resized = cv2.resize(pred, original_size, interpolation=cv2.INTER_LINEAR)
|
| 618 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
# # Normalize to 0-255
|
| 620 |
+
# pred_normalized = ((pred_resized - pred_resized.min()) /
|
| 621 |
+
# (pred_resized.max() - pred_resized.min() + 1e-8) * 255)
|
| 622 |
+
|
| 623 |
+
# # Create binary mask
|
|
|
|
|
|
|
| 624 |
# binary_mask = (pred_normalized > (threshold * 255)).astype(np.uint8) * 255
|
| 625 |
+
|
| 626 |
+
# # Optional: Morphological operations for cleaner mask
|
| 627 |
# kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
| 628 |
# binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
|
| 629 |
# binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
|
| 630 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
# # Create RGBA if needed
|
| 632 |
# rgba_image = None
|
| 633 |
# if return_type in ["rgba", "both"]:
|
| 634 |
+
# # Create 4-channel image
|
| 635 |
# rgba = np.dstack([image_rgb, binary_mask])
|
| 636 |
# rgba_image = Image.fromarray(rgba, mode='RGBA')
|
| 637 |
+
|
| 638 |
+
# # Return based on type
|
|
|
|
|
|
|
|
|
|
| 639 |
# if return_type == "mask":
|
| 640 |
# return binary_mask, None
|
| 641 |
# elif return_type == "rgba":
|
| 642 |
# return None, rgba_image
|
| 643 |
+
# else: # both
|
| 644 |
# return binary_mask, rgba_image
|
| 645 |
+
|
| 646 |
+
def segment(
    self,
    image: np.ndarray,
    threshold: float = 0.5,
    return_type: Literal["mask", "rgba", "both"] = "mask"
) -> Tuple[Optional[np.ndarray], Optional[Image.Image]]:
    """
    Segment the foreground object from an image.

    Args:
        image: Colour image as a numpy array of shape (H, W, 3). The
            channels are always treated as BGR and converted to RGB
            before inference.
        threshold: Cut-off in the range 0-1, applied to the min-max
            normalised prediction.
        return_type: What to return - "mask", "rgba", or "both".

    Returns:
        Tuple of (binary_mask, rgba_image). binary_mask is a uint8 array
        of shape (H, W) holding 0 or 255; rgba_image is a PIL image whose
        alpha channel is that mask. For "mask" the second element is
        None; for "rgba" the first element is None.

    Raises:
        ValueError: If image is not a 3-dimensional array with 3 channels.
        RuntimeError: If the model prediction does not squeeze down to a
            single 2-D map.
    """
    # Guard clause: only (H, W, 3) colour arrays are supported.
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError("Input must be a color image (H, W, 3)")

    # Remember the input resolution; numpy shape order is (height, width).
    orig_h, orig_w = image.shape[:2]

    # Input is assumed to be BGR (OpenCV convention) and is always swapped.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_pil = Image.fromarray(image_rgb)

    # Build a single-item batch on the configured device.
    # NOTE(review): self.transform presumably resizes to the model's input
    # size - confirm against its definition.
    input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
    if DEVICE == "cpu":
        input_tensor = input_tensor.float()

    # Inference without gradient tracking.
    with torch.no_grad():
        if self.model_type == "u2netp":
            # First output is used as the main prediction.
            pred = self.model(input_tensor)[0]
        else:  # birefnet or rmbg
            pred = self.model(input_tensor)[-1].sigmoid()

    # Drop singleton dimensions and move to numpy. The float32 cast is a
    # no-op for float32 output.
    # NOTE(review): cast added on the assumption that OpenCV resize does
    # not accept half-precision arrays - confirm.
    pred = pred.squeeze().float().cpu().numpy()

    # Explicit check instead of `assert`, which is removed under `python -O`.
    # A 2-D input to cv2.resize yields exactly (orig_h, orig_w), so this one
    # precondition covers the shape guarantees of the mask and RGBA image.
    if pred.ndim != 2:
        raise RuntimeError(
            f"Expected a 2-D prediction map, got shape {pred.shape}"
        )

    # cv2.resize takes the target size as (width, height).
    pred_resized = cv2.resize(
        pred,
        (orig_w, orig_h),
        interpolation=cv2.INTER_LINEAR
    )

    # Min-max normalise to 0-255; the epsilon avoids division by zero on a
    # constant prediction.
    pred_min = pred_resized.min()
    pred_max = pred_resized.max()
    pred_normalized = (pred_resized - pred_min) / (pred_max - pred_min + 1e-8) * 255

    # Threshold into a 0/255 uint8 mask.
    binary_mask = (pred_normalized > (threshold * 255)).astype(np.uint8) * 255

    # Morphological close followed by open for a cleaner mask.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)

    # Lazy %-formatting: the message is only rendered if INFO is enabled.
    logger.info(
        "Input shape: (%s, %s) | Output mask shape: %s ✅",
        orig_h,
        orig_w,
        binary_mask.shape,
    )

    # Build the RGBA image only when it was requested; the mask becomes
    # the alpha channel.
    rgba_image = None
    if return_type in ["rgba", "both"]:
        rgba = np.dstack([image_rgb, binary_mask])
        rgba_image = Image.fromarray(rgba, mode='RGBA')

    if return_type == "mask":
        return binary_mask, None
    elif return_type == "rgba":
        return None, rgba_image
    else:
        return binary_mask, rgba_image
|
| 729 |
|
| 730 |
def batch_segment(
|
| 731 |
self,
|