# flash-attention-explorer / src/attention_utils.py
# commit a0y0346 (parent 685194e)
# fix: Add fallback SDPA benchmark when attention layer fails
"""
Attention layer extraction and benchmarking utilities.
Provides functions to:
- Extract attention layers from HuggingFace models
- Create proper inputs for attention forward passes
- Benchmark attention with different SDPA backends
"""
import torch
import torch.nn as nn
from typing import Tuple, Dict, Any, Optional
from transformers import PreTrainedModel
def extract_attention_layer(model: PreTrainedModel, layer_idx: int = 0) -> nn.Module:
    """
    Pull one self-attention module out of a loaded HuggingFace model.

    Handles the common decoder-only layouts:
    - Llama / Qwen / SmolLM / Mistral: ``model.model.layers[i].self_attn``
    - GPT-2:                           ``model.transformer.h[i].attn``
    - GPT-NeoX:                        ``model.gpt_neox.layers[i].attention``

    Args:
        model: Loaded HuggingFace causal LM model
        layer_idx: Index of the decoder layer to extract (default: 0)

    Returns:
        The attention module (nn.Module)

    Raises:
        ValueError: If the model matches none of the known layouts.
    """
    try:
        # Most decoder-only models follow the Llama-style layout.
        return model.model.layers[layer_idx].self_attn
    except AttributeError:
        # Fall back to the other layouts we recognize.
        if hasattr(model, 'transformer'):
            # GPT-2 style
            return model.transformer.h[layer_idx].attn
        if hasattr(model, 'gpt_neox'):
            # GPT-NeoX style
            return model.gpt_neox.layers[layer_idx].attention
        raise ValueError(
            f"Could not extract attention layer from model type: {type(model).__name__}. "
            "Supported architectures: Llama, Qwen, SmolLM, Mistral, GPT-2, GPT-NeoX"
        )
