MLX
MLX_SAM3 / test_models.py
Hoodrobot's picture
Upload 15 files
ced11e2 verified
"""
Tests for SAM3 MLX models
Validates that all model components work correctly
"""
try:
import pytest
except ImportError:
pytest = None
import mlx.core as mx
import sys
from pathlib import Path
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.attention import MultiHeadAttentionRoPE, WindowedAttention, RoPEEmbedding
from models.hiera import HieraVisionEncoder, create_hiera_base
from models.prompt_encoder import PromptEncoder, create_prompt_encoder
from models.mask_decoder import MaskDecoder, create_mask_decoder
from models.sam3 import SAM3MLX
class TestAttention:
    """Shape-level smoke tests for the attention modules.

    Each test constructs a module with small fixed hyper-parameters, runs it
    once, and checks only the shape of what comes back.
    """

    def test_rope_embedding(self):
        """RoPEEmbedding.forward(seq_len=256) returns a (2, 256, 64) array."""
        rope = RoPEEmbedding(dim=64, max_seq_len=1024)
        emb = rope.forward(seq_len=256)

        # NOTE(review): the leading axis of size 2 presumably stacks the
        # cos/sin tables -- confirm against models.attention.
        assert emb.shape == (2, 256, 64), f"Wrong shape: {emb.shape}"
        print("✅ RoPE embedding test passed")

    def test_multihead_attention_rope(self):
        """MultiHeadAttentionRoPE with use_rope=True preserves the input shape."""
        attn = MultiHeadAttentionRoPE(dim=256, num_heads=8, use_rope=True)

        # Dummy input laid out as (batch, seq_len, dim).
        x = mx.random.normal((2, 64, 256))

        out = attn(x)

        assert out.shape == x.shape, f"Wrong output shape: {out.shape}"
        print("✅ Multi-head attention RoPE test passed")

    def test_windowed_attention(self):
        """WindowedAttention preserves the input shape."""
        attn = WindowedAttention(dim=256, num_heads=8, window_size=14)

        x = mx.random.normal((2, 64, 256))
        out = attn(x)

        # Report the offending shape on failure, like the sibling tests do.
        assert out.shape == x.shape, f"Wrong output shape: {out.shape}"
        print("✅ Windowed attention test passed")
class TestHiera:
    """Shape-level smoke test for the Hiera vision encoder."""

    def test_hiera_base(self):
        """create_hiera_base() maps a 1024x1024 image to ~81 tokens of width 1024."""
        encoder = create_hiera_base()

        # Dummy 1024x1024 RGB image in NHWC layout.
        image = mx.random.normal((1, 1024, 1024, 3))

        features = encoder(image)

        # Guard the 3-way unpack below so a rank mismatch fails with a
        # readable message instead of an opaque ValueError.
        assert len(features.shape) == 3, (
            f"Expected a rank-3 output, got shape: {features.shape}"
        )
        batch, num_patches, embed_dim = features.shape

        assert batch == 1, f"Wrong batch size: {batch}"
        assert embed_dim == 1024, f"Wrong embed dim: {embed_dim}"

        # Original author's expectation: patch embedding gives 1024/14 = 73
        # per side, three downsample stages give 73/8 = 9, i.e. roughly a
        # 9x9 = 81 token grid. The bound is kept loose on purpose.
        # TODO confirm the arithmetic against models.hiera.
        assert 70 < num_patches < 90, f"Wrong number of patches: {num_patches}"

        print(f"✅ Hiera-Base test passed - output shape: {features.shape}")
class TestPromptEncoder:
    """Shape-level smoke tests for the prompt encoder."""

    def test_point_encoding(self):
        """A single point prompt yields sparse and dense embeddings of width 256."""
        encoder = create_prompt_encoder(
            embed_dim=256,
            image_embedding_size=(64, 64),
            input_image_size=(1024, 1024),
        )

        # One point with one label: coords are (1, 1, 2), labels are (1, 1).
        point_coords = mx.array([[[512, 384]]]).astype(mx.float32)
        point_labels = mx.array([[1]]).astype(mx.float32)

        sparse_emb, dense_emb = encoder(
            points=(point_coords, point_labels),
            boxes=None,
            masks=None,
        )

        # Guard the positional indexing below with an explicit rank check.
        assert len(sparse_emb.shape) == 3, (
            f"Expected rank-3 sparse embeddings, got shape: {sparse_emb.shape}"
        )
        # The token axis (shape[1]) is deliberately not pinned: the original
        # note says the sparse embeddings should include padding.
        assert sparse_emb.shape[0] == 1, f"Wrong batch size: {sparse_emb.shape}"
        assert sparse_emb.shape[2] == 256, f"Wrong embed dim: {sparse_emb.shape}"

        assert dense_emb.shape == (1, 64, 64, 256), (
            f"Wrong dense embedding shape: {dense_emb.shape}"
        )

        print("✅ Prompt encoder point test passed")

    def test_box_encoding(self):
        """A single box prompt yields two sparse embeddings of width 256."""
        encoder = create_prompt_encoder(embed_dim=256)

        # Box prompt given as [x0, y0, x1, y1].
        box = mx.array([[100, 100, 500, 500]]).astype(mx.float32)

        # The dense embedding is not inspected by this test.
        sparse_emb, _ = encoder(
            points=None,
            boxes=box,
            masks=None,
        )

        assert len(sparse_emb.shape) == 3, (
            f"Expected rank-3 sparse embeddings, got shape: {sparse_emb.shape}"
        )
        # Two embeddings are expected, one per box corner.
        assert sparse_emb.shape[1] == 2, (
            f"Wrong number of box embeddings: {sparse_emb.shape}"
        )
        assert sparse_emb.shape[2] == 256, f"Wrong embed dim: {sparse_emb.shape}"

        print("✅ Prompt encoder box test passed")
