#!/usr/bin/env python3
"""
SAM3 MLX Click Segmentation Example

Demonstrates how to:
1. Load SAM3 MLX model
2. Process an image
3. Segment objects with point clicks
4. Visualize results

Usage:
    python click_segment.py --image path/to/image.jpg --point 100,200

Points are given as 'x,y' in pixel coordinates of the original image.
Prefix a point with '+' (positive, the default) or '-' (negative).
Because argparse reads a value that starts with '-' as an option flag,
negative points must be attached with '=':

    python click_segment.py --image img.jpg --point 100,200 --point=-300,50
"""

import argparse
import sys
import time
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import mlx.core as mx

try:
    from PIL import Image
    import matplotlib.pyplot as plt
except ImportError:
    print("❌ Please install PIL and matplotlib:")
    print(" pip install pillow matplotlib")
    sys.exit(1)

# Add parent directory to path so the project packages below resolve
sys.path.insert(0, str(Path(__file__).parent.parent))

from models.sam3 import SAM3MLX  # noqa: E402
from utils.weights import load_weights  # noqa: E402


def load_image(image_path: str, target_size: int = 1024) -> Tuple[mx.array, np.ndarray]:
    """
    Load and preprocess image for SAM3

    The image is resized to a square of ``target_size`` pixels without
    preserving its aspect ratio, and pixel values are scaled to [0, 1].

    Args:
        image_path: Path to image file
        target_size: Target image size (SAM3 uses 1024x1024)

    Returns:
        Tuple of (preprocessed MLX array of shape
        (1, target_size, target_size, 3), original RGB numpy array)
    """
    # convert() returns a fully loaded copy, so the file can be closed here.
    with Image.open(image_path) as img_file:
        img = img_file.convert("RGB")
    original = np.array(img)

    # Resize to target size
    img_resized = img.resize((target_size, target_size), Image.BILINEAR)
    img_np = np.array(img_resized).astype(np.float32) / 255.0

    # Convert to MLX array in NHWC format
    img_mlx = mx.array(img_np).reshape(1, target_size, target_size, 3)

    return img_mlx, original


def visualize_prediction(
    image: np.ndarray,
    masks: mx.array,
    point_coords: mx.array,
    point_labels: mx.array,
    iou_scores: mx.array,
    save_path: Optional[str] = None,
) -> None:
    """
    Visualize segmentation results

    Draws the input image with the click points in the first panel and one
    green overlay panel per mask, optionally saves the figure, then shows it.

    Args:
        image: Original image (H, W, 3)
        masks: Predicted masks (1, num_masks, H, W)
        point_coords: Input point coordinates (1, N, 2), in pixels of ``image``
        point_labels: Input point labels (1, N); 1 is drawn as positive,
            anything else as negative
        iou_scores: IoU quality scores (1, num_masks)
        save_path: Optional path to save visualization
    """
    # Convert MLX to numpy
    masks_np = np.array(masks[0])  # (num_masks, H, W)
    point_coords_np = np.array(point_coords[0])  # (N, 2)
    point_labels_np = np.array(point_labels[0])  # (N,)
    iou_scores_np = np.array(iou_scores[0])  # (num_masks,)

    num_masks = masks_np.shape[0]
    height, width = image.shape[:2]

    # squeeze=False always yields a 2-D axes array, so indexing also works
    # when there is only a single panel (num_masks == 0).
    fig, axes_grid = plt.subplots(
        1,
        num_masks + 1,
        figsize=(5 * (num_masks + 1), 5),
        squeeze=False,
    )
    axes = axes_grid[0]

    # Show original image with points
    axes[0].imshow(image)
    axes[0].set_title("Input Image with Points")

    # Plot positive points (green) and negative points (red)
    for coord, label in zip(point_coords_np, point_labels_np):
        color = 'g' if label == 1 else 'r'
        marker = 'o' if label == 1 else 'x'
        axes[0].scatter(coord[0], coord[1], c=color, marker=marker, s=200, linewidths=3)

    axes[0].axis('off')

    green = np.array([0, 255, 0])

    # Show each predicted mask
    for i, mask in enumerate(masks_np):
        # Clip before the uint8 cast: values outside [0, 1] would otherwise
        # wrap around instead of saturating.
        mask_u8 = (np.clip(mask.astype(np.float32), 0.0, 1.0) * 255).astype(np.uint8)

        # Resize mask to original image size
        mask_img = Image.fromarray(mask_u8).resize((width, height), Image.BILINEAR)
        mask_resized = np.array(mask_img) / 255.0

        # Overlay mask on image: blend 50% green where the mask is set
        alpha = mask_resized[..., None] * 0.5
        overlay = (image * (1 - alpha) + green * alpha).astype(np.uint8)

        axes[i + 1].imshow(overlay)
        axes[i + 1].set_title(f"Mask {i+1} (IoU: {iou_scores_np[i]:.3f})")
        axes[i + 1].axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=150)
        print(f"💾 Saved visualization to {save_path}")

    plt.show()
    # Release the figure once the window is closed.
    plt.close(fig)


