mansi-object-detector / segmenter.py
mansi-2's picture
Upload 11 files
eab4d9b verified
"""
models/segmenter.py
--------------------
SAM (Segment Anything Model) wrapper.
Given a scene image and one or more bounding boxes (from the detector),
produces precise pixel-level masks for each detected object.
"""
import os
import sys
from typing import List, Tuple, Optional
import numpy as np
import torch
from PIL import Image
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import DEVICE, SAM_CHECKPOINT, SAM_MODEL_TYPE, MASK_DILATION_PX
from image_utils import load_image, save_image, show_mask, show_box, dilate_mask_with_sam_prediction, dilate_mask, combine_masks
class SAMSegmenter:
    """
    Wraps SAM to convert bounding boxes (or point prompts) into fine-grained masks.

    The underlying ``SamPredictor`` is built lazily on first use, so creating
    an instance is cheap and does not import ``segment_anything``.
    """

    def __init__(self) -> None:
        # SamPredictor instance; stays None until _load() succeeds.
        self._predictor = None

    def _load(self) -> None:
        """
        Build the SAM predictor on first call.

        On subsequent calls the already-loaded model is moved to ``DEVICE``
        again (a no-op if it is already there).

        Raises:
            RuntimeError: if the ``segment_anything`` package cannot be imported.
        """
        if self._predictor is not None:
            # NOTE(review): presumably re-applied in case something moved the
            # model off DEVICE after loading; confirm with callers.
            self._predictor.model.to(DEVICE)
            return
        print(" Loading SAM (this may take a moment) ...")
        # Only the import is guarded: an ImportError raised while *building* the
        # model must not be misreported as "package not installed".
        try:
            from segment_anything import sam_model_registry, SamPredictor
        except ImportError as e:
            raise RuntimeError(
                "segment-anything is not installed.\n"
                "Run: pip install git+https://github.com/facebookresearch/segment-anything.git\n"
                f"Original error: {e}"
            ) from e
        sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
        sam = sam.to(DEVICE)
        self._predictor = SamPredictor(sam)

    def _set_image(self, image_pil: Image.Image) -> Tuple[int, int]:
        """
        Load SAM if needed, compute the embedding for ``image_pil``, return (h, w).

        The image is converted to 3-channel RGB first because
        ``SamPredictor.set_image`` expects an HWC uint8 RGB array; grayscale
        or RGBA inputs would otherwise fail inside SAM.
        """
        self._load()
        img_np = np.array(image_pil.convert("RGB"))
        h, w = img_np.shape[:2]
        self._predictor.set_image(img_np)
        return h, w

    @staticmethod
    def _empty_mask(image_pil: Image.Image) -> np.ndarray:
        """Return an all-zero (H, W) uint8 mask matching ``image_pil``'s size."""
        w, h = image_pil.size  # PIL reports (width, height)
        return np.zeros((h, w), dtype=np.uint8)

    @staticmethod
    def _best_mask(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """Pick the highest-scoring SAM candidate as a uint8 mask (255 / 0)."""
        # With multimask_output=True: scores shape (3,), masks shape (3, H, W).
        best_idx = int(np.argmax(scores))
        return masks[best_idx].astype(np.uint8) * 255

    def segment_boxes(
        self,
        image_pil: Image.Image,
        boxes: List[Tuple[int, int, int, int]],
        dilation_px: int = MASK_DILATION_PX,
    ) -> np.ndarray:
        """
        For each box, run SAM and return the combined binary mask (HW, uint8).

        Args:
            image_pil: The scene image (any PIL mode; converted to RGB for SAM).
            boxes: List of (x1, y1, x2, y2) in absolute pixels.
            dilation_px: How many pixels to dilate the final mask (covers edges).
        Returns:
            Combined mask (255 = object, 0 = background). If ``boxes`` is
            empty, an all-zero mask is returned without loading SAM.
        """
        # Materialize once so generators / numpy arrays work with the
        # emptiness check and the loop below.
        boxes = list(boxes)
        if not boxes:
            return self._empty_mask(image_pil)

        h, w = self._set_image(image_pil)
        individual_masks = []
        for x1, y1, x2, y2 in boxes:
            box_np = np.array([[x1, y1, x2, y2]], dtype=np.float32)
            masks, scores, _ = self._predictor.predict(
                box=box_np,
                multimask_output=True,
            )
            individual_masks.append(self._best_mask(masks, scores))

        combined = combine_masks(individual_masks)
        if dilation_px > 0:
            combined = dilate_mask(combined, dilation_px)
        pct = 100 * (combined > 0).sum() / (h * w)
        print(f" SAM mask covers {pct:.1f}% of the image")
        return combined

    def segment_points(
        self,
        image_pil: Image.Image,
        points: List[Tuple[int, int]],
        point_labels: Optional[List[int]] = None,
        dilation_px: int = MASK_DILATION_PX,
    ) -> np.ndarray:
        """
        Segment using point prompts (label 1 = foreground, 0 = background).

        Args:
            image_pil: The scene image (any PIL mode; converted to RGB for SAM).
            points: List of (x, y) in absolute pixels.
            point_labels: One label per point. If None (or empty), every point
                is treated as foreground.
            dilation_px: How many pixels to dilate the final mask.
        Returns:
            Best mask (255 = object, 0 = background). If ``points`` is empty,
            an all-zero mask is returned without loading SAM (mirrors
            ``segment_boxes`` with no boxes).
        Raises:
            ValueError: if ``point_labels`` is given but its length differs
                from the number of points.
        """
        points = list(points)
        if not points:
            return self._empty_mask(image_pil)

        # list(...) also accepts numpy arrays, whose truthiness is ambiguous.
        labels = list(point_labels) if point_labels is not None else []
        if not labels:
            labels = [1] * len(points)
        elif len(labels) != len(points):
            raise ValueError(
                f"point_labels has {len(labels)} entries but "
                f"{len(points)} points were given"
            )

        self._set_image(image_pil)
        pts_np = np.array(points, dtype=np.float32)
        labels_np = np.array(labels, dtype=np.int32)
        masks, scores, _ = self._predictor.predict(
            point_coords=pts_np,
            point_labels=labels_np,
            multimask_output=True,
        )
        best_mask = self._best_mask(masks, scores)
        if dilation_px > 0:
            best_mask = dilate_mask(best_mask, dilation_px)
        return best_mask