File size: 3,965 Bytes
ced11e2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 | #!/usr/bin/env python3
"""
SAM3 MLX Benchmarks
Measures per-component and full-pipeline latency on Apple Silicon to validate the <200ms full-pipeline target
"""
import time
import mlx.core as mx
import numpy as np
import sys
from pathlib import Path
# Add the parent of this script's directory to sys.path so the `models` package can be imported
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.sam3 import SAM3MLX
def _force_eval(result):
    """Materialize every lazily-computed MLX array reachable from *result*."""
    # A component may return nothing at all; there is nothing to evaluate.
    if result is None:
        return
    # mx.eval accepts a single array or a tree (dict / list / tuple) of arrays
    # and ignores non-array leaves, so one call covers every result shape.
    # Without it, tuple or nested results would stay unevaluated and the
    # timing would only cover graph construction.
    mx.eval(result)


def benchmark_component(name: str, func, *args, warmup=3, iterations=10, **kwargs):
    """Benchmark a callable and print per-iteration timings plus statistics.

    Args:
        name: Label printed in the benchmark header.
        func: Callable under test; invoked as ``func(*args, **kwargs)``.
        *args: Positional arguments forwarded to ``func``.
        warmup: Number of untimed calls made before measuring.
        iterations: Number of timed calls; must be at least 1.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Mean wall-clock time per timed iteration, in milliseconds.

    Raises:
        ValueError: If ``iterations`` is less than 1 (no samples to summarize).
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    print(f"\n{'='*60}")
    print(f"Benchmarking: {name}")
    print(f"{'='*60}")

    # Warmup: run (and fully evaluate) the component without recording time.
    print(f"Warming up ({warmup} iterations)...")
    for _ in range(warmup):
        _force_eval(func(*args, **kwargs))

    # Benchmark
    print(f"Running benchmark ({iterations} iterations)...")
    samples_ms = []
    for i in range(iterations):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        result = func(*args, **kwargs)
        # Force evaluation inside the timed region so the work is measured.
        _force_eval(result)
        elapsed = (time.perf_counter() - start) * 1000  # Convert to ms
        samples_ms.append(elapsed)
        print(f" Iteration {i+1}: {elapsed:.2f}ms")

    # Statistics
    times = np.array(samples_ms)
    print(f"\n📊 Results:")
    print(f" Mean: {times.mean():.2f}ms")
    print(f" Median: {np.median(times):.2f}ms")
    print(f" Min: {times.min():.2f}ms")
    print(f" Max: {times.max():.2f}ms")
    print(f" Std: {times.std():.2f}ms")
    return times.mean()
def main():
    """Run the SAM3 MLX benchmark suite and print a summary report.

    Builds a default ``SAM3MLX`` model, creates a random image tensor and a
    single point prompt, times the vision encoder, the prompt encoder and the
    full ``predict`` call via ``benchmark_component``, then prints the mean
    time of each and compares them against the hard-coded targets.
    """
    rule = "=" * 60

    # Banner
    print("🚀 SAM3 MLX Performance Benchmarks")
    print(rule)
    print(f"MLX version: {mx.__version__}")
    print("Device: Apple Silicon (Metal)")
    print(rule)

    # Model under test
    print("\n🏗️ Initializing SAM3 MLX...")
    model = SAM3MLX()

    # Synthetic inputs (image layout presumably batch x H x W x C -- confirm
    # against the model's expectations)
    print("\n📦 Preparing test inputs...")
    image = mx.random.normal((1, 1024, 1024, 3))
    point_coords = mx.array([[[512, 384]]]).astype(mx.float32)
    point_labels = mx.array([[1]]).astype(mx.float32)

    # 1. Vision encoder
    vision_ms = benchmark_component(
        "Vision Encoder (Hiera)",
        model.encode_image,
        image,
        warmup=3,
        iterations=10,
    )

    # 2. Prompt encoder (points given as a (coords, labels) pair; the two
    #    remaining positional arguments are passed as None)
    prompt_ms = benchmark_component(
        "Prompt Encoder",
        model.prompt_encoder,
        (point_coords, point_labels),
        None,
        None,
        warmup=3,
        iterations=20,
    )

    # 3. End-to-end prediction
    pipeline_ms = benchmark_component(
        "Full Pipeline (encode + decode)",
        model.predict,
        image,
        point_coords,
        point_labels,
        warmup=3,
        iterations=10,
    )

    results = {
        'vision_encoder': vision_ms,
        'prompt_encoder': prompt_ms,
        'full_pipeline': pipeline_ms,
    }

    # Per-component summary: anything under one second gets a check mark.
    print("\n" + rule)
    print("PERFORMANCE SUMMARY")
    print(rule)
    for component, avg_time in results.items():
        if avg_time < 1000:
            status = "✅"
        else:
            status = "⚠️"
        print(f"{status} {component:30s} {avg_time:8.2f}ms")

    # Target comparison
    print("\n" + rule)
    print("TARGET METRICS")
    print(rule)
    vision_target = 500  # ms
    full_target = 200  # ms (after optimization)

    if results['vision_encoder'] < vision_target:
        vision_status = "✅ PASS"
    else:
        vision_status = "❌ FAIL"

    if results['full_pipeline'] < full_target:
        full_status = "🎯 TARGET"
    else:
        full_status = "⚠️ NEEDS OPTIMIZATION"

    print(f"Vision Encoding: {vision_status} (target: <{vision_target}ms)")
    print(f"Full Pipeline: {full_status} (target: <{full_target}ms)")

    print("\n" + rule)
    print("Benchmark complete!")
    print(rule)
# Script entry point: run the benchmark suite only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|