Spaces:
Running on Zero
Running on Zero
| """ | |
| Attention layer extraction and benchmarking utilities. | |
| Provides functions to: | |
| - Extract attention layers from HuggingFace models | |
| - Create proper inputs for attention forward passes | |
| - Benchmark attention with different SDPA backends | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from typing import Tuple, Dict, Any, Optional | |
| from transformers import PreTrainedModel | |
def extract_attention_layer(model: PreTrainedModel, layer_idx: int = 0) -> nn.Module:
    """
    Extract the attention module from a loaded HuggingFace model.

    Works for common architectures: Llama, Qwen, SmolLM, Mistral, etc.
    These all follow the pattern: model.model.layers[i].self_attn

    Args:
        model: Loaded HuggingFace causal LM model
        layer_idx: Which layer to extract (default: 0, first layer)

    Returns:
        The attention module (nn.Module)

    Raises:
        ValueError: If the model does not match any supported layout.
        IndexError: If layer_idx is out of range for a supported layout.
    """
    # Most decoder-only models follow this pattern.
    try:
        return model.model.layers[layer_idx].self_attn
    except AttributeError:
        # Fall through to architecture-specific layouts below.
        pass
    # Fix: previously a model exposing `transformer` (or `gpt_neox`) with a
    # non-matching inner layout leaked a raw AttributeError; now every
    # unsupported layout raises the documented ValueError instead.
    try:
        if hasattr(model, 'transformer'):
            # GPT-2 style
            return model.transformer.h[layer_idx].attn
        if hasattr(model, 'gpt_neox'):
            # GPT-NeoX style
            return model.gpt_neox.layers[layer_idx].attention
    except AttributeError:
        pass
    raise ValueError(
        f"Could not extract attention layer from model type: {type(model).__name__}. "
        "Supported architectures: Llama, Qwen, SmolLM, Mistral, GPT-2, GPT-NeoX"
    )
def create_attention_inputs(
    model: PreTrainedModel,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    dtype: torch.dtype = torch.float16,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build the tensors an attention layer's forward pass expects.

    Args:
        model: The loaded model; only ``config.hidden_size`` is read.
        batch_size: Batch size
        seq_len: Sequence length
        device: Target device (cuda/cpu)
        dtype: Data type (default: float16)

    Returns:
        Tuple of (hidden_states, position_ids) with shapes
        [batch, seq_len, hidden_dim] and [batch, seq_len].
    """
    hidden_dim = model.config.hidden_size

    # Random activations standing in for real token embeddings.
    hidden_states = torch.randn(
        batch_size, seq_len, hidden_dim,
        dtype=dtype, device=device,
    )

    # One shared 0..seq_len-1 row, broadcast across the batch dimension.
    positions = torch.arange(seq_len, device=device)
    position_ids = positions.unsqueeze(0).expand(batch_size, -1)

    return hidden_states, position_ids
def create_causal_mask(
    seq_len: int,
    device: torch.device,
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """
    Build an additive causal attention mask.

    Positions a token may attend to (itself and earlier tokens) hold 0.0;
    future positions hold -inf so they vanish under softmax.

    Args:
        seq_len: Sequence length
        device: Target device
        dtype: Data type

    Returns:
        Causal mask tensor [1, 1, seq_len, seq_len]
    """
    # Start from an all--inf square, then zero out the lower triangle
    # (diagonal included) in a single triu call.
    blocked = torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype)
    mask = torch.triu(blocked, diagonal=1)
    # Add broadcastable batch and head dimensions.
    return mask[None, None, :, :]
