"""Multi-Head Attention with RoPE integration and memory optimizations.
Critical implementation details:
1. Apply RoPE only to Q and K, never to V
2. Use SDPA for Flash Attention 2 support
3. Pre-normalization architecture
4. Memory-efficient implementation
"""
import logging
import math
import sys
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .rope import RotaryPositionEmbeddings

logger = logging.getLogger(__name__)
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention with RoPE and Flash Attention support.
This implementation:
- Uses Rotary Position Embeddings (RoPE) on Q and K only
- Supports Flash Attention 2 via torch.nn.functional.scaled_dot_product_attention
- Uses no bias terms (modern approach)
- Includes proper causal masking
- Memory-efficient implementation
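
    Example (minimal usage sketch):
        >>> attn = MultiHeadAttention(d_model=768, n_heads=12)
        >>> x = torch.randn(2, 128, 768)
        >>> out, kv = attn(x)  # kv is None unless use_cache=True
        >>> out.shape
        torch.Size([2, 128, 768])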
"""
def __init__(
self,
d_model: int = 768,
n_heads: int = 12,
dropout: float = 0.1,
max_seq_len: int = 2048,
rope_base: int = 10000,
rope_percentage: float = 0.5,
use_flash_attention: bool = True,
):
super().__init__()
assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
        # Windows Flash Attention: test with PyTorch 2.10+ nightly builds.
        # Older versions had freezing issues, but newer versions may work.
        if sys.platform == 'win32' and use_flash_attention:
            # Allow Flash Attention on Windows with PyTorch 2.10+.
            # If freezing occurs, set use_flash_attention: false in the config.
            self.use_flash_attention = True
            logger.info("[Windows] Attempting Flash Attention with PyTorch 2.10+ - if freezing occurs, disable in config")
        elif sys.platform == 'win32':
            self.use_flash_attention = False
            logger.info("[Windows] Flash Attention disabled - using manual attention")
        else:
            self.use_flash_attention = use_flash_attention
self.dropout = dropout
self.scale = 1.0 / math.sqrt(self.head_dim)
# Q, K, V projections (no bias)
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.o_proj = nn.Linear(d_model, d_model, bias=False)
# RoPE for positional encoding
# Apply to only part of head dimensions (typically 50%)
rope_dim = int(self.head_dim * rope_percentage)
self.rope_dim = rope_dim
self.rope = RotaryPositionEmbeddings(
head_dim=rope_dim,
max_seq_len=max_seq_len,
base=rope_base
)
# Dropout
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
        # Causal-mask cache: built lazily for the largest sequence length seen so far
self.register_buffer('cached_mask', None, persistent=False)
self.register_buffer('cached_mask_size', torch.tensor(0), persistent=False)
def _get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
"""Get or create causal mask for the given sequence length.
CRITICAL: Always returns mask on the specified device to prevent CPU OOM errors.
"""
if self.cached_mask is None or self.cached_mask_size < seq_len:
# Create a new mask directly on the target device
mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
self.cached_mask = mask
self.cached_mask_size = torch.tensor(seq_len)
# CRITICAL: Ensure the returned mask is on the correct device
# This prevents CPU OOM when broadcasting during attn_scores + causal_mask
return self.cached_mask[:seq_len, :seq_len].to(device)
def _apply_rope(
self,
q: torch.Tensor,
k: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply RoPE to partial dimensions of Q and K.
Args:
q: Query tensor [batch, seq_len, n_heads, head_dim]
k: Key tensor [batch, seq_len, n_heads, head_dim]
position_ids: Optional custom position IDs
Returns:
Rotated Q and K tensors
"""
# Split into RoPE and pass-through dimensions
if self.rope_dim > 0:
q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:]
k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:]
# Apply RoPE to the first part
q_rope, k_rope = self.rope(q_rope, k_rope, position_ids)
# Concatenate back
q = torch.cat([q_rope, q_pass], dim=-1)
k = torch.cat([k_rope, k_pass], dim=-1)
return q, k
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
use_cache: bool = False,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
"""Forward pass of multi-head attention.
Args:
x: Input tensor [batch, seq_len, d_model]
attention_mask: Optional attention mask
position_ids: Optional position IDs for RoPE
use_cache: Whether to return KV cache for inference
past_kv: Past key-value cache for inference
Returns:
Output tensor and optional KV cache
"""
batch_size, seq_len, _ = x.size()
# Project to Q, K, V
q = self.q_proj(x) # [batch, seq_len, d_model]
k = self.k_proj(x) # [batch, seq_len, d_model]
v = self.v_proj(x) # [batch, seq_len, d_model]
# Reshape for multi-head attention
# [batch, seq_len, d_model] -> [batch, seq_len, n_heads, head_dim]
q = q.view(batch_size, seq_len, self.n_heads, self.head_dim)
k = k.view(batch_size, seq_len, self.n_heads, self.head_dim)
v = v.view(batch_size, seq_len, self.n_heads, self.head_dim)
# Apply RoPE to Q and K only (not V!)
q, k = self._apply_rope(q, k, position_ids)
        # Handle the KV cache for inference.
        # NOTE: during incremental decoding, callers should pass position_ids
        # offset by the cache length so RoPE rotates each new token to its
        # true absolute position (default positions start at 0).
        if use_cache and past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)
kv_cache = (k, v) if use_cache else None
# Transpose for attention computation
# [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
        # Use Flash Attention 2 via SDPA when available.
        # This is MUCH more memory efficient than manual attention.
        # With a KV cache, the single new query token must be allowed to
        # attend to every cached position, so built-in causal masking is
        # only safe when there is no explicit mask and no cache.
        is_causal = attention_mask is None and past_kv is None
        if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
            # Flash Attention 2 is selected automatically when available;
            # the causal mask is handled internally when is_causal=True.
            if sys.platform == 'win32':
                # Windows compatibility: skip the kernel-selection context
                # manager, which has been observed to freeze.
                attn_output = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attention_mask,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=is_causal,
                    scale=self.scale,
                )
            else:
                # On Linux, select kernels explicitly for best performance.
                # NOTE: torch.backends.cuda.sdp_kernel is deprecated in recent
                # PyTorch releases in favor of torch.nn.attention.sdpa_kernel.
                with torch.backends.cuda.sdp_kernel(
                    enable_flash=True,          # Use Flash Attention when possible
                    enable_math=True,           # Fall back to the math implementation
                    enable_mem_efficient=True,  # Allow memory-efficient attention
                ):
                    attn_output = F.scaled_dot_product_attention(
                        q, k, v,
                        attn_mask=attention_mask,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=is_causal,
                        scale=self.scale,
                    )
        else:
            # Manual attention computation (fallback).
            # Memory-intensive: use only for short sequences.
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
            if is_causal:
                # Apply the causal mask, expanded over batch and head dims
                causal_mask = self._get_causal_mask(seq_len, x.device)
                attn_scores = attn_scores + causal_mask.unsqueeze(0).unsqueeze(0)
            elif attention_mask is not None:
                attn_scores = attn_scores + attention_mask
            # (With a KV cache and a single-token query, no mask is needed:
            # the new token may attend to all cached positions.)
# Apply softmax
attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
attn_weights = self.attn_dropout(attn_weights)
# Compute output
attn_output = torch.matmul(attn_weights, v)
# Reshape back
# [batch, n_heads, seq_len, head_dim] -> [batch, seq_len, d_model]
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.d_model)
# Output projection
output = self.o_proj(attn_output)
output = self.resid_dropout(output)
return output, kv_cache
# Test the attention implementation
def test_attention():
"""Test multi-head attention with various configurations."""
print("Testing Multi-Head Attention...")
# Test configuration
batch_size = 2
seq_len = 128
d_model = 768
n_heads = 12
# Create attention module
attention = MultiHeadAttention(
d_model=d_model,
n_heads=n_heads,
dropout=0.1,
max_seq_len=2048,
rope_percentage=0.5,
use_flash_attention=True, # Enable Flash Attention
)
# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
attention = attention.to(device)
attention.eval() # Set to eval mode for testing
    # Create dummy input (match the module's parameter dtype, float32 by
    # default; a bfloat16 input against float32 weights raises a dtype error)
    x = torch.randn(batch_size, seq_len, d_model, device=device)

    # Forward pass
    with torch.no_grad():
        output, _ = attention(x)

    # Check output shape
    assert output.shape == (batch_size, seq_len, d_model), \
        f"Expected shape {(batch_size, seq_len, d_model)}, got {output.shape}"

    # Check for NaN
    assert not torch.isnan(output).any(), "Output contains NaN values!"

    print("✓ Multi-Head Attention test passed!")
    print(f"  Input shape: {x.shape}")
    print(f"  Output shape: {output.shape}")
    print(f"  Device: {device}")
    if device.type == 'cuda':
        print(f"  Memory allocated: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")
return True
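
# A minimal KV-cache smoke test (a sketch, not part of the original suite).
# It checks only output shapes and cache growth during single-token decoding;
# verifying numerical equivalence against a full-sequence forward pass would
# also require position_ids offset by the cache length, which depends on the
# RotaryPositionEmbeddings implementation imported above.
def test_kv_cache():
    """Smoke-test incremental decoding through the KV cache."""
    torch.manual_seed(0)
    attention = MultiHeadAttention(d_model=64, n_heads=4, dropout=0.0)
    attention.eval()
    past_kv = None
    with torch.no_grad():
        for step in range(4):
            x = torch.randn(1, 1, 64)  # decode one token at a time
            out, past_kv = attention(x, use_cache=True, past_kv=past_kv)
            assert out.shape == (1, 1, 64)
            # Cache tensors are stored as [batch, seq_len, n_heads, head_dim]
            assert past_kv[0].shape[1] == step + 1
    print("✓ KV-cache smoke test passed!")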
if __name__ == "__main__":
    test_attention()
    test_kv_cache()