""" Tests for SAM3 MLX models Validates that all model components work correctly """ try: import pytest except ImportError: pytest = None import mlx.core as mx import sys from pathlib import Path # Add parent directory to path sys.path.insert(0, str(Path(__file__).parent.parent)) from models.attention import MultiHeadAttentionRoPE, WindowedAttention, RoPEEmbedding from models.hiera import HieraVisionEncoder, create_hiera_base from models.prompt_encoder import PromptEncoder, create_prompt_encoder from models.mask_decoder import MaskDecoder, create_mask_decoder from models.sam3 import SAM3MLX class TestAttention: """Test attention modules""" def test_rope_embedding(self): """Test RoPE embedding generation""" rope = RoPEEmbedding(dim=64, max_seq_len=1024) emb = rope.forward(seq_len=256) assert emb.shape == (2, 256, 64), f"Wrong shape: {emb.shape}" print("✅ RoPE embedding test passed") def test_multihead_attention_rope(self): """Test multi-head attention with RoPE""" attn = MultiHeadAttentionRoPE(dim=256, num_heads=8, use_rope=True) # Create dummy input x = mx.random.normal((2, 64, 256)) # (batch, seq_len, dim) # Forward pass out = attn(x) assert out.shape == x.shape, f"Wrong output shape: {out.shape}" print("✅ Multi-head attention RoPE test passed") def test_windowed_attention(self): """Test windowed attention""" attn = WindowedAttention(dim=256, num_heads=8, window_size=14) x = mx.random.normal((2, 64, 256)) out = attn(x) assert out.shape == x.shape print("✅ Windowed attention test passed") class TestHiera: """Test Hiera vision encoder""" def test_hiera_base(self): """Test Hiera-Base encoder""" encoder = create_hiera_base() # Create dummy image (1024x1024 RGB in NHWC format) image = mx.random.normal((1, 1024, 1024, 3)) # Forward pass features = encoder(image) # Check output shape # After patch embedding (1024/14 = 73) and 3 downsample layers (73/8 = 9) # Should be (1, 81, 1024) - approximately 9x9 grid batch, num_patches, embed_dim = features.shape assert batch == 1, f"Wrong batch size: {batch}" assert embed_dim == 1024, f"Wrong embed dim: {embed_dim}" # Approximately 9x9 = 81 patches assert 70 < num_patches < 90, f"Wrong number of patches: {num_patches}" print(f"✅ Hiera-Base test passed - output shape: {features.shape}") class TestPromptEncoder: """Test prompt encoder""" def test_point_encoding(self): """Test point prompt encoding""" encoder = create_prompt_encoder( embed_dim=256, image_embedding_size=(64, 64), input_image_size=(1024, 1024), ) # Create point prompts point_coords = mx.array([[[512, 384]]]).astype(mx.float32) # (1, 1, 2) point_labels = mx.array([[1]]).astype(mx.float32) # (1, 1) sparse_emb, dense_emb = encoder( points=(point_coords, point_labels), boxes=None, masks=None, ) # Check sparse embeddings (should include padding) assert sparse_emb.shape[0] == 1 # batch assert sparse_emb.shape[2] == 256 # embed_dim # Check dense embeddings assert dense_emb.shape == (1, 64, 64, 256) print("✅ Prompt encoder point test passed") def test_box_encoding(self): """Test box prompt encoding""" encoder = create_prompt_encoder(embed_dim=256) # Create box prompt [x0, y0, x1, y1] box = mx.array([[100, 100, 500, 500]]).astype(mx.float32) sparse_emb, dense_emb = encoder( points=None, boxes=box, masks=None, ) # Should have 2 corner embeddings assert sparse_emb.shape[1] == 2 assert sparse_emb.shape[2] == 256 print("✅ Prompt encoder box test passed") class TestMaskDecoder: """Test mask decoder""" def test_mask_decoder(self): """Test mask decoder forward pass""" decoder = create_mask_decoder(transformer_dim=256) # Create dummy inputs B, H, W, C = 1, 64, 64, 256 image_embeddings = mx.random.normal((B, H, W, C)) image_pe = mx.random.normal((B, H, W, C)) sparse_prompt_embeddings = mx.random.normal((B, 3, C)) dense_prompt_embeddings = mx.zeros((B, H, W, C)) # Forward pass masks, iou_pred = decoder( image_embeddings=image_embeddings, image_pe=image_pe, sparse_prompt_embeddings=sparse_prompt_embeddings, dense_prompt_embeddings=dense_prompt_embeddings, multimask_output=True, ) # Check outputs assert masks.shape[0] == B assert masks.shape[1] == 3 # 3 masks in multimask mode assert iou_pred.shape == (B, 3) print(f"✅ Mask decoder test passed - masks shape: {masks.shape}") class TestSAM3: """Test complete SAM3 model""" def test_sam3_initialization(self): """Test SAM3 model initialization""" model = SAM3MLX() assert model is not None assert hasattr(model, 'vision_encoder') assert hasattr(model, 'prompt_encoder') assert hasattr(model, 'mask_decoder') print("✅ SAM3 initialization test passed") def test_sam3_forward(self): """Test SAM3 forward pass""" model = SAM3MLX() # Create dummy inputs image = mx.random.normal((1, 1024, 1024, 3)) point_coords = mx.array([[[512, 384]]]).astype(mx.float32) point_labels = mx.array([[1]]).astype(mx.float32) # Forward pass result = model.predict( image=image, point_coords=point_coords, point_labels=point_labels, multimask_output=True, ) # Check outputs assert "masks" in result assert "iou_predictions" in result masks = result["masks"] iou_pred = result["iou_predictions"] assert masks.shape[0] == 1 # batch assert masks.shape[1] == 3 # 3 masks assert iou_pred.shape == (1, 3) print(f"✅ SAM3 forward test passed") print(f" Masks shape: {masks.shape}") print(f" IoU predictions shape: {iou_pred.shape}") if __name__ == "__main__": print("🧪 Running SAM3 MLX Tests\n") print("=" * 60) # Run tests test_suite = [ ("Attention Tests", TestAttention), ("Hiera Tests", TestHiera), ("Prompt Encoder Tests", TestPromptEncoder), ("Mask Decoder Tests", TestMaskDecoder), ("SAM3 Tests", TestSAM3), ] passed = 0 failed = 0 for suite_name, test_class in test_suite: print(f"\n{suite_name}") print("-" * 60) test_instance = test_class() methods = [m for m in dir(test_instance) if m.startswith('test_')] for method_name in methods: try: method = getattr(test_instance, method_name) method() passed += 1 except Exception as e: print(f"❌ {method_name} failed: {e}") failed += 1 print("\n" + "=" * 60) print(f"Test Results: {passed} passed, {failed} failed") if failed == 0: print("✅ All tests passed!") exit(0) else: print(f"❌ {failed} tests failed") exit(1)