Instructions to use Hoodrobot/MLX_SAM3 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use Hoodrobot/MLX_SAM3 with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir MLX_SAM3 Hoodrobot/MLX_SAM3
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
File size: 6,382 Bytes
ced11e2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 | """
RoPE Multi-Head Attention for SAM3
Implements Rotary Position Embeddings for spatial awareness
"""
import mlx.core as mx
import mlx.nn as nn
from mlx.nn import Module
import math
from typing import Optional
class RoPEEmbedding(Module):
    """Rotary position embedding (RoPE) table generator.

    Stores the inverse frequencies ``1 / 10000 ** (2i / dim)`` and builds,
    on demand, the cosine and sine tables for positions ``0 .. seq_len - 1``.
    The table is indexed by a single flat position; no separate row/column
    axes are encoded here.

    Args:
        dim: Width of the generated tables (the per-head dimension when used
            for attention). Expected to be even so that ``dim / 2``
            frequencies duplicated once give exactly ``dim`` columns.
        max_seq_len: Accepted for interface compatibility. Tables are built
            on demand in :meth:`forward`, so this value is not used and
            imposes no limit.
    """

    def __init__(self, dim: int, max_seq_len: int = 8192):
        super().__init__()
        self.dim = dim
        # One inverse frequency per pair of channels: 10000 ** -(2i / dim).
        exponents = mx.arange(0, dim, 2).astype(mx.float32) / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def register_buffer(self, name: str, tensor: mx.array) -> None:
        """Attach ``tensor`` to the module as a non-trainable constant.

        MLX treats every public ``mx.array`` attribute of a module as a
        parameter, so a plain ``setattr`` would expose the tensor to the
        optimizer. Freezing the key keeps it in ``parameters()`` (so saved
        weights keep the same key) but removes it from
        ``trainable_parameters()``.
        """
        setattr(self, name, tensor)
        self.freeze(keys=[name], recurse=False)

    def forward(self, seq_len: int) -> mx.array:
        """Build the cos/sin tables for ``seq_len`` positions.

        Args:
            seq_len: Number of positions; positions are ``0 .. seq_len - 1``.

        Returns:
            Array of shape ``(2, seq_len, dim)``; index 0 holds the cosines
            and index 1 the sines. The last axis is the ``dim / 2`` angles
            repeated twice.
        """
        positions = mx.arange(seq_len, dtype=mx.float32)
        # Outer product position x frequency -> rotation angles (seq_len, dim/2).
        angles = mx.outer(positions, self.inv_freq)
        # Duplicate so both halves of a head share the same angle per pair.
        table = mx.concatenate([angles, angles], axis=-1)
        return mx.stack([mx.cos(table), mx.sin(table)], axis=0)

    def __call__(self, seq_len: int) -> mx.array:
        """Alias for :meth:`forward` so the module is callable like other MLX layers."""
        return self.forward(seq_len)
def apply_rotary_pos_emb(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array) -> tuple:
    """Apply rotary position embeddings to queries and keys.

    Each head vector is split into halves ``(x1, x2)`` and rotated to
    ``(x1 * cos - x2 * sin, x2 * cos + x1 * sin)``.

    The previous implementation multiplied the half-width slices of ``q``/``k``
    by the full-width ``cos``/``sin`` tables, which cannot broadcast for any
    ``head_dim > 2``. The rotation is now expressed over the full head width.

    Args:
        q: Queries, ``(batch, seq_len, num_heads, head_dim)``.
        k: Keys, ``(batch, seq_len, num_heads, head_dim)``.
        cos: Cosine table, ``(seq_len, head_dim)``. A half-width table
            ``(seq_len, head_dim // 2)`` is also accepted and is tiled.
        sin: Sine table with the same layout as ``cos``.

    Returns:
        Tuple ``(q_rotated, k_rotated)`` with the same shapes as ``q`` and ``k``.
    """
    half = q.shape[-1] // 2

    def _broadcastable(table: mx.array) -> mx.array:
        # Tile a half-width table so it lines up with the full head dimension.
        if table.shape[-1] == half:
            table = mx.concatenate([table, table], axis=-1)
        # (seq_len, head_dim) -> (1, seq_len, 1, head_dim) to broadcast over
        # the batch and head axes.
        return table.reshape(1, -1, 1, table.shape[-1])

    cos = _broadcastable(cos)
    sin = _broadcastable(sin)

    def _rotate(x: mx.array) -> mx.array:
        first, second = mx.split(x, 2, axis=-1)
        # x * cos + rotate_half(x) * sin, with rotate_half(x) = [-x2, x1].
        return x * cos + mx.concatenate([-second, first], axis=-1) * sin

    return _rotate(q), _rotate(k)
