Instructions to use Hoodrobot/MLX_SAM3 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use Hoodrobot/MLX_SAM3 with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir MLX_SAM3 Hoodrobot/MLX_SAM3
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
| """ | |
| Hiera (Hierarchical Vision Transformer) - Complete MLX Implementation | |
| This is the vision backbone used in SAM3, featuring: | |
| - Multi-scale hierarchical processing | |
| - Stage-wise spatial pooling | |
| - RoPE attention at each scale | |
| - Efficient computation via MLX/Metal | |
| """ | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
| from mlx.nn import Module | |
| from typing import List, Optional, Tuple | |
| from .attention import MultiHeadAttentionRoPE, WindowedAttention | |
class MLP(Module):
    """
    Multi-Layer Perceptron with GELU activation

    Standard FFN block in transformers:
    Linear -> GELU -> (Dropout) -> Linear -> (Dropout)

    Args:
        dim: Input and output feature dimension
        hidden_dim: Width of the hidden layer
        dropout: Dropout probability; no dropout layer is created when it is 0
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        # None means "no dropout"; forward() checks for it explicitly.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x: mx.array) -> mx.array:
        """
        Args:
            x: Input whose last dimension is ``dim``
        Returns:
            Array with the same last dimension ``dim``
        """
        # Compare against None instead of relying on truthiness:
        # mlx.nn.Module subclasses dict, so ``bool(module)`` reflects the
        # entries it stores, not whether the layer was configured.
        use_dropout = self.dropout is not None

        x = self.fc1(x)
        x = self.act(x)
        if use_dropout:
            x = self.dropout(x)
        x = self.fc2(x)
        if use_dropout:
            x = self.dropout(x)
        return x

    def __call__(self, x: mx.array) -> mx.array:
        """Invoke the module; MLX modules are called through ``__call__``.

        Delegates to ``forward`` so existing ``forward`` callers keep working.
        """
        return self.forward(x)
class HieraBlock(Module):
    """
    Single Hiera transformer block

    Features:
    - Pre-LayerNorm architecture
    - Attention (MultiHeadAttentionRoPE, or WindowedAttention when requested)
    - MLP with GELU
    - Residual connections

    Args:
        dim: Feature dimension of the tokens
        num_heads: Number of attention heads
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``
        qkv_bias: Forwarded to the attention module
        dropout: Dropout probability forwarded to attention and MLP
        use_windowed_attn: Use WindowedAttention instead of MultiHeadAttentionRoPE
        window_size: Forwarded to WindowedAttention (ignored otherwise)
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        dropout: float = 0.0,
        use_windowed_attn: bool = False,
        window_size: int = 14,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)

        # Choose attention type
        if use_windowed_attn:
            self.attn = WindowedAttention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                dropout=dropout,
                window_size=window_size,
            )
        else:
            self.attn = MultiHeadAttentionRoPE(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                dropout=dropout,
            )

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), dropout=dropout)

    def forward(self, x: mx.array) -> mx.array:
        """
        Args:
            x: Token features whose last dimension is ``dim``
        Returns:
            Array of the same shape as ``x``
        """
        # Attention with pre-norm and residual.
        # NOTE(review): assumes the modules imported from .attention are
        # callable with a single array argument -- confirm against that file.
        x = x + self.attn(self.norm1(x))

        # MLP with pre-norm and residual. MLP exposes its computation as
        # ``forward``, so call that method directly.
        x = x + self.mlp.forward(self.norm2(x))
        return x

    def __call__(self, x: mx.array) -> mx.array:
        """Invoke the block; MLX modules are called through ``__call__``.

        Delegates to ``forward`` so existing ``forward`` callers keep working.
        """
        return self.forward(x)
