Ayesha-Majeed committed on
Commit
034884b
·
verified ·
1 Parent(s): f4aa803

Update binary_segmentation.py

Browse files
Files changed (1) hide show
  1. binary_segmentation.py +124 -124
binary_segmentation.py CHANGED
@@ -566,166 +566,166 @@ class BinarySegmenter:
566
  except ImportError:
567
  raise ImportError("RMBG requires: pip install transformers")
568
 
569
- def segment(
570
- self,
571
- image: np.ndarray,
572
- threshold: float = 0.5,
573
- return_type: Literal["mask", "rgba", "both"] = "mask"
574
- ) -> Tuple[Optional[np.ndarray], Optional[Image.Image]]:
575
- """
576
- Segment foreground object from image.
577
-
578
- Args:
579
- image: Input image as numpy array (H, W, 3) in RGB or BGR
580
- threshold: Threshold for binary mask (0-1)
581
- return_type: What to return - "mask", "rgba", or "both"
582
-
583
- Returns:
584
- Tuple of (binary_mask, rgba_image) based on return_type
585
- """
586
- # Convert BGR to RGB if needed
587
- if len(image.shape) == 3 and image.shape[2] == 3:
588
- if image[0, 0, 0] != image[0, 0, 2]: # Simple heuristic
589
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
590
- else:
591
- image_rgb = image
592
- else:
593
- raise ValueError("Input must be a color image (H, W, 3)")
594
-
595
- # Convert to PIL
596
- image_pil = Image.fromarray(image_rgb)
597
- original_size = image_pil.size
598
-
599
- # Transform
600
- input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
601
- if DEVICE == "cpu":
602
- input_tensor = input_tensor.float()
603
-
604
-
605
- # Inference
606
- with torch.no_grad():
607
- if self.model_type == "u2netp":
608
- outputs = self.model(input_tensor)
609
- pred = outputs[0] # Main output
610
- else: # birefnet or rmbg
611
- pred = self.model(input_tensor)[-1].sigmoid()
612
-
613
- # Post-process
614
- pred = pred.squeeze().cpu().numpy()
615
-
616
- # Resize to original
617
- pred_resized = cv2.resize(pred, original_size, interpolation=cv2.INTER_LINEAR)
618
-
619
- # Normalize to 0-255
620
- pred_normalized = ((pred_resized - pred_resized.min()) /
621
- (pred_resized.max() - pred_resized.min() + 1e-8) * 255)
622
-
623
- # Create binary mask
624
- binary_mask = (pred_normalized > (threshold * 255)).astype(np.uint8) * 255
625
-
626
- # Optional: Morphological operations for cleaner mask
627
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
628
- binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
629
- binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
630
-
631
- # Create RGBA if needed
632
- rgba_image = None
633
- if return_type in ["rgba", "both"]:
634
- # Create 4-channel image
635
- rgba = np.dstack([image_rgb, binary_mask])
636
- rgba_image = Image.fromarray(rgba, mode='RGBA')
637
-
638
- # Return based on type
639
- if return_type == "mask":
640
- return binary_mask, None
641
- elif return_type == "rgba":
642
- return None, rgba_image
643
- else: # both
644
- return binary_mask, rgba_image
645
-
646
  # def segment(
647
  # self,
648
  # image: np.ndarray,
649
  # threshold: float = 0.5,
650
  # return_type: Literal["mask", "rgba", "both"] = "mask"
651
  # ) -> Tuple[Optional[np.ndarray], Optional[Image.Image]]:
652
-
653
- # # Convert BGR to RGB
 
 
 
 
 
 
 
 
 
 
654
  # if len(image.shape) == 3 and image.shape[2] == 3:
655
- # image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
 
 
656
  # else:
657
  # raise ValueError("Input must be a color image (H, W, 3)")
658
-
659
- # # Store ORIGINAL dimensions (H, W) from numpy
660
- # orig_h, orig_w = image.shape[:2]
661
-
662
- # # Convert to PIL for transforms
663
  # image_pil = Image.fromarray(image_rgb)
664
-
665
- # # Transform (model resizes internally e.g. 320x320 / 512x512)
 
666
  # input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
667
  # if DEVICE == "cpu":
668
  # input_tensor = input_tensor.float()
669
 
 
670
  # # Inference
671
  # with torch.no_grad():
672
  # if self.model_type == "u2netp":
673
  # outputs = self.model(input_tensor)
674
- # pred = outputs[0]
675
  # else: # birefnet or rmbg
676
  # pred = self.model(input_tensor)[-1].sigmoid()
