| """ |
| Tests for SAM3 MLX models |
| |
| Validates that all model components work correctly |
| """ |
|
|
| try: |
| import pytest |
| except ImportError: |
| pytest = None |
|
|
| import mlx.core as mx |
| import sys |
| from pathlib import Path |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from models.attention import MultiHeadAttentionRoPE, WindowedAttention, RoPEEmbedding |
| from models.hiera import HieraVisionEncoder, create_hiera_base |
| from models.prompt_encoder import PromptEncoder, create_prompt_encoder |
| from models.mask_decoder import MaskDecoder, create_mask_decoder |
| from models.sam3 import SAM3MLX |
|
|
|
|
class TestAttention:
    """Tests for the attention building blocks in ``models.attention``."""

    def test_rope_embedding(self):
        """RoPEEmbedding.forward(seq_len=256) returns a (2, 256, 64) table."""
        rope = RoPEEmbedding(dim=64, max_seq_len=1024)
        emb = rope.forward(seq_len=256)
        # MLX evaluates lazily; force evaluation so the computation really runs
        # instead of only checking the inferred shape.
        mx.eval(emb)

        # NOTE(review): the leading axis of size 2 is presumably a (cos, sin)
        # pair -- confirm against RoPEEmbedding.forward.
        assert emb.shape == (2, 256, 64), f"Wrong shape: {emb.shape}"
        print("✅ RoPE embedding test passed")

    def test_multihead_attention_rope(self):
        """Multi-head attention with RoPE preserves the input shape."""
        attn = MultiHeadAttentionRoPE(dim=256, num_heads=8, use_rope=True)

        # Input is presumably laid out as (batch, tokens, dim); dim matches
        # the module's dim=256.
        x = mx.random.normal((2, 64, 256))
        out = attn(x)
        mx.eval(out)

        assert out.shape == x.shape, f"Wrong output shape: {out.shape}"
        print("✅ Multi-head attention RoPE test passed")

    def test_windowed_attention(self):
        """Windowed attention preserves the input shape."""
        attn = WindowedAttention(dim=256, num_heads=8, window_size=14)

        # NOTE(review): confirm that 64 tokens with window_size=14 exercises the
        # intended windowing path (e.g. padding) rather than a degenerate case.
        x = mx.random.normal((2, 64, 256))
        out = attn(x)
        mx.eval(out)

        assert out.shape == x.shape, f"Wrong output shape: {out.shape}"
        print("✅ Windowed attention test passed")
|
|
|
|
class TestHiera:
    """Tests for the Hiera vision encoder."""

    def test_hiera_base(self):
        """Hiera-Base maps a 1024x1024, 3-channel image to (1, N, 1024) features."""
        encoder = create_hiera_base()

        # Single image with its 3 channels on the last axis (channels-last).
        image = mx.random.normal((1, 1024, 1024, 3))
        features = encoder(image)
        # MLX evaluates lazily; force the forward pass to actually execute.
        mx.eval(features)

        # Check the rank first so a layout mismatch fails with a clear message
        # rather than a ValueError from the tuple unpacking below.
        assert features.ndim == 3, f"Expected 3-D output, got shape {features.shape}"
        batch, num_patches, embed_dim = features.shape

        assert batch == 1, f"Wrong batch size: {batch}"
        assert embed_dim == 1024, f"Wrong embed dim: {embed_dim}"
        # NOTE(review): this token-count range is not derivable from this file;
        # confirm it against the encoder's patch size and downsampling schedule.
        assert 70 < num_patches < 90, f"Wrong number of patches: {num_patches}"

        print(f"✅ Hiera-Base test passed - output shape: {features.shape}")
|
|
|
|
class TestPromptEncoder:
    """Tests for the prompt encoder (point and box prompts)."""

    def test_point_encoding(self):
        """A single labelled point yields sparse tokens and a 64x64 dense map."""
        encoder = create_prompt_encoder(
            embed_dim=256,
            image_embedding_size=(64, 64),
            input_image_size=(1024, 1024),
        )

        # One point for one image, presumably in input-image pixel coordinates;
        # label 1 presumably marks a foreground point -- confirm in the encoder.
        point_coords = mx.array([[[512, 384]]]).astype(mx.float32)
        point_labels = mx.array([[1]]).astype(mx.float32)

        sparse_emb, dense_emb = encoder(
            points=(point_coords, point_labels),
            boxes=None,
            masks=None,
        )
        # MLX evaluates lazily; force the encoder to actually run.
        mx.eval(sparse_emb, dense_emb)

        # Sparse output: batch on axis 0, embed_dim on axis 2; the token count
        # on axis 1 is intentionally not pinned here.
        assert sparse_emb.shape[0] == 1, f"Wrong sparse batch: {sparse_emb.shape}"
        assert sparse_emb.shape[2] == 256, f"Wrong sparse embed dim: {sparse_emb.shape}"

        # Dense output matches the (64, 64) image embedding grid, channels-last.
        assert dense_emb.shape == (1, 64, 64, 256), f"Wrong dense shape: {dense_emb.shape}"

        print("✅ Prompt encoder point test passed")

    def test_box_encoding(self):
        """A single box prompt is encoded as exactly two sparse tokens."""
        encoder = create_prompt_encoder(embed_dim=256)

        # One box, presumably (x0, y0, x1, y1) in pixel coordinates.
        box = mx.array([[100, 100, 500, 500]]).astype(mx.float32)

        sparse_emb, dense_emb = encoder(
            points=None,
            boxes=box,
            masks=None,
        )
        mx.eval(sparse_emb, dense_emb)

        # Two tokens per box (presumably one per corner).
        assert sparse_emb.shape[0] == 1, f"Wrong sparse batch: {sparse_emb.shape}"
        assert sparse_emb.shape[1] == 2, f"Wrong token count: {sparse_emb.shape}"
        assert sparse_emb.shape[2] == 256, f"Wrong sparse embed dim: {sparse_emb.shape}"

        print("✅ Prompt encoder box test passed")
