Spaces:
Sleeping
Sleeping
| import torch | |
| from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor | |
| import cv2 | |
class SamWrapper:
    """
    Wrapper for Segment Anything Model (SAM).

    Handles both automatic mask generation and guided segmentation via
    bounding boxes.
    """

    def __init__(self, model_type="vit_b", checkpoint_path=None, device=None):
        """
        Initialize the SAM model, the automatic mask generator and the predictor.

        :param model_type: Key into ``sam_model_registry`` selecting the SAM
            backbone (e.g., 'vit_b', 'vit_l', 'vit_h')
        :param checkpoint_path: Path to the .pth checkpoint file
        :param device: Device to run on ('cuda' or 'cpu'); if None, CUDA is
            used when available and CPU otherwise. Pass 'cpu' explicitly to
            force CPU execution.
        """
        # Honour the caller's choice; only auto-detect when nothing was given.
        self.device = self._resolve_device(device)

        self.model = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.model.to(self.device)

        # Generator used when no boxes are supplied to generate_masks().
        self.automatic_generator = SamAutomaticMaskGenerator(
            model=self.model,
            points_per_side=12,
            pred_iou_thresh=0.92,
            stability_score_thresh=0.95,
            min_mask_region_area=1500,
            box_nms_thresh=0.3,
        )
        # Predictor used for box-guided segmentation.
        self.predictor = SamPredictor(self.model)

    @staticmethod
    def _resolve_device(device):
        """
        Return the device to run on.

        :param device: Explicit device, or None to auto-detect
        :return: ``device`` unchanged when it is not None; otherwise 'cuda'
            if ``torch.cuda.is_available()`` reports True, else 'cpu'
        """
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def generate_masks(self, image, boxes=None):
        """
        Generate segmentation masks for the given image.

        :param image: Input image as NumPy array (BGR)
        :param boxes: Optional list of bounding boxes [x1, y1, x2, y2].
            When None, masks are produced by the automatic generator; an
            empty sequence yields an empty result.
        :return: List of binary masks (NumPy arrays)
        """
        # Nothing to segment: skip the image encoding and avoid sending a
        # zero-length box batch through the predictor.
        if boxes is not None and len(boxes) == 0:
            return []

        # The input is BGR (OpenCV order); it is converted to RGB before
        # being handed to SAM.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if boxes is None:
            masks = self.automatic_generator.generate(image_rgb)
            return [mask['segmentation'] for mask in masks]

        # Set the image once so every box is predicted against the same
        # encoded image.
        self.predictor.set_image(image_rgb)

        # Convert boxes to a tensor on the model's device and map them from
        # original image coordinates into the predictor's input frame.
        transformed_boxes = self.predictor.transform.apply_boxes_torch(
            torch.tensor(boxes, dtype=torch.float32, device=self.device),
            image.shape[:2],
        )

        # Predict masks for all boxes at once; the two other return values
        # are not used here.
        masks, _, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

        # Keep the first mask of each per-box result and move it to NumPy.
        return [m[0].cpu().numpy() for m in masks]

    def predict_with_box(self, image, box):
        """
        Predict a single segmentation mask for the given box.

        :param image: Input image (BGR)
        :param box: One bounding box [x1, y1, x2, y2]
        :return: Binary mask (NumPy array), or None if no mask was produced
        """
        masks = self.generate_masks(image, boxes=[box])
        return masks[0] if masks else None