"""
SAM3 Prompt Encoder - Complete MLX Implementation

Encodes different types of user prompts:
- Points (clicks): Positive/negative points with coordinates
- Boxes: Bounding box coordinates (top-left, bottom-right)
- Masks: Dense mask inputs

Outputs:
- Sparse embeddings: Point and box prompt embeddings
- Dense embeddings: Mask prompt embeddings

All dense tensors produced here are channels-last (..., H, W, C), which is
the layout MLX convolutions operate on.
"""

import mlx.core as mx
import mlx.nn as nn
from mlx.nn import Module
from typing import Optional, Tuple, List
import math


class PositionEmbeddingRandom(Module):
    """
    Positional encoding using random spatial frequencies.

    Similar to Fourier features: 2D coordinates are projected through a
    random Gaussian frequency matrix and passed through sin/cos, giving a
    ``2 * num_pos_feats``-dimensional encoding per coordinate.

    Args:
        num_pos_feats: Number of random frequencies (output has twice as
            many channels: sin and cos of each projection).
        scale: Standard-deviation multiplier for the frequency matrix.
            ``None`` or a non-positive value falls back to 1.0.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None):
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.scale = scale

        # Random frequency matrix of shape (2, num_pos_feats); each column is
        # a 2D frequency vector.
        # NOTE(review): as an mx.array attribute this is presumably registered
        # as a module parameter; freeze it if it must stay fixed in training.
        self.positional_encoding_gaussian_matrix = (
            mx.random.normal(shape=(2, num_pos_feats)) * scale
        )

    def _pe_encoding(self, coords: mx.array) -> mx.array:
        """
        Positionally encode points normalized to [0, 1].

        Args:
            coords: (B, N, 2) coordinates in [0, 1] range

        Returns:
            (B, N, num_pos_feats * 2) positional encoding
        """
        # NOTE(review): the reference SAM encoder first remaps coordinates to
        # [-1, 1] (2 * coords - 1) before projecting; this port does not.
        # Confirm whether numerical parity with pretrained weights is needed.
        coords_scaled = coords * 2 * math.pi

        # (B, N, 2) @ (2, num_pos_feats) -> (B, N, num_pos_feats)
        projected = coords_scaled @ self.positional_encoding_gaussian_matrix

        # Concatenate sin and cos: (B, N, num_pos_feats * 2)
        return mx.concatenate([mx.sin(projected), mx.cos(projected)], axis=-1)

    def forward(self, size: Tuple[int, int]) -> mx.array:
        """
        Generate positional encoding for a 2D grid.

        Args:
            size: (H, W) grid size

        Returns:
            (H, W, C) positional encoding
        """
        h, w = size

        # Normalized row/column coordinates in [0, 1).
        # NOTE(review): the reference SAM encoder samples pixel centers
        # ((i + 0.5) / size); this port samples at i / size.
        ys = mx.arange(h, dtype=mx.float32).reshape(-1, 1) / h
        xs = mx.arange(w, dtype=mx.float32).reshape(1, -1) / w

        # mx.array has no broadcast_to method; use the functional form.
        y_embed = mx.broadcast_to(ys, (h, w))
        x_embed = mx.broadcast_to(xs, (h, w))

        # Stack to (H, W, 2) in (x, y) order, flatten to a batch of one,
        # encode, then restore the grid shape.
        coords = mx.stack([x_embed, y_embed], axis=-1).reshape(1, h * w, 2)
        pe = self._pe_encoding(coords)
        return pe.reshape(h, w, -1)

    def __call__(self, size: Tuple[int, int]) -> mx.array:
        """Make the module callable (MLX modules dispatch via ``__call__``)."""
        return self.forward(size)

    def forward_with_coords(
        self, coords_input: mx.array, image_size: Tuple[int, int]
    ) -> mx.array:
        """
        Encode arbitrary point coordinates.

        Args:
            coords_input: (B, N, 2) in pixel coordinates, ordered (x, y)
            image_size: (H, W) image dimensions for normalization

        Returns:
            (B, N, C) positional encodings
        """
        height, width = image_size
        # Divide x by W and y by H in a single broadcast; this avoids
        # in-place slice assignment on the (possibly shared) input array.
        extent = mx.array([width, height], dtype=mx.float32)
        coords = coords_input.astype(mx.float32) / extent
        return self._pe_encoding(coords)


class PromptEncoder(Module):
    """
    Complete SAM3 Prompt Encoder.

    Encodes prompts into embeddings for the mask decoder:
    - Points: Sparse embeddings with learned type (positive/negative)
    - Boxes: Sparse embeddings for corners (top-left, bottom-right)
    - Masks: Dense embeddings from downsampled mask

    Args:
        embed_dim: Channel dimension for embeddings. The positional encoding
            produces ``2 * (embed_dim // 2)`` channels, so this should be even.
        image_embedding_size: (H, W) size of image embeddings from encoder
        input_image_size: (H, W) original input image size
        mask_in_chans: Hidden channels for mask encoder (default 16)
    """

    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int = 16,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size

        # Positional encoding for points and boxes
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        # Learnable type embeddings, indexed as:
        # 0 = negative point, 1 = positive point,
        # 2 = box top-left corner, 3 = box bottom-right corner
        self.num_point_embeddings = 4
        self.point_embeddings = [
            nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)
        ]

        # Embedding for the "not a point" padding token
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        # Mask downsampling encoder: two stride-2 convolutions (4x spatial
        # reduction overall) followed by a 1x1 projection to embed_dim.
        # Inputs are channels-last, so LayerNorm normalizes over channels.
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            nn.LayerNorm(mask_in_chans // 4),
            nn.GELU(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            nn.LayerNorm(mask_in_chans),
            nn.GELU(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )

        # No mask embedding (used when no mask prompt provided)
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> mx.array:
        """
        Get positional encoding for image embedding grid.

        Returns:
            (H, W, C) dense positional encoding
        """
        return self.pe_layer(self.image_embedding_size)

    def _embed_points(
        self,
        points: mx.array,
        labels: mx.array,
        pad: bool,
    ) -> mx.array:
        """
        Embed point prompts.

        Args:
            points: (B, N, 2) point coordinates
            labels: (B, N) point labels (1 = positive; any other value is
                treated as negative)
            pad: Whether to pad with "not a point" embedding

        Returns:
            (B, N, C) or (B, N+1, C) point embeddings
        """
        points = points + 0.5  # Shift to center of pixel
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )
        batch, _, channels = point_embedding.shape

        # Select the type embedding per point in one vectorized op instead of
        # a Python loop with a host sync (.item()) per label. Labels are
        # truncated to integers first, matching int(label) semantics.
        is_positive = mx.expand_dims(labels.astype(mx.int32) == 1, -1)  # (B, N, 1)
        type_embedding = mx.where(
            is_positive,
            self.point_embeddings[1].weight,  # (1, C) positive
            self.point_embeddings[0].weight,  # (1, C) negative / unknown
        )
        point_embedding = point_embedding + type_embedding

        # Pad with "not a point" embedding if requested
        if pad:
            padding_point = mx.broadcast_to(
                self.not_a_point_embed.weight.reshape(1, 1, -1),
                (batch, 1, channels),
            )
            point_embedding = mx.concatenate(
                [point_embedding, padding_point], axis=1
            )

        return point_embedding

    def _embed_boxes(self, boxes: mx.array) -> mx.array:
        """
        Embed box prompts.

        Args:
            boxes: (B, 4) boxes as [x0, y0, x1, y1]

        Returns:
            (B, 2, C) corner embeddings [top-left, bottom-right]
        """
        boxes = boxes + 0.5  # Shift to pixel centers

        # Split into corners: (B, 2, 2)
        coords = mx.stack(
            [
                boxes[:, :2],  # top-left [x0, y0]
                boxes[:, 2:],  # bottom-right [x1, y1]
            ],
            axis=1,
        )

        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )  # (B, 2, C)

        # Add the learned corner type embeddings via broadcasting:
        # (B, 2, C) + (2, C)
        corner_types = mx.concatenate(
            [self.point_embeddings[2].weight, self.point_embeddings[3].weight],
            axis=0,
        )
        return corner_embedding + corner_types

    def _embed_masks(self, masks: mx.array) -> mx.array:
        """
        Embed mask prompts.

        Args:
            masks: Dense masks, either channels-first (B, 1, H, W) or
                channels-last (B, H, W, 1). Input whose last axis has size 1
                is treated as already channels-last.

        Returns:
            (B, H_emb, W_emb, C) downsampled mask embeddings
        """
        # MLX convolutions expect channels-last input; convert NCHW masks.
        if masks.ndim == 4 and masks.shape[-1] != 1 and masks.shape[1] == 1:
            masks = mx.transpose(masks, (0, 2, 3, 1))
        return self.mask_downscaling(masks)

    def _get_batch_size(
        self,
        points: Optional[Tuple[mx.array, mx.array]],
        boxes: Optional[mx.array],
        masks: Optional[mx.array],
    ) -> int:
        """Return the batch size implied by whichever prompt is present (1 if none)."""
        if points is not None:
            return points[0].shape[0]
        if boxes is not None:
            return boxes.shape[0]
        if masks is not None:
            return masks.shape[0]
        return 1

    def forward(
        self,
        points: Optional[Tuple[mx.array, mx.array]] = None,
        boxes: Optional[mx.array] = None,
        masks: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        """
        Encode prompts into sparse and dense embeddings.

        Args:
            points: Optional tuple of (coords, labels)
                - coords: (B, N, 2) point coordinates
                - labels: (B, N) point labels (0=neg, 1=pos)
            boxes: Optional (B, 4) boxes as [x0, y0, x1, y1]
            masks: Optional (B, 1, H, W) mask prompts

        Returns:
            sparse_embeddings: (B, N_sparse, C) point/box embeddings; a single
                "not a point" token per batch item when no points or boxes
                are given
            dense_embeddings: (B, H_emb, W_emb, C) mask embeddings
        """
        # Resolve the batch size once so the sparse and dense fallbacks agree
        # even when only a mask prompt is supplied.
        bs = self._get_batch_size(points, boxes, masks)

        # Handle sparse prompts (points and boxes)
        sparse_parts = []
        if points is not None:
            coords, labels = points
            # Points are padded with a "not a point" token only when no box
            # accompanies them.
            sparse_parts.append(
                self._embed_points(coords, labels, pad=(boxes is None))
            )
        if boxes is not None:
            sparse_parts.append(self._embed_boxes(boxes))

        if sparse_parts:
            sparse_embeddings = mx.concatenate(sparse_parts, axis=1)
        else:
            # No sparse prompts - use "not a point" embedding
            sparse_embeddings = mx.broadcast_to(
                self.not_a_point_embed.weight.reshape(1, 1, -1),
                (bs, 1, self.embed_dim),
            )

        # Handle dense prompts (masks)
        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            # No mask prompt - broadcast no_mask_embed to image embedding size
            grid_h, grid_w = self.image_embedding_size
            dense_embeddings = mx.broadcast_to(
                self.no_mask_embed.weight.reshape(1, 1, 1, -1),
                (bs, grid_h, grid_w, self.embed_dim),
            )

        return sparse_embeddings, dense_embeddings

    def __call__(
        self,
        points: Optional[Tuple[mx.array, mx.array]] = None,
        boxes: Optional[mx.array] = None,
        masks: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        """Make the module callable (MLX modules dispatch via ``__call__``)."""
        return self.forward(points=points, boxes=boxes, masks=masks)


def create_prompt_encoder(
    embed_dim: int = 256,
    image_embedding_size: Tuple[int, int] = (64, 64),
    input_image_size: Tuple[int, int] = (1024, 1024),
    mask_in_chans: int = 16,
) -> PromptEncoder:
    """
    Factory function to create SAM3 prompt encoder.

    Args:
        embed_dim: Embedding dimension
        image_embedding_size: Size of vision encoder output
        input_image_size: Size of input images
        mask_in_chans: Hidden channels of the mask downscaling encoder

    Returns:
        PromptEncoder instance
    """
    return PromptEncoder(
        embed_dim=embed_dim,
        image_embedding_size=image_embedding_size,
        input_image_size=input_image_size,
        mask_in_chans=mask_in_chans,
    )