def create_attention_inputs(
    model: PreTrainedModel,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    dtype: torch.dtype = torch.float16,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build random inputs suitable for an attention layer forward pass.

    Args:
        model: The loaded model (hidden_size is read from its config)
        batch_size: Batch size
        seq_len: Sequence length
        device: Target device (cuda/cpu)
        dtype: Data type (default: float16)

    Returns:
        Tuple of (hidden_states [batch, seq_len, hidden_dim],
                  position_ids  [batch, seq_len])
    """
    dim = model.config.hidden_size
    hidden_states = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device)
    # One 0..seq_len-1 row, broadcast (no copy) to every batch element.
    positions = torch.arange(seq_len, device=device)
    position_ids = positions.unsqueeze(0).expand(batch_size, -1)
    return hidden_states, position_ids
def create_causal_mask(
    seq_len: int,
    device: torch.device,
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """
    Build an additive causal attention mask.

    Positions a query may attend to (itself and earlier tokens) hold 0.0;
    future positions hold -inf, so the mask can be added to attention scores.

    Args:
        seq_len: Sequence length
        device: Target device
        dtype: Data type

    Returns:
        Causal mask tensor of shape [1, 1, seq_len, seq_len]
    """
    # Start fully masked, then keep -inf only strictly above the diagonal.
    full = torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype)
    causal = torch.triu(full, diagonal=1)
    # Add broadcast dims: [1, 1, seq_len, seq_len]
    return causal[None, None, :, :]
def benchmark_attention_layer(
    attention_layer: nn.Module,
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    backend: str = "flash",
    num_iterations: int = 10,
    warmup_iterations: int = 3,
) -> Dict[str, Any]:
    """
    Benchmark an attention layer under a specific SDPA backend.

    Args:
        attention_layer: The attention module to benchmark
        hidden_states: Input hidden states [batch, seq, hidden_dim]
        position_ids: Position IDs [batch, seq]
        attention_mask: Optional attention mask (currently unused by the call)
        backend: Which SDPA backend ("math", "flash", "mem_efficient")
        num_iterations: Number of timed iterations
        warmup_iterations: Number of warmup iterations

    Returns:
        Dict with timing/memory results, or an error/unsupported status.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available", "status": "error"}

    # Per-backend (enable_math, enable_flash, enable_mem_efficient) flags.
    flag_table = {
        "math": (True, False, False),
        "flash": (False, True, False),
        "mem_efficient": (False, False, True),
    }
    flags = flag_table.get(backend)
    if flags is None:
        return {"error": f"Unknown backend: {backend}", "status": "error"}
    use_math, use_flash, use_mem = flags

    def call_layer():
        """Invoke the layer, tolerating signatures without position_ids."""
        try:
            return attention_layer(hidden_states, position_ids=position_ids)
        except TypeError:
            return attention_layer(hidden_states)

    try:
        # Warmup passes (not timed) under the selected backend.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=use_flash,
            enable_math=use_math,
            enable_mem_efficient=use_mem,
        ), torch.no_grad():
            for _ in range(warmup_iterations):
                call_layer()
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

        # Timed passes, measured with CUDA events.
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        with torch.backends.cuda.sdp_kernel(
            enable_flash=use_flash,
            enable_math=use_math,
            enable_mem_efficient=use_mem,
        ), torch.no_grad():
            start_evt.record()
            for _ in range(num_iterations):
                output = call_layer()
            end_evt.record()
        torch.cuda.synchronize()

        per_iter_ms = start_evt.elapsed_time(end_evt) / num_iterations
        peak_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
        return {
            "time_ms": round(per_iter_ms, 3),
            "memory_mb": round(peak_mb, 1),
            "status": "success",
            "backend": backend,
        }
    except Exception as e:
        import traceback
        error_msg = str(e)
        tb = traceback.format_exc()
        # Common case: flash attention unsupported on this GPU generation.
        if "flash" in error_msg.lower() or "sm75" in error_msg.lower():
            return {
                "time_ms": None,
                "memory_mb": None,
                "status": f"unsupported: {error_msg[:80]}",
                "backend": backend,
            }
        # Log detailed error for debugging
        print(f"[benchmark_attention_layer] Error for {backend}: {error_msg}")
        print(f"[benchmark_attention_layer] Traceback: {tb[:500]}")
        return {
            "time_ms": None,
            "memory_mb": None,
            "status": f"error: {error_msg[:80]}",
            "backend": backend,
        }
