Spaces:
Runtime error
Runtime error
| """ | |
| models/segmenter.py | |
| -------------------- | |
| SAM (Segment Anything Model) wrapper. | |
| Given a scene image and one or more bounding boxes (from the detector), | |
| produces precise pixel-level masks for each detected object. | |
| """ | |
| import os | |
| import sys | |
| from typing import List, Tuple, Optional | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from config import DEVICE, SAM_CHECKPOINT, SAM_MODEL_TYPE, MASK_DILATION_PX | |
| from image_utils import load_image, save_image, show_mask, show_box, dilate_mask_with_sam_prediction, dilate_mask, combine_masks | |
class SAMSegmenter:
    """
    Wraps SAM to convert box or point prompts into fine-grained masks.

    The model is loaded lazily on first use, so constructing a
    SAMSegmenter is cheap and does not touch the checkpoint or DEVICE.
    """

    def __init__(self) -> None:
        # SamPredictor instance; stays None until _load() has run once.
        self._predictor = None

    def _load(self) -> None:
        """
        Ensure the SAM predictor exists and its model is on DEVICE.

        Raises:
            RuntimeError: if the segment_anything package cannot be imported.
        """
        if self._predictor is not None:
            # Already built: only re-assert device placement (presumably the
            # model may have been moved elsewhere between calls).
            self._predictor.model.to(DEVICE)
            return

        print(" Loading SAM (this may take a moment) ...")
        # Keep the try body to the import alone so that unrelated failures
        # (e.g. a missing checkpoint file) surface with their own error.
        try:
            from segment_anything import sam_model_registry, SamPredictor
        except ImportError as e:
            raise RuntimeError(
                "segment-anything is not installed.\n"
                "Run: pip install git+https://github.com/facebookresearch/segment-anything.git\n"
                f"Original error: {e}"
            ) from e

        sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
        self._predictor = SamPredictor(sam.to(DEVICE))

    def _prepare_image(self, image_pil: Image.Image) -> np.ndarray:
        """Load SAM if needed, hand it the image, and return the RGB array."""
        self._load()
        # Normalise to 3-channel RGB: RGBA / grayscale / palette images would
        # otherwise produce an array SAM's predictor cannot consume.
        img_np = np.array(image_pil.convert("RGB"))
        self._predictor.set_image(img_np)
        return img_np

    @staticmethod
    def _best_mask(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """Return the highest-scoring candidate as a 0/255 uint8 mask."""
        return masks[int(np.argmax(scores))].astype(np.uint8) * 255

    def segment_boxes(
        self,
        image_pil: Image.Image,
        boxes: List[Tuple[int, int, int, int]],
        dilation_px: int = MASK_DILATION_PX,
    ) -> np.ndarray:
        """
        For each box, run SAM and return the combined binary mask (HW, uint8).

        Args:
            image_pil: The scene image.
            boxes: List of (x1, y1, x2, y2) in absolute pixels.
            dilation_px: How many pixels to dilate the final mask (covers edges).

        Returns:
            Combined mask (255 = object, 0 = background). An all-zero mask
            of the image's size is returned when ``boxes`` is empty.
        """
        boxes = list(boxes)
        if not boxes:
            # Nothing to segment: skip model loading and the image embedding.
            width, height = image_pil.size
            return np.zeros((height, width), dtype=np.uint8)

        img_np = self._prepare_image(image_pil)
        h, w = img_np.shape[:2]

        individual_masks = []
        for x1, y1, x2, y2 in boxes:
            box_np = np.array([[x1, y1, x2, y2]], dtype=np.float32)
            masks, scores, _ = self._predictor.predict(
                box=box_np,
                multimask_output=True,
            )
            # scores shape: (3,); masks shape: (3, H, W)
            individual_masks.append(self._best_mask(masks, scores))

        combined = combine_masks(individual_masks)
        if dilation_px > 0:
            combined = dilate_mask(combined, dilation_px)

        pct = 100 * (combined > 0).sum() / (h * w)
        print(f" SAM mask covers {pct:.1f}% of the image")
        return combined

    def segment_points(
        self,
        image_pil: Image.Image,
        points: List[Tuple[int, int]],
        point_labels: Optional[List[int]] = None,
        dilation_px: int = MASK_DILATION_PX,
    ) -> np.ndarray:
        """
        Segment using point prompts (1 = foreground, 0 = background).

        Args:
            image_pil: The scene image.
            points: List of (x, y) prompt coordinates in absolute pixels.
            point_labels: One label per point; if None or empty, every
                point is treated as foreground.
            dilation_px: How many pixels to dilate the resulting mask.

        Returns:
            Best-scoring mask (255 = object, 0 = background). An all-zero
            mask of the image's size is returned when ``points`` is empty.

        Raises:
            ValueError: if ``point_labels`` is given but its length differs
                from ``points``.
        """
        points = list(points)
        if not points:
            # No prompts: SAM cannot predict from an empty point set.
            width, height = image_pil.size
            return np.zeros((height, width), dtype=np.uint8)

        if point_labels is None or len(point_labels) == 0:
            labels = [1] * len(points)
        else:
            labels = list(point_labels)
            if len(labels) != len(points):
                raise ValueError(
                    f"point_labels has {len(labels)} entries "
                    f"but {len(points)} points were given"
                )

        self._prepare_image(image_pil)

        pts_np = np.array(points, dtype=np.float32)
        labels_np = np.array(labels, dtype=np.int32)
        masks, scores, _ = self._predictor.predict(
            point_coords=pts_np,
            point_labels=labels_np,
            multimask_output=True,
        )

        best_mask = self._best_mask(masks, scores)
        if dilation_px > 0:
            best_mask = dilate_mask(best_mask, dilation_px)
        return best_mask