677
-
678
- # # Post-process - squeeze to 2D
679
  # pred = pred.squeeze().cpu().numpy()
680
-
681
- # # ✅ FIX: Resize back to ORIGINAL (width, height) for cv2
682
- # # cv2.resize takes (width, height) = (orig_w, orig_h)
683
- # pred_resized = cv2.resize(
684
- # pred,
685
- # (orig_w, orig_h), # ← correct order for cv2
686
- # interpolation=cv2.INTER_LINEAR
687
- # )
688
-
689
- # # Verify shape matches original
690
- # assert pred_resized.shape == (orig_h, orig_w), \
691
- # f"Shape mismatch! Got {pred_resized.shape}, expected ({orig_h}, {orig_w})"
692
-
693
  # # Normalize to 0-255
694
- # pred_normalized = (
695
- # (pred_resized - pred_resized.min()) /
696
- # (pred_resized.max() - pred_resized.min() + 1e-8) * 255
697
- # )
698
-
699
- # # Binary mask
700
  # binary_mask = (pred_normalized > (threshold * 255)).astype(np.uint8) * 255
701
-
702
- # # Morphological cleanup
703
  # kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
704
  # binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
705
  # binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
706
-
707
- # # ✅ Verify final mask dimensions match input
708
- # assert binary_mask.shape == (orig_h, orig_w), \
709
- # f"Final mask mismatch! Got {binary_mask.shape}, expected ({orig_h}, {orig_w})"
710
-
711
- # logger.info(f"Input shape: ({orig_h}, {orig_w}) | Output mask shape: {binary_mask.shape} ✅")
712
-
713
  # # Create RGBA if needed
714
  # rgba_image = None
715
  # if return_type in ["rgba", "both"]:
 
716
  # rgba = np.dstack([image_rgb, binary_mask])
717
  # rgba_image = Image.fromarray(rgba, mode='RGBA')
718
-
719
- # # Verify RGBA dimensions
720
- # assert rgba_image.size == (orig_w, orig_h), \
721
- # f"RGBA size mismatch! Got {rgba_image.size}, expected ({orig_w}, {orig_h})"
722
-
723
  # if return_type == "mask":
724
  # return binary_mask, None
725
  # elif return_type == "rgba":
726
  # return None, rgba_image
727
- # else:
728
  # return binary_mask, rgba_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
729
 
