MLX
MLX_SAM3 / click_segment.py
Hoodrobot's picture
Upload 15 files
ced11e2 verified
#!/usr/bin/env python3
"""
SAM3 MLX Click Segmentation Example
Demonstrates how to:
1. Load SAM3 MLX model
2. Process an image
3. Segment objects with point clicks
4. Visualize results
Usage:
python click_segment.py --image path/to/image.jpg --point 100,200
"""
import argparse
import time
from pathlib import Path
from typing import Tuple, Optional
import numpy as np
import mlx.core as mx
# Pillow and matplotlib are required for image loading and visualization.
# Fail early with an installation hint when either one is missing.
try:
    from PIL import Image
    import matplotlib.pyplot as plt
except ImportError:
    print("โŒ Please install PIL and matplotlib:")
    print(" pip install pillow matplotlib")
    # The bare `exit` helper is injected by the `site` module for interactive
    # sessions and is missing under `python -S` or in frozen builds; raising
    # SystemExit is the portable way to terminate with status 1.
    raise SystemExit(1)
# Add parent directory to path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.sam3 import SAM3MLX
from utils.weights import load_weights
def load_image(image_path: str, target_size: int = 1024) -> Tuple[mx.array, np.ndarray]:
    """
    Load an image from disk and preprocess it for SAM3.

    The image is converted to RGB, resized to a ``target_size`` x ``target_size``
    square with bilinear resampling (the aspect ratio is NOT preserved), and
    scaled to float32 values in the range [0, 1].

    Args:
        image_path: Path to image file
        target_size: Side length of the square model input (default 1024)

    Returns:
        Tuple of (preprocessed MLX array of shape
        (1, target_size, target_size, 3), original RGB image as a numpy array
        of shape (H, W, 3))
    """
    # Open through a context manager so the underlying file handle is released
    # deterministically instead of whenever the garbage collector runs.
    # convert() returns an independent, fully loaded image, so it stays valid
    # after the source file is closed.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    original = np.array(img)

    # Resize to the square model input and scale pixel values to [0, 1]
    img_resized = img.resize((target_size, target_size), Image.BILINEAR)
    img_np = np.array(img_resized).astype(np.float32) / 255.0

    # Convert to MLX array in NHWC format (adds the leading batch dimension)
    img_mlx = mx.array(img_np).reshape(1, target_size, target_size, 3)
    return img_mlx, original
def visualize_prediction(
    image: np.ndarray,
    masks: mx.array,
    point_coords: mx.array,
    point_labels: mx.array,
    iou_scores: mx.array,
    save_path: Optional[str] = None,
):
    """
    Visualize segmentation results.

    Draws one panel with the input image and the click points, followed by one
    panel per predicted mask showing the mask blended in green over the image.
    The figure is optionally saved, then shown, then closed.

    Args:
        image: Original image (H, W, 3)
        masks: Predicted masks (1, num_masks, H, W)
        point_coords: Input point coordinates (1, N, 2), in the pixel frame of
            ``image``
        point_labels: Input point labels (1, N); 1 is drawn as a positive
            click, anything else as a negative click
        iou_scores: IoU quality scores (1, num_masks)
        save_path: Optional path to save visualization
    """
    # Convert MLX to numpy, dropping the batch dimension
    masks_np = np.array(masks[0])  # (num_masks, H, W)
    point_coords_np = np.array(point_coords[0])  # (N, 2)
    point_labels_np = np.array(point_labels[0])  # (N,)
    iou_scores_np = np.array(iou_scores[0])  # (num_masks,)
    num_masks = masks_np.shape[0]
    height, width = image.shape[:2]

    # squeeze=False always returns a 2-D grid of Axes, so indexing stays valid
    # for any panel count (including the single-panel case with zero masks).
    fig, axes_grid = plt.subplots(
        1,
        num_masks + 1,
        figsize=(5 * (num_masks + 1), 5),
        squeeze=False,
    )
    axes = axes_grid[0]

    # Show original image with points
    axes[0].imshow(image)
    axes[0].set_title("Input Image with Points")
    # Plot positive points (green circles) and negative points (red crosses)
    for coord, label in zip(point_coords_np, point_labels_np):
        is_positive = label == 1
        axes[0].scatter(
            coord[0],
            coord[1],
            c='g' if is_positive else 'r',
            marker='o' if is_positive else 'x',
            s=200,
            linewidths=3,
        )
    axes[0].axis('off')

    highlight = np.array([0, 255, 0])  # RGB color used to tint masked pixels

    # Show each predicted mask
    for i, mask in enumerate(masks_np):
        # Clamp to [0, 1] before the uint8 conversion so out-of-range values
        # cannot wrap around.
        # NOTE(review): if the model returns logits rather than binary or
        # probability masks, they should be thresholded first - confirm
        # against SAM3MLX.predict.
        mask_unit = np.clip(np.asarray(mask, dtype=np.float32), 0.0, 1.0)

        # Resize mask to original image size
        mask_img = Image.fromarray((mask_unit * 255).astype(np.uint8))
        mask_img = mask_img.resize((width, height), Image.BILINEAR)

        # Per-pixel blend weight (at most 0.5), broadcast over the RGB channels
        alpha = (np.array(mask_img) / 255.0)[..., None] * 0.5

        # Overlay mask on image
        overlay = (image * (1 - alpha) + highlight * alpha).astype(np.uint8)
        axes[i + 1].imshow(overlay)
        axes[i + 1].set_title(f"Mask {i+1} (IoU: {iou_scores_np[i]:.3f})")
        axes[i + 1].axis('off')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=150)
        print(f"๐Ÿ’พ Saved visualization to {save_path}")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)