class PatchEmbed(Module):
    """
    Image to Patch Embedding using Conv2d

    Converts (B, H, W, C) image to (B, num_patches, embed_dim) patches

    Args:
        img_size: Expected side length of the (square) input image
        patch_size: Side length of each patch; used as conv kernel and stride
        in_chans: Number of input channels
        embed_dim: Output embedding dimension per patch
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 14,
        in_chans: int = 3,
        embed_dim: int = 1024,
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Floor division: trailing pixels that do not fill a whole patch
        # are not counted.
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2

        # Convolution for patch embedding (non-overlapping patches)
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: mx.array) -> mx.array:
        """
        Args:
            x: (B, H, W, C) in NHWC format (MLX convention)
        Returns:
            (B, H' * W', embed_dim) where H', W' are the conv output sizes

        Raises:
            ValueError: If ``x`` is not 4-dimensional
        """
        if x.ndim != 4:
            raise ValueError(
                f"PatchEmbed expects a (B, H, W, C) input, got shape {x.shape}"
            )

        # Apply convolution
        x = self.proj(x)  # (B, H', W', embed_dim)

        # Flatten spatial dimensions
        batch, grid_h, grid_w, channels = x.shape
        return x.reshape(batch, grid_h * grid_w, channels)

    def __call__(self, x: mx.array) -> mx.array:
        """Invoke the module; MLX modules are called through ``__call__``.

        Delegates to ``forward`` so existing ``forward`` callers keep working.
        """
        return self.forward(x)
class DownsampleBlock(Module):
    """
    Spatial downsampling block for hierarchical processing

    Reduces spatial resolution by 2x while changing the channel count
    Uses depthwise-separable convolution for efficiency

    Args:
        in_dim: Channel dimension of the incoming tokens
        out_dim: Channel dimension of the outgoing tokens
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Depthwise convolution (2x2 kernel with stride 2, one group per channel)
        self.dw_conv = nn.Conv2d(
            in_dim, in_dim, kernel_size=2, stride=2, groups=in_dim
        )
        # Pointwise convolution (1x1 to change channels)
        self.pw_conv = nn.Conv2d(in_dim, out_dim, kernel_size=1)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x: mx.array, h: int, w: int) -> Tuple[mx.array, int, int]:
        """
        Args:
            x: (B, N, C) where N = h*w
            h, w: Spatial dimensions of the token grid
        Returns:
            (B, h_new * w_new, out_dim), h_new, w_new -- the new grid size is
            read from the convolution output (h // 2, w // 2 for the
            unpadded 2x2/stride-2 kernel, so an odd trailing row/column is
            dropped)

        Raises:
            ValueError: If N does not equal h * w
        """
        B, N, C = x.shape
        if N != h * w:
            raise ValueError(
                f"Token count {N} does not match spatial grid {h}x{w}"
            )

        # Reshape to spatial format: (B, N, C) -> (B, h, w, C)
        x = x.reshape(B, h, w, C)

        # Apply convolutions
        x = self.pw_conv(self.dw_conv(x))

        # Flatten back: (B, h_new, w_new, out_dim) -> (B, h_new * w_new, out_dim)
        B, h_new, w_new, C_new = x.shape
        x = x.reshape(B, h_new * w_new, C_new)

        # Normalize over the channel dimension
        x = self.norm(x)
        return x, h_new, w_new

    def __call__(self, x: mx.array, h: int, w: int) -> Tuple[mx.array, int, int]:
        """Invoke the module; MLX modules are called through ``__call__``.

        Delegates to ``forward`` so existing ``forward`` callers keep working.
        """
        return self.forward(x, h, w)
