# Vortex-13b-V1 / models/attention_layer.py
# Uploaded by Zandy-Wandy ("Upload Vortex model", commit 5c43f61, verified).
"""
VortexLocalAttention: Local windowed attention with global token support.
Uses a sliding window of 512 tokens for efficiency, with special handling
for global tokens that can attend across the entire sequence.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
class VortexLocalAttention(nn.Module):
    """
    Local windowed attention with window_size=512.

    Science documents have strong local coherence — equations reference
    nearby text, not distant paragraphs.

    Global tokens (special [SCIENCE] tokens) are exposed through a separate
    K/V projection so every position can attend to them in addition to its
    local window, and global positions themselves attend to everything.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        window_size: int = 512,
        use_flash_attention: bool = True,
    ):
        """
        Initialize local windowed attention.

        Args:
            d_model: Model dimension.
            num_heads: Number of attention heads.
            window_size: Size of the local attention window.
            use_flash_attention: Use Flash Attention 2 if available (CUDA only).
        """
        super().__init__()
        # Validate before deriving head_dim so a bad config fails eagerly.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.window_size = window_size
        self.use_flash_attention = use_flash_attention

        # Fused QKV projection for local attention.
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        # Separate projection used to expose global tokens as extra keys/values.
        self.global_qkv = nn.Linear(d_model, d_model * 3, bias=False)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize all projection weights with N(0, 0.02)."""
        # nn.Linear always has a .weight, so no hasattr guard is needed.
        for module in (self.qkv, self.global_qkv, self.out_proj):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with local windowed attention.

        Reference implementation with a Python loop over positions; see
        forward_optimized for the vectorized / flash paths.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            global_mask: Boolean mask of global tokens, shape (batch, seq_len),
                or None. May differ per batch row (the previous version
                incorrectly assumed row 0's mask for the whole batch).
            attention_mask: Optional padding mask (batch, seq_len); 0 marks
                positions that must not be attended to as keys.

        Returns:
            Output tensor (batch, seq_len, d_model).
        """
        batch, seq_len, _ = x.shape
        device = x.device

        if global_mask is None:
            global_mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=device)

        # Compute QKV for all tokens and reshape to (batch, heads, seq, head_dim).
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        has_global = bool(global_mask.any())
        if has_global:
            # Global tokens are served through the dedicated K/V projection.
            global_kv = self.global_qkv(x)
            _, gk, gv = global_kv.chunk(3, dim=-1)  # global queries are unused
            gk = gk.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            gv = gv.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

            # Additive bias hiding non-global (and padded) positions in the
            # global K/V segment. The original concatenated keys for EVERY
            # position, which made all tokens attend to the full sequence.
            global_bias = torch.zeros(batch, 1, 1, seq_len, device=device, dtype=x.dtype)
            global_bias = global_bias.masked_fill(
                ~global_mask.unsqueeze(1).unsqueeze(2), float(-1e9)
            )
            if attention_mask is not None:
                global_bias = global_bias.masked_fill(
                    attention_mask.unsqueeze(1).unsqueeze(2) == 0, float(-1e9)
                )

        half = self.window_size // 2
        output = torch.zeros_like(x)

        for t in range(seq_len):
            # Symmetric window around position t, clipped to the sequence.
            window_start = max(0, t - half)
            window_end = min(seq_len, t + half + 1)

            q_t = q[:, :, t:t + 1, :]  # (batch, heads, 1, head_dim)
            k_full = k[:, :, window_start:window_end, :]
            v_full = v[:, :, window_start:window_end, :]

            # Additive bias for the local window (padding mask only). Using a
            # bias the same length as the key axis fixes the original shape
            # mismatch when global keys were appended after masking.
            bias = torch.zeros(
                batch, 1, 1, window_end - window_start, device=device, dtype=x.dtype
            )
            if attention_mask is not None:
                bias = bias.masked_fill(
                    attention_mask[:, window_start:window_end]
                    .unsqueeze(1).unsqueeze(2) == 0,
                    float(-1e9),
                )

            if has_global:
                # Append the global segment; non-global columns are already
                # disabled via global_bias, so batch rows may differ.
                k_full = torch.cat([k_full, gk], dim=2)
                v_full = torch.cat([v_full, gv], dim=2)
                bias = torch.cat([bias, global_bias], dim=-1)

            # Scaled dot-product attention for this single query position.
            attn_scores = torch.matmul(q_t, k_full.transpose(-2, -1)) / (self.head_dim ** 0.5)
            attn_scores = attn_scores + bias
            attn_weights = F.softmax(attn_scores, dim=-1)

            attn_output = torch.matmul(attn_weights, v_full)  # (batch, heads, 1, head_dim)
            output[:, t, :] = attn_output.transpose(1, 2).reshape(batch, self.d_model)

        # Project once at the end: out_proj is linear, so this is equivalent
        # to the original per-position projection but much cheaper.
        return self.out_proj(output)

    def forward_optimized(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Optimized forward pass.

        When the window covers the whole sequence, full (flash) attention is
        exact and is used; otherwise the vectorized windowed path runs.
        """
        batch, seq_len, _ = x.shape
        if self.use_flash_attention and self.window_size >= seq_len:
            return self._flash_attention_forward(x, attention_mask)
        return self._windowed_attention_forward(x, global_mask, attention_mask)

    def _flash_attention_forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Use Flash Attention 2 if available (requires: pip install flash-attn).

        flash_attn_func has no padding-mask argument, so masked inputs fall
        back to standard attention (the original silently dropped the mask).
        """
        if attention_mask is not None:
            return self._standard_attention(x, attention_mask)

        try:
            from flash_attn import flash_attn_func
        except ImportError:
            print("Flash Attention not available, falling back to standard attention")
            return self._standard_attention(x, attention_mask)

        batch, seq_len, _ = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        # Flash attention expects (batch, seq_len, num_heads, head_dim)
        # and returns the same shape.
        q = q.view(batch, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch, seq_len, self.num_heads, self.head_dim)
        v = v.view(batch, seq_len, self.num_heads, self.head_dim)
        output = flash_attn_func(
            q, k, v,
            causal=False,
            softmax_scale=1.0 / (self.head_dim ** 0.5),
        )
        output = output.reshape(batch, seq_len, self.d_model)
        return self.out_proj(output)

    def _standard_attention(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Standard full attention (quadratic in seq_len)."""
        batch, seq_len, _ = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if attention_mask is not None:
            # Broadcast (batch, 1, 1, seq_len) over heads and query positions.
            attn_scores = attn_scores.masked_fill(
                attention_mask.unsqueeze(1).unsqueeze(2) == 0,
                -1e9,
            )
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, seq_len, self.d_model)
        return self.out_proj(attn_output)

    def _windowed_attention_forward(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Vectorized windowed attention using unfold + batched matmul.

        Each position t attends to keys in [t - w//2, t + (w - 1 - w//2)].

        Fixes vs. the original:
          * asymmetric padding so unfold yields exactly seq_len windows
            (symmetric w//2 padding with an even window produced seq_len + 1);
          * correct permute — Tensor.unfold appends the window axis LAST, so
            (0, 1, 2, 4, 3) yields (batch, seq, heads, window, head_dim);
          * zero-padded out-of-range keys and attention_mask positions are
            masked out instead of silently attended to;
          * the global blend happens after out_proj, avoiding the double
            projection of _standard_attention's already-projected output.
        """
        batch, seq_len, _ = x.shape
        device = x.device

        if global_mask is None:
            global_mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=device)

        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        # (batch, seq_len, num_heads, head_dim)
        q = q.view(batch, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch, seq_len, self.num_heads, self.head_dim)
        v = v.view(batch, seq_len, self.num_heads, self.head_dim)

        # Asymmetric pad: pad_left + pad_right = window_size - 1, so unfold
        # produces one window per original position.
        pad_left = self.window_size // 2
        pad_right = self.window_size - 1 - pad_left
        k_padded = F.pad(k, (0, 0, 0, 0, pad_left, pad_right))
        v_padded = F.pad(v, (0, 0, 0, 0, pad_left, pad_right))

        # unfold: (batch, seq_len, heads, head_dim, window)
        k_windows = k_padded.unfold(1, self.window_size, 1)
        v_windows = v_padded.unfold(1, self.window_size, 1)
        # -> (batch, seq_len, heads, window, head_dim)
        k_windows = k_windows.permute(0, 1, 2, 4, 3)
        v_windows = v_windows.permute(0, 1, 2, 4, 3)

        # Scores: (batch, seq_len, heads, window)
        q_expanded = q.unsqueeze(3)  # (batch, seq_len, heads, 1, head_dim)
        attn_scores = torch.matmul(
            q_expanded, k_windows.transpose(-2, -1)
        ) / (self.head_dim ** 0.5)
        attn_scores = attn_scores.squeeze(3)

        # Validity per key: 1 inside the sequence (and unpadded), 0 for the
        # zero-padding introduced above or for attention_mask == 0.
        key_valid = torch.ones(batch, seq_len, device=device, dtype=x.dtype)
        if attention_mask is not None:
            key_valid = key_valid * attention_mask.to(x.dtype)
        key_valid = F.pad(key_valid, (pad_left, pad_right))
        key_valid = key_valid.unfold(1, self.window_size, 1)  # (batch, seq, window)
        attn_scores = attn_scores.masked_fill(key_valid.unsqueeze(2) == 0, float(-1e9))

        attn_weights = F.softmax(attn_scores, dim=-1)
        # (batch, seq_len, heads, head_dim) -> (batch, seq_len, d_model)
        attn_output = torch.matmul(attn_weights.unsqueeze(3), v_windows).squeeze(3)
        attn_output = attn_output.reshape(batch, seq_len, self.d_model)
        local_out = self.out_proj(attn_output)

        if global_mask.any():
            # Simplified global handling: global positions get full attention.
            # _standard_attention already applies out_proj, so blend projected
            # outputs (the original projected the blend a second time).
            full_out = self._standard_attention(x, attention_mask)
            local_out = torch.where(global_mask.unsqueeze(-1), full_out, local_out)
        return local_out
def test_vortex_local_attention():
    """Smoke-test VortexLocalAttention: shape preservation with and without a global mask."""
    batch_size, seq_len = 2, 256
    d_model, num_heads, window_size = 4096, 32, 512

    layer = VortexLocalAttention(d_model, num_heads, window_size, use_flash_attention=False)
    inputs = torch.randn(batch_size, seq_len, d_model)

    # Plain forward pass must preserve the input shape.
    result = layer(inputs)
    print(f"Input shape: {inputs.shape}")
    print(f"Output shape: {result.shape}")
    assert result.shape == inputs.shape, f"Expected {inputs.shape}, got {result.shape}"

    # Mark one token per batch row as global and run again.
    is_global = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    is_global[0, 0] = True    # first token of row 0 is global
    is_global[1, -1] = True   # last token of row 1 is global
    result_with_global = layer(inputs, global_mask=is_global)
    assert result_with_global.shape == inputs.shape

    print("VortexLocalAttention test passed!")


if __name__ == "__main__":
    test_vortex_local_attention()