| """ |
| RoPE Multi-Head Attention for SAM3 |
| Implements Rotary Position Embeddings for spatial awareness |
| """ |
|
|
| import mlx.core as mx |
| import mlx.nn as nn |
| from mlx.nn import Module |
| import math |
| from typing import Optional |
|
|
class RoPEEmbedding(Module):
    """Rotary Position Embedding (1D, applied over the flattened image token sequence)"""

    def __init__(self, dim: int, max_seq_len: int = 8192):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        # Inverse frequencies, one per channel pair: shape (dim // 2,)
        inv_freq = 1.0 / (10000 ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len: int) -> mx.array:
        """Generate RoPE cos/sin tables for the given sequence length.

        Returns:
            (2, seq_len, dim // 2) array: index 0 is the cos table, index 1 the sin table.
        """
        # Token positions 0 .. seq_len - 1
        t = mx.arange(seq_len, dtype=mx.float32)

        # Rotation angles per position and frequency: (seq_len, dim // 2)
        freqs = mx.outer(t, self.inv_freq)

        # Keep the tables at width dim // 2: apply_rotary_pos_emb splits q/k into
        # matching half-width tensors before rotating each channel pair.
        return mx.stack([mx.cos(freqs), mx.sin(freqs)], axis=0)

    def register_buffer(self, name: str, tensor: mx.array):
        """Compatibility shim: MLX has no register_buffer, so store the array as a plain attribute"""
        setattr(self, name, tensor)

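
# Illustrative sketch (sizes are assumptions, not the SAM3 configuration): for a
# 64-dimensional head and 196 tokens, RoPEEmbedding(64).forward(196) returns an
# array of shape (2, 196, 32) -- the cos table at index 0 and the sin table at
# index 1, one column per rotated channel pair.
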
def apply_rotary_pos_emb(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array) -> tuple:
    """
    Apply rotary position embeddings to queries and keys

    Args:
        q: (batch, seq_len, num_heads, head_dim)
        k: (batch, seq_len, num_heads, head_dim)
        cos: (seq_len, head_dim // 2)
        sin: (seq_len, head_dim // 2)

    Returns:
        Rotated q and k, each with the same shape as its input
    """
    # Broadcast the tables over batch and heads: (1, seq_len, 1, head_dim // 2)
    cos = cos.reshape(1, -1, 1, cos.shape[-1])
    sin = sin.reshape(1, -1, 1, sin.shape[-1])

    # Treat (first half, second half) of each head dimension as (x1, x2) pairs
    q_half1, q_half2 = mx.split(q, 2, axis=-1)
    k_half1, k_half2 = mx.split(k, 2, axis=-1)

    # Rotate each pair by its position-dependent angle
    q_rotated = mx.concatenate([
        q_half1 * cos - q_half2 * sin,
        q_half1 * sin + q_half2 * cos
    ], axis=-1)

    k_rotated = mx.concatenate([
        k_half1 * cos - k_half2 * sin,
        k_half1 * sin + k_half2 * cos
    ], axis=-1)

    return q_rotated, k_rotated
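
# Illustrative sketch (shapes are assumptions, not the SAM3 configuration): with
# q and k of shape (2, 196, 16, 64) and the tables from RoPEEmbedding(64),
#
#   tables = RoPEEmbedding(64).forward(196)          # (2, 196, 32)
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, tables[0], tables[1])
#
# q_rot and k_rot keep the shape (2, 196, 16, 64); only the orientation of each
# channel pair changes with token position.
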
class MultiHeadAttentionRoPE(Module):
    """
    Multi-Head Attention with Rotary Position Embeddings

    Key features:
    - RoPE for relative position encoding
    - Plain scaled dot-product attention that can be swapped for a fused kernel
    - Optimized for MLX/Metal
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 16,
        qkv_bias: bool = True,
        dropout: float = 0.0,
        use_rope: bool = True
    ):
        super().__init__()

        assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.use_rope = use_rope

        # Fused projection for queries, keys and values
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        # Output projection
        self.proj = nn.Linear(dim, dim)

        # Optional dropout on the attention weights and the output projection
        self.attn_dropout = nn.Dropout(dropout) if dropout > 0 else None
        self.proj_dropout = nn.Dropout(dropout) if dropout > 0 else None

        # RoPE acts per head, so it is sized by head_dim rather than dim
        if use_rope:
            self.rope = RoPEEmbedding(self.head_dim)

    def forward(self, x: mx.array, attn_mask: Optional[mx.array] = None) -> mx.array:
        """
        Forward pass

        Args:
            x: (batch, seq_len, dim)
            attn_mask: Optional additive mask, broadcastable to
                (batch, num_heads, seq_len, seq_len)

        Returns:
            Output: (batch, seq_len, dim)
        """
        B, N, C = x.shape

        # Project to q, k, v and split heads: each is (B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.transpose(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Rotate queries and keys with RoPE
        if self.use_rope:
            rope_emb = self.rope.forward(N)
            cos, sin = rope_emb[0], rope_emb[1]

            # apply_rotary_pos_emb expects (B, N, num_heads, head_dim)
            q = q.transpose(0, 2, 1, 3)
            k = k.transpose(0, 2, 1, 3)

            q, k = apply_rotary_pos_emb(q, k, cos, sin)

            # Back to (B, num_heads, N, head_dim)
            q = q.transpose(0, 2, 1, 3)
            k = k.transpose(0, 2, 1, 3)

        # Scaled dot-product attention scores: (B, num_heads, N, N)
        attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale

        # Additive mask (e.g. -inf outside a local window)
        if attn_mask is not None:
            attn = attn + attn_mask

        attn = mx.softmax(attn, axis=-1)

        if self.attn_dropout is not None:
            attn = self.attn_dropout(attn)

        # Weighted sum of values, then merge heads back to (B, N, dim)
        x = attn @ v
        x = x.transpose(0, 2, 1, 3).reshape(B, N, C)
        x = self.proj(x)

        if self.proj_dropout is not None:
            x = self.proj_dropout(x)

        return x
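
# Sketch of an alternative attention core (an assumption about the environment:
# it requires an MLX build that provides mx.fast.scaled_dot_product_attention).
# With dropout disabled, the explicit matmul/softmax above could be replaced by
# a single fused call on the (B, num_heads, N, head_dim) tensors:
#
#   x = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attn_mask)
#
# while the head split/merge and the RoPE application stay unchanged.
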
class WindowedAttention(MultiHeadAttentionRoPE):
    """
    Windowed Multi-Head Attention for local processing
    Used in certain Hiera blocks for efficiency
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 16,
        window_size: int = 14,
        **kwargs
    ):
        super().__init__(dim, num_heads, **kwargs)
        self.window_size = window_size

    def create_window_mask(self, seq_len: int) -> mx.array:
        """Create a banded mask: each position attends only to neighbours within
        window_size // 2 steps on either side."""
        positions = mx.arange(seq_len)

        # Distance |i - j| between query position i and key position j
        distance = mx.abs(positions.reshape(-1, 1) - positions.reshape(1, -1))

        # 0 inside the window, -inf outside; broadcastable over batch and heads
        mask = mx.where(distance <= self.window_size // 2, 0.0, float("-inf"))
        return mask.reshape(1, 1, seq_len, seq_len)

    def forward(self, x: mx.array) -> mx.array:
        """Forward with windowed attention"""
        B, N, C = x.shape

        # Build the local window mask for this sequence length and reuse the
        # parent attention with it as an additive mask
        window_mask = self.create_window_mask(N)
        return super().forward(x, attn_mask=window_mask)
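

if __name__ == "__main__":
    # Minimal smoke test (sketch): the batch/sequence/width values below are
    # illustrative assumptions, not taken from the SAM3 configuration.
    x = mx.random.normal((2, 196, 256))  # (batch, seq_len, dim)

    attn = MultiHeadAttentionRoPE(dim=256, num_heads=8)
    print("MultiHeadAttentionRoPE:", attn.forward(x).shape)  # (2, 196, 256)

    windowed = WindowedAttention(dim=256, num_heads=8, window_size=14)
    print("WindowedAttention:", windowed.forward(x).shape)  # (2, 196, 256)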