# SAM2-Image-Auto-Segment / model / sam2_detection_function.py
# Author: Singh — "Initial deployment" (commit 36fcf33)
import numpy as np
import cv2
import torch
import sys
import os

# Add sam2 folder to path to import from local sam2 directory.
# Assumed layout: <project_root>/model/<this file> and <project_root>/sam2/ — TODO confirm.
_current_file_dir = os.path.dirname(os.path.abspath(__file__))
_project_root = os.path.dirname(_current_file_dir)
_sam2_repo_dir = os.path.join(_project_root, "sam2")
# Add sam2 directory to sys.path if not already there; must happen BEFORE the
# `from sam2...` import below so the local checkout wins over any installed copy.
abs_sam2_dir = os.path.abspath(_sam2_repo_dir)
if abs_sam2_dir not in sys.path:
    sys.path.insert(0, abs_sam2_dir)

from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from model.utils import mask_to_polygon

# Hugging Face model ID for SAM2.1 Hiera Large model
HUGGINGFACE_MODEL_ID = "facebook/sam2.1-hiera-large"
# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
class SAM2AutoAnnotation:
"""
SAM2 Auto Annotation wrapper for automatically generating masks for all objects in an image.
Uses SAM2AutomaticMaskGenerator from Hugging Face.
"""
def __init__(
self,
points_per_side: int = 32,
points_per_batch: int = 64,
pred_iou_thresh: float = 0.88,
stability_score_thresh: float = 0.95,
min_mask_region_area: int = 100,
):
"""
Initialize SAM2 Auto Annotation.
Args:
points_per_side: Number of points per side of the image grid
points_per_batch: Number of points to process in each batch
pred_iou_thresh: Prediction IoU threshold
stability_score_thresh: Stability score threshold
min_mask_region_area: Minimum mask region area in pixels
"""
self.points_per_side = points_per_side
self.points_per_batch = points_per_batch
self.pred_iou_thresh = pred_iou_thresh
self.stability_score_thresh = stability_score_thresh
self.min_mask_region_area = min_mask_region_area
self._mask_generator = None
def _get_mask_generator(self):
"""Lazy initialization of mask generator."""
if self._mask_generator is None:
try:
# Try to load with configuration parameters first
try:
self._mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
HUGGINGFACE_MODEL_ID,
device=device,
points_per_side=self.points_per_side,
points_per_batch=self.points_per_batch,
pred_iou_thresh=self.pred_iou_thresh,
stability_score_thresh=self.stability_score_thresh,
crop_n_layers=1,
crop_n_points_downscale_factor=2,
min_mask_region_area=self.min_mask_region_area,
)
except TypeError:
# If parameters are not accepted by from_pretrained, load without them
self._mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
HUGGINGFACE_MODEL_ID,
device=device
)
# Try to set parameters if the generator supports it
for attr_name in ['points_per_side', 'points_per_batch', 'pred_iou_thresh',
'stability_score_thresh', 'min_mask_region_area']:
if hasattr(self._mask_generator, attr_name):
setattr(self._mask_generator, attr_name, getattr(self, attr_name))
except ImportError as e:
raise RuntimeError(
f"Failed to import required modules for SAM2. Please ensure 'sam2' and 'huggingface_hub' are installed. "
f"Error: {str(e)}"
)
except Exception as e:
raise RuntimeError(
f"Failed to load SAM2 Automatic Mask Generator from Hugging Face ({HUGGINGFACE_MODEL_ID}). "
f"Please check your internet connection and ensure the model ID is correct. "
f"Error: {str(e)}"
)
return self._mask_generator
def generate_masks(
self,
image: np.ndarray,
min_confidence: float = 0.0,
min_area: int = None,
filter_blank_regions: bool = True,
scale_factors: tuple = (1.0, 1.0),
) -> list:
"""
Generate all masks for objects in the image.
Args:
image: Image as numpy array (RGB format, H, W, 3)
min_confidence: Minimum confidence score to filter masks (default: 0.0)
min_area: Minimum mask area in pixels (default: uses self.min_mask_region_area)
filter_blank_regions: Filter out blank/black regions (default: True)
scale_factors: Tuple (scale_x, scale_y) to scale coordinates FROM processed TO display size
(matching predict_polygon_from_point logic)
Returns:
List of mask dictionaries, each containing:
- polygon: flattened coordinates [x1, y1, x2, y2, ...] (scaled to display size)
- confidence: confidence score
- area: mask area in pixels
"""
if min_area is None:
min_area = self.min_mask_region_area
# Get mask generator
mask_generator = self._get_mask_generator()
# Generate all masks automatically
masks = mask_generator.generate(image)
# Convert image to grayscale for blank region detection
if filter_blank_regions:
if len(image.shape) == 3:
gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
else:
gray_image = image
# Process masks and convert to polygons
results = []
for mask_data in masks:
# Extract mask information
mask = mask_data["segmentation"] # Boolean mask
score = float(mask_data.get("stability_score", mask_data.get("predicted_iou", 0.0)))
area = int(mask_data.get("area", 0))
# Filter by confidence threshold
if score < min_confidence:
continue
# Filter by minimum area
if area < min_area:
continue
# Filter blank/black regions if enabled
if filter_blank_regions:
masked_region = gray_image[mask]
if len(masked_region) > 0:
mean_intensity = float(np.mean(masked_region))
if mean_intensity < 30:
variance = float(np.var(masked_region))
if variance < 100:
continue # Skip blank/black regions
elif mean_intensity < 50:
variance = float(np.var(masked_region))
if variance < 50:
continue # Skip very uniform dark regions
# Convert boolean mask to uint8 format
mask_uint8 = (mask.astype(np.uint8) * 255)
# Convert mask to polygon with proper scaling (matching predict_polygon_from_point)
# scale_factors should represent FROM processed image TO display size
# mask_to_polygon divides by scale_factors to convert FROM processed TO display
polygon = mask_to_polygon(mask_uint8, scale_factors=scale_factors)
results.append({
"polygon": polygon, # Flattened format [x1, y1, x2, y2, ...] (scaled to display size)
"confidence": score,
"area": area
})
return results
def create_sam2_auto_annotation(
    points_per_side: int = 32,
    points_per_batch: int = 64,
    pred_iou_thresh: float = 0.88,
    stability_score_thresh: float = 0.95,
    min_mask_region_area: int = 100,
) -> SAM2AutoAnnotation:
    """
    Factory function to create a SAM2 Auto Annotation instance.

    Args:
        points_per_side: Number of points per side of the image grid
        points_per_batch: Number of points to process in each batch
        pred_iou_thresh: Prediction IoU threshold
        stability_score_thresh: Stability score threshold
        min_mask_region_area: Minimum mask region area in pixels

    Returns:
        SAM2AutoAnnotation instance
    """
    config = {
        "points_per_side": points_per_side,
        "points_per_batch": points_per_batch,
        "pred_iou_thresh": pred_iou_thresh,
        "stability_score_thresh": stability_score_thresh,
        "min_mask_region_area": min_mask_region_area,
    }
    return SAM2AutoAnnotation(**config)