# MLX / MLX_SAM3 / hiera.py
# Uploaded by Hoodrobot ("Upload 15 files", commit ced11e2, verified).
"""
Hiera (Hierarchical Vision Transformer) - Complete MLX Implementation
This is the vision backbone used in SAM3, featuring:
- Multi-scale hierarchical processing
- Stage-wise spatial pooling
- RoPE attention at each scale
- Efficient computation via MLX/Metal
"""
import mlx.core as mx
import mlx.nn as nn
from mlx.nn import Module
from typing import List, Optional, Tuple
from .attention import MultiHeadAttentionRoPE, WindowedAttention
class MLP(Module):
    """
    Multi-Layer Perceptron with GELU activation (standard transformer FFN).

    Computes ``fc2(dropout(gelu(fc1(x))))`` followed by an optional second
    dropout, mapping the last axis ``dim -> hidden_dim -> dim``.

    Args:
        dim: Input and output feature size (last axis of ``x``).
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability applied after the activation and after
            ``fc2``. A value of ``0`` (the default) disables dropout.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def __call__(self, x: mx.array) -> mx.array:
        # mlx.nn.Module does not route __call__ to forward(); without this the
        # module cannot be invoked as ``mlp(x)``. Delegating (rather than
        # aliasing) keeps subclass overrides of forward() effective.
        return self.forward(x)

    def forward(self, x: mx.array) -> mx.array:
        """Apply the FFN to ``x`` (..., dim) and return a tensor of (..., dim)."""
        x = self.act(self.fc1(x))
        # Compare against None explicitly: nn.Module subclasses dict and
        # nn.Dropout stores no parameters, so it is an *empty* (falsy) dict.
        # ``if self.dropout:`` would therefore never apply dropout.
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.fc2(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return x
class HieraBlock(Module):
    """
    Single Hiera transformer block.

    Pre-LayerNorm layout with two residual branches::

        x = x + attn(norm1(x))
        x = x + mlp(norm2(x))

    Args:
        dim: Token feature size.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: Forwarded to the attention module as ``qkv_bias``.
        dropout: Dropout probability forwarded to the attention module and MLP.
        use_windowed_attn: If True use ``WindowedAttention``; otherwise use
            ``MultiHeadAttentionRoPE``.
        window_size: Forwarded to ``WindowedAttention`` (ignored otherwise).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        dropout: float = 0.0,
        use_windowed_attn: bool = False,
        window_size: int = 14,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # Choose attention type
        if use_windowed_attn:
            self.attn = WindowedAttention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                dropout=dropout,
                window_size=window_size,
            )
        else:
            self.attn = MultiHeadAttentionRoPE(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                dropout=dropout,
            )
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), dropout=dropout)

    def __call__(self, x: mx.array) -> mx.array:
        # mlx.nn.Module does not route __call__ to forward(); without this the
        # block cannot be invoked as ``block(x)`` (as HieraStage does).
        return self.forward(x)

    def forward(self, x: mx.array) -> mx.array:
        """Apply attention and MLP sub-blocks, each with pre-norm and a residual."""
        # NOTE(review): both attention variants receive only the token tensor
        # (no h/w). Confirm WindowedAttention can recover the spatial layout it
        # needs from that, and that both attention classes are callable.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
