# Provenance: Hugging Face repo "MLX_SAM3" (uploader: Hoodrobot), file benchmark.py,
# commit ced11e2 ("Upload 15 files") — page-header text removed so the file parses.
#!/usr/bin/env python3
"""
SAM3 MLX Benchmarks
Measures performance on Apple Silicon to validate <200ms target
"""
import time
import mlx.core as mx
import numpy as np
import sys
from pathlib import Path
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.sam3 import SAM3MLX
def benchmark_component(name: str, func, *args, warmup=3, iterations=10, **kwargs):
    """Benchmark ``func(*args, **kwargs)`` and return the mean wall time in ms.

    Runs ``warmup`` untimed iterations first (lets MLX compile/cache kernels),
    then ``iterations`` timed ones, printing per-iteration and summary stats.

    Args:
        name: Human-readable label printed in the report header.
        func: Callable to benchmark; may return an ``mx.array`` or a dict
            whose values may be ``mx.array``s.
        *args, **kwargs: Forwarded to ``func`` on every call.
        warmup: Number of untimed warmup iterations.
        iterations: Number of timed iterations.

    Returns:
        Mean elapsed time per timed iteration, in milliseconds (float).
    """

    def _force_eval(result):
        # MLX is lazy: explicitly evaluate outputs so the timing captures
        # the actual compute, not just graph construction.
        if isinstance(result, dict):
            for v in result.values():
                if isinstance(v, mx.array):
                    mx.eval(v)
        elif isinstance(result, mx.array):
            # BUGFIX: the warmup path previously tested `isinstance(v, ...)`,
            # where `v` is unbound unless the result was a dict — a plain
            # array result raised NameError on the first warmup call.
            mx.eval(result)

    print(f"\n{'='*60}")
    print(f"Benchmarking: {name}")
    print(f"{'='*60}")

    # Warmup (untimed)
    print(f"Warming up ({warmup} iterations)...")
    for _ in range(warmup):
        _force_eval(func(*args, **kwargs))

    # Timed iterations
    print(f"Running benchmark ({iterations} iterations)...")
    times = []
    for i in range(iterations):
        # perf_counter is monotonic and high-resolution — preferred over
        # time.time() for interval measurement.
        start = time.perf_counter()
        _force_eval(func(*args, **kwargs))
        elapsed = (time.perf_counter() - start) * 1000  # Convert to ms
        times.append(elapsed)
        print(f" Iteration {i+1}: {elapsed:.2f}ms")

    # Statistics
    times = np.array(times)
    print(f"\n📊 Results:")
    print(f" Mean: {times.mean():.2f}ms")
    print(f" Median: {np.median(times):.2f}ms")
    print(f" Min: {times.min():.2f}ms")
    print(f" Max: {times.max():.2f}ms")
    print(f" Std: {times.std():.2f}ms")
    return times.mean()
def main():
    """Run the SAM3 MLX benchmark suite and print a pass/fail summary."""
    # Banner
    print("🚀 SAM3 MLX Performance Benchmarks")
    print("=" * 60)
    print(f"MLX version: {mx.__version__}")
    print(f"Device: Apple Silicon (Metal)")
    print("=" * 60)

    # Build the model once; all benchmarks share it.
    print("\n🏗️ Initializing SAM3 MLX...")
    sam = SAM3MLX()

    # Fixed synthetic inputs: one 1024x1024 RGB image and a single
    # foreground point prompt near the center.
    print("\n📦 Preparing test inputs...")
    test_image = mx.random.normal((1, 1024, 1024, 3))
    coords = mx.array([[[512, 384]]]).astype(mx.float32)
    labels = mx.array([[1]]).astype(mx.float32)

    results = {}

    # Component 1: image encoder (the heavy part of the pipeline).
    results['vision_encoder'] = benchmark_component(
        "Vision Encoder (Hiera)",
        sam.encode_image,
        test_image,
        warmup=3,
        iterations=10,
    )

    # Component 2: prompt encoder (cheap; more iterations for stable stats).
    results['prompt_encoder'] = benchmark_component(
        "Prompt Encoder",
        sam.prompt_encoder,
        (coords, labels),
        None,
        None,
        warmup=3,
        iterations=20,
    )

    # Component 3: end-to-end predict (encode + decode).
    results['full_pipeline'] = benchmark_component(
        "Full Pipeline (encode + decode)",
        sam.predict,
        test_image,
        coords,
        labels,
        warmup=3,
        iterations=10,
    )

    # Per-component summary table.
    print(f"\n{'='*60}")
    print(f"PERFORMANCE SUMMARY")
    print(f"{'='*60}")
    for component, avg_time in results.items():
        if avg_time < 1000:
            status = "✅"
        else:
            status = "⚠️"
        print(f"{status} {component:30s} {avg_time:8.2f}ms")

    # Compare against the project targets.
    print(f"\n{'='*60}")
    print(f"TARGET METRICS")
    print(f"{'='*60}")
    vision_target = 500  # ms
    full_target = 200  # ms (after optimization)
    if results['vision_encoder'] < vision_target:
        vision_status = "✅ PASS"
    else:
        vision_status = "❌ FAIL"
    if results['full_pipeline'] < full_target:
        full_status = "🎯 TARGET"
    else:
        full_status = "⚠️ NEEDS OPTIMIZATION"
    print(f"Vision Encoding: {vision_status} (target: <{vision_target}ms)")
    print(f"Full Pipeline: {full_status} (target: <{full_target}ms)")

    print(f"\n{'='*60}")
    print("Benchmark complete!")
    print(f"{'='*60}")
# Entry-point guard: run the benchmark suite only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()