class MultiHeadAttentionRoPE(Module):
    """Multi-head self-attention with optional rotary position embeddings.

    Queries, keys and values come from one fused linear projection. When
    ``use_rope`` is set, queries and keys are rotated with tables from
    :class:`RoPEEmbedding` before the scaled dot-product. Attention is
    computed explicitly (matmul, softmax, matmul).

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused QKV projection has a bias term.
        dropout: Dropout probability applied to the attention weights and to
            the projected output. ``0`` disables both dropout layers.
        use_rope: Whether to rotate queries and keys with RoPE.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 16,
        qkv_bias: bool = True,
        dropout: float = 0.0,
        use_rope: bool = True
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Standard 1 / sqrt(head_dim) scaling of the attention logits.
        self.scale = self.head_dim ** -0.5
        self.use_rope = use_rope

        # Fused projection producing q, k and v in one matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Output projection applied after the heads are merged.
        self.proj = nn.Linear(dim, dim)

        # Dropout layers are only created when they would do something.
        self.attn_dropout = nn.Dropout(dropout) if dropout > 0 else None
        self.proj_dropout = nn.Dropout(dropout) if dropout > 0 else None

        if use_rope:
            self.rope = RoPEEmbedding(self.head_dim)

    def __call__(self, *args, **kwargs) -> mx.array:
        """Make the layer callable like other MLX modules.

        MLX's ``Module`` base class does not dispatch calls to ``forward``.
        Arguments are passed through unchanged so subclasses that narrow the
        ``forward`` signature keep working.
        """
        return self.forward(*args, **kwargs)

    def forward(self, x: mx.array, attn_mask: Optional[mx.array] = None) -> mx.array:
        """Run self-attention over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, dim)``.
            attn_mask: Optional additive mask, added to the attention logits
                before the softmax; it must broadcast against
                ``(batch, num_heads, seq_len, seq_len)``.

        Returns:
            Array of shape ``(batch, seq_len, dim)``.
        """
        B, N, C = x.shape

        # (B, N, 3 * dim) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.transpose(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.use_rope:
            rope_emb = self.rope.forward(N)  # (2, N, head_dim)
            cos, sin = rope_emb[0], rope_emb[1]

            # apply_rotary_pos_emb expects (B, N, heads, head_dim).
            q = q.transpose(0, 2, 1, 3)
            k = k.transpose(0, 2, 1, 3)
            q, k = apply_rotary_pos_emb(q, k, cos, sin)
            # Back to (B, heads, N, head_dim) for the matmuls below.
            q = q.transpose(0, 2, 1, 3)
            k = k.transpose(0, 2, 1, 3)

        # Attention logits: (B, heads, N, N).
        attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale

        if attn_mask is not None:
            attn = attn + attn_mask

        attn = mx.softmax(attn, axis=-1)

        if self.attn_dropout is not None:
            attn = self.attn_dropout(attn)

        # Weighted sum of values, then merge heads: (B, heads, N, head_dim) -> (B, N, C).
        x = attn @ v
        x = x.transpose(0, 2, 1, 3).reshape(B, N, C)
        x = self.proj(x)

        if self.proj_dropout is not None:
            x = self.proj_dropout(x)

        return x
class WindowedAttention(MultiHeadAttentionRoPE):
    """Multi-head attention restricted to a local window.

    Each position may attend only to positions whose flat sequence index is
    within ``window_size // 2`` of its own; everything else is masked with
    ``-inf`` before the softmax. The window is defined over the flattened
    sequence, not over 2D image coordinates.

    Per the original author's note, this is used in certain Hiera blocks for
    efficiency.

    Args:
        dim: Model width.
        num_heads: Number of attention heads.
        window_size: Window width; positions within ``window_size // 2`` on
            either side are visible.
        **kwargs: Forwarded to :class:`MultiHeadAttentionRoPE`.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 16,
        window_size: int = 14,
        **kwargs
    ):
        super().__init__(dim, num_heads, **kwargs)
        self.window_size = window_size

    def create_window_mask(self, seq_len: int) -> mx.array:
        """Build the additive mask for windowed attention.

        Args:
            seq_len: Sequence length ``N``.

        Returns:
            Array of shape ``(1, 1, N, N)`` holding ``0.0`` where
            ``|i - j| <= window_size // 2`` and ``-inf`` elsewhere.
        """
        reach = self.window_size // 2
        positions = mx.arange(seq_len)
        # Pairwise index distance, computed in one shot instead of filling
        # the mask row by row in a Python loop.
        distance = mx.abs(positions[:, None] - positions[None, :])
        mask = mx.where(distance <= reach, 0.0, float('-inf'))
        return mask.reshape(1, 1, seq_len, seq_len)

    def forward(self, x: mx.array, attn_mask: Optional[mx.array] = None) -> mx.array:
        """Run attention with the local window mask applied.

        Args:
            x: Input of shape ``(batch, seq_len, dim)``.
            attn_mask: Optional extra additive mask; it is added to the
                window mask, so it must broadcast against
                ``(1, 1, seq_len, seq_len)``.

        Returns:
            Array of shape ``(batch, seq_len, dim)``.
        """
        B, N, C = x.shape

        window_mask = self.create_window_mask(N)
        if attn_mask is not None:
            # Additive masks compose by summation; -inf entries stay -inf.
            window_mask = window_mask + attn_mask

        return super().forward(x, attn_mask=window_mask)
|