| import torch |
| import numpy as np |
| import cv2 |
| import psutil |
| import os |
| import sys |
|
|
| |
# Make the local SAM2 checkout importable: the "sam2" repository is expected
# to live at <project_root>/sam2, one directory above this file's directory.
_current_file_dir = os.path.dirname(os.path.abspath(__file__))
_project_root = os.path.dirname(_current_file_dir)
_sam2_repo_dir = os.path.join(_project_root, "sam2")

abs_sam2_dir = os.path.abspath(_sam2_repo_dir)
if abs_sam2_dir not in sys.path:
    # Prepend so the local checkout shadows any installed sam2 package.
    sys.path.insert(0, abs_sam2_dir)
|
|
| from sam2.sam2_image_predictor import SAM2ImagePredictor |
| from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator |
| from model.utils import mask_to_polygon |
|
|
| |
| |
| |
# Hugging Face model ID of the SAM 2.1 Hiera-Large checkpoint used by both
# the interactive predictor and the automatic mask generator.
HUGGINGFACE_MODEL_ID = "facebook/sam2.1-hiera-large"

# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Lazily-initialized module-level singletons; created on first use so that
# importing this module does not trigger a model download.
predictor = None
mask_generator = None
|
|
|
|
def initialize_sam():
    """
    Initialize SAM2 Large model from Hugging Face if not already loaded.

    The predictor is cached in the module-level ``predictor`` singleton,
    so repeated calls are cheap.

    Returns:
        SAM2ImagePredictor instance

    Raises:
        ImportError: If sam2 or huggingface_hub is not installed
        RuntimeError: If model fails to load from Hugging Face
    """
    global predictor
    if predictor is None:
        try:
            # First call downloads the checkpoint; huggingface_hub caches it.
            predictor = SAM2ImagePredictor.from_pretrained(
                HUGGINGFACE_MODEL_ID,
                device=device
            )
        except ImportError as e:
            # Point users at "sam2" (SAM v2) - "segment-anything" is the
            # unrelated SAM v1 package and does not provide these modules.
            raise ImportError(
                f"Failed to import required modules. Please ensure 'sam2' and 'huggingface_hub' are installed. "
                f"Install with: pip install sam2 huggingface_hub. "
                f"Error: {str(e)}"
            )
        except Exception as e:
            error_msg = str(e)
            raise RuntimeError(
                f"Failed to load SAM2 model from Hugging Face ({HUGGINGFACE_MODEL_ID}). "
                f"Please check your internet connection and ensure the model ID is correct. "
                f"Error: {error_msg}"
            )
    return predictor
|
|
|
|
def initialize_mask_generator(points_per_side=32, points_per_batch=64):
    """
    Initialize SAM2 Automatic Mask Generator from Hugging Face if not already loaded.
    Configured with memory-efficient parameters for CPU usage.

    Note: the tuning arguments only take effect on the first call; afterwards
    the cached module-level singleton is returned unchanged.

    Args:
        points_per_side: Number of points per side of the image grid (default: 32, lower = less memory)
        points_per_batch: Number of points to process in each batch (default: 64, lower = less memory)

    Returns:
        SAM2AutomaticMaskGenerator instance

    Raises:
        ImportError: If sam2 or huggingface_hub is not installed
        RuntimeError: If model fails to load from Hugging Face
    """
    global mask_generator
    if mask_generator is None:
        try:
            try:
                # Preferred path: pass all tuning parameters through the
                # constructor in one go.
                mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
                    HUGGINGFACE_MODEL_ID,
                    device=device,
                    points_per_side=points_per_side,
                    points_per_batch=points_per_batch,
                    pred_iou_thresh=0.88,
                    stability_score_thresh=0.95,
                    crop_n_layers=1,
                    crop_n_points_downscale_factor=2,
                    min_mask_region_area=100,
                )
            except TypeError:
                # Fallback for sam2 versions whose from_pretrained does not
                # accept the extra kwargs: load with defaults and patch the
                # two key memory parameters afterwards, when present.
                mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
                    HUGGINGFACE_MODEL_ID,
                    device=device
                )

                if hasattr(mask_generator, 'points_per_side'):
                    mask_generator.points_per_side = points_per_side
                if hasattr(mask_generator, 'points_per_batch'):
                    mask_generator.points_per_batch = points_per_batch
        except ImportError as e:
            # Point users at "sam2" (SAM v2) - "segment-anything" is the
            # unrelated SAM v1 package and does not provide these modules.
            raise ImportError(
                f"Failed to import required modules. Please ensure 'sam2' and 'huggingface_hub' are installed. "
                f"Install with: pip install sam2 huggingface_hub. "
                f"Error: {str(e)}"
            )
        except Exception as e:
            error_msg = str(e)
            raise RuntimeError(
                f"Failed to load SAM2 Automatic Mask Generator from Hugging Face ({HUGGINGFACE_MODEL_ID}). "
                f"Please check your internet connection and ensure the model ID is correct. "
                f"Error: {error_msg}"
            )
    return mask_generator