def create_kv_cache(
    model: PreTrainedModel,
    batch_size: int,
    cache_len: int,
    device: torch.device,
    dtype: torch.dtype = torch.float16,
    layer_idx: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build a random simulated KV cache for decode-phase benchmarking.

    Args:
        model: The loaded model (to get config)
        batch_size: Batch size
        cache_len: Number of cached tokens
        device: Target device
        dtype: Data type
        layer_idx: Which layer (unused today; reserved for multi-layer support)

    Returns:
        Tuple of (key_cache, value_cache),
        each shaped [batch, num_kv_heads, cache_len, head_dim]
    """
    cfg = model.config
    # GQA models declare fewer KV heads than query heads.
    kv_heads = getattr(cfg, 'num_key_value_heads', cfg.num_attention_heads)
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    shape = (batch_size, kv_heads, cache_len, head_dim)
    key_cache = torch.randn(*shape, dtype=dtype, device=device)
    value_cache = torch.randn(*shape, dtype=dtype, device=device)
    return key_cache, value_cache
def benchmark_decode_attention(
    attention_layer: nn.Module,
    model: PreTrainedModel,
    kv_cache_len: int,
    num_tokens: int = 10,
    batch_size: int = 1,
    backend: str = "flash",
    num_iterations: int = 5,
) -> Dict[str, Any]:
    """
    Benchmark decode-phase attention (single query attending to KV cache).

    Args:
        attention_layer: The attention module
        model: The loaded model (for config)
        kv_cache_len: Length of the KV cache (context)
        num_tokens: Number of decode tokens to simulate
        batch_size: Batch size
        backend: SDPA backend to use ("math", "flash", "mem_efficient")
        num_iterations: Iterations per token for averaging

    Returns:
        Dict with per-token timing and memory stats
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available", "status": "error"}

    # Validate the backend BEFORE allocating any GPU tensors, so a bad
    # argument cannot leave stray CUDA allocations behind.
    backend_flags = {
        "math": (True, False, False),
        "flash": (False, True, False),
        "mem_efficient": (False, False, True),
    }
    if backend not in backend_flags:
        return {"error": f"Unknown backend: {backend}", "status": "error"}
    enable_math, enable_flash, enable_mem_efficient = backend_flags[backend]

    device = torch.device("cuda")
    dtype = torch.float16

    # Create single-token query input: [batch, 1, hidden_dim]
    hidden_dim = model.config.hidden_size
    query_hidden = torch.randn(batch_size, 1, hidden_dim, dtype=dtype, device=device)

    # Simulated KV cache; kept allocated during the timed runs so its
    # footprint contributes to the peak-memory measurement.
    key_cache, value_cache = create_kv_cache(
        model, batch_size, kv_cache_len, device, dtype
    )

    # Position ID for the new token (at position = cache_len)
    position_ids = torch.tensor([[kv_cache_len]], device=device).expand(batch_size, 1)

    def run_attention():
        """Call the layer, falling back for forward signatures that do not
        accept position_ids (same fallback as benchmark_attention_layer)."""
        try:
            return attention_layer(query_hidden, position_ids=position_ids)
        except TypeError:
            return attention_layer(query_hidden)

    try:
        # Note: For proper decode simulation, we'd need to pass past_key_values
        # This is a simplified version that measures attention with asymmetric Q/KV sizes
        # Real models handle this via the past_key_value mechanism
        # Warmup
        with torch.backends.cuda.sdp_kernel(
            enable_flash=enable_flash,
            enable_math=enable_math,
            enable_mem_efficient=enable_mem_efficient
        ):
            with torch.no_grad():
                for _ in range(2):
                    _ = run_attention()
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

        # Time multiple tokens with CUDA events
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        with torch.backends.cuda.sdp_kernel(
            enable_flash=enable_flash,
            enable_math=enable_math,
            enable_mem_efficient=enable_mem_efficient
        ):
            with torch.no_grad():
                start.record()
                for _ in range(num_tokens * num_iterations):
                    output = run_attention()
                end.record()
        torch.cuda.synchronize()

        total_time_ms = start.elapsed_time(end)
        time_per_token_ms = total_time_ms / (num_tokens * num_iterations)
        memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

        # Release the benchmark tensors promptly
        del query_hidden, key_cache, value_cache
        torch.cuda.empty_cache()

        return {
            "time_ms_per_token": round(time_per_token_ms, 4),
            "total_time_ms": round(total_time_ms / num_iterations, 3),
            "memory_mb": round(memory_mb, 1),
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "status": "success",
            "backend": backend,
        }
    except Exception as e:
        return {
            "time_ms_per_token": None,
            "total_time_ms": None,
            "memory_mb": None,
            "status": f"error: {str(e)[:80]}",
            "backend": backend,
        }
def get_model_attention_info(model: PreTrainedModel) -> Dict[str, Any]:
    """
    Summarize attention-related configuration of a model.

    Returns:
        Dict with num_attention_heads, num_kv_heads, head_dim, hidden_size,
        num_layers, gqa_ratio, and is_gqa.
    """
    cfg = model.config
    heads = cfg.num_attention_heads
    # GQA models declare a separate (smaller) num_key_value_heads.
    kv_heads = getattr(cfg, 'num_key_value_heads', heads)
    return {
        "num_attention_heads": heads,
        "num_kv_heads": kv_heads,
        "head_dim": cfg.hidden_size // heads,
        "hidden_size": cfg.hidden_size,
        "num_layers": cfg.num_hidden_layers,
        "gqa_ratio": heads // kv_heads if kv_heads > 0 else 1,
        "is_gqa": kv_heads < heads,
    }