""" Attention layer extraction and benchmarking utilities. Provides functions to: - Extract attention layers from HuggingFace models - Create proper inputs for attention forward passes - Benchmark attention with different SDPA backends """ import torch import torch.nn as nn from typing import Tuple, Dict, Any, Optional from transformers import PreTrainedModel def extract_attention_layer(model: PreTrainedModel, layer_idx: int = 0) -> nn.Module: """ Extract the attention module from a loaded HuggingFace model. Works for common architectures: Llama, Qwen, SmolLM, Mistral, etc. These all follow the pattern: model.model.layers[i].self_attn Args: model: Loaded HuggingFace causal LM model layer_idx: Which layer to extract (default: 0, first layer) Returns: The attention module (nn.Module) """ # Most decoder-only models follow this pattern try: attention = model.model.layers[layer_idx].self_attn return attention except AttributeError: # Fallback for different architectures if hasattr(model, 'transformer'): # GPT-2 style return model.transformer.h[layer_idx].attn elif hasattr(model, 'gpt_neox'): # GPT-NeoX style return model.gpt_neox.layers[layer_idx].attention else: raise ValueError( f"Could not extract attention layer from model type: {type(model).__name__}. " "Supported architectures: Llama, Qwen, SmolLM, Mistral, GPT-2, GPT-NeoX" ) def create_attention_inputs( model: PreTrainedModel, batch_size: int, seq_len: int, device: torch.device, dtype: torch.dtype = torch.float16, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Create proper inputs for an attention layer forward pass. Args: model: The loaded model (to get hidden_size from config) batch_size: Batch size seq_len: Sequence length device: Target device (cuda/cpu) dtype: Data type (default: float16) Returns: Tuple of (hidden_states, position_ids) """ hidden_dim = model.config.hidden_size # Hidden states: [batch, seq_len, hidden_dim] hidden_states = torch.randn( batch_size, seq_len, hidden_dim, dtype=dtype, device=device ) # Position IDs: [batch, seq_len] position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) return hidden_states, position_ids def create_causal_mask( seq_len: int, device: torch.device, dtype: torch.dtype = torch.float16, ) -> torch.Tensor: """ Create a causal attention mask. Args: seq_len: Sequence length device: Target device dtype: Data type Returns: Causal mask tensor [1, 1, seq_len, seq_len] """ # Create lower triangular mask (1 = attend, 0 = mask) mask = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=dtype)) # Convert to attention mask format (0 = attend, -inf = mask) mask = mask.masked_fill(mask == 0, float('-inf')) mask = mask.masked_fill(mask == 1, 0.0) return mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len] def benchmark_attention_layer( attention_layer: nn.Module, hidden_states: torch.Tensor, position_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, backend: str = "flash", num_iterations: int = 10, warmup_iterations: int = 3, ) -> Dict[str, Any]: """ Benchmark an attention layer with a specific SDPA backend. Args: attention_layer: The attention module to benchmark hidden_states: Input hidden states [batch, seq, hidden_dim] position_ids: Position IDs [batch, seq] attention_mask: Optional attention mask backend: Which SDPA backend ("math", "flash", "mem_efficient") num_iterations: Number of timed iterations warmup_iterations: Number of warmup iterations Returns: Dict with timing and memory results """ if not torch.cuda.is_available(): return {"error": "CUDA not available", "status": "error"} # Map backend name to sdp_kernel flags backend_flags = { "math": (True, False, False), # enable_math, enable_flash, enable_mem_efficient "flash": (False, True, False), "mem_efficient": (False, False, True), } if backend not in backend_flags: return {"error": f"Unknown backend: {backend}", "status": "error"} enable_math, enable_flash, enable_mem_efficient = backend_flags[backend] def run_attention(): """Run attention with fallback for different call signatures.""" try: # Try standard call with position_ids return attention_layer( hidden_states, position_ids=position_ids, ) except TypeError: # Fallback: just hidden_states return attention_layer(hidden_states) try: # Warmup with torch.backends.cuda.sdp_kernel( enable_flash=enable_flash, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient ): with torch.no_grad(): for _ in range(warmup_iterations): _ = run_attention() torch.cuda.synchronize() torch.cuda.reset_peak_memory_stats() # Timed runs start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) with torch.backends.cuda.sdp_kernel( enable_flash=enable_flash, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient ): with torch.no_grad(): start.record() for _ in range(num_iterations): output = run_attention() end.record() torch.cuda.synchronize() time_ms = start.elapsed_time(end) / num_iterations memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024) return { "time_ms": round(time_ms, 3), "memory_mb": round(memory_mb, 1), "status": "success", "backend": backend, } except Exception as e: import traceback error_msg = str(e) tb = traceback.format_exc() # Common error: Flash attention not available on certain GPUs if "flash" in error_msg.lower() or "sm75" in error_msg.lower(): return { "time_ms": None, "memory_mb": None, "status": f"unsupported: {error_msg[:80]}", "backend": backend, } # Log detailed error for debugging print(f"[benchmark_attention_layer] Error for {backend}: {error_msg}") print(f"[benchmark_attention_layer] Traceback: {tb[:500]}") return { "time_ms": None, "memory_mb": None, "status": f"error: {error_msg[:80]}", "backend": backend, } def create_kv_cache( model: PreTrainedModel, batch_size: int, cache_len: int, device: torch.device, dtype: torch.dtype = torch.float16, layer_idx: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Create a simulated KV cache for decode-phase benchmarking. Args: model: The loaded model (to get config) batch_size: Batch size cache_len: Number of cached tokens device: Target device dtype: Data type layer_idx: Which layer (for future multi-layer support) Returns: Tuple of (key_cache, value_cache), each [batch, num_kv_heads, cache_len, head_dim] """ config = model.config # Get number of KV heads (for GQA models) if hasattr(config, 'num_key_value_heads'): num_kv_heads = config.num_key_value_heads else: num_kv_heads = config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads # Create KV cache tensors key_cache = torch.randn( batch_size, num_kv_heads, cache_len, head_dim, dtype=dtype, device=device ) value_cache = torch.randn( batch_size, num_kv_heads, cache_len, head_dim, dtype=dtype, device=device ) return key_cache, value_cache def benchmark_decode_attention( attention_layer: nn.Module, model: PreTrainedModel, kv_cache_len: int, num_tokens: int = 10, batch_size: int = 1, backend: str = "flash", num_iterations: int = 5, ) -> Dict[str, Any]: """ Benchmark decode-phase attention (single query attending to KV cache). Args: attention_layer: The attention module model: The loaded model (for config) kv_cache_len: Length of the KV cache (context) num_tokens: Number of decode tokens to simulate batch_size: Batch size backend: SDPA backend to use num_iterations: Iterations per token for averaging Returns: Dict with per-token timing and memory stats """ if not torch.cuda.is_available(): return {"error": "CUDA not available", "status": "error"} device = torch.device("cuda") dtype = torch.float16 # Create single-token query input hidden_dim = model.config.hidden_size query_hidden = torch.randn(batch_size, 1, hidden_dim, dtype=dtype, device=device) # Create KV cache key_cache, value_cache = create_kv_cache( model, batch_size, kv_cache_len, device, dtype ) # Position ID for the new token (at position = cache_len) position_ids = torch.tensor([[kv_cache_len]], device=device).expand(batch_size, 1) # Backend flags backend_flags = { "math": (True, False, False), "flash": (False, True, False), "mem_efficient": (False, False, True), } if backend not in backend_flags: return {"error": f"Unknown backend: {backend}", "status": "error"} enable_math, enable_flash, enable_mem_efficient = backend_flags[backend] try: # Note: For proper decode simulation, we'd need to pass past_key_values # This is a simplified version that measures attention with asymmetric Q/KV sizes # Real models handle this via the past_key_value mechanism # Warmup with torch.backends.cuda.sdp_kernel( enable_flash=enable_flash, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient ): with torch.no_grad(): for _ in range(2): _ = attention_layer( query_hidden, position_ids=position_ids, ) torch.cuda.synchronize() torch.cuda.reset_peak_memory_stats() # Time multiple tokens start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) with torch.backends.cuda.sdp_kernel( enable_flash=enable_flash, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient ): with torch.no_grad(): start.record() for _ in range(num_tokens * num_iterations): output = attention_layer( query_hidden, position_ids=position_ids, ) end.record() torch.cuda.synchronize() total_time_ms = start.elapsed_time(end) time_per_token_ms = total_time_ms / (num_tokens * num_iterations) memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024) # Clean up del query_hidden, key_cache, value_cache torch.cuda.empty_cache() return { "time_ms_per_token": round(time_per_token_ms, 4), "total_time_ms": round(total_time_ms / num_iterations, 3), "memory_mb": round(memory_mb, 1), "kv_cache_len": kv_cache_len, "num_tokens": num_tokens, "status": "success", "backend": backend, } except Exception as e: return { "time_ms_per_token": None, "total_time_ms": None, "memory_mb": None, "status": f"error: {str(e)[:80]}", "backend": backend, } def get_model_attention_info(model: PreTrainedModel) -> Dict[str, Any]: """ Extract attention-related configuration from a model. Returns: Dict with num_heads, num_kv_heads, head_dim, hidden_size, etc. """ config = model.config num_heads = config.num_attention_heads # GQA models have separate num_key_value_heads if hasattr(config, 'num_key_value_heads'): num_kv_heads = config.num_key_value_heads else: num_kv_heads = num_heads head_dim = config.hidden_size // num_heads return { "num_attention_heads": num_heads, "num_kv_heads": num_kv_heads, "head_dim": head_dim, "hidden_size": config.hidden_size, "num_layers": config.num_hidden_layers, "gqa_ratio": num_heads // num_kv_heads if num_kv_heads > 0 else 1, "is_gqa": num_kv_heads < num_heads, }