class HieraStage(Module):
    """
    Single stage of Hiera with multiple blocks

    Each stage processes at a specific spatial scale

    Args:
        dim: Feature dimension used by every block in the stage
        depth: Number of HieraBlock instances
        num_heads: Attention heads per block
        mlp_ratio: MLP hidden dim ratio
        use_windowed_attn: When True, even-indexed blocks use windowed
            attention and odd-indexed blocks use the non-windowed attention
        window_size: Window size forwarded to every block
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        use_windowed_attn: bool = False,
        window_size: int = 14,
    ):
        super().__init__()
        self.blocks = [
            HieraBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                # Alternate global/local: windowed only on even block indices
                use_windowed_attn=use_windowed_attn and (i % 2 == 0),
                window_size=window_size,
            )
            for i in range(depth)
        ]

    def forward(self, x: mx.array) -> mx.array:
        """
        Args:
            x: Token features whose last dimension is ``dim``
        Returns:
            Features after all blocks have been applied in order
        """
        for block in self.blocks:
            # HieraBlock exposes its computation as ``forward``.
            x = block.forward(x)
        return x

    def __call__(self, x: mx.array) -> mx.array:
        """Invoke the stage; MLX modules are called through ``__call__``.

        Delegates to ``forward`` so existing ``forward`` callers keep working.
        """
        return self.forward(x)
class HieraVisionEncoder(Module):
    """
    Complete Hiera Vision Encoder

    Multi-scale hierarchical vision transformer with:
    - One stage per entry of ``embed_dims`` (4 with the defaults)
    - Spatial downsampling between consecutive stages
    - Attention blocks in every stage
    - Both global and windowed attention (when ``use_windowed_attn`` is set)

    Args:
        img_size: Input image size
        patch_size: Initial patch size
        in_chans: Input channels (3 for RGB)
        embed_dims: Channel dimensions for each stage
            (default ``[256, 512, 1024, 1024]``)
        depths: Number of blocks per stage (default ``[2, 8, 16, 6]``)
        num_heads: Attention heads per stage (default ``[4, 8, 16, 16]``)
        mlp_ratio: MLP hidden dim ratio
        use_windowed_attn: Use windowed attention in stages
        window_size: Window size forwarded to every stage
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 14,
        in_chans: int = 3,
        embed_dims: Optional[List[int]] = None,  # Progressive channel increase
        depths: Optional[List[int]] = None,  # Blocks per stage
        num_heads: Optional[List[int]] = None,
        mlp_ratio: float = 4.0,
        use_windowed_attn: bool = True,
        window_size: int = 14,
    ):
        super().__init__()

        # Build the default lists per call rather than as list literals in
        # the signature, so instances never share one mutable default.
        if embed_dims is None:
            embed_dims = [256, 512, 1024, 1024]
        if depths is None:
            depths = [2, 8, 16, 6]
        if num_heads is None:
            num_heads = [4, 8, 16, 16]

        assert len(embed_dims) == len(depths) == len(num_heads), \
            "embed_dims, depths, and num_heads must have same length"

        self.num_stages = len(embed_dims)
        self.patch_size = patch_size

        # Patch embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dims[0],
        )

        # Initial spatial dimensions of the token grid
        self.init_h = self.init_w = img_size // patch_size

        # Pre-norm before stages
        self.norm_pre = nn.LayerNorm(embed_dims[0])

        # Build stages; stage i is followed by downsample layer i, except
        # for the last stage which has none.
        self.stages = []
        self.downsample_layers = []
        for i in range(self.num_stages):
            # Create stage
            stage = HieraStage(
                dim=embed_dims[i],
                depth=depths[i],
                num_heads=num_heads[i],
                mlp_ratio=mlp_ratio,
                use_windowed_attn=use_windowed_attn,
                window_size=window_size,
            )
            self.stages.append(stage)

            # Create downsampling layer (except for last stage)
            if i < self.num_stages - 1:
                downsample = DownsampleBlock(embed_dims[i], embed_dims[i + 1])
                self.downsample_layers.append(downsample)

        # Final norm
        self.norm = nn.LayerNorm(embed_dims[-1])

    def forward(self, x: mx.array) -> mx.array:
        """
        Args:
            x: (B, H, W, C) image in NHWC format
        Returns:
            (B, num_patches_final, embed_dim_final) features
        """
        # PatchEmbed, HieraStage and DownsampleBlock expose their
        # computation as ``forward``, so those methods are called directly.

        # Patch embedding
        x = self.patch_embed.forward(x)  # (B, num_patches, embed_dims[0])

        # Pre-norm
        x = self.norm_pre(x)

        # Track spatial dimensions of the token grid; this assumes the input
        # image matches the ``img_size`` given at construction.
        h, w = self.init_h, self.init_w

        # Process through stages
        for i, stage in enumerate(self.stages):
            # Apply stage
            x = stage.forward(x)

            # Downsample (except last stage, which has no downsample layer)
            if i < len(self.downsample_layers):
                x, h, w = self.downsample_layers[i].forward(x, h, w)

        # Final norm
        x = self.norm(x)
        return x

    def __call__(self, x: mx.array) -> mx.array:
        """Invoke the encoder; MLX modules are called through ``__call__``.

        Delegates to ``forward`` so existing ``forward`` callers keep working.
        """
        return self.forward(x)
def create_hiera_base() -> HieraVisionEncoder:
    """Build a HieraVisionEncoder with the Hiera-Base settings (SAM3 default)."""
    base_config = {
        "img_size": 1024,
        "patch_size": 14,
        "embed_dims": [256, 512, 1024, 1024],
        "depths": [2, 8, 16, 6],
        "num_heads": [4, 8, 16, 16],
    }
    # All remaining constructor arguments keep their defaults.
    return HieraVisionEncoder(**base_config)
def create_hiera_large() -> HieraVisionEncoder:
    """Build a HieraVisionEncoder with the Hiera-Large settings."""
    large_config = {
        "img_size": 1024,
        "patch_size": 14,
        "embed_dims": [384, 768, 1536, 1536],
        "depths": [2, 8, 20, 8],
        "num_heads": [6, 12, 24, 24],
    }
    # All remaining constructor arguments keep their defaults.
    return HieraVisionEncoder(**large_config)