"""
SAM3 Mask Decoder - Complete MLX Implementation
Predicts high-resolution segmentation masks from:
- Image embeddings (from Hiera vision encoder)
- Prompt embeddings (from prompt encoder)
Architecture:
1. Transformer decoder with cross-attention to image features
2. Dynamic mask prediction head
3. IoU quality prediction
4. Multi-mask output (3 masks + confidence scores)
"""
import mlx.core as mx
import mlx.nn as nn
from mlx.nn import Module
from typing import Tuple
class MLPBlock(Module):
"""
Simple MLP block with one hidden layer
Used in transformer and prediction heads
"""
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
activation=nn.GELU
):
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = activation()
    # MLX modules define the forward pass in __call__ (mlx.nn has no `forward` hook)
    def __call__(self, x: mx.array) -> mx.array:
return self.lin2(self.act(self.lin1(x)))
class TwoWayAttentionBlock(Module):
"""
Two-way cross-attention transformer block
Performs:
1. Self-attention on queries (prompts)
2. Cross-attention from queries to keys (image features)
3. MLP on queries
4. Cross-attention from keys to queries
"""
def __init__(
self,
embedding_dim: int,
num_heads: int = 8,
mlp_dim: int = 2048,
activation=nn.GELU,
skip_first_layer_pe: bool = False,
):
super().__init__()
self.self_attn = nn.MultiHeadAttention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = nn.MultiHeadAttention(
embedding_dim, num_heads // 2
)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = nn.MultiHeadAttention(
embedding_dim, num_heads // 2
)
self.skip_first_layer_pe = skip_first_layer_pe
    def __call__(
self,
queries: mx.array,
keys: mx.array,
query_pe: mx.array,
key_pe: mx.array,
) -> Tuple[mx.array, mx.array]:
"""
Args:
queries: (B, N_q, C) prompt tokens
keys: (B, N_k, C) image tokens
query_pe: (B, N_q, C) positional encoding for queries
key_pe: (B, N_k, C) positional encoding for keys
Returns:
Updated queries and keys
"""
        # Self-attention on queries
        if self.skip_first_layer_pe:
            queries = self.self_attn(queries, queries, queries)
        else:
            q = queries + query_pe
            queries = queries + self.self_attn(q, q, queries)
        queries = self.norm1(queries)
# Cross-attention: queries -> image
q = queries + query_pe
k = keys + key_pe
queries = queries + self.cross_attn_token_to_image(q, k, keys)
queries = self.norm2(queries)
# MLP
queries = queries + self.mlp(queries)
queries = self.norm3(queries)
# Cross-attention: image -> queries
q = queries + query_pe
k = keys + key_pe
keys = keys + self.cross_attn_image_to_token(k, q, queries)
keys = self.norm4(keys)
return queries, keys
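
# Usage sketch for TwoWayAttentionBlock (illustrative only; the shapes below
# are assumptions, e.g. 8 prompt tokens over a flattened 64x64 feature map):
#
#   block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
#   queries = mx.zeros((1, 8, 256))       # (B, N_q, C) prompt tokens
#   keys = mx.zeros((1, 64 * 64, 256))    # (B, N_k, C) image tokens
#   q_pe = mx.zeros_like(queries)         # positional encodings (zeros here)
#   k_pe = mx.zeros_like(keys)
#   queries, keys = block(queries, keys, q_pe, k_pe)
#   # shapes are preserved: (1, 8, 256) and (1, 4096, 256)
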
class TwoWayTransformer(Module):
"""
Two-way transformer decoder
Processes sparse prompts and dense image features
to produce mask predictions
"""
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
):
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
# Stack of two-way attention blocks
self.layers = [
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
skip_first_layer_pe=(i == 0),
)
for i in range(depth)
]
self.final_attn_token_to_image = nn.MultiHeadAttention(
embedding_dim, num_heads
)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
    def __call__(
self,
image_embedding: mx.array,
image_pe: mx.array,
point_embedding: mx.array,
) -> Tuple[mx.array, mx.array]:
"""
Args:
image_embedding: (B, H*W, C) image features
image_pe: (B, H*W, C) positional encoding for image
point_embedding: (B, N, C) prompt embeddings
Returns:
Updated tokens and image features
"""
# Prepare queries (prompts) and keys (image)
queries = point_embedding
keys = image_embedding
# Pass through transformer layers
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Final attention from prompts to image
q = queries + point_embedding
k = keys + image_pe
queries = queries + self.final_attn_token_to_image(q, k, keys)
queries = self.norm_final_attn(queries)
return queries, keys
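
# Usage sketch for TwoWayTransformer (illustrative shapes, not taken from any
# checkpoint): MaskDecoder below flattens (B, H, W, C) image features to
# (B, H*W, C) before calling the transformer.
#
#   transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
#   image = mx.zeros((1, 64 * 64, 256))   # flattened image features
#   image_pe = mx.zeros_like(image)       # image positional encoding
#   tokens = mx.zeros((1, 7, 256))        # IoU token + mask tokens + sparse prompts
#   tokens, image = transformer(image, image_pe, tokens)
#   # tokens: (1, 7, 256), image: (1, 4096, 256)
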
class MaskDecoder(Module):
"""
Complete SAM3 Mask Decoder
Predicts segmentation masks from image and prompt embeddings.
Outputs multiple masks with quality scores.
Args:
transformer_dim: Channel dimension of transformer
transformer: Two-way transformer for mask prediction
num_multimask_outputs: Number of masks to predict (default 3)
iou_head_depth: Depth of IoU prediction MLP
iou_head_hidden_dim: Hidden dim for IoU MLP
"""
def __init__(
self,
transformer_dim: int = 256,
transformer_depth: int = 2,
transformer_num_heads: int = 8,
transformer_mlp_dim: int = 2048,
num_multimask_outputs: int = 3,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
):
super().__init__()
self.transformer_dim = transformer_dim
self.num_multimask_outputs = num_multimask_outputs
# Two-way transformer
self.transformer = TwoWayTransformer(
depth=transformer_depth,
embedding_dim=transformer_dim,
num_heads=transformer_num_heads,
mlp_dim=transformer_mlp_dim,
)
# IoU prediction head
self.iou_token = nn.Embedding(1, transformer_dim)
# Mask tokens for multi-mask prediction
self.num_mask_tokens = num_multimask_outputs + 1 # +1 for single mask
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
# Output upscaling layers
# Upsample from 64x64 -> 256x256 (4x upsampling)
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
),
nn.LayerNorm(transformer_dim // 4),
nn.GELU(),
nn.ConvTranspose2d(
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
),
nn.GELU(),
)
        # Mask prediction hypernetworks (one small MLP per mask token),
        # projecting each token down to the upscaled channel dimension
        self.output_hypernetworks_mlps = [
            nn.Sequential(
                nn.Linear(transformer_dim, transformer_dim),
                nn.GELU(),
                nn.Linear(transformer_dim, transformer_dim // 8),
            )
            for _ in range(self.num_mask_tokens)
        ]
# IoU prediction head
self.iou_prediction_head = MLPBlock(
transformer_dim, iou_head_hidden_dim, nn.GELU
)
        # MLPBlock returns features of size transformer_dim, so the final
        # projection maps transformer_dim -> one score per mask token
        self.iou_prediction_linear = nn.Linear(transformer_dim, self.num_mask_tokens)
    def __call__(
self,
image_embeddings: mx.array,
image_pe: mx.array,
sparse_prompt_embeddings: mx.array,
dense_prompt_embeddings: mx.array,
multimask_output: bool = True,
) -> Tuple[mx.array, mx.array]:
"""
Predict masks from image and prompt embeddings
Args:
image_embeddings: (B, H, W, C) from vision encoder
image_pe: (B, H, W, C) positional encoding for image
sparse_prompt_embeddings: (B, N, C) point/box embeddings
dense_prompt_embeddings: (B, H, W, C) mask embeddings
multimask_output: Return 3 masks or 1 mask
Returns:
            masks: (B, num_masks, 4*H, 4*W) predicted mask logits (after 4x upscaling)
iou_pred: (B, num_masks) quality scores
"""
B, H, W, C = image_embeddings.shape
# Flatten image embeddings and PE
image_embeddings_flat = image_embeddings.reshape(B, H * W, C)
image_pe_flat = image_pe.reshape(B, H * W, C)
        # Expand the learned output tokens (IoU token + mask tokens) to the batch
        iou_token_out = mx.broadcast_to(
            self.iou_token.weight.reshape(1, 1, -1),
            (B, 1, self.transformer_dim),
        )
        mask_tokens_out = mx.broadcast_to(
            self.mask_tokens.weight.reshape(1, -1, self.transformer_dim),
            (B, self.num_mask_tokens, self.transformer_dim),
        )
# Combine all prompt tokens: [IoU token, mask tokens, sparse prompts]
tokens = mx.concatenate(
[iou_token_out, mask_tokens_out, sparse_prompt_embeddings], axis=1
)
# Add dense prompt embeddings to image
src = image_embeddings_flat + dense_prompt_embeddings.reshape(B, H * W, C)
# Run through transformer
hs, src = self.transformer(src, image_pe_flat, tokens)
# Extract tokens
iou_token_out = hs[:, 0:1, :]
mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
# Upscale image embeddings
# Reshape to (B, H, W, C) for upsampling
src = src.reshape(B, H, W, C)
upscaled_embedding = self.output_upscaling(src) # (B, H*4, W*4, C//8)
B_up, H_up, W_up, C_up = upscaled_embedding.shape
# Predict masks using hypernetworks
masks = []
for i in range(self.num_mask_tokens):
# Get mask token features
mask_features = self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
# (B, C//8)
# Expand to spatial dimensions and compute dot product
mask_features = mask_features.reshape(B, 1, 1, C_up)
mask = (upscaled_embedding * mask_features).sum(axis=-1) # (B, H_up, W_up)
masks.append(mask)
masks = mx.stack(masks, axis=1) # (B, num_masks, H_up, W_up)
# Predict IoU scores
iou_pred = self.iou_prediction_head(iou_token_out)
iou_pred = self.iou_prediction_linear(iou_pred).squeeze(1) # (B, num_masks)
# Select correct masks
if multimask_output:
# Return 3 multi-masks
mask_slice = slice(1, None)
else:
# Return single mask
mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
return masks, iou_pred
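
# Note on mask prediction (sketch): for mask token i, the hypernetwork MLP
# produces a per-channel weight vector w_i of size C//8, and the mask logit at
# pixel (h, w) is the dot product
#   mask[b, i, h, w] = sum_c upscaled[b, h, w, c] * w_i[b, c]
# which is exactly what the elementwise multiply + sum over the channel axis
# computes above. On MLX versions that provide mx.einsum, the same result could
# be written as mx.einsum("bhwc,bnc->bnhw", upscaled_embedding, hyper_weights),
# where hyper_weights is a hypothetical array stacking the w_i vectors.
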
def create_mask_decoder(
transformer_dim: int = 256,
num_multimask_outputs: int = 3,
) -> MaskDecoder:
"""
Factory function to create SAM3 mask decoder
Args:
transformer_dim: Feature dimension
num_multimask_outputs: Number of masks to output
Returns:
MaskDecoder instance
"""
return MaskDecoder(
transformer_dim=transformer_dim,
num_multimask_outputs=num_multimask_outputs,
)
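

if __name__ == "__main__":
    # Minimal smoke test with dummy inputs. The shapes are assumptions chosen
    # for illustration (a 64x64 feature map with 256 channels, 2 sparse prompt
    # tokens, batch size 1), not values mandated by SAM3.
    decoder = create_mask_decoder()
    B, H, W, C = 1, 64, 64, 256
    image_embeddings = mx.random.normal((B, H, W, C))
    image_pe = mx.random.normal((B, H, W, C))
    sparse_prompts = mx.random.normal((B, 2, C))
    dense_prompts = mx.zeros((B, H, W, C))
    masks, iou_pred = decoder(
        image_embeddings,
        image_pe,
        sparse_prompts,
        dense_prompts,
        multimask_output=True,
    )
    print("masks:", masks.shape)        # expected (1, 3, 256, 256)
    print("iou_pred:", iou_pred.shape)  # expected (1, 3)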