class PatchEmbed(Module):
    """
    Image-to-patch embedding using a strided Conv2d.

    A ``patch_size`` x ``patch_size`` convolution with stride ``patch_size``
    (no padding) maps a (B, H, W, C) NHWC image to a
    (B, H // patch_size, W // patch_size, embed_dim) grid. That grid is then
    flattened to (B, num_patches, embed_dim).

    Args:
        img_size: Nominal square input size. Used only to precompute
            ``grid_size`` and ``num_patches``.
        patch_size: Patch side length (conv kernel size and stride).
        in_chans: Number of input image channels.
        embed_dim: Output embedding size per patch.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 14,
        in_chans: int = 3,
        embed_dim: int = 1024,
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2
        # Convolution for patch embedding (kernel == stride => non-overlapping)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def __call__(self, x: mx.array) -> mx.array:
        # mlx.nn.Module does not route __call__ to forward(); required for
        # ``patch_embed(x)`` to work.
        return self.forward(x)

    def forward(self, x: mx.array) -> mx.array:
        """
        Args:
            x: (B, H, W, C) image in NHWC format (MLX convention).
        Returns:
            (B, num_patches, embed_dim) patch tokens, in row-major grid order.
        """
        x = self.proj(x)  # (B, H // patch_size, W // patch_size, embed_dim)
        batch, grid_h, grid_w, channels = x.shape
        return x.reshape(batch, grid_h * grid_w, channels)
class DownsampleBlock(Module):
    """
    Spatial downsampling block for hierarchical processing.

    Halves the spatial resolution (floor division) and changes the channel
    count from ``in_dim`` to ``out_dim``. It uses a depthwise-separable
    convolution:
    - a 2x2 / stride-2 depthwise conv (``groups=in_dim``);
    - a 1x1 pointwise conv;
    - a LayerNorm over the output channels.

    Args:
        in_dim: Channel count of the incoming tokens.
        out_dim: Channel count of the outgoing tokens.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Depthwise convolution (2x2 pooling with stride 2)
        self.dw_conv = nn.Conv2d(in_dim, in_dim, kernel_size=2, stride=2, groups=in_dim)
        # Pointwise convolution (1x1 to change channels)
        self.pw_conv = nn.Conv2d(in_dim, out_dim, kernel_size=1)
        self.norm = nn.LayerNorm(out_dim)

    def __call__(self, x: mx.array, h: int, w: int) -> Tuple[mx.array, int, int]:
        # mlx.nn.Module does not route __call__ to forward(); required for
        # ``downsample(x, h, w)`` to work.
        return self.forward(x, h, w)

    def forward(self, x: mx.array, h: int, w: int) -> Tuple[mx.array, int, int]:
        """
        Args:
            x: (B, N, C) tokens laid out row-major on an h x w grid, N == h*w.
            h, w: Spatial grid dimensions of ``x``.
        Returns:
            (B, (h // 2) * (w // 2), out_dim) tokens, h // 2, w // 2
        Raises:
            ValueError: if ``N != h * w``.
        """
        B, N, C = x.shape
        if N != h * w:
            raise ValueError(
                f"DownsampleBlock expected h*w = {h}*{w} = {h * w} tokens, got {N}"
            )
        # (B, N, C) -> (B, h, w, C) so the convolutions see the spatial layout
        x = x.reshape(B, h, w, C)
        x = self.pw_conv(self.dw_conv(x))
        # (B, h//2, w//2, out_dim) -> (B, h//2 * w//2, out_dim)
        B, h_new, w_new, C_new = x.shape
        x = x.reshape(B, h_new * w_new, C_new)
        return self.norm(x), h_new, w_new
