MLX
File size: 6,382 Bytes
ced11e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""
RoPE Multi-Head Attention for SAM3
Implements Rotary Position Embeddings for spatial awareness
"""

import mlx.core as mx
import mlx.nn as nn
from mlx.nn import Module
import math
from typing import Optional

class RoPEEmbedding(Module):
    """Rotary position embedding (RoPE) tables over a 1D sequence index.

    Positions are the plain token indices ``0 .. seq_len - 1``. A flattened
    image grid is therefore encoded in flattening order; no separate
    row/column rotation is applied.

    Attributes:
        dim: Width of the rotated feature dimension (the per-head width).
        inv_freq: ``(dim / 2,)`` inverse frequencies ``1 / 10000^(2i / dim)``,
            kept as a frozen (non-trainable) buffer.
    """

    def __init__(self, dim: int, max_seq_len: int = 8192):
        """Precompute the inverse frequencies.

        Args:
            dim: Feature width to rotate; must be even because features are
                rotated in pairs.
            max_seq_len: Accepted for interface compatibility only. Tables are
                generated on demand for whatever length is requested, so this
                value is not used.

        Raises:
            ValueError: If ``dim`` is odd.
        """
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE dim {dim} must be even")
        self.dim = dim

        # Precompute frequency matrix
        exponents = mx.arange(0, dim, 2).astype(mx.float32) / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def __call__(self, seq_len: int) -> mx.array:
        """Alias for :meth:`forward` so the module follows the MLX call convention."""
        return self.forward(seq_len)

    def forward(self, seq_len: int) -> mx.array:
        """Generate the cos/sin tables for ``seq_len`` positions.

        Args:
            seq_len: Number of positions to encode.

        Returns:
            Array of shape ``(2, seq_len, dim)``: index 0 holds the cosines and
            index 1 the sines. The ``dim / 2`` angles are duplicated across
            both halves of the last axis.
        """
        # Generate position indices
        positions = mx.arange(seq_len, dtype=mx.float32)

        # Outer product of positions and inverse frequencies: (seq_len, dim/2)
        angles = mx.outer(positions, self.inv_freq)

        # Duplicate so the table spans the full feature width: (seq_len, dim)
        emb = mx.concatenate([angles, angles], axis=-1)

        return mx.stack([mx.cos(emb), mx.sin(emb)], axis=0)  # (2, seq_len, dim)

    def register_buffer(self, name: str, tensor: mx.array) -> None:
        """Store ``tensor`` on the module as non-trainable state.

        MLX treats every public array attribute of a module as a trainable
        parameter. Freezing the key keeps the array in the module state (so
        it is still saved and loaded) while excluding it from gradients and
        optimizer updates.

        Args:
            name: Attribute name to assign.
            tensor: Array to store.
        """
        setattr(self, name, tensor)
        self.freeze(recurse=False, keys=[name])


def _rotate_half(x: mx.array) -> mx.array:
    """Map the two halves ``(x1, x2)`` of the last axis to ``(-x2, x1)``."""
    x1, x2 = mx.split(x, 2, axis=-1)
    return mx.concatenate([-x2, x1], axis=-1)


def _as_full_width_table(table: mx.array, head_dim: int) -> mx.array:
    """Return ``table`` as ``(1, seq_len, 1, head_dim)``, tiling a half-width table."""
    width = table.shape[-1]
    if width * 2 == head_dim:
        # Half-width table: the same angle applies to both halves of a pair.
        table = mx.concatenate([table, table], axis=-1)
    elif width != head_dim:
        raise ValueError(
            f"rotary table width {width} must be head_dim ({head_dim}) "
            f"or head_dim / 2"
        )
    return table.reshape(1, -1, 1, head_dim)


def apply_rotary_pos_emb(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array) -> tuple:
    """
    Apply rotary position embeddings to queries and keys

    Feature ``i`` in the first half of the last axis is paired with feature
    ``i + head_dim / 2`` in the second half, and each pair is rotated by the
    position-dependent angle encoded in ``cos``/``sin``.

    Args:
        q: (batch, seq_len, num_heads, head_dim)
        k: (batch, seq_len, num_heads, head_dim)
        cos: (seq_len, head_dim), or (seq_len, head_dim / 2)
        sin: (seq_len, head_dim), or (seq_len, head_dim / 2)

    Returns:
        Rotated q and k, with the same shapes and dtypes as the inputs

    Raises:
        ValueError: If head_dim is odd, or a table width is neither head_dim
            nor head_dim / 2.
    """
    head_dim = q.shape[-1]
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim {head_dim} must be even to apply RoPE")

    # Reshape for broadcasting against (batch, seq_len, num_heads, head_dim)
    cos = _as_full_width_table(cos, head_dim)
    sin = _as_full_width_table(sin, head_dim)

    # Rotate-half form: multiplying the full-width input by full-width tables
    # keeps every operand broadcast-compatible. Splitting the input first and
    # multiplying the halves by full-width tables does not broadcast.
    q_rotated = q * cos + _rotate_half(q) * sin
    k_rotated = k * cos + _rotate_half(k) * sin

    # The tables may be wider-precision than q/k; cast back so downstream
    # attention keeps the caller's dtype.
    return q_rotated.astype(q.dtype), k_rotated.astype(k.dtype)


class MultiHeadAttentionRoPE(Module):
    """
    Multi-Head self-attention with optional Rotary Position Embeddings

    Queries, keys and values come from one fused linear projection. When
    ``use_rope`` is set, queries and keys are rotated by their sequence
    position before the scaled dot-product. The full ``(N, N)`` attention
    matrix is materialized for every head.

    Attributes:
        dim: Model width.
        num_heads: Number of attention heads.
        head_dim: Per-head width, ``dim // num_heads``.
        scale: Score scaling factor, ``head_dim ** -0.5``.
        use_rope: Whether rotary embeddings are applied to q and k.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 16,
        qkv_bias: bool = True,
        dropout: float = 0.0,
        use_rope: bool = True
    ):
        """
        Build the projections

        Args:
            dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused QKV projection has a bias term.
            dropout: Dropout probability used for both the attention weights
                and the projected output; no dropout layers are created when
                it is not positive.
            use_rope: Whether to rotate queries and keys with RoPE.
        """
        super().__init__()

        assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.use_rope = use_rope

        # QKV projection
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        # Output projection
        self.proj = nn.Linear(dim, dim)

        # Dropout
        self.attn_dropout = nn.Dropout(dropout) if dropout > 0 else None
        self.proj_dropout = nn.Dropout(dropout) if dropout > 0 else None

        # RoPE
        if use_rope:
            self.rope = RoPEEmbedding(self.head_dim)

    def __call__(self, *args, **kwargs) -> mx.array:
        """Invoke :meth:`forward`.

        MLX modules are called through ``__call__``. Arguments are passed
        through untouched so subclasses that override ``forward`` with a
        different signature stay callable.
        """
        return self.forward(*args, **kwargs)

    def forward(self, x: mx.array, attn_mask: Optional[mx.array] = None) -> mx.array:
        """
        Forward pass

        Args:
            x: (batch, seq_len, dim)
            attn_mask: Optional additive mask, broadcastable to
                (batch, num_heads, seq_len, seq_len); it is added to the
                attention scores before the softmax

        Returns:
            Output: (batch, seq_len, dim)
        """
        B, N, C = x.shape

        # Fused projection, then split while still in (B, N, heads, head_dim)
        # layout, which is the layout apply_rotary_pos_emb expects.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # Apply RoPE if enabled
        if self.use_rope:
            rope_emb = self.rope.forward(N)  # (2, N, head_dim)
            cos, sin = rope_emb[0], rope_emb[1]
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Move heads ahead of the sequence axis: (B, num_heads, N, head_dim)
        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        # Scaled dot-product attention
        attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale  # (B, num_heads, N, N)

        # Apply attention mask if provided
        if attn_mask is not None:
            attn = attn + attn_mask

        # Softmax
        attn = mx.softmax(attn, axis=-1)

        # Apply dropout
        if self.attn_dropout is not None:
            attn = self.attn_dropout(attn)

        # Apply attention to values
        x = attn @ v  # (B, num_heads, N, head_dim)

        # Merge heads back into the model width and project
        x = x.transpose(0, 2, 1, 3).reshape(B, N, C)
        x = self.proj(x)

        # Apply output dropout
        if self.proj_dropout is not None:
            x = self.proj_dropout(x)

        return x