def _parse_point(point_str: str) -> Tuple[float, float, int]:
    """
    Parse one --point value into (x, y, label)

    A leading '+' marks a positive point (label 1), a leading '-' marks a
    negative point (label 0); without a prefix the point is positive.

    Raises:
        ValueError: If the value is not two comma-separated numbers.
    """
    label = 1  # Default to positive
    text = point_str
    if text.startswith('+'):
        text = text[1:]
    elif text.startswith('-'):
        label = 0  # Negative
        text = text[1:]

    parts = text.split(',')
    if len(parts) != 2:
        raise ValueError(f"Invalid point '{point_str}': expected 'x,y'")
    try:
        x, y = float(parts[0]), float(parts[1])
    except ValueError:
        raise ValueError(
            f"Invalid point '{point_str}': x and y must be numbers"
        ) from None
    return x, y, label


def main():
    """
    Run the click-segmentation example from the command line

    Parses the arguments, loads the image and the model (with weights when
    the checkpoint file exists), runs ``model.predict`` on the click points,
    prints a summary and shows the visualization. Exits with status 1 when
    no point is given or a point cannot be parsed.
    """
    parser = argparse.ArgumentParser(description="SAM3 MLX Click Segmentation Example")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument(
        "--point",
        type=str,
        action="append",
        help=(
            "Click point as 'x,y' (can specify multiple). "
            "Use +x,y for positive, -x,y for negative "
            "(write negative points as --point=-x,y)"
        ),
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="./checkpoints/sam3_mlx",
        help="Path to SAM3 MLX checkpoint directory",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Path to save output visualization",
    )
    parser.add_argument(
        "--single-mask",
        action="store_true",
        help="Output single mask instead of 3 masks",
    )

    args = parser.parse_args()

    print("🚀 SAM3 MLX Click Segmentation Example")
    print("=" * 60)

    # Parse points
    if not args.point:
        print("❌ Please specify at least one point with --point x,y")
        sys.exit(1)

    point_coords_list = []
    point_labels_list = []

    for point_str in args.point:
        try:
            x, y, label = _parse_point(point_str)
        except ValueError as err:
            print(f"❌ {err}")
            sys.exit(1)
        point_coords_list.append([x, y])
        point_labels_list.append(label)

    point_coords = mx.array(point_coords_list).reshape(1, -1, 2)
    point_labels = mx.array(point_labels_list).reshape(1, -1)

    print(f"📍 Input points: {len(point_coords_list)}")
    for i, (coord, label) in enumerate(zip(point_coords_list, point_labels_list)):
        label_str = "positive" if label == 1 else "negative"
        print(f" Point {i+1}: ({coord[0]:.0f}, {coord[1]:.0f}) [{label_str}]")

    # Load image
    print(f"\n📸 Loading image: {args.image}")
    image_mlx, image_original = load_image(args.image)
    print(f" Image size: {image_original.shape[1]}x{image_original.shape[0]}")

    # The model is given the resized image, while the clicks are expressed in
    # original-image pixels (that is also where they are plotted), so rescale
    # them into the resized frame before prediction.
    # NOTE(review): assumes SAM3MLX.predict takes point_coords as pixel
    # coordinates of the image it is given — confirm against models/sam3.
    orig_h, orig_w = image_original.shape[:2]
    input_h, input_w = image_mlx.shape[1], image_mlx.shape[2]
    model_point_coords = point_coords * mx.array([input_w / orig_w, input_h / orig_h])

    # Initialize model
    print("\n🏗️ Initializing SAM3 MLX model...")
    model = SAM3MLX()

    # Load weights if available
    checkpoint_dir = Path(args.checkpoint)
    weights_path = checkpoint_dir / "sam3_mlx_weights.npz"

    if weights_path.exists():
        print(f"\n📥 Loading weights from {checkpoint_dir}")
        model = load_weights(model, str(weights_path), strict=False, verbose=True)
    else:
        print(f"\n⚠️ Weights not found at {weights_path}")
        print(" Using randomly initialized model (for testing architecture only)")

    # Run inference
    print("\n🎯 Running segmentation...")
    start_time = time.perf_counter()

    result = model.predict(
        image=image_mlx,
        point_coords=model_point_coords,
        point_labels=point_labels,
        multimask_output=not args.single_mask,
    )

    # Ensure computation is complete (MLX is lazy) so the timing covers
    # both outputs that are used below.
    mx.eval(result["masks"], result["iou_predictions"])
    inference_time = (time.perf_counter() - start_time) * 1000

    print(f"✅ Inference completed in {inference_time:.1f}ms")

    # Print results
    masks = result["masks"]
    iou_predictions = result["iou_predictions"]

    print("\n📊 Results:")
    print(f" Number of masks: {masks.shape[1]}")
    print(f" Mask resolution: {masks.shape[2]}x{masks.shape[3]}")
    print(f" IoU scores: {np.array(iou_predictions[0])}")

    # Visualize (points stay in original-image coordinates for plotting)
    print("\n🎨 Visualizing results...")
    visualize_prediction(
        image_original,
        masks,
        point_coords,
        point_labels,
        iou_predictions,
        save_path=args.output,
    )

    print("\n✅ Done!")


if __name__ == "__main__":
    main()