Spaces:
Sleeping
Sleeping
| """ | |
| Binary Image Segmentation Tool | |
| A lightweight, professional implementation for foreground object segmentation. | |
| Supports multiple models: | |
| - U2NETP (fastest, 1.1M params) | |
| - BiRefNet (best accuracy, larger model) | |
| - RMBG (good balance) | |
| """ | |
| import os | |
| import logging | |
| from pathlib import Path | |
| from typing import Literal, Tuple, Optional | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
| import cv2 | |
# Root logging setup: timestamped, level-tagged records at INFO and above.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Pick the compute device once at import time: CUDA when torch reports it
# as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")
class U2NETP(torch.nn.Module):
    """Small encoder-decoder segmentation network with five side outputs.

    The encoder has four stages separated by three 2x max-pools, followed by
    a bridge stage at the coarsest resolution. The decoder mirrors it with
    skip connections (channel concatenation). Each decoder level plus the
    bridge feeds a 3x3 "side" convolution; the five side maps are resized to
    the input resolution and fused by a 1x1 convolution.

    NOTE(review): the layer layout here (plain 3-conv stages) does not look
    like the RSU-based U2NETP of the upstream U-2-Net repository, so upstream
    ``u2netp.pth`` checkpoints probably will not load into this class --
    confirm before relying on pretrained weights.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of every side map and of the
            fused map.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()
        # Encoder
        self.stage1 = self._make_stage(in_ch, 16, 64)
        self.pool12 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = self._make_stage(64, 16, 64)
        self.pool23 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = self._make_stage(64, 16, 64)
        self.pool34 = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = self._make_stage(64, 16, 64)
        # Bridge (same resolution as stage4: there is no pool between them)
        self.stage5 = self._make_stage(64, 16, 64)
        # Decoder: every stage consumes 64 upsampled + 64 skip channels
        self.stage4d = self._make_stage(128, 16, 64)
        self.stage3d = self._make_stage(128, 16, 64)
        self.stage2d = self._make_stage(128, 16, 64)
        self.stage1d = self._make_stage(128, 16, 64)
        # Side outputs
        self.side1 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = torch.nn.Conv2d(64, out_ch, 3, padding=1)
        # Output fusion over the five concatenated side maps
        self.outconv = torch.nn.Conv2d(5 * out_ch, out_ch, 1)

    def _make_stage(self, in_ch, mid_ch, out_ch):
        """Return a stack of three 3x3 conv + ReLU layers (in -> mid -> mid -> out)."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, mid_ch, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(mid_ch, mid_ch, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(mid_ch, out_ch, 3, padding=1),
            torch.nn.ReLU(inplace=True),
        )

    @staticmethod
    def _resize_like(src, ref):
        """Bilinearly resize ``src`` to the spatial size of ``ref``."""
        return torch.nn.functional.interpolate(
            src, size=ref.shape[2:], mode='bilinear', align_corners=True
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch of shape (N, in_ch, H, W).

        Returns:
            A 6-tuple ``(d0, d1, d2, d3, d4, d5)`` of sigmoid-activated maps,
            each of shape (N, out_ch, H, W). ``d0`` is the fused prediction;
            ``d1``..``d5`` are the side outputs from fine to coarse.
        """
        # Encoder
        hx1 = self.stage1(x)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(hx4)

        # Decoder. The pools use ceil_mode=True, so for odd sizes a plain 2x
        # upsample would be one pixel larger than the skip tensor and the
        # concatenation would fail; resize to the skip tensor's size instead.
        hx4d = self.stage4d(torch.cat((hx5, hx4), 1))
        hx3d = self.stage3d(torch.cat((self._resize_like(hx4d, hx3), hx3), 1))
        hx2d = self.stage2d(torch.cat((self._resize_like(hx3d, hx2), hx2), 1))
        hx1d = self.stage1d(torch.cat((self._resize_like(hx2d, hx1), hx1), 1))

        # Side outputs, all brought to the resolution of the finest one
        d1 = self.side1(hx1d)
        d2 = self._resize_like(self.side2(hx2d), d1)
        d3 = self._resize_like(self.side3(hx3d), d1)
        d4 = self._resize_like(self.side4(hx4d), d1)
        d5 = self._resize_like(self.side5(hx5), d1)

        # Fusion
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5), 1))
        return (
            torch.sigmoid(d0),
            torch.sigmoid(d1),
            torch.sigmoid(d2),
            torch.sigmoid(d3),
            torch.sigmoid(d4),
            torch.sigmoid(d5),
        )
class BinarySegmenter:
    """Binary foreground segmentation with a selectable model backend.

    The model and its matching preprocessing transform are loaded once in
    the constructor, moved to ``DEVICE`` and put in eval mode.

    Args:
        model_type: Backend to load: "u2netp", "birefnet" or "rmbg".
        cache_dir: Directory used for local weights (``u2netp.pth``) and as
            the download cache for the Hugging Face backends. Created,
            including missing parents, if it does not exist.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
        ImportError: If a Hugging Face backend is requested but
            ``transformers`` is not installed.
    """

    def __init__(
        self,
        model_type: Literal["u2netp", "birefnet", "rmbg"] = "u2netp",
        cache_dir: str = "./.model_cache"
    ):
        self.model_type = model_type
        self.cache_dir = Path(cache_dir)
        # parents=True so a nested cache path does not fail on first use.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.model = None
        self.transform = None
        self._load_model()

    @staticmethod
    def _build_transform(size, mean, std):
        """Return the resize -> tensor -> normalize preprocessing pipeline."""
        return transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    def _load_model(self):
        """Load the backend named by ``self.model_type`` and prepare it for inference."""
        logger.info(f"Loading {self.model_type} model...")
        if self.model_type == "u2netp":
            self._load_u2netp()
        elif self.model_type == "birefnet":
            self._load_birefnet()
        elif self.model_type == "rmbg":
            self._load_rmbg()
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")
        self.model.to(DEVICE)
        self.model.eval()
        logger.info(f"{self.model_type} loaded successfully")

    def _load_u2netp(self):
        """Build the local U2NETP network and load ``u2netp.pth`` from the cache dir if present.

        Without a weights file the network keeps its random initialisation
        and only a warning is logged.
        """
        self.model = U2NETP(3, 1)
        model_path = self.cache_dir / "u2netp.pth"
        if model_path.exists():
            logger.info(f"Loading weights from {model_path}")
            # weights_only=True: a state dict needs only tensors, so refuse
            # to unpickle arbitrary objects from the checkpoint file.
            state_dict = torch.load(
                model_path, map_location=DEVICE, weights_only=True
            )
            self.model.load_state_dict(state_dict)
        else:
            logger.warning(f"No pretrained weights found at {model_path}")
            logger.warning("Download from: https://github.com/xuebinqin/U-2-Net")
        # Standard ImageNet normalization
        self.transform = self._build_transform(
            320, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )

    def _load_hf_model(self, repo_id, display_name):
        """Load a Hugging Face image-segmentation model into ``self.model``.

        Only the ``transformers`` import is guarded, so errors raised while
        downloading or building the model (including missing dependencies of
        the repository's own code) surface unchanged instead of being
        reported as "install transformers".
        """
        try:
            from transformers import AutoModelForImageSegmentation
        except ImportError as exc:
            raise ImportError(
                f"{display_name} requires: pip install transformers"
            ) from exc
        # SECURITY: trust_remote_code=True executes Python code shipped in
        # the model repository; only use repositories you trust.
        self.model = AutoModelForImageSegmentation.from_pretrained(
            repo_id,
            trust_remote_code=True,
            cache_dir=str(self.cache_dir)
        )

    def _load_birefnet(self):
        """Load BiRefNet model (best accuracy, larger)"""
        self._load_hf_model('ZhengPeng7/BiRefNet', "BiRefNet")
        self.transform = self._build_transform(
            1024, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )

    def _load_rmbg(self):
        """Load RMBG model (good balance)"""
        self._load_hf_model('briaai/RMBG-1.4', "RMBG")
        # The RMBG-1.4 model card preprocesses with mean 0.5 / std 1.0
        # rather than ImageNet statistics.
        self.transform = self._build_transform(
            1024, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]
        )

    def segment(
        self,
        image: np.ndarray,
        threshold: float = 0.5,
        return_type: Literal["mask", "rgba", "both"] = "mask",
        color_order: Literal["bgr", "rgb"] = "bgr"
    ) -> Tuple[Optional[np.ndarray], Optional[Image.Image]]:
        """
        Segment the foreground object of one image.

        Args:
            image: Input image as a numpy array of shape (H, W, 3).
            threshold: Cut-off in [0, 1]. It is applied after the prediction
                has been min-max rescaled per image, so it is relative to
                that image's prediction range, not an absolute probability.
            return_type: What to return - "mask", "rgba", or "both".
            color_order: Channel order of ``image``. "bgr" (default, the
                OpenCV convention) converts to RGB first; "rgb" uses the
                array as is. Channel order cannot be detected from pixel
                values, so it must be stated.

        Returns:
            Tuple ``(binary_mask, rgba_image)``. ``binary_mask`` is a uint8
            (H, W) array with values 0/255; ``rgba_image`` is a PIL RGBA
            image whose alpha channel is the mask. The element not requested
            by ``return_type`` is None.

        Raises:
            ValueError: If ``image`` is not (H, W, 3) or ``color_order`` is
                not "bgr" or "rgb".
        """
        if image.ndim != 3 or image.shape[2] != 3:
            raise ValueError("Input must be a color image (H, W, 3)")
        if color_order == "bgr":
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        elif color_order == "rgb":
            image_rgb = image
        else:
            raise ValueError(f"Unknown color order: {color_order}")

        # Convert to PIL; .size is (width, height), the order cv2.resize expects
        image_pil = Image.fromarray(image_rgb)
        original_size = image_pil.size

        # Transform
        input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)

        # Inference
        with torch.no_grad():
            outputs = self.model(input_tensor)
            if self.model_type == "u2netp":
                # First element is the fused map, already sigmoid-activated.
                pred = outputs[0]
            elif self.model_type == "rmbg":
                # Per the RMBG-1.4 model card the prediction is
                # ``result[0][0]`` and is already sigmoid-activated.
                pred = outputs[0][0]
            else:  # birefnet: last element holds the final logits
                pred = outputs[-1].sigmoid()

        # Post-process
        pred = pred.squeeze().cpu().numpy()

        # Resize to original
        pred_resized = cv2.resize(pred, original_size, interpolation=cv2.INTER_LINEAR)

        # Min-max rescale to 0-255; the epsilon guards a constant prediction
        low = pred_resized.min()
        high = pred_resized.max()
        pred_normalized = (pred_resized - low) / (high - low + 1e-8) * 255

        # Create binary mask
        binary_mask = (pred_normalized > (threshold * 255)).astype(np.uint8) * 255

        # Morphological close then open: fill small holes, drop small specks
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)

        # Create RGBA if needed: the mask becomes the alpha channel
        rgba_image = None
        if return_type in ["rgba", "both"]:
            rgba = np.dstack([image_rgb, binary_mask])
            rgba_image = Image.fromarray(rgba, mode='RGBA')

        # Return based on type
        if return_type == "mask":
            return binary_mask, None
        elif return_type == "rgba":
            return None, rgba_image
        else:  # both
            return binary_mask, rgba_image

    def batch_segment(
        self,
        images: list[np.ndarray],
        threshold: float = 0.5,
        return_type: Literal["mask", "rgba", "both"] = "mask",
        color_order: Literal["bgr", "rgb"] = "bgr"
    ) -> list:
        """
        Segment several images one after another.

        Images are processed sequentially through ``segment``; they are not
        stacked into a single model batch.

        Args:
            images: List of input images, each (H, W, 3).
            threshold: Threshold passed to ``segment`` for every image.
            return_type: What to return for each image.
            color_order: Channel order of every input image ("bgr" or "rgb").

        Returns:
            List with one ``segment`` result tuple per input image, in order.
        """
        total = len(images)
        results = []
        for position, img in enumerate(images, start=1):
            logger.info(f"Processing image {position}/{total}")
            results.append(
                self.segment(img, threshold, return_type, color_order)
            )
        return results
def segment_image_file(
    input_path: str,
    output_path: str,
    model_type: str = "u2netp",
    threshold: float = 0.5,
    save_rgba: bool = True
):
    """
    Segment an image file and write the result to disk.

    A new ``BinarySegmenter`` is created on every call, so the model is
    loaded each time.

    Args:
        input_path: Path to the input image (read with ``cv2.imread``).
        output_path: Path to save the output; missing parent directories
            are created.
        model_type: Model to use ("u2netp", "birefnet" or "rmbg").
        threshold: Segmentation threshold passed to the segmenter.
        save_rgba: If True, save the RGBA cut-out; if False, save the
            binary mask.

    Returns:
        The output path as a string.

    Raises:
        FileNotFoundError: If the input image cannot be read.
        OSError: If the mask cannot be written to ``output_path``.
    """
    # Load image; cv2.imread signals failure by returning None, not raising
    image = cv2.imread(str(input_path))
    if image is None:
        raise FileNotFoundError(f"Could not load image: {input_path}")

    # Create segmenter
    segmenter = BinarySegmenter(model_type=model_type)

    # Segment
    return_type = "rgba" if save_rgba else "mask"
    mask, rgba = segmenter.segment(image, threshold, return_type)

    # Save
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if save_rgba and rgba is not None:
        # NOTE: the target format must support alpha (e.g. PNG); PIL raises
        # for formats that cannot store RGBA.
        rgba.save(output_path)
        logger.info(f"Saved RGBA to: {output_path}")
    elif mask is not None:
        # cv2.imwrite reports failure through its return value; do not log
        # success or return the path for a file that was never written.
        if not cv2.imwrite(str(output_path), mask):
            raise OSError(f"Could not write mask to: {output_path}")
        logger.info(f"Saved mask to: {output_path}")
    return str(output_path)
# Command-line entry point: segment one image file and write the result.
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="Binary image segmentation")

    # Positional arguments
    cli.add_argument("input", help="Input image path")
    cli.add_argument("output", help="Output path")

    # Optional arguments, declared as (flag, keyword settings) pairs
    option_table = (
        ("--model", {
            "choices": ["u2netp", "birefnet", "rmbg"],
            "default": "u2netp",
            "help": "Segmentation model",
        }),
        ("--threshold", {
            "type": float,
            "default": 0.5,
            "help": "Segmentation threshold (0-1)",
        }),
        ("--format", {
            "choices": ["mask", "rgba"],
            "default": "rgba",
            "help": "Output format",
        }),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    opts = cli.parse_args()

    # "rgba" writes the cut-out with alpha; "mask" writes the binary mask
    wants_rgba = opts.format == "rgba"
    segment_image_file(
        opts.input,
        opts.output,
        model_type=opts.model,
        threshold=opts.threshold,
        save_rgba=wants_rgba,
    )