| import numpy as np |
| import cv2 |
| import torch |
| import sys |
| import os |
|
|
| |
# Make the SAM2 checkout importable: it is expected in a "sam2" directory that
# sits next to this file's parent directory (presumably a vendored clone of the
# SAM2 repository -- confirm against the project layout).
_current_file_dir = os.path.split(os.path.abspath(__file__))[0]
_project_root = os.path.split(_current_file_dir)[0]
_sam2_repo_dir = os.path.join(_project_root, "sam2")

# Prepend the absolute path once so repeated imports of this module do not
# keep growing sys.path.
abs_sam2_dir = os.path.abspath(_sam2_repo_dir)
if not any(entry == abs_sam2_dir for entry in sys.path):
    sys.path.insert(0, abs_sam2_dir)
|
|
| from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator |
|
|
| from model.utils import mask_to_polygon |
|
|
| |
# Hugging Face model identifier passed to SAM2AutomaticMaskGenerator.from_pretrained.
HUGGINGFACE_MODEL_ID = "facebook/sam2.1-hiera-large"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
class SAM2AutoAnnotation:
    """
    SAM2 auto-annotation wrapper that generates masks for all objects in an image.

    The underlying ``SAM2AutomaticMaskGenerator`` is loaded from Hugging Face
    (``HUGGINGFACE_MODEL_ID``) lazily, on the first call that needs it, so
    constructing this class only stores the configuration.
    """

    # Bands used by the blank-region filter, checked in order. Each entry is
    # (mean-intensity ceiling, variance ceiling): the first band whose mean
    # ceiling exceeds the region's mean decides, and the region is considered
    # blank when its variance is below that band's variance ceiling.
    _BLANK_REGION_RULES = ((30, 100), (50, 50))

    def __init__(
        self,
        points_per_side: int = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        min_mask_region_area: int = 100,
    ):
        """
        Store the mask-generator configuration; the model itself is not loaded here.

        Args:
            points_per_side: Number of points per side of the image grid
            points_per_batch: Number of points to process in each batch
            pred_iou_thresh: Prediction IoU threshold
            stability_score_thresh: Stability score threshold
            min_mask_region_area: Minimum mask region area in pixels
        """
        self.points_per_side = points_per_side
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.min_mask_region_area = min_mask_region_area
        self._mask_generator = None

    def _get_mask_generator(self):
        """
        Return the mask generator, creating and caching it on first use.

        Raises:
            RuntimeError: If the generator cannot be imported or loaded. The
                original exception is attached as ``__cause__``.
        """
        if self._mask_generator is not None:
            return self._mask_generator

        try:
            try:
                generator = SAM2AutomaticMaskGenerator.from_pretrained(
                    HUGGINGFACE_MODEL_ID,
                    device=device,
                    points_per_side=self.points_per_side,
                    points_per_batch=self.points_per_batch,
                    pred_iou_thresh=self.pred_iou_thresh,
                    stability_score_thresh=self.stability_score_thresh,
                    crop_n_layers=1,
                    crop_n_points_downscale_factor=2,
                    min_mask_region_area=self.min_mask_region_area,
                )
            except TypeError:
                # Fallback, presumably for versions of from_pretrained that
                # reject these keyword arguments: build with defaults, then
                # overwrite whichever of our settings the generator exposes
                # as attributes. Settings it does not expose are left at the
                # generator's defaults.
                generator = SAM2AutomaticMaskGenerator.from_pretrained(
                    HUGGINGFACE_MODEL_ID,
                    device=device
                )
                for attr_name in ['points_per_side', 'points_per_batch', 'pred_iou_thresh',
                                  'stability_score_thresh', 'min_mask_region_area']:
                    if hasattr(generator, attr_name):
                        setattr(generator, attr_name, getattr(self, attr_name))
        except ImportError as e:
            raise RuntimeError(
                f"Failed to import required modules for SAM2. Please ensure 'sam2' and 'huggingface_hub' are installed. "
                f"Error: {str(e)}"
            ) from e
        except Exception as e:
            raise RuntimeError(
                f"Failed to load SAM2 Automatic Mask Generator from Hugging Face ({HUGGINGFACE_MODEL_ID}). "
                f"Please check your internet connection and ensure the model ID is correct. "
                f"Error: {str(e)}"
            ) from e

        # Cache only a fully configured generator, so a failure above cannot
        # leave a half-initialised one behind for later calls.
        self._mask_generator = generator
        return generator

    @staticmethod
    def _is_blank_region(gray_image, mask) -> bool:
        """Return True when the pixels of ``gray_image`` selected by boolean ``mask`` are dark and nearly uniform."""
        region = gray_image[mask]
        if region.size == 0:
            # Nothing selected: there is no evidence the region is blank.
            return False

        mean_intensity = float(np.mean(region))
        for mean_ceiling, variance_ceiling in SAM2AutoAnnotation._BLANK_REGION_RULES:
            if mean_intensity < mean_ceiling:
                # Only the first matching band is consulted.
                return float(np.var(region)) < variance_ceiling
        return False

    def generate_masks(
        self,
        image: np.ndarray,
        min_confidence: float = 0.0,
        min_area: int = None,
        filter_blank_regions: bool = True,
        scale_factors: tuple = (1.0, 1.0),
    ) -> list:
        """
        Generate all masks for objects in the image.

        Args:
            image: Image as numpy array (RGB format, H, W, 3)
            min_confidence: Minimum confidence score to filter masks (default: 0.0)
            min_area: Minimum mask area in pixels (default: uses self.min_mask_region_area)
            filter_blank_regions: Filter out blank/black regions (default: True)
            scale_factors: Tuple (scale_x, scale_y) to scale coordinates FROM processed TO display size
                (matching predict_polygon_from_point logic); passed through to ``mask_to_polygon``

        Returns:
            List of mask dictionaries, each containing:
                - polygon: value returned by ``mask_to_polygon`` (flattened coordinates
                  [x1, y1, x2, y2, ...] scaled to display size)
                - confidence: the mask's ``stability_score`` if present, otherwise its
                  ``predicted_iou``, otherwise 0.0
                - area: mask area in pixels as reported by the generator

        Raises:
            RuntimeError: If the mask generator cannot be loaded.
        """
        if min_area is None:
            min_area = self.min_mask_region_area

        masks = self._get_mask_generator().generate(image)

        # The grayscale image is only needed by the blank-region filter.
        gray_image = None
        if filter_blank_regions:
            if image.ndim == 3:
                gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            else:
                gray_image = image

        results = []
        for mask_data in masks:
            mask = mask_data["segmentation"]
            score = float(mask_data.get("stability_score", mask_data.get("predicted_iou", 0.0)))
            area = int(mask_data.get("area", 0))

            if score < min_confidence or area < min_area:
                continue

            # Boolean indexing and the 0/255 conversion below both require a
            # boolean mask; a uint8 0/255 mask would otherwise be treated as
            # integer indices and overflow when multiplied by 255.
            if mask.dtype != np.bool_:
                mask = mask.astype(bool)

            if filter_blank_regions and self._is_blank_region(gray_image, mask):
                continue

            mask_uint8 = mask.astype(np.uint8) * 255
            polygon = mask_to_polygon(mask_uint8, scale_factors=scale_factors)

            results.append({
                "polygon": polygon,
                "confidence": score,
                "area": area
            })

        return results
|
|
|
|
def create_sam2_auto_annotation(
    points_per_side: int = 32,
    points_per_batch: int = 64,
    pred_iou_thresh: float = 0.88,
    stability_score_thresh: float = 0.95,
    min_mask_region_area: int = 100,
) -> SAM2AutoAnnotation:
    """
    Build a SAM2AutoAnnotation configured with the given settings.

    Every argument is forwarded unchanged, by keyword, to the
    SAM2AutoAnnotation constructor.

    Args:
        points_per_side: Number of points per side of the image grid
        points_per_batch: Number of points to process in each batch
        pred_iou_thresh: Prediction IoU threshold
        stability_score_thresh: Stability score threshold
        min_mask_region_area: Minimum mask region area in pixels

    Returns:
        SAM2AutoAnnotation instance
    """
    settings = {
        "points_per_side": points_per_side,
        "points_per_batch": points_per_batch,
        "pred_iou_thresh": pred_iou_thresh,
        "stability_score_thresh": stability_score_thresh,
        "min_mask_region_area": min_mask_region_area,
    }
    return SAM2AutoAnnotation(**settings)
|
|
|
|