class HieraStage(Module):
    """
    Single Hiera stage: ``depth`` HieraBlocks applied sequentially at one
    spatial scale.

    When ``use_windowed_attn`` is True, even-indexed blocks use windowed
    attention and odd-indexed blocks use global attention. Otherwise every
    block uses global attention.

    Args:
        dim: Token feature size for every block in the stage.
        depth: Number of blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        use_windowed_attn: Enable the alternating windowed/global pattern.
        window_size: Window size for the windowed blocks.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        use_windowed_attn: bool = False,
        window_size: int = 14,
    ):
        super().__init__()
        self.blocks = [
            HieraBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                use_windowed_attn=use_windowed_attn and (i % 2 == 0),  # Alternate global/local
                window_size=window_size,
            )
            for i in range(depth)
        ]

    def __call__(self, x: mx.array) -> mx.array:
        # mlx.nn.Module does not route __call__ to forward(); required for
        # ``stage(x)`` to work.
        return self.forward(x)

    def forward(self, x: mx.array) -> mx.array:
        """Run ``x`` through every block of the stage in order."""
        for block in self.blocks:
            x = block(x)
        return x
class HieraVisionEncoder(Module):
    """
    Complete Hiera vision encoder.

    Multi-scale hierarchical vision transformer:
    - patch embedding followed by a pre-LayerNorm;
    - one ``HieraStage`` per entry of ``embed_dims``;
    - a ``DownsampleBlock`` between consecutive stages (halving h and w and
      moving to the next stage's width);
    - a final LayerNorm.

    Args:
        img_size: Nominal input image size. Used to size the patch embedding
            and ``init_h``/``init_w``. ``forward`` derives the actual grid from
            the input shape.
        patch_size: Initial patch size.
        in_chans: Input channels (3 for RGB).
        embed_dims: Channel dimensions for each stage.
            ``None`` means ``[256, 512, 1024, 1024]``.
        depths: Number of blocks per stage.
            ``None`` means ``[2, 8, 16, 6]``.
        num_heads: Attention heads per stage.
            ``None`` means ``[4, 8, 16, 16]``.
        mlp_ratio: MLP hidden dim ratio.
        use_windowed_attn: Use windowed attention in stages
            (alternating with global).
        window_size: Window size for windowed attention.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 14,
        in_chans: int = 3,
        embed_dims: Optional[List[int]] = None,
        depths: Optional[List[int]] = None,
        num_heads: Optional[List[int]] = None,
        mlp_ratio: float = 4.0,
        use_windowed_attn: bool = True,
        window_size: int = 14,
    ):
        super().__init__()
        # None sentinels instead of list defaults: default lists would be shared
        # across every call. Caller lists are copied so later mutation by the
        # caller cannot affect this instance.
        embed_dims = [256, 512, 1024, 1024] if embed_dims is None else list(embed_dims)
        depths = [2, 8, 16, 6] if depths is None else list(depths)
        num_heads = [4, 8, 16, 16] if num_heads is None else list(num_heads)

        assert len(embed_dims) == len(depths) == len(num_heads), \
            "embed_dims, depths, and num_heads must have same length"
        self.num_stages = len(embed_dims)
        self.patch_size = patch_size

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dims[0],
        )
        # Grid size for a nominal img_size input (kept for compatibility);
        # forward() computes the grid from the actual input instead.
        self.init_h = self.init_w = img_size // patch_size
        self.norm_pre = nn.LayerNorm(embed_dims[0])

        # Build stages and downsamplers interleaved (stage i, then downsample i)
        # so that modules are constructed in the same order as before.
        stages = []
        downsample_layers = []
        for i in range(self.num_stages):
            stages.append(
                HieraStage(
                    dim=embed_dims[i],
                    depth=depths[i],
                    num_heads=num_heads[i],
                    mlp_ratio=mlp_ratio,
                    use_windowed_attn=use_windowed_attn,
                    window_size=window_size,
                )
            )
            # No downsampling after the last stage
            if i < self.num_stages - 1:
                downsample_layers.append(DownsampleBlock(embed_dims[i], embed_dims[i + 1]))
        self.stages = stages
        self.downsample_layers = downsample_layers

        self.norm = nn.LayerNorm(embed_dims[-1])

    def __call__(self, x: mx.array) -> mx.array:
        # mlx.nn.Module does not route __call__ to forward(); required for
        # ``encoder(images)`` to work.
        return self.forward(x)

    def forward(self, x: mx.array) -> mx.array:
        """
        Args:
            x: (B, H, W, C) image in NHWC format.
        Returns:
            (B, num_patches_final, embed_dim_final) features.
        """
        _, height, width, _ = x.shape
        # The patch conv has kernel == stride == patch_size and no padding,
        # so it yields an (H // p) x (W // p) grid. Derive the grid from the
        # real input rather than the nominal img_size, so other resolutions
        # and non-square images reshape correctly in DownsampleBlock.
        h, w = height // self.patch_size, width // self.patch_size

        x = self.patch_embed(x)  # (B, h*w, embed_dims[0])
        x = self.norm_pre(x)

        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i < len(self.downsample_layers):
                x, h, w = self.downsample_layers[i](x, h, w)

        return self.norm(x)
def create_hiera_base() -> HieraVisionEncoder:
    """
    Build the Hiera-Base encoder (the SAM3 default configuration).

    Returns:
        A ``HieraVisionEncoder`` with 1024px input and 14px patches. It has
        four stages with widths 256/512/1024/1024, depths 2/8/16/6 and
        4/8/16/16 heads. All other arguments use the encoder defaults.
    """
    base_config = dict(
        img_size=1024,
        patch_size=14,
        embed_dims=[256, 512, 1024, 1024],
        depths=[2, 8, 16, 6],
        num_heads=[4, 8, 16, 16],
    )
    return HieraVisionEncoder(**base_config)
def create_hiera_large() -> HieraVisionEncoder:
    """
    Build the Hiera-Large encoder configuration.

    Returns:
        A ``HieraVisionEncoder`` with 1024px input and 14px patches. It has
        four stages with widths 384/768/1536/1536, depths 2/8/20/8 and
        6/12/24/24 heads. All other arguments use the encoder defaults.
    """
    large_config = dict(
        img_size=1024,
        patch_size=14,
        embed_dims=[384, 768, 1536, 1536],
        depths=[2, 8, 20, 8],
        num_heads=[6, 12, 24, 24],
    )
    return HieraVisionEncoder(**large_config)