""" SAM3 Mask Decoder - Complete MLX Implementation Predicts high-resolution segmentation masks from: - Image embeddings (from Hiera vision encoder) - Prompt embeddings (from prompt encoder) Architecture: 1. Transformer decoder with cross-attention to image features 2. Dynamic mask prediction head 3. IoU quality prediction 4. Multi-mask output (3 masks + confidence scores) """ import mlx.core as mx import mlx.nn as nn from mlx.nn import Module from typing import Tuple, List class MLPBlock(Module): """ Simple MLP block with one hidden layer Used in transformer and prediction heads """ def __init__( self, embedding_dim: int, mlp_dim: int, activation=nn.GELU ): super().__init__() self.lin1 = nn.Linear(embedding_dim, mlp_dim) self.lin2 = nn.Linear(mlp_dim, embedding_dim) self.act = activation() def forward(self, x: mx.array) -> mx.array: return self.lin2(self.act(self.lin1(x))) class TwoWayAttentionBlock(Module): """ Two-way cross-attention transformer block Performs: 1. Self-attention on queries (prompts) 2. Cross-attention from queries to keys (image features) 3. MLP on queries 4. Cross-attention from keys to queries """ def __init__( self, embedding_dim: int, num_heads: int = 8, mlp_dim: int = 2048, activation=nn.GELU, skip_first_layer_pe: bool = False, ): super().__init__() self.self_attn = nn.MultiHeadAttention(embedding_dim, num_heads) self.norm1 = nn.LayerNorm(embedding_dim) self.cross_attn_token_to_image = nn.MultiHeadAttention( embedding_dim, num_heads // 2 ) self.norm2 = nn.LayerNorm(embedding_dim) self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) self.norm3 = nn.LayerNorm(embedding_dim) self.norm4 = nn.LayerNorm(embedding_dim) self.cross_attn_image_to_token = nn.MultiHeadAttention( embedding_dim, num_heads // 2 ) self.skip_first_layer_pe = skip_first_layer_pe def forward( self, queries: mx.array, keys: mx.array, query_pe: mx.array, key_pe: mx.array, ) -> Tuple[mx.array, mx.array]: """ Args: queries: (B, N_q, C) prompt tokens keys: (B, N_k, C) image tokens query_pe: (B, N_q, C) positional encoding for queries key_pe: (B, N_k, C) positional encoding for keys Returns: Updated queries and keys """ # Self-attention on queries if self.skip_first_layer_pe: queries = self.self_attn(queries, queries, queries) else: q = queries + query_pe queries = self.self_attn(q, q, queries) queries = self.norm1(queries) # Cross-attention: queries -> image q = queries + query_pe k = keys + key_pe queries = queries + self.cross_attn_token_to_image(q, k, keys) queries = self.norm2(queries) # MLP queries = queries + self.mlp(queries) queries = self.norm3(queries) # Cross-attention: image -> queries q = queries + query_pe k = keys + key_pe keys = keys + self.cross_attn_image_to_token(k, q, queries) keys = self.norm4(keys) return queries, keys class TwoWayTransformer(Module): """ Two-way transformer decoder Processes sparse prompts and dense image features to produce mask predictions """ def __init__( self, depth: int, embedding_dim: int, num_heads: int, mlp_dim: int, ): super().__init__() self.depth = depth self.embedding_dim = embedding_dim # Stack of two-way attention blocks self.layers = [ TwoWayAttentionBlock( embedding_dim=embedding_dim, num_heads=num_heads, mlp_dim=mlp_dim, skip_first_layer_pe=(i == 0), ) for i in range(depth) ] self.final_attn_token_to_image = nn.MultiHeadAttention( embedding_dim, num_heads ) self.norm_final_attn = nn.LayerNorm(embedding_dim) def forward( self, image_embedding: mx.array, image_pe: mx.array, point_embedding: mx.array, ) -> Tuple[mx.array, mx.array]: """ Args: 


class TwoWayTransformer(Module):
    """
    Two-way transformer decoder.

    Processes sparse prompts and dense image features to produce mask
    predictions.
    """

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
    ):
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim

        # Stack of two-way attention blocks; the first block skips the
        # positional encoding on its self-attention.
        self.layers = [
            TwoWayAttentionBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                skip_first_layer_pe=(i == 0),
            )
            for i in range(depth)
        ]

        self.final_attn_token_to_image = nn.MultiHeadAttention(
            embedding_dim, num_heads
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def __call__(
        self,
        image_embedding: mx.array,
        image_pe: mx.array,
        point_embedding: mx.array,
    ) -> Tuple[mx.array, mx.array]:
        """
        Args:
            image_embedding: (B, H*W, C) image features
            image_pe: (B, H*W, C) positional encoding for image
            point_embedding: (B, N, C) prompt embeddings

        Returns:
            Updated tokens and image features
        """
        # Prepare queries (prompts) and keys (image)
        queries = point_embedding
        keys = image_embedding

        # Pass through the transformer layers
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Final attention from prompts to image
        q = queries + point_embedding
        k = keys + image_pe
        queries = queries + self.final_attn_token_to_image(q, k, keys)
        queries = self.norm_final_attn(queries)

        return queries, keys
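

# Usage sketch for the transformer (hypothetical shapes): a 64x64 image grid
# flattened to 4096 tokens attends with 7 prompt tokens; both outputs keep
# their input shapes.
def _demo_two_way_transformer() -> None:
    transformer = TwoWayTransformer(
        depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048
    )
    image = mx.zeros((1, 64 * 64, 256))     # flattened image features
    image_pe = mx.zeros((1, 64 * 64, 256))  # matching positional encoding
    prompts = mx.zeros((1, 7, 256))         # e.g. IoU + mask + sparse tokens
    tokens, features = transformer(image, image_pe, prompts)
    assert tokens.shape == (1, 7, 256)
    assert features.shape == (1, 64 * 64, 256)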


class MaskDecoder(Module):
    """
    Complete SAM3 Mask Decoder

    Predicts segmentation masks from image and prompt embeddings and
    outputs multiple masks with quality scores.

    Args:
        transformer_dim: Channel dimension of the transformer
        transformer_depth: Number of two-way transformer blocks
        transformer_num_heads: Attention heads per block
        transformer_mlp_dim: Hidden dimension of the transformer MLPs
        num_multimask_outputs: Number of masks to predict (default 3)
        iou_head_depth: Depth of the IoU prediction MLP (currently unused;
            the IoU head is a fixed two-layer MLP)
        iou_head_hidden_dim: Hidden dim for the IoU MLP
    """

    def __init__(
        self,
        transformer_dim: int = 256,
        transformer_depth: int = 2,
        transformer_num_heads: int = 8,
        transformer_mlp_dim: int = 2048,
        num_multimask_outputs: int = 3,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ):
        super().__init__()
        self.transformer_dim = transformer_dim
        self.num_multimask_outputs = num_multimask_outputs

        # Two-way transformer
        self.transformer = TwoWayTransformer(
            depth=transformer_depth,
            embedding_dim=transformer_dim,
            num_heads=transformer_num_heads,
            mlp_dim=transformer_mlp_dim,
        )

        # Learned IoU token, prepended to the prompt tokens
        self.iou_token = nn.Embedding(1, transformer_dim)

        # Mask tokens for multi-mask prediction (+1 for the single-mask output)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        # Output upscaling layers: 64x64 -> 256x256 (4x upsampling).
        # MLX convolutions expect NHWC tensors and LayerNorm normalizes the
        # trailing channel axis, so no transposes are needed here.
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            nn.LayerNorm(transformer_dim // 4),
            nn.GELU(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            nn.GELU(),
        )

        # Mask prediction heads (one per mask token). Each head must project
        # its token to the upscaled channel dimension (C // 8) so that the
        # per-pixel dot product in __call__ is valid.
        self.output_hypernetworks_mlps = [
            MLPBlock(
                transformer_dim,
                transformer_dim // 8,
                nn.GELU,
                output_dim=transformer_dim // 8,
            )
            for _ in range(self.num_mask_tokens)
        ]

        # IoU prediction head. MLPBlock returns transformer_dim features, so
        # the final linear maps transformer_dim -> num_mask_tokens.
        self.iou_prediction_head = MLPBlock(
            transformer_dim, iou_head_hidden_dim, nn.GELU
        )
        self.iou_prediction_linear = nn.Linear(
            transformer_dim, self.num_mask_tokens
        )

    def __call__(
        self,
        image_embeddings: mx.array,
        image_pe: mx.array,
        sparse_prompt_embeddings: mx.array,
        dense_prompt_embeddings: mx.array,
        multimask_output: bool = True,
    ) -> Tuple[mx.array, mx.array]:
        """
        Predict masks from image and prompt embeddings

        Args:
            image_embeddings: (B, H, W, C) from the vision encoder
            image_pe: (B, H, W, C) positional encoding for the image
            sparse_prompt_embeddings: (B, N, C) point/box embeddings
            dense_prompt_embeddings: (B, H, W, C) mask embeddings
            multimask_output: If True, return 3 masks; otherwise 1 mask

        Returns:
            masks: (B, num_masks, 4H, 4W) predicted masks
            iou_pred: (B, num_masks) quality scores
        """
        B, H, W, C = image_embeddings.shape

        # Flatten image embeddings and PE
        image_embeddings_flat = image_embeddings.reshape(B, H * W, C)
        image_pe_flat = image_pe.reshape(B, H * W, C)

        # Broadcast the learned output tokens over the batch
        # (mx.broadcast_to is a module-level function, not an array method)
        iou_token_out = mx.broadcast_to(
            self.iou_token.weight.reshape(1, 1, -1),
            (B, 1, self.transformer_dim),
        )
        mask_tokens_out = mx.broadcast_to(
            self.mask_tokens.weight.reshape(1, -1, self.transformer_dim),
            (B, self.num_mask_tokens, self.transformer_dim),
        )

        # Combine all prompt tokens: [IoU token, mask tokens, sparse prompts]
        tokens = mx.concatenate(
            [iou_token_out, mask_tokens_out, sparse_prompt_embeddings], axis=1
        )

        # Add dense prompt embeddings to the image features
        src = image_embeddings_flat + dense_prompt_embeddings.reshape(B, H * W, C)

        # Run through the transformer
        hs, src = self.transformer(src, image_pe_flat, tokens)

        # Extract the output tokens
        iou_token_out = hs[:, 0:1, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale image embeddings; reshape back to (B, H, W, C) first
        src = src.reshape(B, H, W, C)
        upscaled_embedding = self.output_upscaling(src)  # (B, 4H, 4W, C // 8)
        B_up, H_up, W_up, C_up = upscaled_embedding.shape

        # Predict masks using the hypernetworks: each mask is the per-pixel
        # dot product between the upscaled features and one projected token.
        masks = []
        for i in range(self.num_mask_tokens):
            mask_features = self.output_hypernetworks_mlps[i](
                mask_tokens_out[:, i, :]
            )  # (B, C // 8)
            # Expand to spatial dimensions and compute the dot product
            mask_features = mask_features.reshape(B, 1, 1, C_up)
            mask = (upscaled_embedding * mask_features).sum(axis=-1)  # (B, H_up, W_up)
            masks.append(mask)
        masks = mx.stack(masks, axis=1)  # (B, num_mask_tokens, H_up, W_up)

        # Predict IoU scores
        iou_pred = self.iou_prediction_head(iou_token_out)
        iou_pred = self.iou_prediction_linear(iou_pred).squeeze(1)  # (B, num_mask_tokens)

        # Select the requested masks
        if multimask_output:
            # Tokens 1..num_multimask_outputs are the multi-mask outputs
            mask_slice = slice(1, None)
        else:
            # Token 0 is the single-mask output
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        return masks, iou_pred


def create_mask_decoder(
    transformer_dim: int = 256,
    num_multimask_outputs: int = 3,
) -> MaskDecoder:
    """
    Factory function to create a SAM3 mask decoder

    Args:
        transformer_dim: Feature dimension
        num_multimask_outputs: Number of masks to output

    Returns:
        MaskDecoder instance
    """
    return MaskDecoder(
        transformer_dim=transformer_dim,
        num_multimask_outputs=num_multimask_outputs,
    )
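

if __name__ == "__main__":
    # End-to-end smoke test with hypothetical inputs: a 64x64, 256-channel
    # feature map and two sparse prompt embeddings. The weights are randomly
    # initialized, so the masks are meaningless; this only checks shapes and
    # wiring.
    decoder = create_mask_decoder()
    B, H, W, C = 1, 64, 64, 256
    image_embeddings = mx.random.normal((B, H, W, C))
    image_pe = mx.random.normal((B, H, W, C))
    sparse_prompts = mx.random.normal((B, 2, C))
    dense_prompts = mx.zeros((B, H, W, C))

    masks, iou_pred = decoder(
        image_embeddings,
        image_pe,
        sparse_prompts,
        dense_prompts,
        multimask_output=True,
    )
    print("masks:", masks.shape)        # (1, 3, 256, 256)
    print("iou_pred:", iou_pred.shape)  # (1, 3)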