| """ |
| SAM3 Prompt Encoder - Complete MLX Implementation |
| |
| Encodes different types of user prompts: |
| - Points (clicks): Positive/negative points with coordinates |
| - Boxes: Bounding box coordinates (top-left, bottom-right) |
| - Masks: Dense mask inputs |
| |
| Outputs: |
| - Sparse embeddings: Point and box prompt embeddings |
| - Dense embeddings: Mask prompt embeddings |
| """ |
|
|
| import mlx.core as mx |
| import mlx.nn as nn |
| from mlx.nn import Module |
| from typing import Optional, Tuple, List |
| import math |
|
|
|
|
class PositionEmbeddingRandom(Module):
    """
    Positional encoding using random spatial frequencies.

    Maps 2D coordinates in [0, 1] to a higher-dimensional space by projecting
    onto a fixed random Gaussian frequency matrix (Fourier-feature style) and
    concatenating the sin/cos responses along the channel axis.

    Args:
        num_pos_feats: Number of random frequencies; the output has
            2 * num_pos_feats channels (sin half + cos half).
        scale: Standard deviation applied to the random frequency matrix.
            None or a non-positive value falls back to 1.0.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None):
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.scale = scale

        # Fixed random projection basis: (2, num_pos_feats).
        self.positional_encoding_gaussian_matrix = (
            mx.random.normal(shape=(2, num_pos_feats)) * scale
        )

    def _pe_encoding(self, coords: mx.array) -> mx.array:
        """
        Positionally encode points normalized to [0, 1].

        Args:
            coords: (..., 2) coordinates in [0, 1] range.

        Returns:
            (..., num_pos_feats * 2) positional encoding.
        """
        # Map [0, 1] -> [0, 2*pi] before projecting onto the random basis.
        projected = (coords * 2 * math.pi) @ self.positional_encoding_gaussian_matrix
        return mx.concatenate([mx.sin(projected), mx.cos(projected)], axis=-1)

    def forward(self, size: Tuple[int, int]) -> mx.array:
        """
        Generate positional encoding for a 2D grid.

        Args:
            size: (H, W) grid size.

        Returns:
            (H, W, C) positional encoding with C = 2 * num_pos_feats.
        """
        h, w = size
        # BUGFIX: removed `self.positional_encoding_gaussian_matrix.device` —
        # mx.array has no `.device` attribute (MLX uses unified memory), so
        # that line raised AttributeError; the value was never used anyway.

        y_embed = mx.arange(h, dtype=mx.float32).reshape(-1, 1).broadcast_to((h, w))
        x_embed = mx.arange(w, dtype=mx.float32).reshape(1, -1).broadcast_to((h, w))

        # Normalize pixel indices into [0, 1); stacked as (x, y) to match
        # forward_with_coords' coordinate order.
        coords = mx.stack([x_embed / w, y_embed / h], axis=-1)

        pe = self._pe_encoding(coords.reshape(1, h * w, 2))
        return pe.reshape(h, w, -1)

    # BUGFIX: mlx.nn.Module does not route `instance(...)` to `forward`, so
    # callers that invoke this module directly (e.g. a dense-PE helper doing
    # `self.pe_layer(size)`) would raise TypeError without this alias.
    __call__ = forward

    def forward_with_coords(
        self, coords_input: mx.array, image_size: Tuple[int, int]
    ) -> mx.array:
        """
        Encode arbitrary point coordinates given in pixels.

        Args:
            coords_input: (B, N, 2) pixel coordinates as (x, y).
            image_size: (H, W) image dimensions used for normalization.

        Returns:
            (B, N, C) positional encodings.
        """
        coords = coords_input.astype(mx.float32)
        # Normalize x by W and y by H functionally instead of with in-place
        # slice assignment, which mutated a view of the caller's array.
        h, w = image_size
        scale = mx.array([1.0 / w, 1.0 / h], dtype=mx.float32)
        return self._pe_encoding(coords * scale)
|
|
|
|
class PromptEncoder(Module):
    """
    Complete SAM3 Prompt Encoder.

    Encodes user prompts into embeddings consumed by the mask decoder:
    - Points: sparse embeddings (positional encoding + learned type embedding)
    - Boxes:  sparse embeddings for the two corners (top-left, bottom-right)
    - Masks:  dense embeddings from a small downscaling conv stack

    Args:
        embed_dim: Channel dimension for all prompt embeddings.
        image_embedding_size: (H, W) of the image-embedding grid produced by
            the vision encoder.
        input_image_size: (H, W) of the original input image, used to
            normalize pixel coordinates.
        mask_in_chans: Hidden channel count of the mask downscaling stack.
    """

    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int = 16,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size

        # Shared positional encoder; embed_dim // 2 frequencies produce
        # embed_dim output channels after the sin/cos concatenation.
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        # Learned type embeddings: [0]=negative point, [1]=positive point,
        # [2]=box top-left corner, [3]=box bottom-right corner.
        self.num_point_embeddings = 4
        self.point_embeddings = [
            nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)
        ]

        # Placeholder embedding used when padding the point set.
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        # Downscales a dense mask prompt 4x into embed_dim channels.
        # NOTE(review): mlx Conv2d / LayerNorm are channels-last (NHWC), so
        # masks are expected as (B, H, W, 1) — confirm against callers.
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            nn.LayerNorm(mask_in_chans // 4),
            nn.GELU(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            nn.LayerNorm(mask_in_chans),
            nn.GELU(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )

        # Fallback dense embedding when no mask prompt is given.
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> mx.array:
        """
        Positional encoding for the full image-embedding grid.

        Returns:
            (H, W, C) dense positional encoding.
        """
        return self.pe_layer(self.image_embedding_size)

    def _embed_points(
        self,
        points: mx.array,
        labels: mx.array,
        pad: bool,
    ) -> mx.array:
        """
        Embed point prompts.

        Args:
            points: (B, N, 2) point coordinates in pixels.
            labels: (B, N) point labels (1 = positive; any other value is
                treated as negative, matching the original per-point loop,
                which also mapped unknown labels to the negative embedding).
            pad: If True, append one "not a point" embedding per batch entry.

        Returns:
            (B, N, C) or (B, N + 1, C) point embeddings.
        """
        # Shift to pixel centers before positional encoding.
        points = points + 0.5
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )

        # PERF/BUGFIX: the original Python double loop did O(B*N) host
        # round-trips via .item() plus in-place item assignment; a single
        # vectorized mx.where selects the same learned type embedding.
        neg_embed = self.point_embeddings[0].weight.reshape(-1)  # (C,)
        pos_embed = self.point_embeddings[1].weight.reshape(-1)  # (C,)
        is_positive = mx.expand_dims(labels == 1, -1)  # (B, N, 1)
        point_embedding = point_embedding + mx.where(
            is_positive, pos_embed, neg_embed
        )

        if pad:
            B, _, C = point_embedding.shape
            padding_point = self.not_a_point_embed.weight.reshape(
                1, 1, -1
            ).broadcast_to((B, 1, C))
            point_embedding = mx.concatenate([point_embedding, padding_point], axis=1)

        return point_embedding

    def _embed_boxes(self, boxes: mx.array) -> mx.array:
        """
        Embed box prompts.

        Args:
            boxes: (B, 4) boxes as [x0, y0, x1, y1] in pixels.

        Returns:
            (B, 2, C) corner embeddings ordered [top-left, bottom-right].
        """
        # Shift to pixel centers, then split into the two corners: (B, 2, 2).
        boxes = boxes + 0.5
        coords = mx.stack([boxes[:, :2], boxes[:, 2:]], axis=1)

        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )

        # Add the learned corner-type embeddings via broadcasting instead of
        # in-place slice assignment: (2, C) broadcasts across the batch axis.
        corner_types = mx.stack(
            [
                self.point_embeddings[2].weight.reshape(-1),
                self.point_embeddings[3].weight.reshape(-1),
            ],
            axis=0,
        )
        return corner_embedding + mx.expand_dims(corner_types, 0)

    def _embed_masks(self, masks: mx.array) -> mx.array:
        """
        Embed dense mask prompts via the downscaling conv stack.

        Args:
            masks: Dense mask prompts; presumably (B, H, W, 1) since mlx
                Conv2d is channels-last — TODO confirm caller layout (the
                original docstring claimed (B, 1, H, W)).

        Returns:
            (B, H_emb, W_emb, C) downsampled mask embeddings.
        """
        return self.mask_downscaling(masks)

    def forward(
        self,
        points: Optional[Tuple[mx.array, mx.array]] = None,
        boxes: Optional[mx.array] = None,
        masks: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        """
        Encode prompts into sparse and dense embeddings.

        Args:
            points: Optional (coords, labels) tuple:
                - coords: (B, N, 2) point coordinates in pixels
                - labels: (B, N) point labels (1 = positive)
            boxes: Optional (B, 4) boxes as [x0, y0, x1, y1].
            masks: Optional dense mask prompts (see _embed_masks).

        Returns:
            sparse_embeddings: (B, N_sparse, C) point/box embeddings.
            dense_embeddings: (B, H_emb, W_emb, C) mask embeddings.
        """
        bs = 1  # falls back to 1 when no prompt fixes the batch size
        sparse_embeddings_list = []

        if points is not None:
            coords, labels = points
            bs = coords.shape[0]
            # Pad with the "not a point" slot only when no box accompanies
            # the points (the box corners fill that role otherwise).
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings_list.append(point_embeddings)

        if boxes is not None:
            bs = boxes.shape[0]
            sparse_embeddings_list.append(self._embed_boxes(boxes))

        if sparse_embeddings_list:
            sparse_embeddings = mx.concatenate(sparse_embeddings_list, axis=1)
        else:
            # No sparse prompt at all: emit a single placeholder token.
            sparse_embeddings = self.not_a_point_embed.weight.reshape(
                1, 1, -1
            ).broadcast_to((bs, 1, self.embed_dim))

        if masks is not None:
            bs = masks.shape[0]
            dense_embeddings = self._embed_masks(masks)
        else:
            # No mask prompt: broadcast the learned "no mask" embedding over
            # the full embedding grid.
            H, W = self.image_embedding_size
            dense_embeddings = self.no_mask_embed.weight.reshape(
                1, 1, 1, -1
            ).broadcast_to((bs, H, W, self.embed_dim))

        return sparse_embeddings, dense_embeddings

    # BUGFIX: mlx.nn.Module does not dispatch `instance(...)` to `forward`;
    # keep both entry points so `encoder(...)` and `encoder.forward(...)`
    # behave identically.
    __call__ = forward
|
|
|
|
def create_prompt_encoder(
    embed_dim: int = 256,
    image_embedding_size: Tuple[int, int] = (64, 64),
    input_image_size: Tuple[int, int] = (1024, 1024),
) -> PromptEncoder:
    """
    Build a SAM3 ``PromptEncoder`` with the standard configuration.

    Args:
        embed_dim: Embedding dimension.
        image_embedding_size: (H, W) of the vision-encoder output grid.
        input_image_size: (H, W) of the input images.

    Returns:
        A freshly constructed ``PromptEncoder``.
    """
    config = {
        "embed_dim": embed_dim,
        "image_embedding_size": image_embedding_size,
        "input_image_size": input_image_size,
    }
    return PromptEncoder(**config)
|
|