class WindowedAttention(MultiHeadAttentionRoPE):
    """
    Windowed Multi-Head Attention for local processing

    Each position attends only to positions whose sequence index is within
    ``window_size // 2`` of its own. The restriction is enforced with an
    additive mask over the full attention matrix.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 16,
        window_size: int = 14,
        **kwargs
    ):
        """
        Args:
            dim: Model width.
            num_heads: Number of attention heads.
            window_size: Nominal window width. The reach is
                ``window_size // 2`` on each side, so an even value covers
                ``window_size + 1`` positions.
            **kwargs: Forwarded to ``MultiHeadAttentionRoPE``.
        """
        super().__init__(dim, num_heads, **kwargs)
        self.window_size = window_size

    def create_window_mask(self, seq_len: int, dtype: mx.Dtype = mx.float32) -> mx.array:
        """Create the additive attention mask for windowed attention.

        Args:
            seq_len: Number of positions in the sequence.
            dtype: Floating dtype of the returned mask.

        Returns:
            Array of shape ``(1, 1, seq_len, seq_len)`` holding ``0`` where
            ``|i - j| <= window_size // 2`` and ``-inf`` everywhere else.
        """
        reach = self.window_size // 2

        # Signed distance between every query index i and key index j,
        # computed in one broadcasted op instead of a per-row Python loop.
        positions = mx.arange(seq_len)
        distance = positions[:, None] - positions[None, :]
        inside = mx.abs(distance) <= reach

        mask = mx.where(
            inside,
            mx.array(0.0, dtype=dtype),
            mx.array(float('-inf'), dtype=dtype),
        )
        return mask.reshape(1, 1, seq_len, seq_len)

    def forward(self, x: mx.array, attn_mask: Optional[mx.array] = None) -> mx.array:
        """Forward with windowed attention.

        Args:
            x: (batch, seq_len, dim)
            attn_mask: Optional additive mask applied together with the
                window mask; must broadcast against
                (1, 1, seq_len, seq_len).

        Returns:
            Output: (batch, seq_len, dim)
        """
        # Match the input dtype so the mask does not promote half-precision
        # attention scores to float32.
        window_mask = self.create_window_mask(x.shape[1], x.dtype)

        # Additive masks compose by summation: -inf from either one blocks.
        if attn_mask is not None:
            window_mask = window_mask + attn_mask

        return super().forward(x, attn_mask=window_mask)