def benchmark_attention_layer(
    attention_layer: nn.Module,
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    backend: str = "flash",
    num_iterations: int = 10,
    warmup_iterations: int = 3,
) -> Dict[str, Any]:
    """
    Time an attention layer's forward pass under one SDPA backend.

    Args:
        attention_layer: The attention module to benchmark
        hidden_states: Input hidden states [batch, seq, hidden_dim]
        position_ids: Position IDs [batch, seq]
        attention_mask: Optional attention mask (currently unused by the call)
        backend: Which SDPA backend ("math", "flash", "mem_efficient")
        num_iterations: Number of timed iterations
        warmup_iterations: Number of warmup iterations

    Returns:
        Dict with timing and memory results, or an error/unsupported status.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available", "status": "error"}

    # sdp_kernel keyword sets: exactly one backend enabled per entry.
    kernel_flags = {
        "math": dict(enable_math=True, enable_flash=False, enable_mem_efficient=False),
        "flash": dict(enable_math=False, enable_flash=True, enable_mem_efficient=False),
        "mem_efficient": dict(enable_math=False, enable_flash=False, enable_mem_efficient=True),
    }
    if backend not in kernel_flags:
        return {"error": f"Unknown backend: {backend}", "status": "error"}
    sdp_kwargs = kernel_flags[backend]

    def invoke():
        """Call the layer, tolerating signatures without position_ids."""
        try:
            return attention_layer(hidden_states, position_ids=position_ids)
        except TypeError:
            return attention_layer(hidden_states)

    try:
        # Warmup passes so lazy kernel selection/compilation is excluded.
        with torch.backends.cuda.sdp_kernel(**sdp_kwargs), torch.no_grad():
            for _ in range(warmup_iterations):
                invoke()
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

        # Timed region bracketed by CUDA events for device-side timing.
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        with torch.backends.cuda.sdp_kernel(**sdp_kwargs), torch.no_grad():
            start_evt.record()
            for _ in range(num_iterations):
                invoke()
            end_evt.record()
        torch.cuda.synchronize()

        avg_ms = start_evt.elapsed_time(end_evt) / num_iterations
        peak_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
        return {
            "time_ms": round(avg_ms, 3),
            "memory_mb": round(peak_mb, 1),
            "status": "success",
            "backend": backend,
        }
    except Exception as e:
        import traceback
        error_msg = str(e)
        tb = traceback.format_exc()
        # Flash attention is unavailable on some GPU generations (e.g. sm75);
        # report that as "unsupported" rather than a hard error.
        if "flash" in error_msg.lower() or "sm75" in error_msg.lower():
            return {
                "time_ms": None,
                "memory_mb": None,
                "status": f"unsupported: {error_msg[:80]}",
                "backend": backend,
            }
        # Anything else: log details for debugging, then report the error.
        print(f"[benchmark_attention_layer] Error for {backend}: {error_msg}")
        print(f"[benchmark_attention_layer] Traceback: {tb[:500]}")
        return {
            "time_ms": None,
            "memory_mb": None,
            "status": f"error: {error_msg[:80]}",
            "backend": backend,
        }
def create_kv_cache(
    model: PreTrainedModel,
    batch_size: int,
    cache_len: int,
    device: torch.device,
    dtype: torch.dtype = torch.float16,
    layer_idx: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build a random KV cache for decode-phase benchmarking.

    Args:
        model: The loaded model (to get config)
        batch_size: Batch size
        cache_len: Number of cached tokens
        device: Target device
        dtype: Data type
        layer_idx: Which layer (for future multi-layer support; unused)

    Returns:
        Tuple of (key_cache, value_cache), each
        [batch, num_kv_heads, cache_len, head_dim].
    """
    config = model.config
    # GQA models expose num_key_value_heads; MHA models share the head count.
    num_kv_heads = getattr(config, 'num_key_value_heads', config.num_attention_heads)
    head_dim = config.hidden_size // config.num_attention_heads

    cache_shape = (batch_size, num_kv_heads, cache_len, head_dim)
    key_cache = torch.randn(*cache_shape, dtype=dtype, device=device)
    value_cache = torch.randn(*cache_shape, dtype=dtype, device=device)
    return key_cache, value_cache
