""" Hiera (Hierarchical Vision Transformer) - Complete MLX Implementation This is the vision backbone used in SAM3, featuring: - Multi-scale hierarchical processing - Stage-wise spatial pooling - RoPE attention at each scale - Efficient computation via MLX/Metal """ import mlx.core as mx import mlx.nn as nn from mlx.nn import Module from typing import List, Optional, Tuple from .attention import MultiHeadAttentionRoPE, WindowedAttention class MLP(Module): """ Multi-Layer Perceptron with GELU activation Standard FFN block in transformers """ def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0): super().__init__() self.fc1 = nn.Linear(dim, hidden_dim) self.act = nn.GELU() self.fc2 = nn.Linear(hidden_dim, dim) self.dropout = nn.Dropout(dropout) if dropout > 0 else None def forward(self, x: mx.array) -> mx.array: x = self.fc1(x) x = self.act(x) if self.dropout: x = self.dropout(x) x = self.fc2(x) if self.dropout: x = self.dropout(x) return x class HieraBlock(Module): """ Single Hiera transformer block Features: - Pre-LayerNorm architecture - RoPE Multi-Head Attention - MLP with GELU - Residual connections """ def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, dropout: float = 0.0, use_windowed_attn: bool = False, window_size: int = 14, ): super().__init__() self.norm1 = nn.LayerNorm(dim) # Choose attention type if use_windowed_attn: self.attn = WindowedAttention( dim, num_heads=num_heads, qkv_bias=qkv_bias, dropout=dropout, window_size=window_size ) else: self.attn = MultiHeadAttentionRoPE( dim, num_heads=num_heads, qkv_bias=qkv_bias, dropout=dropout ) self.norm2 = nn.LayerNorm(dim) self.mlp = MLP(dim, int(dim * mlp_ratio), dropout=dropout) def forward(self, x: mx.array) -> mx.array: # Attention with pre-norm and residual x = x + self.attn(self.norm1(x)) # MLP with pre-norm and residual x = x + self.mlp(self.norm2(x)) return x class PatchEmbed(Module): """ Image to Patch Embedding using Conv2d Converts (B, H, W, C) image to (B, num_patches, embed_dim) patches """ def __init__( self, img_size: int = 1024, patch_size: int = 14, in_chans: int = 3, embed_dim: int = 1024 ): super().__init__() self.img_size = img_size self.patch_size = patch_size self.grid_size = img_size // patch_size self.num_patches = self.grid_size ** 2 # Convolution for patch embedding self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x: mx.array) -> mx.array: """ Args: x: (B, H, W, C) in NHWC format (MLX convention) Returns: (B, num_patches, embed_dim) """ B, H, W, C = x.shape # Apply convolution x = self.proj(x) # (B, H', W', embed_dim) where H'=W'=grid_size # Flatten spatial dimensions B, H_p, W_p, C_emb = x.shape x = x.reshape(B, H_p * W_p, C_emb) # (B, num_patches, embed_dim) return x class DownsampleBlock(Module): """ Spatial downsampling block for hierarchical processing Reduces spatial resolution by 2x while increasing channels Uses depthwise-separable convolution for efficiency """ def __init__(self, in_dim: int, out_dim: int): super().__init__() # Depthwise convolution (2x2 pooling with stride 2) self.dw_conv = nn.Conv2d(in_dim, in_dim, kernel_size=2, stride=2, groups=in_dim) # Pointwise convolution (1x1 to change channels) self.pw_conv = nn.Conv2d(in_dim, out_dim, kernel_size=1) self.norm = nn.LayerNorm(out_dim) def forward(self, x: mx.array, h: int, w: int) -> Tuple[mx.array, int, int]: """ Args: x: (B, N, C) where N = h*w h, w: Spatial dimensions Returns: (B, N//4, C'), h//2, w//2 """ B, N, C = x.shape # Reshape to spatial format: (B, N, C) -> (B, h, w, C) x = x.reshape(B, h, w, C) # Apply convolutions x = self.dw_conv(x) x = self.pw_conv(x) # Flatten back: (B, h//2, w//2, out_dim) -> (B, N//4, out_dim) B, h_new, w_new, C_new = x.shape x = x.reshape(B, h_new * w_new, C_new) # Normalize x = self.norm(x) return x, h_new, w_new class HieraStage(Module): """ Single stage of Hiera with multiple blocks Each stage processes at a specific spatial scale """ def __init__( self, dim: int, depth: int, num_heads: int, mlp_ratio: float = 4.0, use_windowed_attn: bool = False, window_size: int = 14, ): super().__init__() self.blocks = [ HieraBlock( dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, use_windowed_attn=use_windowed_attn and (i % 2 == 0), # Alternate global/local window_size=window_size ) for i in range(depth) ] def forward(self, x: mx.array) -> mx.array: for block in self.blocks: x = block(x) return x class HieraVisionEncoder(Module): """ Complete Hiera Vision Encoder Multi-scale hierarchical vision transformer with: - 4 stages with increasing channel dimensions - Spatial downsampling between stages - RoPE attention at all scales - Both global and windowed attention Args: img_size: Input image size patch_size: Initial patch size in_chans: Input channels (3 for RGB) embed_dims: Channel dimensions for each stage depths: Number of blocks per stage num_heads: Attention heads per stage mlp_ratio: MLP hidden dim ratio use_windowed_attn: Use windowed attention in stages """ def __init__( self, img_size: int = 1024, patch_size: int = 14, in_chans: int = 3, embed_dims: List[int] = [256, 512, 1024, 1024], # Progressive channel increase depths: List[int] = [2, 8, 16, 6], # Blocks per stage num_heads: List[int] = [4, 8, 16, 16], mlp_ratio: float = 4.0, use_windowed_attn: bool = True, window_size: int = 14, ): super().__init__() assert len(embed_dims) == len(depths) == len(num_heads), \ "embed_dims, depths, and num_heads must have same length" self.num_stages = len(embed_dims) self.patch_size = patch_size # Patch embedding self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0] ) # Initial spatial dimensions self.init_h = self.init_w = img_size // patch_size # Pre-norm before stages self.norm_pre = nn.LayerNorm(embed_dims[0]) # Build stages self.stages = [] self.downsample_layers = [] for i in range(self.num_stages): # Create stage stage = HieraStage( dim=embed_dims[i], depth=depths[i], num_heads=num_heads[i], mlp_ratio=mlp_ratio, use_windowed_attn=use_windowed_attn, window_size=window_size ) self.stages.append(stage) # Create downsampling layer (except for last stage) if i < self.num_stages - 1: downsample = DownsampleBlock(embed_dims[i], embed_dims[i + 1]) self.downsample_layers.append(downsample) # Final norm self.norm = nn.LayerNorm(embed_dims[-1]) def forward(self, x: mx.array) -> mx.array: """ Args: x: (B, H, W, C) image in NHWC format Returns: (B, num_patches_final, embed_dim_final) features """ # Patch embedding x = self.patch_embed(x) # (B, num_patches, embed_dim[0]) # Pre-norm x = self.norm_pre(x) # Track spatial dimensions h, w = self.init_h, self.init_w # Process through stages for i, stage in enumerate(self.stages): # Apply stage x = stage(x) # Downsample (except last stage) if i < len(self.downsample_layers): x, h, w = self.downsample_layers[i](x, h, w) # Final norm x = self.norm(x) return x def create_hiera_base() -> HieraVisionEncoder: """Create Hiera-Base configuration (SAM3 default)""" return HieraVisionEncoder( img_size=1024, patch_size=14, embed_dims=[256, 512, 1024, 1024], depths=[2, 8, 16, 6], num_heads=[4, 8, 16, 16] ) def create_hiera_large() -> HieraVisionEncoder: """Create Hiera-Large configuration""" return HieraVisionEncoder( img_size=1024, patch_size=14, embed_dims=[384, 768, 1536, 1536], depths=[2, 8, 20, 8], num_heads=[6, 12, 24, 24] )