# MLX
# File size: 7,595 Bytes
# ced11e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""
Tests for SAM3 MLX models

Validates that all model components work correctly
"""

try:
    import pytest
except ImportError:
    pytest = None

import mlx.core as mx
import sys
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from models.attention import MultiHeadAttentionRoPE, WindowedAttention, RoPEEmbedding
from models.hiera import HieraVisionEncoder, create_hiera_base
from models.prompt_encoder import PromptEncoder, create_prompt_encoder
from models.mask_decoder import MaskDecoder, create_mask_decoder
from models.sam3 import SAM3MLX


class TestAttention:
    """Shape checks for the attention building blocks.

    Each test forces evaluation with ``mx.eval`` before asserting. MLX builds
    its compute graph lazily and ``.shape`` is available without running any
    kernel, so a shape-only assertion would pass even when the forward pass
    itself cannot execute.
    """

    def test_rope_embedding(self):
        """RoPE embedding for 256 positions at dim 64 has shape (2, 256, 64)."""
        rope = RoPEEmbedding(dim=64, max_seq_len=1024)
        emb = rope.forward(seq_len=256)
        mx.eval(emb)  # run the computation instead of only building the graph

        assert emb.shape == (2, 256, 64), f"Wrong shape: {emb.shape}"
        print("✅ RoPE embedding test passed")

    def test_multihead_attention_rope(self):
        """Multi-head attention with RoPE preserves the input shape."""
        attn = MultiHeadAttentionRoPE(dim=256, num_heads=8, use_rope=True)

        # Dummy input laid out as (batch, seq_len, dim)
        x = mx.random.normal((2, 64, 256))

        # Forward pass, forced so the attention kernels really execute
        out = attn(x)
        mx.eval(out)

        assert out.shape == x.shape, f"Wrong output shape: {out.shape}"
        print("✅ Multi-head attention RoPE test passed")

    def test_windowed_attention(self):
        """Windowed attention preserves the input shape."""
        attn = WindowedAttention(dim=256, num_heads=8, window_size=14)

        x = mx.random.normal((2, 64, 256))
        out = attn(x)
        mx.eval(out)

        # Message added so a failure is self-explanatory, matching the
        # sibling tests.
        assert out.shape == x.shape, f"Wrong output shape: {out.shape}"
        print("✅ Windowed attention test passed")


class TestHiera:
    """Shape check for the Hiera vision encoder."""

    def test_hiera_base(self):
        """Hiera-Base maps a 1024x1024 image to (1, ~81, 1024) features."""
        encoder = create_hiera_base()

        # Dummy image (1024x1024 RGB in NHWC format)
        image = mx.random.normal((1, 1024, 1024, 3))

        # Forward pass. MLX is lazy and `.shape` is known without running
        # anything, so force evaluation to make the encoder really execute.
        features = encoder(image)
        mx.eval(features)

        # Expected output shape:
        # after patch embedding (1024/14 = 73) and 3 downsample layers
        # (73/8 = 9) the result should be (1, 81, 1024) - roughly a 9x9 grid.
        batch, num_patches, embed_dim = features.shape

        assert batch == 1, f"Wrong batch size: {batch}"
        assert embed_dim == 1024, f"Wrong embed dim: {embed_dim}"
        # Approximately 9x9 = 81 patches; the bounds leave room for rounding
        assert 70 < num_patches < 90, f"Wrong number of patches: {num_patches}"

        print(f"✅ Hiera-Base test passed - output shape: {features.shape}")


class TestPromptEncoder:
    """Shape checks for the prompt encoder (point and box prompts).

    Outputs are forced with ``mx.eval`` before asserting: MLX evaluates
    lazily, so shape checks alone would not execute the encoder.
    """

    def test_point_encoding(self):
        """A single foreground point yields 256-dim sparse and dense embeddings."""
        encoder = create_prompt_encoder(
            embed_dim=256,
            image_embedding_size=(64, 64),
            input_image_size=(1024, 1024),
        )

        # One point prompt with a positive label
        point_coords = mx.array([[[512, 384]]]).astype(mx.float32)  # (1, 1, 2)
        point_labels = mx.array([[1]]).astype(mx.float32)  # (1, 1)

        sparse_emb, dense_emb = encoder(
            points=(point_coords, point_labels),
            boxes=None,
            masks=None,
        )
        mx.eval(sparse_emb, dense_emb)  # run the encoder, not just the graph

        # Sparse embeddings (should include padding, so the token count on
        # axis 1 is deliberately not pinned here)
        assert sparse_emb.shape[0] == 1, f"Wrong batch size: {sparse_emb.shape}"
        assert sparse_emb.shape[2] == 256, f"Wrong embed dim: {sparse_emb.shape}"

        # Dense embeddings
        assert dense_emb.shape == (1, 64, 64, 256), (
            f"Wrong dense shape: {dense_emb.shape}"
        )

        print("✅ Prompt encoder point test passed")

    def test_box_encoding(self):
        """A single box prompt yields two 256-dim corner embeddings."""
        encoder = create_prompt_encoder(embed_dim=256)

        # Box prompt as [x0, y0, x1, y1]
        box = mx.array([[100, 100, 500, 500]]).astype(mx.float32)

        sparse_emb, dense_emb = encoder(
            points=None,
            boxes=box,
            masks=None,
        )
        mx.eval(sparse_emb, dense_emb)

        # Should have 2 corner embeddings
        assert sparse_emb.shape[1] == 2, f"Wrong corner count: {sparse_emb.shape}"
        assert sparse_emb.shape[2] == 256, f"Wrong embed dim: {sparse_emb.shape}"

        print("✅ Prompt encoder box test passed")


class TestMaskDecoder:
    """Shape check for the mask decoder."""

    def test_mask_decoder(self):
        """Multimask decoding returns 3 masks and 3 IoU scores per image."""
        decoder = create_mask_decoder(transformer_dim=256)

        # Dummy inputs in NHWC layout; 3 sparse prompt tokens per image
        B, H, W, C = 1, 64, 64, 256
        image_embeddings = mx.random.normal((B, H, W, C))
        image_pe = mx.random.normal((B, H, W, C))
        sparse_prompt_embeddings = mx.random.normal((B, 3, C))
        dense_prompt_embeddings = mx.zeros((B, H, W, C))

        # Forward pass
        masks, iou_pred = decoder(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            multimask_output=True,
        )
        # MLX is lazy and shapes are known without running anything, so force
        # evaluation to make the decoder really execute.
        mx.eval(masks, iou_pred)

        # Check outputs
        assert masks.shape[0] == B, f"Wrong batch size: {masks.shape}"
        # 3 masks in multimask mode
        assert masks.shape[1] == 3, f"Wrong mask count: {masks.shape}"
        assert iou_pred.shape == (B, 3), f"Wrong IoU shape: {iou_pred.shape}"

        print(f"✅ Mask decoder test passed - masks shape: {masks.shape}")


class TestSAM3:
    """End-to-end checks for the assembled SAM3 model."""

    def test_sam3_initialization(self):
        """A default-constructed model exposes its three sub-modules."""
        model = SAM3MLX()

        assert model is not None
        assert hasattr(model, 'vision_encoder')
        assert hasattr(model, 'prompt_encoder')
        assert hasattr(model, 'mask_decoder')

        print("✅ SAM3 initialization test passed")

    def test_sam3_forward(self):
        """``predict`` with one point prompt returns 3 masks and 3 IoU scores."""
        model = SAM3MLX()

        # Dummy NHWC image plus a single positive point prompt
        image = mx.random.normal((1, 1024, 1024, 3))
        point_coords = mx.array([[[512, 384]]]).astype(mx.float32)
        point_labels = mx.array([[1]]).astype(mx.float32)

        # Forward pass
        result = model.predict(
            image=image,
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=True,
        )

        # Check outputs
        assert "masks" in result
        assert "iou_predictions" in result

        masks = result["masks"]
        iou_pred = result["iou_predictions"]
        # MLX is lazy and shapes are known without running anything, so force
        # evaluation to make the whole pipeline really execute.
        mx.eval(masks, iou_pred)

        assert masks.shape[0] == 1, f"Wrong batch size: {masks.shape}"
        assert masks.shape[1] == 3, f"Wrong mask count: {masks.shape}"
        assert iou_pred.shape == (1, 3), f"Wrong IoU shape: {iou_pred.shape}"

        print("✅ SAM3 forward test passed")
        print(f"   Masks shape: {masks.shape}")
        print(f"   IoU predictions shape: {iou_pred.shape}")


if __name__ == "__main__":
    print("πŸ§ͺ Running SAM3 MLX Tests\n")
    print("=" * 60)

    # Run tests
    test_suite = [
        ("Attention Tests", TestAttention),
        ("Hiera Tests", TestHiera),
        ("Prompt Encoder Tests", TestPromptEncoder),
        ("Mask Decoder Tests", TestMaskDecoder),
        ("SAM3 Tests", TestSAM3),
    ]

    passed = 0
    failed = 0

    for suite_name, test_class in test_suite:
        print(f"\n{suite_name}")
        print("-" * 60)

        test_instance = test_class()
        methods = [m for m in dir(test_instance) if m.startswith('test_')]

        for method_name in methods:
            try:
                method = getattr(test_instance, method_name)
                method()
                passed += 1
            except Exception as e:
                print(f"❌ {method_name} failed: {e}")
                failed += 1

    print("\n" + "=" * 60)
    print(f"Test Results: {passed} passed, {failed} failed")

    if failed == 0:
        print("βœ… All tests passed!")
        exit(0)
    else:
        print(f"❌ {failed} tests failed")
        exit(1)