"""
VortexLocalAttention: Local windowed attention with global token support.

Uses a sliding window (512 tokens by default) for efficiency, with special
handling for global tokens that can attend across the entire sequence.
"""

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


class VortexLocalAttention(nn.Module):
    """
    Local windowed attention (window_size defaults to 512).

    Science documents have strong local coherence — equations reference nearby
    text, not distant paragraphs. Global tokens (special [SCIENCE] tokens)
    attend to everything.

    Attention pattern, shared by ``forward``, ``forward_optimized`` and the
    private helpers:

    * a non-global query at position ``t`` attends to the keys at positions
      ``s`` with ``|t - s| <= window_size // 2``, using the ``qkv`` projection;
    * a global query (``global_mask[b, t]`` is True) attends to every position
      of its own sequence, using the separate ``global_qkv`` projection;
    * keys whose ``attention_mask`` entry is 0 (padding) are never attended to.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        window_size: int = 512,
        use_flash_attention: bool = True,
    ):
        """
        Initialize local windowed attention.

        Args:
            d_model: Model dimension
            num_heads: Number of attention heads (must divide d_model)
            window_size: Size of the local attention window; each query sees
                ``window_size // 2`` keys on either side of itself
            use_flash_attention: Use Flash Attention 2 if available (CUDA only)

        Raises:
            ValueError: if d_model is not divisible by num_heads, or
                window_size is negative.
        """
        super().__init__()
        # Real exceptions rather than ``assert``: asserts vanish under ``python -O``.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        if window_size < 0:
            raise ValueError("window_size must be non-negative")

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.window_size = window_size
        self.use_flash_attention = use_flash_attention

        # QKV projection
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        # Global token projection (for tokens that attend globally)
        self.global_qkv = nn.Linear(d_model, d_model * 3, bias=False)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Draw every projection weight from N(0, 0.02^2)."""
        for module in (self.qkv, self.global_qkv, self.out_proj):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #

    def _split_heads(
        self, projection: nn.Linear, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project x (batch, seq_len, d_model) into per-head q, k, v.

        Each returned tensor has shape (batch, num_heads, seq_len, head_dim).
        The layout matches ``projection(x).chunk(3, dim=-1)`` followed by a
        per-head view.
        """
        batch, seq_len, _ = x.shape
        qkv = projection(x).view(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        return q, k, v

    def _merge_heads(self, attn: torch.Tensor) -> torch.Tensor:
        """(batch, num_heads, seq_len, head_dim) -> (batch, seq_len, d_model)."""
        batch, _, seq_len, _ = attn.shape
        return attn.transpose(1, 2).reshape(batch, seq_len, self.d_model)

    @staticmethod
    def _key_valid(
        attention_mask: Optional[torch.Tensor],
        batch: int,
        seq_len: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Boolean (batch, seq_len) tensor that is True where a key may be attended."""
        if attention_mask is None:
            return torch.ones(batch, seq_len, dtype=torch.bool, device=device)
        return attention_mask.to(device=device) != 0

    def _local_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        key_valid: torch.Tensor,
    ) -> torch.Tensor:
        """Sliding-window attention computed block by block.

        Queries are processed in blocks of ``radius`` positions. Each block
        only scores the key span ``[start - radius, end + radius)``, so peak
        memory is O(seq_len * window) rather than O(seq_len ** 2). The band
        condition ``|t - s| <= radius`` and the padding mask are applied inside
        each block.

        Args:
            q, k, v: (batch, num_heads, seq_len, head_dim)
            key_valid: (batch, seq_len) bool, True for attendable keys

        Returns:
            (batch, num_heads, seq_len, head_dim)
        """
        seq_len = q.shape[2]
        if seq_len == 0:
            return torch.zeros_like(v)

        radius = self.window_size // 2
        block = max(radius, 1)
        # finfo.min instead of -1e9: -1e9 does not fit in float16.
        neg = torch.finfo(q.dtype).min
        positions = torch.arange(seq_len, device=q.device)

        outputs = []
        for start in range(0, seq_len, block):
            end = min(start + block, seq_len)
            k_start = max(0, start - radius)
            k_end = min(seq_len, end + radius)

            scores = torch.matmul(
                q[:, :, start:end], k[:, :, k_start:k_end].transpose(-2, -1)
            ) / (self.head_dim ** 0.5)  # (batch, heads, Q, K)

            offset = positions[start:end, None] - positions[None, k_start:k_end]
            allowed = (offset.abs() <= radius) & key_valid[:, None, None, k_start:k_end]
            scores = scores.masked_fill(~allowed, neg)

            weights = F.softmax(scores, dim=-1)
            outputs.append(torch.matmul(weights, v[:, :, k_start:k_end]))

        return torch.cat(outputs, dim=2)

    def _global_attention(
        self,
        x: torch.Tensor,
        local_out: torch.Tensor,
        global_mask: torch.Tensor,
        key_valid: torch.Tensor,
    ) -> torch.Tensor:
        """Overwrite the rows of global tokens with full-sequence attention.

        Global queries use the ``global_qkv`` projection and attend to every
        valid key of their own batch element. Only the selected query rows are
        scored, so the cost is O(num_global * seq_len) per batch element.

        Args:
            x: (batch, seq_len, d_model) layer input
            local_out: (batch, num_heads, seq_len, head_dim) local-attention result
            global_mask: (batch, seq_len) bool, True for global tokens
            key_valid: (batch, seq_len) bool, True for attendable keys

        Returns:
            (batch, num_heads, seq_len, head_dim), where the global positions
            hold global attention and all other positions hold ``local_out``.
        """
        gq, gk, gv = self._split_heads(self.global_qkv, x)
        neg = torch.finfo(gq.dtype).min
        out = local_out.clone()
        for b in range(x.shape[0]):
            idx = global_mask[b].nonzero(as_tuple=True)[0]
            if idx.numel() == 0:
                continue
            scores = torch.matmul(
                gq[b][:, idx], gk[b].transpose(-2, -1)
            ) / (self.head_dim ** 0.5)  # (heads, num_global, seq_len)
            scores = scores.masked_fill(~key_valid[b], neg)
            weights = F.softmax(scores, dim=-1)
            # out[b] is a view, so this indexed assignment writes into `out`.
            out[b][:, idx] = torch.matmul(weights, gv[b])
        return out

    # ------------------------------------------------------------------ #
    # Public entry points
    # ------------------------------------------------------------------ #

    def forward(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with local windowed attention.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            global_mask: Boolean mask indicating which tokens are global (attend everywhere)
                Shape: (batch, seq_len) or None
            attention_mask: Optional padding mask (batch, seq_len); keys where it
                is 0 are ignored

        Returns:
            Output tensor (batch, seq_len, d_model)
        """
        return self._windowed_attention_forward(x, global_mask, attention_mask)

    def forward_optimized(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Same result as ``forward``, with a full-attention fast path.

        Full attention (Flash Attention when usable) equals local attention only
        when every key is inside every query's window
        (``seq_len - 1 <= window_size // 2``) and there are no global tokens,
        since global tokens use separate projections. Otherwise the windowed
        implementation is used.
        """
        seq_len = x.shape[1]
        has_global = global_mask is not None and bool(global_mask.any())
        window_covers_sequence = seq_len - 1 <= self.window_size // 2
        if self.use_flash_attention and window_covers_sequence and not has_global:
            return self._flash_attention_forward(x, attention_mask)
        return self._windowed_attention_forward(x, global_mask, attention_mask)

    def _flash_attention_forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Full attention through Flash Attention 2 when it can be used.

        ``flash_attn_func`` takes no padding mask and needs CUDA fp16/bf16
        inputs. In any other case this method falls back to
        ``_standard_attention``, which honors ``attention_mask``.

        Requires: pip install flash-attn
        """
        flash_usable = (
            attention_mask is None
            and x.is_cuda
            and x.dtype in (torch.float16, torch.bfloat16)
        )
        if not flash_usable:
            return self._standard_attention(x, attention_mask)

        try:
            from flash_attn import flash_attn_func
        except ImportError:
            warnings.warn("Flash Attention not available, falling back to standard attention")
            return self._standard_attention(x, attention_mask)

        batch, seq_len, _ = x.shape
        # Flash attention expects and returns (batch, seq_len, num_heads, head_dim).
        qkv = self.qkv(x).view(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        output = flash_attn_func(
            q, k, v,
            causal=False,
            softmax_scale=1.0 / (self.head_dim ** 0.5),
        )
        return self.out_proj(output.reshape(batch, seq_len, self.d_model))

    def _standard_attention(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Standard full attention (quadratic) using the ``qkv`` projection."""
        batch, seq_len, _ = x.shape
        q, k, v = self._split_heads(self.qkv, x)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if attention_mask is not None:
            key_valid = self._key_valid(attention_mask, batch, seq_len, x.device)
            attn_scores = attn_scores.masked_fill(
                ~key_valid[:, None, None, :], torch.finfo(attn_scores.dtype).min
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        return self.out_proj(self._merge_heads(attn_output))

    def _windowed_attention_forward(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Efficient windowed attention with optional global tokens.

        Local rows come from ``_local_attention`` (blocked band attention).
        Rows of global tokens are then replaced by ``_global_attention``.
        ``out_proj`` is applied exactly once, to the combined result.
        """
        batch, seq_len, _ = x.shape
        key_valid = self._key_valid(attention_mask, batch, seq_len, x.device)

        q, k, v = self._split_heads(self.qkv, x)
        attn_output = self._local_attention(q, k, v, key_valid)

        if global_mask is not None:
            global_mask = global_mask.to(device=x.device, dtype=torch.bool)
            if global_mask.any():
                attn_output = self._global_attention(x, attn_output, global_mask, key_valid)

        return self.out_proj(self._merge_heads(attn_output))


def test_vortex_local_attention():
    """Smoke and consistency test for VortexLocalAttention (small sizes, runs fast)."""
    torch.manual_seed(0)
    batch_size = 2
    seq_len = 64
    d_model = 128
    num_heads = 8
    window_size = 16

    attn = VortexLocalAttention(d_model, num_heads, window_size, use_flash_attention=False)
    x = torch.randn(batch_size, seq_len, d_model)

    # Forward pass
    output = attn(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape, f"Expected {x.shape}, got {output.shape}"

    # A token far outside position 0's window must not affect it.
    x_far = x.clone()
    x_far[:, -1] += 1.0
    assert torch.allclose(output[:, 0], attn(x_far)[:, 0], atol=1e-6)

    # With global mask
    global_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    global_mask[0, 0] = True  # First token is global
    global_mask[1, -1] = True  # Last token is global

    output2 = attn(x, global_mask=global_mask)
    assert output2.shape == x.shape
    # The global token at position 0 does see the far token.
    assert not torch.allclose(output2[0, 0], attn(x_far, global_mask=global_mask)[0, 0])
    # forward_optimized must agree with forward.
    assert torch.allclose(attn.forward_optimized(x, global_mask=global_mask), output2, atol=1e-5)

    print("VortexLocalAttention test passed!")


if __name__ == "__main__":
    test_vortex_local_attention()