# MLX_SAM3 / attention.py
"""
RoPE Multi-Head Attention for SAM3
Implements Rotary Position Embeddings for spatial awareness
"""
import mlx.core as mx
import mlx.nn as nn
from mlx.nn import Module
from typing import Optional
class RoPEEmbedding(Module):
    """Rotary Position Embedding (1D, applied over the flattened token sequence)"""
    def __init__(self, dim: int, max_seq_len: int = 8192):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len  # kept for API compatibility; tables are built on demand
        # Precompute inverse frequencies; the leading underscore keeps the array
        # out of the module's trainable parameters in MLX
        self._inv_freq = 1.0 / (10000 ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))
    def __call__(self, seq_len: int) -> mx.array:
        """Generate RoPE cos/sin tables for the given sequence length"""
        # Position indices
        t = mx.arange(seq_len, dtype=mx.float32)
        # Outer product of positions and inverse frequencies
        freqs = mx.outer(t, self._inv_freq)  # (seq_len, dim/2)
        # Duplicate so each table is full width: (seq_len, dim)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        return mx.stack([mx.cos(emb), mx.sin(emb)], axis=0)  # (2, seq_len, dim)
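# Illustrative sketch (not part of the upstream module): a quick sanity check of
# the cos/sin table shapes. The head dimension (64) and token count (16) are
# assumptions chosen for the demo.
def _demo_rope_tables():
    rope = RoPEEmbedding(dim=64)
    table = rope(16)  # (2, seq_len, dim)
    cos, sin = table[0], table[1]
    assert cos.shape == (16, 64) and sin.shape == (16, 64)
    # Position 0 is the identity rotation: cos = 1, sin = 0
    assert mx.allclose(cos[0], mx.ones(64)) and mx.allclose(sin[0], mx.zeros(64))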
def apply_rotary_pos_emb(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array) -> tuple:
    """
    Apply rotary position embeddings to queries and keys
    Args:
        q: (batch, seq_len, num_heads, head_dim)
        k: (batch, seq_len, num_heads, head_dim)
        cos: (seq_len, head_dim)
        sin: (seq_len, head_dim)
    Returns:
        Rotated q and k
    """
    # The tables were built as concat([freqs, freqs]), so both halves are
    # identical; slice to half width so the tables broadcast against the split
    # q/k halves below (the full-width tables would not)
    half = q.shape[-1] // 2
    cos = cos[..., :half].reshape(1, -1, 1, half)  # (1, seq_len, 1, head_dim/2)
    sin = sin[..., :half].reshape(1, -1, 1, half)
    # Split into two halves and rotate each (x, y) pair by its position angle
    q_half1, q_half2 = mx.split(q, 2, axis=-1)
    k_half1, k_half2 = mx.split(k, 2, axis=-1)
    q_rotated = mx.concatenate([
        q_half1 * cos - q_half2 * sin,
        q_half1 * sin + q_half2 * cos
    ], axis=-1)
    k_rotated = mx.concatenate([
        k_half1 * cos - k_half2 * sin,
        k_half1 * sin + k_half2 * cos
    ], axis=-1)
    return q_rotated, k_rotated
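# Illustrative sketch (not part of the upstream module): rotary embedding is a
# pure rotation of (x, y) pairs, so per-token q/k norms are unchanged. The
# shapes (batch 1, 16 tokens, 4 heads, head_dim 64) are assumptions for the demo.
def _demo_rotation_preserves_norm():
    rope = RoPEEmbedding(dim=64)
    table = rope(16)
    cos, sin = table[0], table[1]
    q = mx.random.normal((1, 16, 4, 64))
    k = mx.random.normal((1, 16, 4, 64))
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert mx.allclose(mx.linalg.norm(q, axis=-1),
                       mx.linalg.norm(q_rot, axis=-1), atol=1e-4)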
class MultiHeadAttentionRoPE(Module):
    """
    Multi-Head Attention with Rotary Position Embeddings
    Key features:
    - RoPE for relative position encoding
    - Materialized attention scores; interchangeable with a fused kernel
      such as mx.fast.scaled_dot_product_attention
    - Written for MLX/Metal
    """
def __init__(
self,
dim: int,
num_heads: int = 16,
qkv_bias: bool = True,
dropout: float = 0.0,
use_rope: bool = True
):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.use_rope = use_rope
# QKV projection
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# Output projection
self.proj = nn.Linear(dim, dim)
# Dropout
self.attn_dropout = nn.Dropout(dropout) if dropout > 0 else None
self.proj_dropout = nn.Dropout(dropout) if dropout > 0 else None
# RoPE
if use_rope:
self.rope = RoPEEmbedding(self.head_dim)
    def __call__(self, x: mx.array, attn_mask: Optional[mx.array] = None) -> mx.array:
"""
Forward pass
Args:
x: (batch, seq_len, dim)
attn_mask: Optional attention mask
Returns:
Output: (batch, seq_len, dim)
"""
B, N, C = x.shape
# QKV projection and reshape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.transpose(2, 0, 3, 1, 4) # (3, B, num_heads, N, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2]
# Apply RoPE if enabled
        if self.use_rope:
            rope_emb = self.rope(N)  # (2, N, head_dim)
            cos, sin = rope_emb[0], rope_emb[1]
            # apply_rotary_pos_emb expects (B, N, num_heads, head_dim)
            q = q.transpose(0, 2, 1, 3)
            k = k.transpose(0, 2, 1, 3)
            q, k = apply_rotary_pos_emb(q, k, cos, sin)
            # Back to (B, num_heads, N, head_dim) for attention
            q = q.transpose(0, 2, 1, 3)
            k = k.transpose(0, 2, 1, 3)
# Scaled dot-product attention
# q, k, v: (B, num_heads, N, head_dim)
attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale # (B, num_heads, N, N)
# Apply attention mask if provided
if attn_mask is not None:
attn = attn + attn_mask
# Softmax
attn = mx.softmax(attn, axis=-1)
# Apply dropout
if self.attn_dropout is not None:
attn = self.attn_dropout(attn)
# Apply attention to values
x = attn @ v # (B, num_heads, N, head_dim)
# Reshape and project
x = x.transpose(0, 2, 1, 3).reshape(B, N, C)
x = self.proj(x)
# Apply output dropout
if self.proj_dropout is not None:
x = self.proj_dropout(x)
return x
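# Illustrative sketches (not part of the upstream module): a smoke test for the
# attention block, plus the equivalent computation through MLX's fused kernel.
# The dim/num_heads/shape values here are assumptions chosen for the demo.
def _demo_attention():
    attn = MultiHeadAttentionRoPE(dim=256, num_heads=8)
    x = mx.random.normal((2, 196, 256))  # (batch, tokens, dim)
    out = attn(x)
    assert out.shape == (2, 196, 256)

def _demo_fused_sdpa(q: mx.array, k: mx.array, v: mx.array, scale: float) -> mx.array:
    # Same math as the materialized softmax(q @ k^T) @ v above, but computed by
    # the fused Metal kernel; q, k, v are (B, num_heads, N, head_dim)
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)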
class WindowedAttention(MultiHeadAttentionRoPE):
"""
Windowed Multi-Head Attention for local processing
Used in certain Hiera blocks for efficiency
"""
def __init__(
self,
dim: int,
num_heads: int = 16,
window_size: int = 14,
**kwargs
):
super().__init__(dim, num_heads, **kwargs)
self.window_size = window_size
    def create_window_mask(self, seq_len: int) -> mx.array:
        """Create an additive mask allowing attention only within the local window"""
        # Token j is visible from token i iff |i - j| <= window_size // 2;
        # computed vectorized rather than with a per-row Python loop
        idx = mx.arange(seq_len)
        dist = mx.abs(idx[:, None] - idx[None, :])
        mask = mx.where(dist <= self.window_size // 2, 0.0, float('-inf'))
        return mask.reshape(1, 1, seq_len, seq_len)
    def __call__(self, x: mx.array) -> mx.array:
        """Forward pass with windowed attention"""
        N = x.shape[1]
        # Build the local-window mask and run standard attention with it
        window_mask = self.create_window_mask(N)
        return super().__call__(x, attn_mask=window_mask)
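# Illustrative sketch (not part of the upstream module): windowed attention on a
# flattened 14x14 token grid, so each token attends only to a local neighborhood.
# The dim/num_heads/window_size values are assumptions for the demo.
def _demo_windowed_attention():
    wattn = WindowedAttention(dim=256, num_heads=8, window_size=14)
    x = mx.random.normal((1, 196, 256))
    out = wattn(x)
    assert out.shape == (1, 196, 256)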