class TestMaskDecoder:
    """Shape-level smoke test for the mask decoder."""

    def test_mask_decoder(self):
        """With multimask_output=True the decoder returns 3 masks and 3 IoU scores."""
        decoder = create_mask_decoder(transformer_dim=256)

        # Dummy inputs: image tensors are (B, H, W, C); the sparse prompt
        # carries 3 tokens of width C.
        B, H, W, C = 1, 64, 64, 256
        image_embeddings = mx.random.normal((B, H, W, C))
        image_pe = mx.random.normal((B, H, W, C))
        sparse_prompt_embeddings = mx.random.normal((B, 3, C))
        dense_prompt_embeddings = mx.zeros((B, H, W, C))

        masks, iou_pred = decoder(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            multimask_output=True,
        )

        # Only the leading two mask axes are pinned; the spatial size of the
        # masks is not asserted here.
        assert masks.shape[0] == B, f"Wrong batch size: {masks.shape}"
        assert masks.shape[1] == 3, f"Wrong number of masks: {masks.shape}"
        assert iou_pred.shape == (B, 3), f"Wrong IoU shape: {iou_pred.shape}"

        print(f"✅ Mask decoder test passed - masks shape: {masks.shape}")
class TestSAM3:
    """Smoke tests for the assembled SAM3MLX model."""

    def test_sam3_initialization(self):
        """A default-constructed SAM3MLX exposes its three sub-modules."""
        model = SAM3MLX()

        assert model is not None
        for component in ("vision_encoder", "prompt_encoder", "mask_decoder"):
            assert hasattr(model, component), f"Missing attribute: {component}"

        print("✅ SAM3 initialization test passed")

    def test_sam3_forward(self):
        """predict() with one point prompt returns 3 masks and 3 IoU scores."""
        model = SAM3MLX()

        # Dummy 1024x1024 image plus a single point prompt with label 1.
        image = mx.random.normal((1, 1024, 1024, 3))
        point_coords = mx.array([[[512, 384]]]).astype(mx.float32)
        point_labels = mx.array([[1]]).astype(mx.float32)

        result = model.predict(
            image=image,
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=True,
        )

        # The result is looked up by key, so check the keys first.
        assert "masks" in result, "Result has no 'masks' entry"
        assert "iou_predictions" in result, "Result has no 'iou_predictions' entry"

        masks = result["masks"]
        iou_pred = result["iou_predictions"]

        assert masks.shape[0] == 1, f"Wrong batch size: {masks.shape}"
        assert masks.shape[1] == 3, f"Wrong number of masks: {masks.shape}"
        assert iou_pred.shape == (1, 3), f"Wrong IoU shape: {iou_pred.shape}"

        print("✅ SAM3 forward test passed")
        print(f" Masks shape: {masks.shape}")
        print(f" IoU predictions shape: {iou_pred.shape}")
def run_test_suites(test_suite):
    """Run every ``test_*`` method of each suite class and tally the results.

    Args:
        test_suite: Iterable of ``(suite_name, test_class)`` pairs. Each class
            is instantiated with no arguments; its callable attributes whose
            names start with ``test_`` are invoked in ``dir()`` order.

    Returns:
        A ``(passed, failed)`` tuple of counts.
    """
    passed = 0
    failed = 0

    for suite_name, test_class in test_suite:
        print(f"\n{suite_name}")
        print("-" * 60)

        test_instance = test_class()
        for method_name in dir(test_instance):
            if not method_name.startswith("test_"):
                continue
            method = getattr(test_instance, method_name)
            if not callable(method):
                # A plain attribute that merely looks like a test name.
                continue

            try:
                method()
            except Exception as e:
                # Runner boundary: record the failure and keep going.
                # AssertionError is an Exception, so failed asserts land here.
                # The type name is printed because a bare ``assert`` carries
                # an empty message.
                print(f"❌ {method_name} failed: {type(e).__name__}: {e}")
                failed += 1
            else:
                passed += 1

    return passed, failed


def main():
    """Run all SAM3 MLX suites and return the exit status (0 when all pass)."""
    print("🧪 Running SAM3 MLX Tests\n")
    print("=" * 60)

    test_suite = [
        ("Attention Tests", TestAttention),
        ("Hiera Tests", TestHiera),
        ("Prompt Encoder Tests", TestPromptEncoder),
        ("Mask Decoder Tests", TestMaskDecoder),
        ("SAM3 Tests", TestSAM3),
    ]

    passed, failed = run_test_suites(test_suite)

    print("\n" + "=" * 60)
    print(f"Test Results: {passed} passed, {failed} failed")

    if failed == 0:
        print("✅ All tests passed!")
        return 0

    print(f"❌ {failed} tests failed")
    return 1


if __name__ == "__main__":
    # sys.exit rather than the exit() builtin, which is only injected by the
    # site module and may be absent.
    sys.exit(main())