import numpy as np
import os
import sys
from typing import Optional

import cv2
import torch
# Make the local sam2 checkout importable: the repo is expected to live at
# <project_root>/sam2, i.e. a sibling of this file's parent directory.
_current_file_dir = os.path.dirname(os.path.abspath(__file__))
_project_root = os.path.dirname(_current_file_dir)
_sam2_repo_dir = os.path.join(_project_root, "sam2")
# Prepend to sys.path only once so repeated imports stay idempotent.
abs_sam2_dir = os.path.abspath(_sam2_repo_dir)
if abs_sam2_dir not in sys.path:
    sys.path.insert(0, abs_sam2_dir)
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from model.utils import mask_to_polygon
# Hugging Face model ID for the SAM2.1 Hiera Large checkpoint.
HUGGINGFACE_MODEL_ID = "facebook/sam2.1-hiera-large"
# Run on GPU when available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
class SAM2AutoAnnotation:
    """
    SAM2 Auto Annotation wrapper for automatically generating masks for all objects in an image.
    Uses SAM2AutomaticMaskGenerator from Hugging Face.

    The underlying model is loaded lazily on first use (see
    :meth:`_get_mask_generator`), so constructing this class is cheap and
    requires no network access.
    """

    def __init__(
        self,
        points_per_side: int = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        min_mask_region_area: int = 100,
    ):
        """
        Initialize SAM2 Auto Annotation.

        Args:
            points_per_side: Number of points per side of the image grid
            points_per_batch: Number of points to process in each batch
            pred_iou_thresh: Prediction IoU threshold
            stability_score_thresh: Stability score threshold
            min_mask_region_area: Minimum mask region area in pixels
        """
        self.points_per_side = points_per_side
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.min_mask_region_area = min_mask_region_area
        # Populated on first call to _get_mask_generator().
        self._mask_generator = None

    def _generator_params(self) -> dict:
        """Return the full configuration forwarded to SAM2AutomaticMaskGenerator."""
        return {
            "points_per_side": self.points_per_side,
            "points_per_batch": self.points_per_batch,
            "pred_iou_thresh": self.pred_iou_thresh,
            "stability_score_thresh": self.stability_score_thresh,
            # Crop settings improve recall on small objects at extra compute cost.
            "crop_n_layers": 1,
            "crop_n_points_downscale_factor": 2,
            "min_mask_region_area": self.min_mask_region_area,
        }

    def _get_mask_generator(self):
        """Lazy initialization of the mask generator.

        Returns:
            The cached SAM2AutomaticMaskGenerator instance.

        Raises:
            RuntimeError: If required modules are missing or the model cannot
                be downloaded/loaded from Hugging Face.
        """
        if self._mask_generator is None:
            params = self._generator_params()
            try:
                # Try to load with configuration parameters first.
                try:
                    self._mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
                        HUGGINGFACE_MODEL_ID,
                        device=device,
                        **params,
                    )
                except TypeError:
                    # Older sam2 versions don't accept generator kwargs in
                    # from_pretrained; load bare and set whichever of the same
                    # parameters the generator exposes as attributes.
                    self._mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
                        HUGGINGFACE_MODEL_ID,
                        device=device,
                    )
                    for attr_name, value in params.items():
                        if hasattr(self._mask_generator, attr_name):
                            setattr(self._mask_generator, attr_name, value)
            except ImportError as e:
                # Chain the original exception so the root cause stays visible.
                raise RuntimeError(
                    f"Failed to import required modules for SAM2. Please ensure 'sam2' and 'huggingface_hub' are installed. "
                    f"Error: {str(e)}"
                ) from e
            except Exception as e:
                raise RuntimeError(
                    f"Failed to load SAM2 Automatic Mask Generator from Hugging Face ({HUGGINGFACE_MODEL_ID}). "
                    f"Please check your internet connection and ensure the model ID is correct. "
                    f"Error: {str(e)}"
                ) from e
        return self._mask_generator

    @staticmethod
    def _is_blank_region(gray_image: np.ndarray, mask: np.ndarray) -> bool:
        """Return True when the masked pixels look like a blank/black region.

        Heuristic: dark regions (mean < 30) with low variance (< 100), or
        slightly brighter but very uniform dark regions (mean < 50,
        variance < 50), are treated as blank. Empty masks are never blank.
        """
        masked_region = gray_image[mask]
        if len(masked_region) == 0:
            return False
        mean_intensity = float(np.mean(masked_region))
        if mean_intensity < 30:
            return float(np.var(masked_region)) < 100
        if mean_intensity < 50:
            return float(np.var(masked_region)) < 50
        return False

    def generate_masks(
        self,
        image: np.ndarray,
        min_confidence: float = 0.0,
        min_area: Optional[int] = None,
        filter_blank_regions: bool = True,
        scale_factors: tuple = (1.0, 1.0),
    ) -> list:
        """
        Generate all masks for objects in the image.

        Args:
            image: Image as numpy array (RGB format, H, W, 3)
            min_confidence: Minimum confidence score to filter masks (default: 0.0)
            min_area: Minimum mask area in pixels (default: uses self.min_mask_region_area)
            filter_blank_regions: Filter out blank/black regions (default: True)
            scale_factors: Tuple (scale_x, scale_y) to scale coordinates FROM processed TO display size
                           (matching predict_polygon_from_point logic)

        Returns:
            List of mask dictionaries, each containing:
                - polygon: flattened coordinates [x1, y1, x2, y2, ...] (scaled to display size)
                - confidence: confidence score
                - area: mask area in pixels
        """
        if min_area is None:
            min_area = self.min_mask_region_area

        mask_generator = self._get_mask_generator()
        # Generate all candidate masks automatically.
        masks = mask_generator.generate(image)

        # Grayscale copy is only needed for the blank-region heuristic.
        gray_image = None
        if filter_blank_regions:
            if len(image.shape) == 3:
                gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            else:
                gray_image = image

        results = []
        for mask_data in masks:
            mask = mask_data["segmentation"]  # Boolean mask
            # Prefer stability_score; fall back to predicted_iou, then 0.0.
            score = float(mask_data.get("stability_score", mask_data.get("predicted_iou", 0.0)))
            area = int(mask_data.get("area", 0))

            # Filter by confidence and minimum area.
            if score < min_confidence or area < min_area:
                continue
            # Skip blank/black regions if enabled.
            if filter_blank_regions and self._is_blank_region(gray_image, mask):
                continue

            # Convert boolean mask to uint8 format expected by mask_to_polygon.
            mask_uint8 = (mask.astype(np.uint8) * 255)
            # scale_factors represent FROM processed image TO display size;
            # mask_to_polygon divides by scale_factors to perform the mapping
            # (matching predict_polygon_from_point).
            polygon = mask_to_polygon(mask_uint8, scale_factors=scale_factors)
            results.append({
                "polygon": polygon,  # Flattened [x1, y1, x2, y2, ...] (display size)
                "confidence": score,
                "area": area,
            })
        return results
def create_sam2_auto_annotation(
    points_per_side: int = 32,
    points_per_batch: int = 64,
    pred_iou_thresh: float = 0.88,
    stability_score_thresh: float = 0.95,
    min_mask_region_area: int = 100,
) -> SAM2AutoAnnotation:
    """
    Factory function to create a SAM2 Auto Annotation instance.

    Args:
        points_per_side: Number of points per side of the image grid
        points_per_batch: Number of points to process in each batch
        pred_iou_thresh: Prediction IoU threshold
        stability_score_thresh: Stability score threshold
        min_mask_region_area: Minimum mask region area in pixels

    Returns:
        SAM2AutoAnnotation instance
    """
    # Collect the configuration once, then forward it to the constructor.
    config = {
        "points_per_side": points_per_side,
        "points_per_batch": points_per_batch,
        "pred_iou_thresh": pred_iou_thresh,
        "stability_score_thresh": stability_score_thresh,
        "min_mask_region_area": min_mask_region_area,
    }
    return SAM2AutoAnnotation(**config)