| |
| """ |
| SAM3 MLX Benchmarks |
| |
| Measures performance on Apple Silicon to validate <200ms target |
| """ |
|
|
| import time |
| import mlx.core as mx |
| import numpy as np |
| import sys |
| from pathlib import Path |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from models.sam3 import SAM3MLX |
|
|
|
|
def _force_eval(result):
    """Force MLX to compute every array contained in ``result``.

    MLX evaluates lazily, so a call has only done its real work once its
    output arrays have been evaluated. Dicts, lists and tuples are walked
    (including nested ones); anything that is not an ``mx.array`` is ignored.
    """
    pending = []
    stack = [result]
    while stack:
        item = stack.pop()
        if isinstance(item, mx.array):
            pending.append(item)
        elif isinstance(item, dict):
            stack.extend(item.values())
        elif isinstance(item, (list, tuple)):
            stack.extend(item)
    if pending:
        mx.eval(*pending)


def benchmark_component(name: str, func, *args, warmup=3, iterations=10, **kwargs):
    """Benchmark a component with warmup.

    Args:
        name: Label printed in the report header.
        func: Callable to time; invoked as ``func(*args, **kwargs)``.
        *args: Positional arguments forwarded to ``func``.
        warmup: Number of untimed calls made before measuring.
        iterations: Number of timed calls; must be at least 1.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Mean time per call in milliseconds.

    Raises:
        ValueError: If ``iterations`` is less than 1.
    """
    if iterations < 1:
        raise ValueError("iterations must be at least 1")

    print(f"\n{'='*60}")
    print(f"Benchmarking: {name}")
    print(f"{'='*60}")

    # Untimed warmup calls; outputs are still evaluated so each call actually
    # runs to completion.
    print(f"Warming up ({warmup} iterations)...")
    for _ in range(warmup):
        _force_eval(func(*args, **kwargs))

    print(f"Running benchmark ({iterations} iterations)...")
    times = []

    for i in range(iterations):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        result = func(*args, **kwargs)

        # Evaluation must happen inside the timed region; otherwise only the
        # lazy graph construction would be measured.
        _force_eval(result)

        elapsed = (time.perf_counter() - start) * 1000
        times.append(elapsed)
        print(f" Iteration {i+1}: {elapsed:.2f}ms")

    times = np.array(times)
    print("\n📊 Results:")
    print(f" Mean: {times.mean():.2f}ms")
    print(f" Median: {np.median(times):.2f}ms")
    print(f" Min: {times.min():.2f}ms")
    print(f" Max: {times.max():.2f}ms")
    print(f" Std: {times.std():.2f}ms")

    return times.mean()
|
|
|
|
def main():
    """Run the SAM3 MLX benchmarks and print a summary report.

    Builds a ``SAM3MLX`` model, times ``encode_image``, ``prompt_encoder`` and
    ``predict`` through ``benchmark_component`` on randomly generated input,
    then prints each mean latency and compares the vision encoder and the
    full pipeline against their latency targets.
    """
    rule = "=" * 60

    print("🚀 SAM3 MLX Performance Benchmarks")
    print(rule)
    print(f"MLX version: {mx.__version__}")
    print("Device: Apple Silicon (Metal)")
    print(rule)

    # Build the model under test.
    print("\n🏗️ Initializing SAM3 MLX...")
    model = SAM3MLX()

    # Synthetic inputs: a random (1, 1024, 1024, 3) array plus one point
    # coordinate with one label, both cast to float32.
    print("\n📦 Preparing test inputs...")
    image = mx.random.normal((1, 1024, 1024, 3))
    point_coords = mx.array([[[512, 384]]]).astype(mx.float32)
    point_labels = mx.array([[1]]).astype(mx.float32)

    # Maps component key -> mean latency in milliseconds.
    results = {}

    results['vision_encoder'] = benchmark_component(
        "Vision Encoder (Hiera)",
        model.encode_image,
        image,
        warmup=3,
        iterations=10,
    )

    # The prompt encoder is called with the (coords, labels) pair and two
    # None arguments.
    results['prompt_encoder'] = benchmark_component(
        "Prompt Encoder",
        model.prompt_encoder,
        (point_coords, point_labels),
        None,
        None,
        warmup=3,
        iterations=20,
    )

    results['full_pipeline'] = benchmark_component(
        "Full Pipeline (encode + decode)",
        model.predict,
        image,
        point_coords,
        point_labels,
        warmup=3,
        iterations=10,
    )

    print(f"\n{rule}")
    print("PERFORMANCE SUMMARY")
    print(rule)

    # Summary marks use a flat 1000 ms cutoff, independent of the
    # per-component targets checked below.
    for component, avg_time in results.items():
        status = "✅" if avg_time < 1000 else "⚠️"
        print(f"{status} {component:30s} {avg_time:8.2f}ms")

    print(f"\n{rule}")
    print("TARGET METRICS")
    print(rule)

    # Latency targets in milliseconds.
    vision_target = 500
    full_target = 200

    vision_status = "✅ PASS" if results['vision_encoder'] < vision_target else "❌ FAIL"
    full_status = "🎯 TARGET" if results['full_pipeline'] < full_target else "⚠️ NEEDS OPTIMIZATION"

    print(f"Vision Encoding: {vision_status} (target: <{vision_target}ms)")
    print(f"Full Pipeline: {full_status} (target: <{full_target}ms)")

    print(f"\n{rule}")
    print("Benchmark complete!")
    print(rule)
|
|
|
|
# Entry point: run the benchmarks only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|