730
  def batch_segment(
731
  self,
 
566
  except ImportError:
567
  raise ImportError("RMBG requires: pip install transformers")
568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
  # def segment(
570
  # self,
571
  # image: np.ndarray,
572
  # threshold: float = 0.5,
573
  # return_type: Literal["mask", "rgba", "both"] = "mask"
574
  # ) -> Tuple[Optional[np.ndarray], Optional[Image.Image]]:
575
+ # """
576
+ # Segment foreground object from image.
577
+
578
+ # Args:
579
+ # image: Input image as numpy array (H, W, 3) in RGB or BGR
580
+ # threshold: Threshold for binary mask (0-1)
581
+ # return_type: What to return - "mask", "rgba", or "both"
582
+
583
+ # Returns:
584
+ # Tuple of (binary_mask, rgba_image) based on return_type
585
+ # """
586
+ # # Convert BGR to RGB if needed
587
  # if len(image.shape) == 3 and image.shape[2] == 3:
588
+ # if image[0, 0, 0] != image[0, 0, 2]: # Simple heuristic
589
+ # image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
590
+ # else:
591
+ # image_rgb = image
592
  # else:
593
  # raise ValueError("Input must be a color image (H, W, 3)")
594
+
595
+ # # Convert to PIL
 
 
 
596
  # image_pil = Image.fromarray(image_rgb)
597
+ # original_size = image_pil.size
598
+
599
+ # # Transform
600
  # input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
601
  # if DEVICE == "cpu":
602
  # input_tensor = input_tensor.float()
603
 
604
+
605
  # # Inference
606
  # with torch.no_grad():
607
  # if self.model_type == "u2netp":
608
  # outputs = self.model(input_tensor)
609
+ # pred = outputs[0] # Main output
610
  # else: # birefnet or rmbg
611
  # pred = self.model(input_tensor)[-1].sigmoid()
612
+
613
+ # # Post-process
614
  # pred = pred.squeeze().cpu().numpy()
615
+
616
+ # # Resize to original
617
+ # pred_resized = cv2.resize(pred, original_size, interpolation=cv2.INTER_LINEAR)
618
+
 
 
 
 
 
 
 
 
 
619
  # # Normalize to 0-255
620
+ # pred_normalized = ((pred_resized - pred_resized.min()) /
621
+ # (pred_resized.max() - pred_resized.min() + 1e-8) * 255)
622
+
623
+ # # Create binary mask
 
 
624
  # binary_mask = (pred_normalized > (threshold * 255)).astype(np.uint8) * 255
625
+
626
+ # # Optional: Morphological operations for cleaner mask
627
  # kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
628
  # binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
629
  # binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
630
+
 
 
 
 
 
 
631
  # # Create RGBA if needed
632
  # rgba_image = None
633
  # if return_type in ["rgba", "both"]:
634
+ # # Create 4-channel image
635
  # rgba = np.dstack([image_rgb, binary_mask])
636
  # rgba_image = Image.fromarray(rgba, mode='RGBA')
637
+
638
+ # # Return based on type
 
 
 
639
  # if return_type == "mask":
640
  # return binary_mask, None
641
  # elif return_type == "rgba":
642
  # return None, rgba_image
643
+ # else: # both
644
  # return binary_mask, rgba_image
645
+
646
+ def segment(
647
+ self,
648
+ image: np.ndarray,
649
+ threshold: float = 0.5,
650
+ return_type: Literal["mask", "rgba", "both"] = "mask"
651
+ ) -> Tuple[Optional[np.ndarray], Optional[Image.Image]]:
652
+
653
+ # Convert BGR to RGB
654
+ if len(image.shape) == 3 and image.shape[2] == 3:
655
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
656
+ else:
657
+ raise ValueError("Input must be a color image (H, W, 3)")
658
+
659
+ # Store ORIGINAL dimensions (H, W) from numpy
660
+ orig_h, orig_w = image.shape[:2]
661
+
662
+ # Convert to PIL for transforms
663
+ image_pil = Image.fromarray(image_rgb)
664
+
665
+ # Transform (model resizes internally e.g. 320x320 / 512x512)
666
+ input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
667
+ if DEVICE == "cpu":
668
+ input_tensor = input_tensor.float()
669
+
670
+ # Inference
671
+ with torch.no_grad():
672
+ if self.model_type == "u2netp":
673
+ outputs = self.model(input_tensor)
674
+ pred = outputs[0]
675
+ else: # birefnet or rmbg
676
+ pred = self.model(input_tensor)[-1].sigmoid()
677
+
678
+ # Post-process - squeeze to 2D
679
+ pred = pred.squeeze().cpu().numpy()
680
+
681
+ # ✅ FIX: Resize back to ORIGINAL (width, height) for cv2
682
+ # cv2.resize takes (width, height) = (orig_w, orig_h)
683
+ pred_resized = cv2.resize(
684
+ pred,
685
+ (orig_w, orig_h), # ← correct order for cv2
686
+ interpolation=cv2.INTER_LINEAR
687
+ )
688
+
689
+ # Verify shape matches original
690
+ assert pred_resized.shape == (orig_h, orig_w), \
691
+ f"Shape mismatch! Got {pred_resized.shape}, expected ({orig_h}, {orig_w})"
692
+
693
+ # Normalize to 0-255
694
+ pred_normalized = (
695
+ (pred_resized - pred_resized.min()) /
696
+ (pred_resized.max() - pred_resized.min() + 1e-8) * 255
697
+ )
698
+
699
+ # Binary mask
700
+ binary_mask = (pred_normalized > (threshold * 255)).astype(np.uint8) * 255
701
+
702
+ # Morphological cleanup
703
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
704
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
705
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
706
+
707
+ # ✅ Verify final mask dimensions match input
708
+ assert binary_mask.shape == (orig_h, orig_w), \
709
+ f"Final mask mismatch! Got {binary_mask.shape}, expected ({orig_h}, {orig_w})"
710
+
711
+ logger.info(f"Input shape: ({orig_h}, {orig_w}) | Output mask shape: {binary_mask.shape} ✅")
712
+
713
+ # Create RGBA if needed
714
+ rgba_image = None
715
+ if return_type in ["rgba", "both"]:
716
+ rgba = np.dstack([image_rgb, binary_mask])
717
+ rgba_image = Image.fromarray(rgba, mode='RGBA')
718
+
719
+ # ✅ Verify RGBA dimensions
720
+ assert rgba_image.size == (orig_w, orig_h), \
721
+ f"RGBA size mismatch! Got {rgba_image.size}, expected ({orig_w}, {orig_h})"
722
+
723
+ if return_type == "mask":
724
+ return binary_mask, None
725
+ elif return_type == "rgba":
726
+ return None, rgba_image
727
+ else:
728
+ return binary_mask, rgba_image
729
 
730
  def batch_segment(
731
  self,