# Spaces: Runtime error
# File size: 4,331 Bytes
# eab4d9b
"""
models/segmenter.py
--------------------
SAM (Segment Anything Model) wrapper.
Given a scene image and one or more bounding boxes (from the detector),
produces precise pixel-level masks for each detected object.
"""
import os
import sys
from typing import List, Tuple, Optional
import numpy as np
import torch
from PIL import Image
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import DEVICE, SAM_CHECKPOINT, SAM_MODEL_TYPE, MASK_DILATION_PX
from image_utils import load_image, save_image, show_mask, show_box, dilate_mask_with_sam_prediction, dilate_mask, combine_masks
class SAMSegmenter:
    """
    Wraps SAM to convert bounding boxes or point prompts into fine-grained masks.

    The SAM predictor is loaded lazily on first use, so constructing a
    segmenter is cheap and does not touch the checkpoint.
    """

    def __init__(self) -> None:
        # Populated by _load() on the first segmentation request.
        self._predictor = None

    def _load(self) -> None:
        """
        Ensure the SAM predictor exists and its model is placed on DEVICE.

        Raises:
            RuntimeError: If the ``segment_anything`` package cannot be imported.
        """
        if self._predictor is not None:
            # Already built: only make sure the model is (back) on DEVICE.
            self._predictor.model.to(DEVICE)
            return
        print(" Loading SAM (this may take a moment) ...")
        # Keep the try body limited to the import so that an ImportError raised
        # while building the model is not misreported as "not installed".
        try:
            from segment_anything import sam_model_registry, SamPredictor
        except ImportError as e:
            raise RuntimeError(
                "segment-anything is not installed.\n"
                "Run: pip install git+https://github.com/facebookresearch/segment-anything.git\n"
                f"Original error: {e}"
            ) from e
        sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
        self._predictor = SamPredictor(sam.to(DEVICE))

    def _embed_image(self, image_pil: Image.Image) -> Tuple[int, int]:
        """
        Load SAM if needed, feed it the image, and return the image (height, width).

        The image is converted to RGB first: SamPredictor.set_image expects an
        HWC uint8 array with three channels, which RGBA / grayscale / palette
        images would not provide.
        """
        self._load()
        img_np = np.array(image_pil.convert("RGB"))
        self._predictor.set_image(img_np)
        h, w = img_np.shape[:2]
        return h, w

    @staticmethod
    def _empty_mask(image_pil: Image.Image) -> np.ndarray:
        """Return an all-background (HW, uint8) mask sized like the image."""
        w, h = image_pil.size  # PIL reports (width, height)
        return np.zeros((h, w), dtype=np.uint8)

    @staticmethod
    def _best_mask(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """Pick the highest-scoring candidate and scale it to 0/255 uint8."""
        # scores shape: (3,); masks shape: (3, H, W)
        best_idx = scores.argmax()
        return (masks[best_idx].astype(np.uint8)) * 255

    def segment_boxes(
        self,
        image_pil: Image.Image,
        boxes: List[Tuple[int, int, int, int]],
        dilation_px: int = MASK_DILATION_PX,
    ) -> np.ndarray:
        """
        For each box, run SAM and return the combined binary mask (HW, uint8).

        Args:
            image_pil: The scene image.
            boxes: List of (x1, y1, x2, y2) in absolute pixels.
            dilation_px: How many pixels to dilate the final mask (covers edges).

        Returns:
            Combined mask (255 = object, 0 = background). An all-zero mask is
            returned when ``boxes`` is empty.
        """
        boxes = list(boxes)
        if not boxes:
            # Nothing to segment: skip model loading and the image embedding.
            return self._empty_mask(image_pil)

        h, w = self._embed_image(image_pil)

        individual_masks = []
        for x1, y1, x2, y2 in boxes:
            box_np = np.array([[x1, y1, x2, y2]], dtype=np.float32)
            masks, scores, _ = self._predictor.predict(
                box=box_np,
                multimask_output=True,
            )
            individual_masks.append(self._best_mask(masks, scores))

        combined = combine_masks(individual_masks)
        if dilation_px > 0:
            combined = dilate_mask(combined, dilation_px)

        pct = 100 * (combined > 0).sum() / (h * w)
        print(f" SAM mask covers {pct:.1f}% of the image")
        return combined

    def segment_points(
        self,
        image_pil: Image.Image,
        points: List[Tuple[int, int]],
        point_labels: Optional[List[int]] = None,
        dilation_px: int = MASK_DILATION_PX,
    ) -> np.ndarray:
        """
        Segment using point prompts (1 = foreground, 0 = background).

        Args:
            image_pil: The scene image.
            points: List of (x, y) prompt coordinates.
            point_labels: One label per point. When None (or empty), every
                point is treated as foreground.
            dilation_px: How many pixels to dilate the resulting mask.

        Returns:
            Mask (255 = object, 0 = background). An all-zero mask is returned
            when ``points`` is empty.

        Raises:
            ValueError: If ``point_labels`` is given but its length differs
                from the number of points.
        """
        points = list(points)
        if not points:
            # Same contract as segment_boxes: no prompts means no object.
            return self._empty_mask(image_pil)

        if point_labels is None or len(point_labels) == 0:
            point_labels = [1] * len(points)
        elif len(point_labels) != len(points):
            # Fail here with a clear message instead of deep inside SAM.
            raise ValueError(
                f"point_labels has {len(point_labels)} entries "
                f"but {len(points)} points were given"
            )

        pts_np = np.array(points, dtype=np.float32)
        labels_np = np.array(point_labels, dtype=np.int32)

        self._embed_image(image_pil)
        masks, scores, _ = self._predictor.predict(
            point_coords=pts_np,
            point_labels=labels_np,
            multimask_output=True,
        )

        best_mask = self._best_mask(masks, scores)
        if dilation_px > 0:
            best_mask = dilate_mask(best_mask, dilation_px)
        return best_mask
|