|
|
|
|
def resize_image_if_needed(image_rgb, max_dimension=1024):
    """
    Resize image if it exceeds max_dimension to reduce memory usage.
    Maintains aspect ratio.

    Args:
        image_rgb: numpy array (H, W, 3) in RGB format
        max_dimension: Maximum dimension (width or height) in pixels (default: 1024)

    Returns:
        resized_image: Resized numpy array (the input array itself if no
            resize was needed)
        scale_factor: Tuple (scale_x, scale_y) - factors that map processed
            coordinates back to original coordinates (original / processed)
    """
    h, w = image_rgb.shape[:2]

    if max(h, w) <= max_dimension:
        # Already small enough - return unchanged with identity scale.
        return image_rgb, (1.0, 1.0)

    # Scale the longer side down to max_dimension, preserving aspect ratio.
    # Clamp the short side to >= 1 px so extreme aspect ratios cannot yield
    # a zero-size dimension, which cv2.resize would reject.
    if h > w:
        new_h = max_dimension
        new_w = max(1, int(w * (max_dimension / h)))
    else:
        new_w = max_dimension
        new_h = max(1, int(h * (max_dimension / w)))

    # Note: cv2.resize takes the target size as (width, height).
    resized = cv2.resize(image_rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    # Both new_w and new_h are >= 1 here, so division is always safe.
    scale_x = w / new_w
    scale_y = h / new_h

    return resized, (scale_x, scale_y)
|
|
|
|
def calculate_memory_usage():
    """Report the current process's memory usage.

    Returns:
        dict: "rss_mb" and "vms_mb" in megabytes, plus "percent" - the
        process's share of total system memory.
    """
    proc = psutil.Process(os.getpid())
    info = proc.memory_info()
    bytes_per_mb = 1024 * 1024
    return {
        "rss_mb": info.rss / bytes_per_mb,
        "vms_mb": info.vms / bytes_per_mb,
        "percent": proc.memory_percent(),
    }
|
|
|
|
def estimate_image_memory(image_rgb):
    """Estimate the memory required to process an image.

    These are rough heuristics (float32 image copy, a 256-channel float32
    feature map, and ~100 single-byte masks at full resolution) meant for
    diagnostics, not exact accounting.

    Args:
        image_rgb: numpy array (H, W, 3) in RGB format

    Returns:
        dict: per-component and total estimates in MB, plus the image size
        formatted as "WxH".
    """
    height, width = image_rgb.shape[:2]
    pixels = height * width
    bytes_per_mb = 1024 * 1024

    image_mb = pixels * 3 * 4 / bytes_per_mb        # float32 RGB copy
    features_mb = pixels * 256 * 4 / bytes_per_mb   # float32 feature map
    masks_mb = pixels * 100 * 1 / bytes_per_mb      # ~100 uint8 masks

    return {
        "image_mb": image_mb,
        "features_mb": features_mb,
        "masks_mb": masks_mb,
        "total_estimated_mb": image_mb + features_mb + masks_mb,
        "image_size": f"{width}x{height}",
    }
|
|
|
|
def generate_all_masks(image_rgb, image_size=None, min_area=100, min_confidence=0.5, max_image_dimension=1024, points_per_side=32, points_per_batch=64):
    """
    Generate all possible object masks in an image using SAM2 Automatic Mask Generator.
    Automatically detects and segments all objects without requiring prompts.
    Optimized for CPU usage with image resizing and memory-efficient parameters.

    Args:
        image_rgb: numpy array (H, W, 3) in RGB format
        image_size: Optional dict with "width" and "height" for coordinate scaling
        min_area: Minimum mask area to filter out small/noisy masks (default: 100)
        min_confidence: Minimum confidence score to filter masks (default: 0.5)
        max_image_dimension: Maximum dimension (width or height) in pixels before resizing (default: 1024)
        points_per_side: Number of points per side of the image grid (default: 32, lower = less memory)
        points_per_batch: Number of points to process in each batch (default: 64, lower = less memory)

    Returns:
        dict: Contains:
            - masks: List of dicts, each containing:
                - polygon: flattened coordinates array [x1, y1, x2, y2, ...]
                - confidence: float confidence score
                - area: int mask area in pixels
            - memory_info: Memory usage information
            - was_resized: Whether the image was resized
            - original_size: Original image dimensions
            - processed_size: Processed image dimensions
    """
    # Snapshot process memory before the expensive generation step so the
    # caller can see roughly what this call cost.
    memory_before = calculate_memory_usage()

    original_h, original_w = image_rgb.shape[:2]
    original_size = (original_w, original_h)

    # Downscale oversized images to bound memory usage. resize_scale is
    # (original / processed) per axis, i.e. it maps processed coordinates
    # back to original coordinates.
    processed_image, resize_scale = resize_image_if_needed(image_rgb, max_dimension=max_image_dimension)
    was_resized = resize_scale[0] != 1.0 or resize_scale[1] != 1.0
    processed_h, processed_w = processed_image.shape[:2]
    processed_size = (processed_w, processed_h)

    # Rough heuristic estimate, returned for diagnostics only.
    memory_estimate = estimate_image_memory(processed_image)

    # Lazily loads the model on first use; args only apply on first call.
    generator = initialize_mask_generator(points_per_side=points_per_side, points_per_batch=points_per_batch)

    # Combined factor; its inverse is handed to mask_to_polygon below so
    # polygons come out in the caller's coordinate space.
    scale_x, scale_y = 1.0, 1.0

    if image_size is not None:
        if isinstance(image_size, dict):
            display_w = float(image_size.get("width", original_w))
            display_h = float(image_size.get("height", original_h))
        else:
            display_w, display_h = float(image_size[0]), float(image_size[1])

        # NOTE(review): (processed / display) * (original / processed)
        # collapses to (original / display); confirm this composition is the
        # intended mapping for the emitted polygons.
        scale_x = (processed_w / display_w) * resize_scale[0] if display_w > 0 else resize_scale[0]
        scale_y = (processed_h / display_h) * resize_scale[1] if display_h > 0 else resize_scale[1]
    else:
        # No display size given: only undo the internal resize.
        scale_x = resize_scale[0]
        scale_y = resize_scale[1]

    # The expensive step: grid-prompted mask generation over the whole image.
    masks = generator.generate(processed_image)

    memory_after = calculate_memory_usage()

    result_masks = []

    for mask_data in masks:
        mask = mask_data["segmentation"]
        # Prefer stability_score, fall back to predicted_iou, then 0.0.
        confidence = float(mask_data.get("stability_score", mask_data.get("predicted_iou", 0.0)))
        area = int(mask_data.get("area", 0))

        # Drop small / low-confidence masks before the polygon conversion.
        if area < min_area or confidence < min_confidence:
            continue

        # mask_to_polygon receives a 0/255 uint8 mask.
        mask_uint8 = (mask.astype(np.uint8) * 255)

        # Pass the inverse scale (guarding against division by zero) so
        # polygon coordinates land in the caller's coordinate space.
        polygon = mask_to_polygon(mask_uint8, (1.0/scale_x if scale_x != 0 else 1.0, 1.0/scale_y if scale_y != 0 else 1.0))

        # A valid polygon needs at least 3 points (6 flattened coordinates).
        if polygon and len(polygon) >= 6:
            result_masks.append({
                "polygon": polygon,
                "confidence": confidence,
                "area": area
            })

    # Largest objects first.
    result_masks.sort(key=lambda x: x["area"], reverse=True)

    return {
        "masks": result_masks,
        "memory_info": {
            "before_mb": memory_before["rss_mb"],
            "after_mb": memory_after["rss_mb"],
            # NOTE(review): "peak_mb" is just post-call RSS, not a true peak.
            "peak_mb": memory_after["rss_mb"],
            "estimated_mb": memory_estimate["total_estimated_mb"],
            "memory_used_mb": memory_after["rss_mb"] - memory_before["rss_mb"]
        },
        "was_resized": was_resized,
        "original_size": original_size,
        "processed_size": processed_size,
        "resize_scale": resize_scale
    }
|
|
|
|
def predict_polygon(image_rgb, bbox, image_size=None):
    """
    Predict polygon mask using SAM2 with bbox as prompt (CVAT-style).
    Bbox is used to identify the object, not constrain it.

    Args:
        image_rgb: numpy array (H, W, 3) in RGB format
        bbox: dict with keys "x", "y", "width", "height" OR list [x, y, w, h]
        image_size: Optional dict with "width" and "height" (or [w, h]) giving
            the display size the bbox coordinates are expressed in; used to
            scale them to the actual image resolution

    Returns:
        mask: binary uint8 mask (numpy array) - full object shape, NOT clipped to bbox
        confidence: float confidence score of the selected mask
        scale_factors: tuple (scale_x, scale_y) mapping display -> image coords
    """
    predictor = initialize_sam()
    predictor.set_image(image_rgb)

    # Accept either dict-style or sequence-style bbox.
    if isinstance(bbox, dict):
        x = float(bbox["x"])
        y = float(bbox["y"])
        bbox_w = float(bbox["width"])
        bbox_h = float(bbox["height"])
    else:
        x, y, bbox_w, bbox_h = [float(v) for v in bbox]

    # Scale display coordinates to actual image coordinates when the caller's
    # coordinate space differs from the image resolution.
    scale_x, scale_y = 1.0, 1.0
    original_h, original_w = image_rgb.shape[:2]

    if image_size is not None:
        if isinstance(image_size, dict):
            display_w = float(image_size.get("width", original_w))
            display_h = float(image_size.get("height", original_h))
        else:
            display_w, display_h = float(image_size[0]), float(image_size[1])

        scale_x = original_w / display_w if display_w > 0 else 1.0
        scale_y = original_h / display_h if display_h > 0 else 1.0

        x = x * scale_x
        y = y * scale_y
        bbox_w = bbox_w * scale_x
        bbox_h = bbox_h * scale_y

    # SAM expects boxes as [x1, y1, x2, y2].
    box = np.array([x, y, x + bbox_w, y + bbox_h], dtype=np.float32)

    # Positive point prompts at the bbox center and the four quarter points,
    # biasing SAM toward the object the box selects.
    center_x = x + bbox_w / 2.0
    center_y = y + bbox_h / 2.0

    point_coords = np.array([
        [center_x, center_y],
        [x + bbox_w * 0.25, y + bbox_h * 0.25],
        [x + bbox_w * 0.75, y + bbox_h * 0.25],
        [x + bbox_w * 0.25, y + bbox_h * 0.75],
        [x + bbox_w * 0.75, y + bbox_h * 0.75],
    ], dtype=np.float32)
    point_labels = np.array([1, 1, 1, 1, 1], dtype=np.int32)

    masks, scores, _ = predictor.predict(
        box=box,
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True
    )

    # Pick the candidate that balances SAM's own score (weight 0.6) with how
    # much of the bbox area the mask actually covers (weight 0.4).
    best_mask_idx = 0
    best_score_combined = 0.0
    bbox_area = bbox_w * bbox_h

    for idx, (mask, score) in enumerate(zip(masks, scores)):
        mask_binary = mask.astype(np.uint8) * 255

        # Clamp the bbox to the mask bounds before measuring coverage.
        x1_int = max(0, int(x))
        y1_int = max(0, int(y))
        x2_int = min(mask.shape[1], int(x + bbox_w))
        y2_int = min(mask.shape[0], int(y + bbox_h))

        mask_bbox_region = mask_binary[y1_int:y2_int, x1_int:x2_int]
        mask_area_in_bbox = np.sum(mask_bbox_region > 0)

        coverage_ratio = mask_area_in_bbox / bbox_area if bbox_area > 0 else 0

        score_combined = float(score) * 0.6 + coverage_ratio * 0.4

        if score_combined > best_score_combined:
            best_score_combined = score_combined
            best_mask_idx = idx

    best_mask = masks[best_mask_idx]
    best_score = scores[best_mask_idx]

    # Convert to a 0/255 uint8 mask; this single expression works for both
    # bool and numeric mask dtypes (the previous dtype branch was redundant).
    mask = (best_mask * 255).astype("uint8")

    # Close small gaps along the mask boundary.
    mask_filled = cv2.morphologyEx(mask, cv2.MORPH_CLOSE,
                                   cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)))

    # Fill interior holes: flood-fill the background from the top-left corner,
    # invert it, and OR it back in. NOTE(review): assumes pixel (0, 0) lies
    # outside the object - confirm for objects touching that corner.
    mask_floodfill = mask_filled.copy()
    cv2.floodFill(mask_floodfill, None, (0, 0), 255)
    mask_floodfill_inv = cv2.bitwise_not(mask_floodfill)
    mask_filled = cv2.bitwise_or(mask_filled, mask_floodfill_inv)

    mask = mask_filled

    # predictor.predict may return scores as 0-d or 1-d arrays; normalize.
    score_arr = np.asarray(best_score).flatten()
    confidence = float(score_arr[0])

    return mask, confidence, (scale_x, scale_y)
