#!/usr/bin/env python3
"""
SAM3 MLX Benchmarks
Measures performance on Apple Silicon to validate <200ms target
"""

import time
import mlx.core as mx
import numpy as np
import sys
from pathlib import Path

# Add parent directory to path so the local `models` package is importable
# when this script is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.sam3 import SAM3MLX

_RULE = "=" * 60  # horizontal separator used around every section header


def _collect_arrays(obj):
    """Yield every ``mx.array`` contained in *obj*.

    Descends into dict values, lists and tuples (recursively); any other
    value is skipped.
    """
    if isinstance(obj, mx.array):
        yield obj
    elif isinstance(obj, dict):
        for value in obj.values():
            yield from _collect_arrays(value)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from _collect_arrays(item)


def _force_eval(result):
    """Evaluate every MLX array found in *result* with a single ``mx.eval``.

    MLX evaluates lazily, so the timed call may return before any work has
    been done; forcing evaluation here keeps the computation inside the
    measured interval.
    """
    arrays = list(_collect_arrays(result))
    if arrays:
        mx.eval(*arrays)


def benchmark_component(name: str, func, *args, warmup=3, iterations=10, **kwargs):
    """Benchmark ``func(*args, **kwargs)`` with warmup and return mean latency.

    Args:
        name: Label printed in the section header.
        func: Callable to time. Any MLX arrays it returns -- bare, or nested
            inside dicts/lists/tuples -- are evaluated inside the timed region.
        *args: Positional arguments forwarded to ``func`` on every call.
        warmup: Number of untimed calls made before measuring.
        iterations: Number of timed calls; must be at least 1.
        **kwargs: Keyword arguments forwarded to ``func`` on every call.

    Returns:
        Mean wall-clock time per call, in milliseconds.

    Raises:
        ValueError: If ``iterations`` is less than 1 (the statistics below
            would otherwise be NaN).
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    print(f"\n{_RULE}")
    print(f"Benchmarking: {name}")
    print(_RULE)

    # Untimed calls absorb one-off start-up costs before measurement begins.
    print(f"Warming up ({warmup} iterations)...")
    for _ in range(warmup):
        _force_eval(func(*args, **kwargs))

    print(f"Running benchmark ({iterations} iterations)...")
    times = []
    for i in range(iterations):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        _force_eval(func(*args, **kwargs))
        elapsed = (time.perf_counter() - start) * 1000  # seconds -> ms
        times.append(elapsed)
        print(f" Iteration {i+1}: {elapsed:.2f}ms")

    times = np.array(times)
    print("\n📊 Results:")
    print(f" Mean: {times.mean():.2f}ms")
    print(f" Median: {np.median(times):.2f}ms")
    print(f" Min: {times.min():.2f}ms")
    print(f" Max: {times.max():.2f}ms")
    print(f" Std: {times.std():.2f}ms")

    return float(times.mean())


def main():
    """Benchmark SAM3 MLX components and print a summary against targets."""
    print("🚀 SAM3 MLX Performance Benchmarks")
    print(_RULE)
    print(f"MLX version: {mx.__version__}")
    print("Device: Apple Silicon (Metal)")
    print(_RULE)

    print("\n🏗️ Initializing SAM3 MLX...")
    model = SAM3MLX()

    # Synthetic inputs: a random (1, 1024, 1024, 3) image and one point
    # prompt at (512, 384) with label 1.
    print("\n📦 Preparing test inputs...")
    image = mx.random.normal((1, 1024, 1024, 3))
    point_coords = mx.array([[[512, 384]]]).astype(mx.float32)
    point_labels = mx.array([[1]]).astype(mx.float32)

    results = {}

    # 1. Vision Encoder
    results['vision_encoder'] = benchmark_component(
        "Vision Encoder (Hiera)",
        model.encode_image,
        image,
        warmup=3,
        iterations=10,
    )

    # 2. Prompt Encoder
    results['prompt_encoder'] = benchmark_component(
        "Prompt Encoder",
        model.prompt_encoder,
        (point_coords, point_labels),
        None,
        None,
        warmup=3,
        iterations=20,
    )

    # 3. Full Pipeline
    results['full_pipeline'] = benchmark_component(
        "Full Pipeline (encode + decode)",
        model.predict,
        image,
        point_coords,
        point_labels,
        warmup=3,
        iterations=10,
    )

    print(f"\n{_RULE}")
    print("PERFORMANCE SUMMARY")
    print(_RULE)
    for component, avg_time in results.items():
        # Coarse sanity flag: components averaging 1 s or more are highlighted.
        status = "✅" if avg_time < 1000 else "⚠️"
        print(f"{status} {component:30s} {avg_time:8.2f}ms")

    print(f"\n{_RULE}")
    print("TARGET METRICS")
    print(_RULE)
    vision_target = 500  # ms
    full_target = 200  # ms (after optimization)
    vision_status = "✅ PASS" if results['vision_encoder'] < vision_target else "❌ FAIL"
    full_status = "🎯 TARGET" if results['full_pipeline'] < full_target else "⚠️ NEEDS OPTIMIZATION"
    print(f"Vision Encoding: {vision_status} (target: <{vision_target}ms)")
    print(f"Full Pipeline: {full_status} (target: <{full_target}ms)")

    print(f"\n{_RULE}")
    print("Benchmark complete!")
    print(_RULE)


if __name__ == "__main__":
    main()