| """ |
| SAM3 Mask Decoder - Complete MLX Implementation |
| |
| Predicts high-resolution segmentation masks from: |
| - Image embeddings (from Hiera vision encoder) |
| - Prompt embeddings (from prompt encoder) |
| |
| Architecture: |
| 1. Transformer decoder with cross-attention to image features |
| 2. Dynamic mask prediction head |
| 3. IoU quality prediction |
| 4. Multi-mask output (3 masks + confidence scores) |
| """ |
|
|
| import mlx.core as mx |
| import mlx.nn as nn |
| from mlx.nn import Module |
| from typing import Tuple, List |
|
|
|
|

class MLPBlock(Module):
    """
    Simple MLP block with one hidden layer.
    Used in the transformer and prediction heads.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        activation=nn.GELU,
    ):
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = activation()

    def __call__(self, x: mx.array) -> mx.array:
        return self.lin2(self.act(self.lin1(x)))

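# Usage sketch (illustrative shapes only):
#
#   block = MLPBlock(embedding_dim=256, mlp_dim=2048)
#   y = block(mx.zeros((2, 5, 256)))  # -> (2, 5, 256), same shape as input
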

class TwoWayAttentionBlock(Module):
    """
    Two-way cross-attention transformer block

    Performs:
    1. Self-attention on queries (prompts)
    2. Cross-attention from queries to keys (image features)
    3. MLP on queries
    4. Cross-attention from keys back to queries
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int = 8,
        mlp_dim: int = 2048,
        activation=nn.GELU,
        skip_first_layer_pe: bool = False,
    ):
        super().__init__()
        # Self-attention over the prompt tokens
        self.self_attn = nn.MultiHeadAttention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        # Cross-attention: prompt tokens attend to image features
        # (uses half the heads of the self-attention)
        self.cross_attn_token_to_image = nn.MultiHeadAttention(
            embedding_dim, num_heads // 2
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        # Cross-attention: image features attend back to prompt tokens
        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = nn.MultiHeadAttention(
            embedding_dim, num_heads // 2
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def __call__(
        self,
        queries: mx.array,
        keys: mx.array,
        query_pe: mx.array,
        key_pe: mx.array,
    ) -> Tuple[mx.array, mx.array]:
        """
        Args:
            queries: (B, N_q, C) prompt tokens
            keys: (B, N_k, C) image tokens
            query_pe: (B, N_q, C) positional encoding for queries
            key_pe: (B, N_k, C) positional encoding for keys

        Returns:
            Updated queries and keys
        """
        # 1. Self-attention on the prompt tokens (with residual, except in
        # the first layer, which skips adding the query PE)
        if self.skip_first_layer_pe:
            queries = self.self_attn(queries, queries, queries)
        else:
            q = queries + query_pe
            queries = queries + self.self_attn(q, q, queries)
        queries = self.norm1(queries)

        # 2. Cross-attention: tokens attend to image features
        q = queries + query_pe
        k = keys + key_pe
        queries = queries + self.cross_attn_token_to_image(q, k, keys)
        queries = self.norm2(queries)

        # 3. MLP on the tokens
        queries = queries + self.mlp(queries)
        queries = self.norm3(queries)

        # 4. Cross-attention: image features attend back to tokens
        q = queries + query_pe
        k = keys + key_pe
        keys = keys + self.cross_attn_image_to_token(k, q, queries)
        keys = self.norm4(keys)

        return queries, keys

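# Usage sketch for a single block (illustrative shapes; a 64x64 feature map
# flattens to 4096 image tokens):
#
#   block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
#   queries = mx.zeros((1, 6, 256))   # prompt tokens
#   keys = mx.zeros((1, 4096, 256))   # flattened image tokens
#   q_pe = mx.zeros((1, 6, 256))
#   k_pe = mx.zeros((1, 4096, 256))
#   queries, keys = block(queries, keys, q_pe, k_pe)
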

class TwoWayTransformer(Module):
    """
    Two-way transformer decoder

    Processes sparse prompts and dense image features
    to produce mask predictions
    """

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
    ):
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim

        # Stack of two-way blocks; only the first skips the query PE
        # in its self-attention
        self.layers = [
            TwoWayAttentionBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                skip_first_layer_pe=(i == 0),
            )
            for i in range(depth)
        ]

        # Final attention from tokens to the image features
        self.final_attn_token_to_image = nn.MultiHeadAttention(
            embedding_dim, num_heads
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def __call__(
        self,
        image_embedding: mx.array,
        image_pe: mx.array,
        point_embedding: mx.array,
    ) -> Tuple[mx.array, mx.array]:
        """
        Args:
            image_embedding: (B, H*W, C) image features
            image_pe: (B, H*W, C) positional encoding for image
            point_embedding: (B, N, C) prompt embeddings

        Returns:
            Updated tokens and image features
        """
        queries = point_embedding
        keys = image_embedding

        # Run the stack of two-way blocks
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # One final attention from tokens to image features
        q = queries + point_embedding
        k = keys + image_pe
        queries = queries + self.final_attn_token_to_image(q, k, keys)
        queries = self.norm_final_attn(queries)

        return queries, keys

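# Usage sketch (illustrative shapes, matching the block example above):
#
#   transformer = TwoWayTransformer(
#       depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048
#   )
#   tokens, image = transformer(
#       image_embedding=mx.zeros((1, 4096, 256)),
#       image_pe=mx.zeros((1, 4096, 256)),
#       point_embedding=mx.zeros((1, 6, 256)),
#   )
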

class MaskDecoder(Module):
    """
    Complete SAM3 Mask Decoder

    Predicts segmentation masks from image and prompt embeddings.
    Outputs multiple candidate masks with per-mask quality scores.

    Args:
        transformer_dim: Channel dimension of the transformer
        transformer_depth: Number of two-way transformer blocks
        transformer_num_heads: Attention heads per block
        transformer_mlp_dim: Hidden dim of the transformer MLPs
        num_multimask_outputs: Number of masks to predict (default 3)
        iou_head_depth: Depth of the IoU prediction MLP
        iou_head_hidden_dim: Hidden dim of the IoU prediction MLP
    """

    def __init__(
        self,
        transformer_dim: int = 256,
        transformer_depth: int = 2,
        transformer_num_heads: int = 8,
        transformer_mlp_dim: int = 2048,
        num_multimask_outputs: int = 3,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ):
        super().__init__()
        self.transformer_dim = transformer_dim
        self.num_multimask_outputs = num_multimask_outputs

        self.transformer = TwoWayTransformer(
            depth=transformer_depth,
            embedding_dim=transformer_dim,
            num_heads=transformer_num_heads,
            mlp_dim=transformer_mlp_dim,
        )

        # Learned token whose output embedding predicts mask quality
        self.iou_token = nn.Embedding(1, transformer_dim)

        # One token per multimask output, plus one for single-mask output
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        # 4x upscaling of image features. MLX convolutions are channels-last
        # (NHWC), so LayerNorm here normalizes the channel dimension.
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            nn.LayerNorm(transformer_dim // 4),
            nn.GELU(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            nn.GELU(),
        )

        # One hypernetwork per mask token; each projects a token down to the
        # upscaled channel dim (transformer_dim // 8) so it can be dotted
        # with the upscaled image embedding
        self.output_hypernetworks_mlps = [
            nn.Sequential(
                MLPBlock(transformer_dim, transformer_dim, nn.GELU),
                nn.Linear(transformer_dim, transformer_dim // 8),
            )
            for _ in range(self.num_mask_tokens)
        ]

        # IoU prediction head. NOTE: iou_head_depth is accepted for API
        # compatibility, but the head is currently a fixed two-layer MLP.
        self.iou_prediction_head = MLPBlock(
            transformer_dim, iou_head_hidden_dim, nn.GELU
        )
        self.iou_prediction_linear = nn.Linear(transformer_dim, self.num_mask_tokens)

    def __call__(
        self,
        image_embeddings: mx.array,
        image_pe: mx.array,
        sparse_prompt_embeddings: mx.array,
        dense_prompt_embeddings: mx.array,
        multimask_output: bool = True,
    ) -> Tuple[mx.array, mx.array]:
        """
        Predict masks from image and prompt embeddings.

        Args:
            image_embeddings: (B, H, W, C) features from the vision encoder
            image_pe: (B, H, W, C) positional encoding for the image
            sparse_prompt_embeddings: (B, N, C) point/box embeddings
            dense_prompt_embeddings: (B, H, W, C) mask embeddings
            multimask_output: if True, return 3 masks; otherwise 1 mask

        Returns:
            masks: (B, num_masks, 4*H, 4*W) predicted mask logits
            iou_pred: (B, num_masks) predicted quality scores
        """
        B, H, W, C = image_embeddings.shape

        # Flatten spatial dims so the transformer sees (B, H*W, C) sequences
        image_embeddings_flat = image_embeddings.reshape(B, H * W, C)
        image_pe_flat = image_pe.reshape(B, H * W, C)

        # Broadcast the learned IoU and mask tokens across the batch
        iou_token_out = mx.broadcast_to(
            self.iou_token.weight.reshape(1, 1, -1),
            (B, 1, self.transformer_dim),
        )
        mask_tokens_out = mx.broadcast_to(
            self.mask_tokens.weight.reshape(1, -1, self.transformer_dim),
            (B, self.num_mask_tokens, self.transformer_dim),
        )

        # Token order: [IoU token, mask tokens, sparse prompts]
        tokens = mx.concatenate(
            [iou_token_out, mask_tokens_out, sparse_prompt_embeddings], axis=1
        )

        # Add dense prompt embeddings (e.g. a prior mask) to the image features
        src = image_embeddings_flat + dense_prompt_embeddings.reshape(B, H * W, C)

        # Run the two-way transformer
        hs, src = self.transformer(src, image_pe_flat, tokens)

        # Split off the IoU token and the mask tokens
        iou_token_out = hs[:, 0:1, :]
        mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]

        # Upscale image features 4x for high-resolution masks
        src = src.reshape(B, H, W, C)
        upscaled_embedding = self.output_upscaling(src)
        B_up, H_up, W_up, C_up = upscaled_embedding.shape

        # Each mask token is mapped by its hypernetwork to per-channel
        # weights, then dotted with the upscaled embedding to give a mask
        masks = []
        for i in range(self.num_mask_tokens):
            mask_features = self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
            mask_features = mask_features.reshape(B, 1, 1, C_up)
            mask = (upscaled_embedding * mask_features).sum(axis=-1)
            masks.append(mask)
        masks = mx.stack(masks, axis=1)  # (B, num_mask_tokens, H_up, W_up)

        # Predict per-mask quality from the IoU token
        iou_pred = self.iou_prediction_head(iou_token_out)
        iou_pred = self.iou_prediction_linear(iou_pred).squeeze(axis=1)

        # Select the multimask outputs or the single-mask output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        return masks, iou_pred

def create_mask_decoder(
    transformer_dim: int = 256,
    num_multimask_outputs: int = 3,
) -> MaskDecoder:
    """
    Factory function to create a SAM3 mask decoder

    Args:
        transformer_dim: Feature dimension
        num_multimask_outputs: Number of masks to output

    Returns:
        MaskDecoder instance
    """
    return MaskDecoder(
        transformer_dim=transformer_dim,
        num_multimask_outputs=num_multimask_outputs,
    )
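
# Minimal smoke test with dummy inputs. The shapes below are illustrative
# assumptions (a 64x64, 256-channel feature map); a real SAM3 pipeline would
# produce these tensors from the vision and prompt encoders.
if __name__ == "__main__":
    B, H, W, C = 1, 64, 64, 256
    decoder = create_mask_decoder()

    image_embeddings = mx.random.normal((B, H, W, C))
    image_pe = mx.random.normal((B, H, W, C))
    sparse_prompts = mx.random.normal((B, 2, C))  # e.g. two point embeddings
    dense_prompts = mx.zeros((B, H, W, C))        # "no mask" dense prompt

    masks, iou_pred = decoder(
        image_embeddings,
        image_pe,
        sparse_prompts,
        dense_prompts,
        multimask_output=True,
    )
    print(masks.shape)     # (1, 3, 256, 256) -- masks are 4x upscaled
    print(iou_pred.shape)  # (1, 3)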