"""
models/segmenter.py
--------------------
SAM (Segment Anything Model) wrapper.

Given a scene image and one or more bounding boxes (from the detector),
produces precise pixel-level masks for each detected object.
"""

import os
import sys
from typing import List, Tuple, Optional

import numpy as np
import torch
from PIL import Image

# Make the project root importable so `config` / `image_utils` resolve when
# this module is loaded from inside the `models/` package directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import DEVICE, SAM_CHECKPOINT, SAM_MODEL_TYPE, MASK_DILATION_PX
from image_utils import (
    load_image,
    save_image,
    show_mask,
    show_box,
    dilate_mask_with_sam_prediction,
    dilate_mask,
    combine_masks,
)


class SAMSegmenter:
    """
    Wraps SAM to convert bounding boxes (or point prompts) into fine-grained
    masks. The model is loaded lazily on first use.
    """

    def __init__(self) -> None:
        # Populated by `_load()`; stays None until a segmentation is requested.
        self._predictor = None

    def _load(self) -> None:
        """
        Ensure the SAM predictor exists and its model lives on DEVICE.

        Raises:
            RuntimeError: if the `segment_anything` package cannot be imported.
        """
        if self._predictor is not None:
            # Already built: just make sure the weights are on the target device.
            self._predictor.model.to(DEVICE)
            return

        print(" Loading SAM (this may take a moment) ...")
        # Only the import is guarded, so a failure while reading the checkpoint
        # is not misreported as "package not installed".
        try:
            from segment_anything import sam_model_registry, SamPredictor
        except ImportError as e:
            raise RuntimeError(
                "segment-anything is not installed.\n"
                "Run: pip install git+https://github.com/facebookresearch/segment-anything.git\n"
                f"Original error: {e}"
            ) from e

        sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
        sam = sam.to(DEVICE)
        self._predictor = SamPredictor(sam)

    @staticmethod
    def _prepare_image(image_pil: Image.Image) -> np.ndarray:
        """
        Return the image as a numpy array suitable for SAM.

        PIL images that are not already RGB (RGBA, L, P, ...) are converted
        first, because the predictor is fed the array as-is and expects a
        three-channel RGB layout. Non-PIL inputs are passed through
        `np.array` unchanged, as before.
        """
        if isinstance(image_pil, Image.Image) and image_pil.mode != "RGB":
            image_pil = image_pil.convert("RGB")
        return np.array(image_pil)

    @staticmethod
    def _best_mask(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """Pick the highest-scoring candidate and scale it to 0/255 uint8."""
        # scores shape: (3,); masks shape: (3, H, W)
        best_idx = scores.argmax()
        return (masks[best_idx].astype(np.uint8)) * 255

    def _embed(self, img_np: np.ndarray) -> None:
        """Load the model if needed and hand the image to the predictor."""
        self._load()
        self._predictor.set_image(img_np)

    def segment_boxes(
        self,
        image_pil: Image.Image,
        boxes: List[Tuple[int, int, int, int]],
        dilation_px: int = MASK_DILATION_PX,
    ) -> np.ndarray:
        """
        For each box, run SAM and return the combined binary mask (HW, uint8).

        Args:
            image_pil: The scene image.
            boxes: List of (x1, y1, x2, y2) in absolute pixels.
            dilation_px: How many pixels to dilate the final mask (covers edges).

        Returns:
            Combined mask (255 = object, 0 = background). An all-zero mask is
            returned when `boxes` is empty.
        """
        img_np = self._prepare_image(image_pil)
        h, w = img_np.shape[:2]

        boxes = list(boxes)
        if not boxes:
            # Nothing to segment: skip the model load and image embedding,
            # which are by far the most expensive steps.
            return np.zeros((h, w), dtype=np.uint8)

        self._embed(img_np)

        individual_masks = []
        for x1, y1, x2, y2 in boxes:
            box_np = np.array([[x1, y1, x2, y2]], dtype=np.float32)
            masks, scores, _ = self._predictor.predict(
                box=box_np,
                multimask_output=True,
            )
            individual_masks.append(self._best_mask(masks, scores))

        combined = combine_masks(individual_masks)
        if dilation_px > 0:
            combined = dilate_mask(combined, dilation_px)

        pct = 100 * (combined > 0).sum() / (h * w)
        print(f" SAM mask covers {pct:.1f}% of the image")
        return combined

    def segment_points(
        self,
        image_pil: Image.Image,
        points: List[Tuple[int, int]],
        point_labels: Optional[List[int]] = None,
        dilation_px: int = MASK_DILATION_PX,
    ) -> np.ndarray:
        """
        Segment using point prompts (1 = foreground, 0 = background).

        Args:
            image_pil: The scene image.
            points: List of (x, y) prompt coordinates.
            point_labels: One label per point. Falls back to all-foreground
                when None or empty.
            dilation_px: How many pixels to dilate the resulting mask.

        Returns:
            Best-scoring mask as uint8 (255 = object, 0 = background). An
            all-zero mask is returned when `points` is empty.

        Raises:
            ValueError: if `point_labels` is given but its length differs
                from the number of points.
        """
        img_np = self._prepare_image(image_pil)
        h, w = img_np.shape[:2]

        points = list(points)
        if not points:
            # No prompts: return an empty mask rather than passing a
            # zero-length prompt array to the predictor.
            return np.zeros((h, w), dtype=np.uint8)

        # Explicit None/length test so array-like labels do not trip over
        # ambiguous truthiness.
        if point_labels is None or len(point_labels) == 0:
            point_labels = [1] * len(points)
        elif len(point_labels) != len(points):
            raise ValueError(
                f"point_labels has {len(point_labels)} entries "
                f"but {len(points)} points were given"
            )

        pts_np = np.array(points, dtype=np.float32)
        labels_np = np.array(point_labels, dtype=np.int32)

        self._embed(img_np)
        masks, scores, _ = self._predictor.predict(
            point_coords=pts_np,
            point_labels=labels_np,
            multimask_output=True,
        )

        best_mask = self._best_mask(masks, scores)
        if dilation_px > 0:
            best_mask = dilate_mask(best_mask, dilation_px)
        return best_mask