| | """
|
| | VortexLocalAttention: Local windowed attention with global token support.
|
| | Uses a sliding window of 512 tokens for efficiency, with special handling
|
| | for global tokens that can attend across the entire sequence.
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from typing import Optional, Tuple
|
| |
|
| |
|
class VortexLocalAttention(nn.Module):
    """
    Local windowed attention with window_size=512.

    Science documents have strong local coherence — equations reference
    nearby text, not distant paragraphs — so each token attends only to a
    symmetric window of ``window_size // 2`` positions on either side.
    Global tokens (special [SCIENCE] tokens) attend to everything, are
    visible to every query, and use a dedicated Q/K/V projection.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        window_size: int = 512,
        use_flash_attention: bool = True,
    ):
        """
        Initialize local windowed attention.

        Args:
            d_model: Model dimension; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            window_size: Size of local attention window; a token attends to
                positions within ``window_size // 2`` on either side.
            use_flash_attention: Use Flash Attention 2 if available (CUDA only).
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.window_size = window_size
        self.use_flash_attention = use_flash_attention

        # Fused Q/K/V projection for regular (windowed) tokens.
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        # Separate Q/K/V projection used at global-token positions.
        self.global_qkv = nn.Linear(d_model, d_model * 3, bias=False)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize all projection weights from N(0, 0.02)."""
        for module in [self.qkv, self.global_qkv, self.out_proj]:
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, d_model) -> (batch, heads, seq, head_dim)."""
        batch, seq_len, _ = t.shape
        return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with local windowed attention.

        Each query attends to keys within ``window_size // 2`` positions on
        either side, plus every global key; global queries attend everywhere.
        At global positions the Q/K/V vectors come from ``global_qkv``.

        Fixes over the previous implementation: the old loop concatenated
        the global K/V of *all* positions (not just global ones), used only
        ``global_mask[0]`` for every batch element, never used the global
        queries, and crashed when ``attention_mask`` was combined with
        global tokens (mask covered only the window, scores covered more).

        Args:
            x: Input tensor (batch, seq_len, d_model)
            global_mask: Boolean mask indicating which tokens are global
                (attend everywhere). Shape: (batch, seq_len) or None
            attention_mask: Optional padding mask (batch, seq_len); positions
                with value 0 are excluded as keys.

        Returns:
            Output tensor (batch, seq_len, d_model)

        NOTE: this reference path materializes the full (seq_len, seq_len)
        score matrix, so it is O(seq_len^2) memory; prefer
        ``forward_optimized`` for long sequences.
        """
        batch, seq_len, _ = x.shape
        device = x.device

        if global_mask is None:
            global_mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=device)

        q, k, v = (self._split_heads(t) for t in self.qkv(x).chunk(3, dim=-1))

        # At global positions, substitute the dedicated global projections
        # so global tokens use their own parameterisation.
        if global_mask.any():
            gq, gk, gv = (
                self._split_heads(t) for t in self.global_qkv(x).chunk(3, dim=-1)
            )
            is_global = global_mask.view(batch, 1, seq_len, 1)
            q = torch.where(is_global, gq, q)
            k = torch.where(is_global, gk, k)
            v = torch.where(is_global, gv, v)

        # allowed[b, i, j]: query i may see key j when j lies in i's local
        # window, or when either i or j is a global token.
        pos = torch.arange(seq_len, device=device)
        half = self.window_size // 2
        in_window = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs() <= half
        allowed = (
            in_window.unsqueeze(0)
            | global_mask.unsqueeze(1)   # global keys visible to all queries
            | global_mask.unsqueeze(2)   # global queries see all keys
        )

        # Padded positions are never valid keys.
        if attention_mask is not None:
            allowed = allowed & (attention_mask != 0).unsqueeze(1)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # -1e9 (not -inf) so a fully-masked row softmaxes to uniform, not NaN.
        attn_scores = attn_scores.masked_fill(~allowed.unsqueeze(1), -1e9)

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, seq_len, self.d_model)
        return self.out_proj(attn_output)

    def forward_optimized(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Optimized forward pass using Flash Attention or efficient windowed attention.

        Flash Attention is equivalent only when the window spans the whole
        sequence AND there are no global tokens (global tokens need the
        separate ``global_qkv`` projection), so it is used only then.
        """
        batch, seq_len, _ = x.shape

        no_globals = global_mask is None or not global_mask.any()
        if self.use_flash_attention and self.window_size >= seq_len and no_globals:
            # Window covers the sequence: full attention is equivalent.
            return self._flash_attention_forward(x, attention_mask)
        return self._windowed_attention_forward(x, global_mask, attention_mask)

    def _flash_attention_forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Full attention via Flash Attention 2 if available.
        Requires: pip install flash-attn

        ``flash_attn_func`` accepts no per-token padding mask, so any call
        with an ``attention_mask`` falls back to standard attention (the
        previous code silently dropped the mask). Also falls back when
        flash-attn is not importable.
        """
        if attention_mask is not None:
            # flash_attn_func cannot consume a padding mask here.
            return self._standard_attention(x, attention_mask)

        try:
            from flash_attn import flash_attn_func
        except ImportError:
            print("Flash Attention not available, falling back to standard attention")
            return self._standard_attention(x, attention_mask)

        batch, seq_len, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        # flash_attn expects (batch, seq, heads, head_dim) — no transpose.
        q = q.view(batch, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch, seq_len, self.num_heads, self.head_dim)
        v = v.view(batch, seq_len, self.num_heads, self.head_dim)

        output = flash_attn_func(
            q, k, v,
            causal=False,
            softmax_scale=1.0 / (self.head_dim ** 0.5),
        )

        output = output.view(batch, seq_len, self.d_model)
        return self.out_proj(output)

    def _standard_attention(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Standard full attention (quadratic) over the whole sequence."""
        batch, seq_len, _ = x.shape
        q, k, v = (self._split_heads(t) for t in self.qkv(x).chunk(3, dim=-1))

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if attention_mask is not None:
            # Exclude padded key positions (mask value 0) before the softmax.
            attn_scores = attn_scores.masked_fill(
                attention_mask.unsqueeze(1).unsqueeze(2) == 0,
                -1e9,
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, seq_len, self.d_model)
        return self.out_proj(attn_output)

    def _windowed_attention_forward(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Efficient windowed attention using ``unfold`` and batched matmul.

        Keys/values are zero-padded by ``window_size // 2`` on each side and
        unfolded into one window per query (an inclusive symmetric window of
        ``2 * (window_size // 2) + 1`` slots, matching ``forward``).
        Out-of-bounds padding slots and ``attention_mask``-padded keys are
        masked out of the softmax. Rows for global tokens are recomputed
        with full attention BEFORE the output projection, so ``out_proj``
        is applied exactly once (the previous code applied it twice on the
        global path, had a transposed ``permute`` that broke the matmul,
        produced seq_len+1 windows for even window sizes, and attended to
        zero padding).

        NOTE(review): unlike ``forward``, this path uses the local ``qkv``
        projection for global tokens — kept from the original fallback;
        unify if exact parity with ``forward`` is required.
        """
        batch, seq_len, _ = x.shape
        device = x.device

        if global_mask is None:
            global_mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=device)

        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(batch, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch, seq_len, self.num_heads, self.head_dim)
        v = v.view(batch, seq_len, self.num_heads, self.head_dim)

        pad_len = self.window_size // 2
        # Inclusive window [t - pad_len, t + pad_len] -> 2 * pad_len + 1 keys;
        # padded length seq + 2*pad_len then yields exactly seq_len windows.
        win = 2 * pad_len + 1
        k_padded = F.pad(k, (0, 0, 0, 0, pad_len, pad_len))
        v_padded = F.pad(v, (0, 0, 0, 0, pad_len, pad_len))

        k_windows = k_padded.unfold(1, win, 1)  # (b, s, h, d, win)
        v_windows = v_padded.unfold(1, win, 1)
        k_windows = k_windows.permute(0, 1, 2, 4, 3)  # (b, s, h, win, d)
        v_windows = v_windows.permute(0, 1, 2, 4, 3)

        scale = self.head_dim ** 0.5
        q_expanded = q.unsqueeze(3)  # (b, s, h, 1, d)
        attn_scores = (
            torch.matmul(q_expanded, k_windows.transpose(-2, -1)).squeeze(3) / scale
        )  # (b, s, h, win)

        # Validity of each window slot: inside the un-padded sequence, and
        # not excluded by the padding mask.
        offsets = torch.arange(win, device=device) - pad_len
        key_pos = torch.arange(seq_len, device=device).unsqueeze(1) + offsets  # (s, win)
        valid = ((key_pos >= 0) & (key_pos < seq_len)).unsqueeze(0)  # (1, s, win)
        if attention_mask is not None:
            mask_padded = F.pad(attention_mask, (pad_len, pad_len))
            valid = valid & (mask_padded.unfold(1, win, 1) != 0)  # (b, s, win)

        attn_scores = attn_scores.masked_fill(~valid.unsqueeze(2), -1e9)
        attn_weights = F.softmax(attn_scores, dim=-1)

        # (b, s, h, 1, win) @ (b, s, h, win, d) -> (b, s, h, d)
        attn_output = torch.matmul(attn_weights.unsqueeze(3), v_windows).squeeze(3)
        attn_output = attn_output.reshape(batch, seq_len, self.d_model)

        if global_mask.any():
            # Recompute global rows with full (unwindowed) attention,
            # pre-projection, so out_proj below is applied exactly once.
            qf, kf, vf = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            full_scores = torch.matmul(qf, kf.transpose(-2, -1)) / scale
            if attention_mask is not None:
                full_scores = full_scores.masked_fill(
                    (attention_mask == 0).unsqueeze(1).unsqueeze(2), -1e9
                )
            full_out = torch.matmul(F.softmax(full_scores, dim=-1), vf)
            full_out = full_out.transpose(1, 2).reshape(batch, seq_len, self.d_model)
            attn_output = torch.where(global_mask.unsqueeze(-1), full_out, attn_output)

        return self.out_proj(attn_output)
|
| |
|
| |
|
def test_vortex_local_attention():
    """Smoke test: run VortexLocalAttention with and without global tokens."""
    batch_size, seq_len = 2, 256
    d_model, num_heads, window_size = 4096, 32, 512

    layer = VortexLocalAttention(d_model, num_heads, window_size, use_flash_attention=False)
    x = torch.randn(batch_size, seq_len, d_model)

    # Plain forward pass must preserve the input shape.
    output = layer(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape, f"Expected {x.shape}, got {output.shape}"

    # One global token per batch element: first token of sample 0,
    # last token of sample 1.
    marks = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    marks[0, 0] = True
    marks[1, -1] = True
    assert layer(x, global_mask=marks).shape == x.shape

    print("VortexLocalAttention test passed!")
|
| |
|
| |
|
| | if __name__ == "__main__":
|
| | test_vortex_local_attention()
|
| |
|