|
|
|
|
def predict_polygon_from_point(image_rgb, point, image_size=None):
    """
    Predict polygon mask using SAM2 with a point click as prompt.
    The point identifies the object to segment.

    Args:
        image_rgb: numpy array (H, W, 3) in RGB format
        point: dict with keys "x", "y" OR list [x, y] - the clicked point coordinate
        image_size: Optional dict with "width" and "height" (or [w, h]) giving
            the display size the point coordinate is expressed in

    Returns:
        mask: binary uint8 mask (numpy array) - full object shape
        confidence: float confidence score
        scale_factors: tuple (scale_x, scale_y) for coordinate scaling
    """
    predictor = initialize_sam()
    predictor.set_image(image_rgb)

    # Accept either dict-style or sequence-style point.
    if isinstance(point, dict):
        point_x = float(point["x"])
        point_y = float(point["y"])
    else:
        point_x, point_y = [float(v) for v in point]

    # Scale display coordinates to actual image coordinates when the caller's
    # coordinate space differs from the image resolution.
    scale_x, scale_y = 1.0, 1.0
    original_h, original_w = image_rgb.shape[:2]

    if image_size is not None:
        if isinstance(image_size, dict):
            display_w = float(image_size.get("width", original_w))
            display_h = float(image_size.get("height", original_h))
        else:
            display_w, display_h = float(image_size[0]), float(image_size[1])

        scale_x = original_w / display_w if display_w > 0 else 1.0
        scale_y = original_h / display_h if display_h > 0 else 1.0

        point_x = point_x * scale_x
        point_y = point_y * scale_y

    # Single positive point prompt.
    point_coords = np.array([[point_x, point_y]], dtype=np.float32)
    point_labels = np.array([1], dtype=np.int32)

    masks, scores, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True
    )

    # Keep the highest-scoring candidate.
    best_mask_idx = np.argmax(scores)
    best_mask = masks[best_mask_idx]
    best_score = scores[best_mask_idx]

    # Convert to a 0/255 uint8 mask; this single expression works for both
    # bool and numeric mask dtypes (the previous dtype branch was redundant).
    mask = (best_mask * 255).astype("uint8")

    # Close small gaps along the mask boundary.
    mask_filled = cv2.morphologyEx(mask, cv2.MORPH_CLOSE,
                                   cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)))

    # Fill interior holes: flood-fill the background from the top-left corner,
    # invert it, and OR it back in. NOTE(review): assumes pixel (0, 0) lies
    # outside the object - confirm for objects touching that corner.
    mask_floodfill = mask_filled.copy()
    cv2.floodFill(mask_floodfill, None, (0, 0), 255)
    mask_floodfill_inv = cv2.bitwise_not(mask_floodfill)
    mask_filled = cv2.bitwise_or(mask_filled, mask_floodfill_inv)

    mask = mask_filled

    # predictor.predict may return scores as 0-d or 1-d arrays; normalize.
    score_arr = np.asarray(best_score).flatten()
    confidence = float(score_arr[0])

    return mask, confidence, (scale_x, scale_y)
|
|