def benchmark_decode_attention(
    attention_layer: nn.Module,
    model: PreTrainedModel,
    kv_cache_len: int,
    num_tokens: int = 10,
    batch_size: int = 1,
    backend: str = "flash",
    num_iterations: int = 5,
) -> Dict[str, Any]:
    """
    Benchmark decode-phase attention (single query attending to KV cache).

    Args:
        attention_layer: The attention module
        model: The loaded model (for config)
        kv_cache_len: Length of the KV cache (context)
        num_tokens: Number of decode tokens to simulate
        batch_size: Batch size
        backend: SDPA backend to use
        num_iterations: Iterations per token for averaging

    Returns:
        Dict with per-token timing and memory stats.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available", "status": "error"}

    device = torch.device("cuda")
    dtype = torch.float16

    # Single new-token query: [batch, 1, hidden_dim].
    hidden_dim = model.config.hidden_size
    query_hidden = torch.randn(batch_size, 1, hidden_dim, dtype=dtype, device=device)

    # Simulated cached context (not wired into the layer; see note below).
    key_cache, value_cache = create_kv_cache(
        model, batch_size, kv_cache_len, device, dtype
    )

    # The new token sits right after the cached context.
    position_ids = torch.tensor([[kv_cache_len]], device=device).expand(batch_size, 1)

    # sdp_kernel keyword sets: exactly one backend enabled per entry.
    kernel_flags = {
        "math": dict(enable_math=True, enable_flash=False, enable_mem_efficient=False),
        "flash": dict(enable_math=False, enable_flash=True, enable_mem_efficient=False),
        "mem_efficient": dict(enable_math=False, enable_flash=False, enable_mem_efficient=True),
    }
    if backend not in kernel_flags:
        return {"error": f"Unknown backend: {backend}", "status": "error"}
    sdp_kwargs = kernel_flags[backend]

    try:
        # Note: a faithful decode simulation would pass past_key_values;
        # this simplified version times attention over a single-token query.
        # Real models handle cached context via the past_key_value mechanism.

        # Warmup passes.
        with torch.backends.cuda.sdp_kernel(**sdp_kwargs), torch.no_grad():
            for _ in range(2):
                attention_layer(query_hidden, position_ids=position_ids)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

        # Time num_tokens simulated decode steps, num_iterations times each.
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        total_calls = num_tokens * num_iterations
        with torch.backends.cuda.sdp_kernel(**sdp_kwargs), torch.no_grad():
            start_evt.record()
            for _ in range(total_calls):
                attention_layer(query_hidden, position_ids=position_ids)
            end_evt.record()
        torch.cuda.synchronize()

        total_time_ms = start_evt.elapsed_time(end_evt)
        per_token_ms = total_time_ms / total_calls
        peak_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

        # Release benchmark tensors before returning.
        del query_hidden, key_cache, value_cache
        torch.cuda.empty_cache()

        return {
            "time_ms_per_token": round(per_token_ms, 4),
            "total_time_ms": round(total_time_ms / num_iterations, 3),
            "memory_mb": round(peak_mb, 1),
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "status": "success",
            "backend": backend,
        }
    except Exception as e:
        return {
            "time_ms_per_token": None,
            "total_time_ms": None,
            "memory_mb": None,
            "status": f"error: {str(e)[:80]}",
            "backend": backend,
        }
def get_model_attention_info(model: PreTrainedModel) -> Dict[str, Any]:
    """
    Summarize a model's attention-related configuration.

    Returns:
        Dict with num_attention_heads, num_kv_heads, head_dim, hidden_size,
        num_layers, gqa_ratio, and is_gqa.
    """
    cfg = model.config
    heads = cfg.num_attention_heads
    # GQA models carry a separate num_key_value_heads; otherwise KV == Q heads.
    kv_heads = getattr(cfg, 'num_key_value_heads', heads)
    return {
        "num_attention_heads": heads,
        "num_kv_heads": kv_heads,
        "head_dim": cfg.hidden_size // heads,
        "hidden_size": cfg.hidden_size,
        "num_layers": cfg.num_hidden_layers,
        "gqa_ratio": heads // kv_heads if kv_heads > 0 else 1,
        "is_gqa": kv_heads < heads,
    }