""" RoPE Multi-Head Attention for SAM3 Implements Rotary Position Embeddings for spatial awareness """ import mlx.core as mx import mlx.nn as nn from mlx.nn import Module import math from typing import Optional class RoPEEmbedding(Module): """Rotary Position Embedding - 2D version for images""" def __init__(self, dim: int, max_seq_len: int = 8192): super().__init__() self.dim = dim # Precompute frequency matrix inv_freq = 1.0 / (10000 ** (mx.arange(0, dim, 2).astype(mx.float32) / dim)) self.register_buffer("inv_freq", inv_freq) def forward(self, seq_len: int) -> mx.array: """Generate RoPE embeddings for given sequence length""" # Generate position indices t = mx.arange(seq_len, dtype=mx.float32) # Compute frequencies: outer product of positions and inv_freq freqs = mx.outer(t, self.inv_freq) # (seq_len, dim/2) # Create sin and cos embeddings emb = mx.concatenate([freqs, freqs], axis=-1) # (seq_len, dim) return mx.stack([mx.cos(emb), mx.sin(emb)], axis=0) # (2, seq_len, dim) def register_buffer(self, name: str, tensor: mx.array): """Register buffer (MLX doesn't need this, but keeping for compatibility)""" setattr(self, name, tensor) def apply_rotary_pos_emb(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array) -> tuple: """ Apply rotary position embeddings to queries and keys Args: q: (batch, seq_len, num_heads, head_dim) k: (batch, seq_len, num_heads, head_dim) cos: (seq_len, head_dim) sin: (seq_len, head_dim) Returns: Rotated q and k """ # Reshape for broadcasting cos = cos.reshape(1, -1, 1, cos.shape[-1]) # (1, seq_len, 1, head_dim) sin = sin.reshape(1, -1, 1, sin.shape[-1]) # Split into two halves for rotation q_half1, q_half2 = mx.split(q, 2, axis=-1) k_half1, k_half2 = mx.split(k, 2, axis=-1) # Apply rotation q_rotated = mx.concatenate([ q_half1 * cos - q_half2 * sin, q_half1 * sin + q_half2 * cos ], axis=-1) k_rotated = mx.concatenate([ k_half1 * cos - k_half2 * sin, k_half1 * sin + k_half2 * cos ], axis=-1) return q_rotated, k_rotated class MultiHeadAttentionRoPE(Module): """ Multi-Head Attention with Rotary Position Embeddings Key features: - RoPE for relative position encoding - Flash attention compatible - Optimized for MLX/Metal """ def __init__( self, dim: int, num_heads: int = 16, qkv_bias: bool = True, dropout: float = 0.0, use_rope: bool = True ): super().__init__() assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}" self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.use_rope = use_rope # QKV projection self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # Output projection self.proj = nn.Linear(dim, dim) # Dropout self.attn_dropout = nn.Dropout(dropout) if dropout > 0 else None self.proj_dropout = nn.Dropout(dropout) if dropout > 0 else None # RoPE if use_rope: self.rope = RoPEEmbedding(self.head_dim) def forward(self, x: mx.array, attn_mask: Optional[mx.array] = None) -> mx.array: """ Forward pass Args: x: (batch, seq_len, dim) attn_mask: Optional attention mask Returns: Output: (batch, seq_len, dim) """ B, N, C = x.shape # QKV projection and reshape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) qkv = qkv.transpose(2, 0, 3, 1, 4) # (3, B, num_heads, N, head_dim) q, k, v = qkv[0], qkv[1], qkv[2] # Apply RoPE if enabled if self.use_rope: rope_emb = self.rope.forward(N) # (2, N, head_dim) cos, sin = rope_emb[0], rope_emb[1] # Transpose for apply_rotary: (B, num_heads, N, head_dim) -> (B, N, num_heads, head_dim) q = q.transpose(0, 2, 1, 3) k = k.transpose(0, 2, 
1, 3) q, k = apply_rotary_pos_emb(q, k, cos, sin) # Transpose back q = q.transpose(0, 2, 1, 3) k = k.transpose(0, 2, 1, 3) # Scaled dot-product attention # q, k, v: (B, num_heads, N, head_dim) attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale # (B, num_heads, N, N) # Apply attention mask if provided if attn_mask is not None: attn = attn + attn_mask # Softmax attn = mx.softmax(attn, axis=-1) # Apply dropout if self.attn_dropout is not None: attn = self.attn_dropout(attn) # Apply attention to values x = attn @ v # (B, num_heads, N, head_dim) # Reshape and project x = x.transpose(0, 2, 1, 3).reshape(B, N, C) x = self.proj(x) # Apply output dropout if self.proj_dropout is not None: x = self.proj_dropout(x) return x class WindowedAttention(MultiHeadAttentionRoPE): """ Windowed Multi-Head Attention for local processing Used in certain Hiera blocks for efficiency """ def __init__( self, dim: int, num_heads: int = 16, window_size: int = 14, **kwargs ): super().__init__(dim, num_heads, **kwargs) self.window_size = window_size def create_window_mask(self, seq_len: int) -> mx.array: """Create attention mask for windowed attention""" # Create mask that only allows attention within window_size mask = mx.ones((seq_len, seq_len)) * float('-inf') for i in range(seq_len): start = max(0, i - self.window_size // 2) end = min(seq_len, i + self.window_size // 2 + 1) mask[i, start:end] = 0.0 return mask.reshape(1, 1, seq_len, seq_len) def forward(self, x: mx.array) -> mx.array: """Forward with windowed attention""" B, N, C = x.shape # Create window mask window_mask = self.create_window_mask(N) return super().forward(x, attn_mask=window_mask)
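

if __name__ == "__main__":
    # Minimal usage sketch. The shapes below are illustrative only (a 14x14
    # grid of 256-d tokens), not SAM3's actual configuration; this simply
    # checks that the modules run and that output shapes match the input.
    batch, seq_len, dim = 2, 196, 256

    x = mx.random.normal((batch, seq_len, dim))

    # Global attention with RoPE
    attn = MultiHeadAttentionRoPE(dim=dim, num_heads=8)
    out = attn(x)
    print("MultiHeadAttentionRoPE:", out.shape)  # (2, 196, 256)

    # Local attention restricted to a 14-token window
    w_attn = WindowedAttention(dim=dim, num_heads=8, window_size=14)
    w_out = w_attn(x)
    print("WindowedAttention:", w_out.shape)  # (2, 196, 256)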