Spaces:
Sleeping
Sleeping
File size: 2,622 Bytes
aa1c1e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import cv2
class SamWrapper:
    """
    Wrapper for Segment Anything Model (SAM).
    Handles both automatic mask generation and guided segmentation via bounding boxes.
    """

    # Default settings passed to SamAutomaticMaskGenerator. Individual entries
    # can be overridden via the ``mask_generator_kwargs`` constructor argument.
    _DEFAULT_GENERATOR_KWARGS = {
        "points_per_side": 12,
        "pred_iou_thresh": 0.92,
        "stability_score_thresh": 0.95,
        "min_mask_region_area": 1500,
        "box_nms_thresh": 0.3,
    }

    def __init__(self, model_type="vit_b", checkpoint_path=None, device=None,
                 mask_generator_kwargs=None):
        """
        Initialize the SAM model.
        :param model_type: Type of SAM backbone (e.g., 'vit_b', 'vit_l', 'vit_h')
        :param checkpoint_path: Path to the .pth checkpoint file
        :param device: Torch device string such as 'cuda' or 'cpu'; if None,
            'cuda' is used when torch reports it available, otherwise 'cpu'
        :param mask_generator_kwargs: Optional dict of keyword arguments for
            SamAutomaticMaskGenerator; entries override the class defaults
        """
        # Previously the device was forced to "cpu", silently ignoring both an
        # explicit argument and the documented auto-detection.
        self.device = self._resolve_device(device)
        self.model = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.model.to(self.device)
        generator_kwargs = {
            **self._DEFAULT_GENERATOR_KWARGS,
            **(mask_generator_kwargs or {}),
        }
        self.automatic_generator = SamAutomaticMaskGenerator(
            model=self.model,
            **generator_kwargs,
        )
        self.predictor = SamPredictor(self.model)

    @staticmethod
    def _resolve_device(device):
        """Return *device* unchanged, or pick 'cuda'/'cpu' when it is None."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def generate_masks(self, image, boxes=None):
        """
        Generate segmentation masks for the given image.
        :param image: Input image as NumPy array (BGR)
        :param boxes: Optional sequence of bounding boxes [x1, y1, x2, y2].
            None runs automatic mask generation over the whole image; an
            empty sequence returns an empty list without running the model.
        :return: List of binary masks (NumPy arrays); when boxes are given,
            one mask per box in the same order
        """
        # Short-circuit "no prompts" so an empty batch never reaches the
        # predictor (and the costly image embedding is skipped).
        if boxes is not None and len(boxes) == 0:
            return []

        # The model is fed RGB; callers supply OpenCV-style BGR images.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if boxes is None:
            masks = self.automatic_generator.generate(image_rgb)
            return [mask['segmentation'] for mask in masks]

        # Set the image once; all boxes are then decoded in a single batch.
        self.predictor.set_image(image_rgb)

        # as_tensor avoids a redundant copy (and torch's warning) when the
        # caller already passes a tensor.
        box_tensor = torch.as_tensor(boxes, dtype=torch.float32, device=self.device)
        # Apply the predictor's input transform to the boxes, relative to the
        # original image size.
        transformed_boxes = self.predictor.transform.apply_boxes_torch(
            box_tensor,
            image.shape[:2],
        )

        # Predict masks for all boxes at once; multimask_output=False asks for
        # a single mask per box, taken from index 0 below.
        masks, _, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        return [m[0].cpu().numpy() for m in masks]

    def predict_with_box(self, image, box):
        """
        Predict a single segmentation mask for the given box.
        :param image: Input image (BGR)
        :param box: One bounding box [x1, y1, x2, y2]
        :return: Binary mask (NumPy array), or None if no mask was produced
        """
        masks = self.generate_masks(image, boxes=[box])
        return masks[0] if masks else None
|