File size: 14,139 Bytes
5c43f61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
"""
VortexLocalAttention: Local windowed attention with global token support.

Uses a sliding window (512 tokens by default) for efficiency, with special
handling for global tokens that can attend across the entire sequence.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


class VortexLocalAttention(nn.Module):
    """
    Sliding-window multi-head self-attention with optional global tokens.

    Science documents have strong local coherence — equations reference
    nearby text, not distant paragraphs — so every token attends only to the
    keys whose position is within ``window_size // 2`` of its own.

    Tokens flagged in ``global_mask`` (e.g. special [SCIENCE] tokens) are
    handled Longformer-style, independently for every batch row:

    * a global token attends to every position, using the separate
      ``global_qkv`` projection for its query and for all keys/values;
    * every other token additionally attends to the global tokens of its own
      sequence (through the regular ``qkv`` projection).

    Positions where ``attention_mask == 0`` are never used as keys. A query
    that has no usable key at all produces a zero context vector.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        window_size: int = 512,
        use_flash_attention: bool = True,
    ):
        """
        Initialize local windowed attention.

        Args:
            d_model: Model dimension; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            window_size: Size of the local attention window. A token sees
                ``window_size // 2`` neighbours on each side plus itself.
            use_flash_attention: Allow ``forward_optimized`` to use Flash
                Attention 2 when it is applicable (see
                ``_flash_attention_forward``).
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.window_size = window_size
        self.use_flash_attention = use_flash_attention

        # QKV projection used by ordinary (local) tokens
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        # Separate QKV projection used when a global token attends everywhere
        self.global_qkv = nn.Linear(d_model, d_model * 3, bias=False)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize every projection weight from N(0, 0.02^2)."""
        for module in (self.qkv, self.global_qkv, self.out_proj):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    # ------------------------------------------------------------------
    # Shared building blocks
    # ------------------------------------------------------------------
    def _split_heads(
        self, projection: nn.Linear, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply a fused QKV projection; return q, k, v as (batch, heads, seq_len, head_dim)."""
        batch, seq_len, _ = x.shape
        q, k, v = (
            part.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            for part in projection(x).chunk(3, dim=-1)
        )
        return q, k, v

    def _merge_heads(self, context: torch.Tensor) -> torch.Tensor:
        """Turn (batch, heads, seq_len, head_dim) back into (batch, seq_len, d_model)."""
        batch, _, seq_len, _ = context.shape
        return context.transpose(1, 2).reshape(batch, seq_len, self.d_model)

    def _masked_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        allowed: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Scaled dot-product attention restricted to ``allowed`` query/key pairs.

        Args:
            q: Queries (batch, heads, q_len, head_dim).
            k: Keys (batch, heads, k_len, head_dim).
            v: Values (batch, heads, k_len, head_dim).
            allowed: Boolean tensor broadcastable to (batch, heads, q_len, k_len);
                True where the query may attend to the key. None means no
                restriction.

        Returns:
            Context tensor (batch, heads, q_len, head_dim).
        """
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if allowed is None:
            return torch.matmul(F.softmax(scores, dim=-1), v)

        blocked = ~allowed
        # finfo.min (rather than -1e9) stays representable in fp16/bf16, and
        # (rather than -inf) keeps fully blocked rows free of NaNs.
        scores = scores.masked_fill(blocked, torch.finfo(scores.dtype).min)
        # Zeroing blocked weights makes a fully blocked row contribute nothing
        # instead of a uniform average over keys it must not see.
        weights = F.softmax(scores, dim=-1).masked_fill(blocked, 0.0)
        return torch.matmul(weights, v)

    def _local_global_attention(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Core implementation shared by ``forward`` and ``_windowed_attention_forward``.

        Queries are processed in chunks of ``window_size`` positions so that
        the score tensor for one chunk is (chunk, key_span) instead of
        (seq_len, seq_len). Without global tokens the key span of a chunk is
        just the chunk plus half a window on each side.
        """
        batch, seq_len, _ = x.shape
        if seq_len == 0:
            return torch.zeros_like(x)

        device = x.device
        half_window = self.window_size // 2

        key_keep = None
        if attention_mask is not None:
            key_keep = attention_mask.to(device) != 0  # True = real token

        global_rows = None
        if global_mask is not None:
            global_mask = global_mask.to(device=device, dtype=torch.bool)
            # Single host sync: which query positions are global in any batch row.
            row_flags = global_mask.any(dim=0).tolist()
            if any(row_flags):
                global_rows = row_flags
            else:
                global_mask = None  # nothing is global; take the cheap path

        q, k, v = self._split_heads(self.qkv, x)
        if global_mask is not None:
            gq, gk, gv = self._split_heads(self.global_qkv, x)

        positions = torch.arange(seq_len, device=device)
        chunk_size = max(1, self.window_size)
        contexts = []

        for start in range(0, seq_len, chunk_size):
            stop = min(seq_len, start + chunk_size)
            if global_mask is None:
                key_lo = max(0, start - half_window)
                key_hi = min(seq_len, stop + half_window)
            else:
                # Global tokens can sit anywhere, so expose every key and let
                # the mask decide.
                key_lo, key_hi = 0, seq_len

            # Band: key j is visible to query i when |i - j| <= half_window.
            distance = positions[start:stop, None] - positions[None, key_lo:key_hi]
            allowed = (distance.abs() <= half_window)[None, None]
            if global_mask is not None:
                # Columns of global tokens are visible to every query.
                allowed = allowed | global_mask[:, None, None, key_lo:key_hi]
            if key_keep is not None:
                allowed = allowed & key_keep[:, None, None, key_lo:key_hi]

            context = self._masked_attention(
                q[:, :, start:stop],
                k[:, :, key_lo:key_hi],
                v[:, :, key_lo:key_hi],
                allowed,
            )

            if global_rows is not None and any(global_rows[start:stop]):
                # Global queries attend to the whole sequence through the
                # dedicated global projection.
                visible = None if key_keep is None else key_keep[:, None, None, :]
                global_context = self._masked_attention(
                    gq[:, :, start:stop], gk, gv, visible
                )
                context = torch.where(
                    global_mask[:, None, start:stop, None], global_context, context
                )

            contexts.append(context)

        return self.out_proj(self._merge_heads(torch.cat(contexts, dim=2)))

    # ------------------------------------------------------------------
    # Public / dispatching entry points
    # ------------------------------------------------------------------
    def forward(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with local windowed attention.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            global_mask: Boolean mask indicating which tokens are global
                (attend everywhere). Shape: (batch, seq_len) or None
            attention_mask: Optional padding mask (batch, seq_len); entries
                equal to 0 mark padding that must not be attended to

        Returns:
            Output tensor (batch, seq_len, d_model)
        """
        return self._local_global_attention(x, global_mask, attention_mask)

    def forward_optimized(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass that uses Flash Attention when it gives the same result.

        Full attention is only equivalent to windowed attention when the
        window reaches from the first token to the last one
        (``seq_len - 1 <= window_size // 2``) and no token is global, so the
        Flash path is limited to that case; everything else goes through the
        windowed implementation.
        """
        seq_len = x.shape[1]
        window_covers_sequence = seq_len - 1 <= self.window_size // 2
        has_global = global_mask is not None and bool(global_mask.any())

        if self.use_flash_attention and window_covers_sequence and not has_global:
            return self._flash_attention_forward(x, attention_mask)
        return self._windowed_attention_forward(x, global_mask, attention_mask)

    def _flash_attention_forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Full attention through Flash Attention 2 if available.

        Requires: pip install flash-attn

        Falls back to ``_standard_attention`` when a padding mask is given
        (``flash_attn_func`` has no mask argument), when the input is not on
        a CUDA device, when the projected tensors are not fp16/bf16, or when
        the package is not installed.
        """
        if attention_mask is not None or not x.is_cuda:
            return self._standard_attention(x, attention_mask)

        try:
            from flash_attn import flash_attn_func
        except ImportError:
            import warnings

            # warnings de-duplicates by call site, unlike the print it replaces.
            warnings.warn(
                "Flash Attention not available, falling back to standard attention"
            )
            return self._standard_attention(x, attention_mask)

        batch, seq_len, _ = x.shape
        # Flash attention expects (batch, seq_len, num_heads, head_dim)
        # and returns the same shape.
        q, k, v = (
            part.view(batch, seq_len, self.num_heads, self.head_dim)
            for part in self.qkv(x).chunk(3, dim=-1)
        )
        if q.dtype not in (torch.float16, torch.bfloat16):
            return self._standard_attention(x, attention_mask)

        output = flash_attn_func(
            q, k, v,
            softmax_scale=1.0 / (self.head_dim ** 0.5),
            causal=False,
        )
        return self.out_proj(output.reshape(batch, seq_len, self.d_model))

    def _standard_attention(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Standard full attention (quadratic) with optional key padding mask."""
        q, k, v = self._split_heads(self.qkv, x)
        allowed = None
        if attention_mask is not None:
            allowed = (attention_mask.to(x.device) != 0)[:, None, None, :]
        return self.out_proj(self._merge_heads(self._masked_attention(q, k, v, allowed)))

    def _windowed_attention_forward(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Efficient windowed attention implementation.

        Computes exactly what ``forward`` computes: chunked banded attention,
        honouring both ``global_mask`` and ``attention_mask``, with a single
        output projection.
        """
        return self._local_global_attention(x, global_mask, attention_mask)


def test_vortex_local_attention():
    """
    Smoke-test the VortexLocalAttention layer.

    The dimensions are deliberately small, and ``seq_len`` is larger than
    ``window_size`` so that the sliding window is actually exercised
    (with a window wider than the sequence every token would simply see
    every other token).
    """
    batch_size = 2
    seq_len = 32
    d_model = 64
    num_heads = 4
    window_size = 8

    attn = VortexLocalAttention(d_model, num_heads, window_size, use_flash_attention=False)
    x = torch.randn(batch_size, seq_len, d_model)

    with torch.no_grad():
        # Forward pass
        output = attn(x)
        print(f"Input shape: {x.shape}")
        print(f"Output shape: {output.shape}")
        assert output.shape == x.shape, f"Expected {x.shape}, got {output.shape}"
        assert torch.isfinite(output).all(), "Output contains non-finite values"

        # Locality: the last token lies outside the first token's window, so
        # perturbing it must leave the first output row untouched.
        x_perturbed = x.clone()
        x_perturbed[:, -1] += 1.0
        output_perturbed = attn(x_perturbed)
        assert torch.allclose(output[:, 0], output_perturbed[:, 0], atol=1e-6), (
            "Token outside the window influenced the output"
        )

        # With global mask
        global_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
        global_mask[0, 0] = True  # First token is global
        global_mask[1, -1] = True  # Last token is global
        output2 = attn(x, global_mask=global_mask)
        assert output2.shape == x.shape

    print("VortexLocalAttention test passed!")


# Run the smoke test only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    test_vortex_local_attention()