def _parse_point(point_str: str) -> Tuple[float, float, int]:
    """Parse one ``--point`` value into ``(x, y, label)``; raise ValueError if malformed."""
    label = 1  # Default to positive
    if point_str.startswith('+'):
        point_str = point_str[1:]
    elif point_str.startswith('-'):
        label = 0  # Negative
        point_str = point_str[1:]
    # Raises ValueError for non-numeric parts or a component count other than 2
    x, y = map(float, point_str.split(','))
    return x, y, label


def main():
    """
    Command-line entry point.

    Parses the arguments, loads the image, builds the SAM3 MLX model (loading
    weights when the checkpoint file exists), runs point-prompted segmentation
    and visualizes the result. Exits through ``parser.error`` (status 2) when
    a ``--point`` value cannot be parsed.
    """
    parser = argparse.ArgumentParser(description="SAM3 MLX Click Segmentation Example")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    # A value with a leading '-' such as "-100,200" is not recognized by
    # argparse as a negative number and gets parsed as an option flag, so
    # negative clicks must be given in the "--point=-100,200" form.
    parser.add_argument(
        "--point",
        type=str,
        action="append",
        help="Click point as 'x,y' (can specify multiple). Use +x,y for positive, -x,y for negative",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="./checkpoints/sam3_mlx",
        help="Path to SAM3 MLX checkpoint directory",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Path to save output visualization",
    )
    parser.add_argument(
        "--single-mask",
        action="store_true",
        help="Output single mask instead of 3 masks",
    )
    args = parser.parse_args()

    print("๐Ÿš€ SAM3 MLX Click Segmentation Example")
    print("=" * 60)

    # Parse points
    if not args.point:
        print("โŒ Please specify at least one point with --point x,y")
        return

    point_coords_list = []
    point_labels_list = []
    for point_str in args.point:
        try:
            x, y, label = _parse_point(point_str)
        except ValueError:
            # Report bad input as a usage error instead of a raw traceback
            parser.error(
                f"invalid --point value {point_str!r}: expected 'x,y', '+x,y' or '-x,y'"
            )
        point_coords_list.append([x, y])
        point_labels_list.append(label)

    print(f"๐Ÿ“ Input points: {len(point_coords_list)}")
    for i, (coord, label) in enumerate(zip(point_coords_list, point_labels_list)):
        label_str = "positive" if label == 1 else "negative"
        print(f" Point {i+1}: ({coord[0]:.0f}, {coord[1]:.0f}) [{label_str}]")

    # Load image
    print(f"\n๐Ÿ“ธ Loading image: {args.image}")
    image_mlx, image_original = load_image(args.image)
    print(f" Image size: {image_original.shape[1]}x{image_original.shape[0]}")

    # Clicks are given (and later drawn) in the pixel frame of the original
    # image, while the model receives the image resized to a fixed square.
    # Rescale each axis independently so the prompts land on the same content
    # in the resized image.
    # NOTE(review): assumes SAM3MLX.predict expects pixel coordinates in the
    # frame of the image it is given - confirm against the model code.
    orig_h, orig_w = image_original.shape[:2]
    model_h, model_w = image_mlx.shape[1], image_mlx.shape[2]
    scale_x = model_w / orig_w
    scale_y = model_h / orig_h
    point_coords = mx.array(point_coords_list).reshape(1, -1, 2)
    model_point_coords = mx.array(
        [[x * scale_x, y * scale_y] for x, y in point_coords_list]
    ).reshape(1, -1, 2)
    point_labels = mx.array(point_labels_list).reshape(1, -1)

    # Initialize model
    print("\n๐Ÿ—๏ธ Initializing SAM3 MLX model...")
    model = SAM3MLX()

    # Load weights if available
    checkpoint_dir = Path(args.checkpoint)
    weights_path = checkpoint_dir / "sam3_mlx_weights.npz"
    if weights_path.exists():
        print(f"\n๐Ÿ“ฅ Loading weights from {checkpoint_dir}")
        model = load_weights(model, str(weights_path), strict=False, verbose=True)
    else:
        print(f"\nโš ๏ธ Weights not found at {weights_path}")
        print(" Using randomly initialized model (for testing architecture only)")

    # Run inference
    print("\n๐ŸŽฏ Running segmentation...")
    # perf_counter is monotonic and high-resolution, unlike wall-clock time
    start_time = time.perf_counter()
    result = model.predict(
        image=image_mlx,
        point_coords=model_point_coords,
        point_labels=point_labels,
        multimask_output=not args.single_mask,
    )
    # MLX evaluates lazily; force the computation so the timing is meaningful
    mx.eval(result["masks"])
    inference_time = (time.perf_counter() - start_time) * 1000
    print(f"โœ… Inference completed in {inference_time:.1f}ms")

    # Print results
    masks = result["masks"]
    iou_predictions = result["iou_predictions"]
    print("\n๐Ÿ“Š Results:")
    print(f" Number of masks: {masks.shape[1]}")
    print(f" Mask resolution: {masks.shape[2]}x{masks.shape[3]}")
    print(f" IoU scores: {np.array(iou_predictions[0])}")

    # Visualize (points stay in original-image coordinates for plotting)
    print("\n๐ŸŽจ Visualizing results...")
    visualize_prediction(
        image_original,
        masks,
        point_coords,
        point_labels,
        iou_predictions,
        save_path=args.output,
    )
    print("\nโœ… Done!")


if __name__ == "__main__":
    main()