File size: 11,714 Bytes
fb67af8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 |
"""Multi-Head Attention with RoPE integration and memory optimizations.
Critical implementation details:
1. Apply RoPE only to Q and K, never to V
2. Use SDPA for Flash Attention 2 support
3. Pre-normalization architecture
4. Memory-efficient implementation
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple
from .rope import RotaryPositionEmbeddings
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention with RoPE and Flash Attention support.
This implementation:
- Uses Rotary Position Embeddings (RoPE) on Q and K only
- Supports Flash Attention 2 via torch.nn.functional.scaled_dot_product_attention
- Uses no bias terms (modern approach)
- Includes proper causal masking
- Memory-efficient implementation
"""
def __init__(
self,
d_model: int = 768,
n_heads: int = 12,
dropout: float = 0.1,
max_seq_len: int = 2048,
rope_base: int = 10000,
rope_percentage: float = 0.5,
use_flash_attention: bool = True,
):
super().__init__()
assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
# Windows Flash Attention: Test with PyTorch 2.10+ nightly
# Older versions had freezing issues, but newer versions may work
import sys
import logging
logger = logging.getLogger(__name__)
if sys.platform == 'win32' and use_flash_attention:
# Allow Flash Attention on Windows with PyTorch 2.10+
# If freezing occurs, set use_flash_attention: false in config
self.use_flash_attention = use_flash_attention
logger.info("[Windows] Attempting Flash Attention with PyTorch 2.10+ - if freezing occurs, disable in config")
elif sys.platform == 'win32':
self.use_flash_attention = False
logger.info("[Windows] Flash Attention disabled - using manual attention")
else:
self.use_flash_attention = use_flash_attention
self.dropout = dropout
self.scale = 1.0 / math.sqrt(self.head_dim)
# Q, K, V projections (no bias)
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.o_proj = nn.Linear(d_model, d_model, bias=False)
# RoPE for positional encoding
# Apply to only part of head dimensions (typically 50%)
rope_dim = int(self.head_dim * rope_percentage)
self.rope_dim = rope_dim
self.rope = RotaryPositionEmbeddings(
head_dim=rope_dim,
max_seq_len=max_seq_len,
base=rope_base
)
# Dropout
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
# Pre-allocate causal mask more efficiently
# We'll create it on-demand based on sequence length
self.register_buffer('cached_mask', None, persistent=False)
self.register_buffer('cached_mask_size', torch.tensor(0), persistent=False)
def _get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
"""Get or create causal mask for the given sequence length.
CRITICAL: Always returns mask on the specified device to prevent CPU OOM errors.
"""
if self.cached_mask is None or self.cached_mask_size < seq_len:
# Create a new mask directly on the target device
mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
self.cached_mask = mask
self.cached_mask_size = torch.tensor(seq_len)
# CRITICAL: Ensure the returned mask is on the correct device
# This prevents CPU OOM when broadcasting during attn_scores + causal_mask
return self.cached_mask[:seq_len, :seq_len].to(device)
def _apply_rope(
self,
q: torch.Tensor,
k: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply RoPE to partial dimensions of Q and K.
Args:
q: Query tensor [batch, seq_len, n_heads, head_dim]
k: Key tensor [batch, seq_len, n_heads, head_dim]
position_ids: Optional custom position IDs
Returns:
Rotated Q and K tensors
"""
# Split into RoPE and pass-through dimensions
if self.rope_dim > 0:
q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:]
k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:]
# Apply RoPE to the first part
q_rope, k_rope = self.rope(q_rope, k_rope, position_ids)
# Concatenate back
q = torch.cat([q_rope, q_pass], dim=-1)
k = torch.cat([k_rope, k_pass], dim=-1)
return q, k
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
use_cache: bool = False,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
"""Forward pass of multi-head attention.
Args:
x: Input tensor [batch, seq_len, d_model]
attention_mask: Optional attention mask
position_ids: Optional position IDs for RoPE
use_cache: Whether to return KV cache for inference
past_kv: Past key-value cache for inference
Returns:
Output tensor and optional KV cache
"""
batch_size, seq_len, _ = x.size()
# Project to Q, K, V
q = self.q_proj(x) # [batch, seq_len, d_model]
k = self.k_proj(x) # [batch, seq_len, d_model]
v = self.v_proj(x) # [batch, seq_len, d_model]
# Reshape for multi-head attention
# [batch, seq_len, d_model] -> [batch, seq_len, n_heads, head_dim]
q = q.view(batch_size, seq_len, self.n_heads, self.head_dim)
k = k.view(batch_size, seq_len, self.n_heads, self.head_dim)
v = v.view(batch_size, seq_len, self.n_heads, self.head_dim)
# Apply RoPE to Q and K only (not V!)
q, k = self._apply_rope(q, k, position_ids)
# Handle KV cache for inference
if use_cache and past_kv is not None:
past_k, past_v = past_kv
k = torch.cat([past_k, k], dim=1)
v = torch.cat([past_v, v], dim=1)
kv_cache = (k, v) if use_cache else None
# Transpose for attention computation
# [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
# Use Flash Attention 2 via SDPA when available
# This is MUCH more memory efficient than manual attention
if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
# Flash Attention 2 is automatically used when available
# It handles the causal mask internally when is_causal=True
# NOTE: Windows compatibility - skip context manager to avoid freezing
import sys
if sys.platform == 'win32':
# On Windows, use SDPA without explicit kernel selection
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=True if attention_mask is None else False,
scale=self.scale,
)
else:
# On Linux, use explicit kernel selection for best performance
with torch.backends.cuda.sdp_kernel(
enable_flash=True, # Use Flash Attention when possible
enable_math=True, # Fallback to math implementation
enable_mem_efficient=True # Use memory-efficient attention
):
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=True if attention_mask is None else False,
scale=self.scale,
)
else:
# Manual attention computation (fallback)
# This is memory-intensive and should only be used for small sequences
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
# Apply causal mask
if attention_mask is None:
causal_mask = self._get_causal_mask(seq_len, x.device)
# Expand mask for batch and heads
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
attn_scores = attn_scores + causal_mask
else:
attn_scores = attn_scores + attention_mask
# Apply softmax
attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
attn_weights = self.attn_dropout(attn_weights)
# Compute output
attn_output = torch.matmul(attn_weights, v)
# Reshape back
# [batch, n_heads, seq_len, head_dim] -> [batch, seq_len, d_model]
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.d_model)
# Output projection
output = self.o_proj(attn_output)
output = self.resid_dropout(output)
return output, kv_cache
# Test the attention implementation
def test_attention():
    """Smoke-test MultiHeadAttention: output shape and absence of NaNs.

    Runs on CUDA in bfloat16 when available, otherwise on CPU in float32.
    The module is cast to the same dtype as the input. Otherwise the float32
    projection weights meet a bfloat16 input and the linear layer raises a
    dtype-mismatch error.

    Returns:
        True when all checks pass (failures raise AssertionError).
    """
    print("Testing Multi-Head Attention...")
    batch_size = 2
    seq_len = 128
    d_model = 768
    n_heads = 12

    attention = MultiHeadAttention(
        d_model=d_model,
        n_heads=n_heads,
        dropout=0.1,
        max_seq_len=2048,
        rope_percentage=0.5,
        use_flash_attention=True,  # Enable Flash Attention
    )

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    # bfloat16 on GPU (as originally intended); float32 on CPU for portability.
    dtype = torch.bfloat16 if use_cuda else torch.float32
    attention = attention.to(device=device, dtype=dtype)
    attention.eval()  # disables dropout for a deterministic check

    x = torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
    with torch.no_grad():
        output, _ = attention(x)

    assert output.shape == (batch_size, seq_len, d_model), \
        f"Expected shape {(batch_size, seq_len, d_model)}, got {output.shape}"
    assert not torch.isnan(output).any(), "Output contains NaN values!"

    print("✓ Multi-Head Attention test passed!")
    print(f" Input shape: {x.shape}")
    print(f" Output shape: {output.shape}")
    print(f" Device: {device}")
    # torch.cuda.memory_allocated raises for non-CUDA devices.
    if use_cuda:
        print(f" Memory allocated: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")
    return True
if __name__ == "__main__":
    # Run the smoke test when executed directly.
    test_attention()