|
|
|
|
class TestMaskDecoder:
    """Tests for the mask decoder."""

    def test_mask_decoder(self):
        """With multimask_output=True the decoder returns 3 masks and 3 IoU scores."""
        decoder = create_mask_decoder(transformer_dim=256)

        # Synthetic channels-last inputs on a 64x64 grid with C=256 channels:
        # random image embeddings and positional encodings, 3 random sparse
        # prompt tokens, and all-zero dense prompt embeddings.
        B, H, W, C = 1, 64, 64, 256
        image_embeddings = mx.random.normal((B, H, W, C))
        image_pe = mx.random.normal((B, H, W, C))
        sparse_prompt_embeddings = mx.random.normal((B, 3, C))
        dense_prompt_embeddings = mx.zeros((B, H, W, C))

        masks, iou_pred = decoder(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            multimask_output=True,
        )
        # MLX evaluates lazily; force the decoder to actually run.
        mx.eval(masks, iou_pred)

        # The 3 masks come from multimask_output=True, independent of the
        # 3 sparse prompt tokens above.
        assert masks.shape[0] == B, f"Wrong mask batch: {masks.shape}"
        assert masks.shape[1] == 3, f"Wrong mask count: {masks.shape}"
        assert iou_pred.shape == (B, 3), f"Wrong IoU shape: {iou_pred.shape}"

        print(f"✅ Mask decoder test passed - masks shape: {masks.shape}")
|
|
|
|
class TestSAM3:
    """End-to-end tests for the complete SAM3MLX model."""

    def test_sam3_initialization(self):
        """SAM3MLX constructs and exposes its encoder/decoder sub-modules."""
        model = SAM3MLX()

        assert model is not None
        for attr in ("vision_encoder", "prompt_encoder", "mask_decoder"):
            assert hasattr(model, attr), f"SAM3MLX is missing attribute: {attr}"

        print("✅ SAM3 initialization test passed")

    def test_sam3_forward(self):
        """predict() on one image and one point returns 3 masks with IoU scores."""
        model = SAM3MLX()

        # One 1024x1024 image with 3 channels on the last axis, plus one point
        # prompt with label 1 (presumably foreground).
        image = mx.random.normal((1, 1024, 1024, 3))
        point_coords = mx.array([[[512, 384]]]).astype(mx.float32)
        point_labels = mx.array([[1]]).astype(mx.float32)

        result = model.predict(
            image=image,
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=True,
        )

        assert "masks" in result, "predict() result has no 'masks' entry"
        assert "iou_predictions" in result, "predict() result has no 'iou_predictions' entry"

        masks = result["masks"]
        iou_pred = result["iou_predictions"]
        # MLX evaluates lazily; force the full pipeline to actually execute.
        mx.eval(masks, iou_pred)

        assert masks.shape[0] == 1, f"Wrong mask batch: {masks.shape}"
        assert masks.shape[1] == 3, f"Wrong mask count: {masks.shape}"
        assert iou_pred.shape == (1, 3), f"Wrong IoU shape: {iou_pred.shape}"

        print("✅ SAM3 forward test passed")
        print(f" Masks shape: {masks.shape}")
        print(f" IoU predictions shape: {iou_pred.shape}")
|
|
|
|
| if __name__ == "__main__": |
| print("π§ͺ Running SAM3 MLX Tests\n") |
| print("=" * 60) |
|
|
| |
| test_suite = [ |
| ("Attention Tests", TestAttention), |
| ("Hiera Tests", TestHiera), |
| ("Prompt Encoder Tests", TestPromptEncoder), |
| ("Mask Decoder Tests", TestMaskDecoder), |
| ("SAM3 Tests", TestSAM3), |
| ] |
|
|
| passed = 0 |
| failed = 0 |
|
|
| for suite_name, test_class in test_suite: |
| print(f"\n{suite_name}") |
| print("-" * 60) |
|
|
| test_instance = test_class() |
| methods = [m for m in dir(test_instance) if m.startswith('test_')] |
|
|
| for method_name in methods: |
| try: |
| method = getattr(test_instance, method_name) |
| method() |
| passed += 1 |
| except Exception as e: |
| print(f"β {method_name} failed: {e}") |
| failed += 1 |
|
|
| print("\n" + "=" * 60) |
| print(f"Test Results: {passed} passed, {failed} failed") |
|
|
| if failed == 0: |
| print("β
All tests passed!") |
| exit(0) |
| else: |
| print(f"β {failed} tests failed") |
| exit(1) |
|
|