Spaces:
Running on Zero
Running on Zero
| """ | |
| Prefill vs Decode phase comparison module. | |
| Demonstrates the key difference between: | |
| - Prefill: Process entire prompt in parallel (N² attention complexity) | |
| - Decode: Generate one token at a time (N attention per token, but sequential) | |
| Uses REAL HuggingFace model attention layers for accurate benchmarking. | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| from .constants import MODEL_CONFIGS, ATTENTION_BACKENDS | |
| from .models import load_model | |
| from .attention_utils import ( | |
| extract_attention_layer, | |
| create_attention_inputs, | |
| benchmark_attention_layer, | |
| get_model_attention_info, | |
| ) | |
def get_real_model_config(model_name: str) -> dict:
    """
    Load a model and pull its ACTUAL architecture values from ``model.config``.

    Guarantees callers work with real model dimensions rather than the
    hardcoded entries in MODEL_CONFIGS.

    Args:
        model_name: Key from MODEL_CONFIGS (e.g., "SmolLM2-360M")

    Returns:
        Dict with real model configuration values
    """
    cfg = load_model(model_name).config

    heads = cfg.num_attention_heads
    # Models without GQA expose no num_key_value_heads; fall back to full MHA.
    kv_heads = getattr(cfg, 'num_key_value_heads', heads)
    dim_per_head = cfg.hidden_size // heads

    return {
        "num_layers": cfg.num_hidden_layers,
        "num_heads": heads,
        "num_kv_heads": kv_heads,
        "head_dim": dim_per_head,
        "hidden_size": cfg.hidden_size,
        "model_type": getattr(cfg, 'model_type', 'unknown'),
        "gqa_ratio": heads // kv_heads if kv_heads > 0 else 1,
    }
def run_prefill_with_real_model(
    model,
    attention_layer,
    seq_len: int,
    batch_size: int = 1,
    num_iterations: int = 5,
    use_flash: bool = True,
) -> dict:
    """
    Benchmark prefill-phase attention through a REAL model's attention layer.

    Prefill processes the whole prompt at once: hidden states have shape
    [batch, seq_len, hidden_dim] and the full N x N attention matrix is
    computed by the real attention module.

    Args:
        model: Loaded HuggingFace model
        attention_layer: Extracted attention module
        seq_len: Sequence length
        batch_size: Batch size
        num_iterations: Number of timed iterations
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with timing and memory stats
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    device = torch.device("cuda")
    dtype = torch.float16

    # Build properly shaped inputs for the real attention layer.
    hidden_states, position_ids = create_attention_inputs(
        model, batch_size, seq_len, device, dtype
    )

    result = benchmark_attention_layer(
        attention_layer=attention_layer,
        hidden_states=hidden_states,
        position_ids=position_ids,
        backend="flash" if use_flash else "math",
        num_iterations=num_iterations,
        warmup_iterations=2,
    )

    # Release the input tensors before returning.
    del hidden_states, position_ids
    torch.cuda.empty_cache()

    # Tag the result with phase metadata.
    result["seq_len"] = seq_len
    result["phase"] = "prefill"
    result["using_real_model"] = True
    return result
def run_prefill_benchmark(
    model_name: str,
    seq_len: int,
    batch_size: int = 1,
    num_iterations: int = 10,
    use_flash: bool = True,
) -> dict:
    """
    Benchmark prefill-phase SDPA with REAL model dimensions.

    Reads the model's actual configuration (from model.config), builds
    correctly sized Q/K/V tensors, and times F.scaled_dot_product_attention
    directly — more reliable than calling attention-layer forward() methods.

    Args:
        model_name: Key from MODEL_CONFIGS (model is loaded to get real config)
        seq_len: Sequence length (prompt tokens)
        batch_size: Batch size
        num_iterations: Number of timed iterations
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with time_ms, memory_mb, and status
    """
    if not torch.cuda.is_available():
        return {"time_ms": 0, "memory_mb": 0, "status": "error: CUDA not available"}

    device = torch.device("cuda")
    dtype = torch.float16

    try:
        # Real architecture values, not constants.
        real_config = get_real_model_config(model_name)
        n_heads = real_config["num_heads"]
        d_head = real_config["head_dim"]

        # Q, K, V: [batch, num_heads, seq_len, head_dim]
        shape = (batch_size, n_heads, seq_len, d_head)
        q = torch.randn(*shape, dtype=dtype, device=device)
        k = torch.randn(*shape, dtype=dtype, device=device)
        v = torch.randn(*shape, dtype=dtype, device=device)

        # Restrict SDPA to exactly one backend.
        kernel_flags = dict(
            enable_flash=use_flash,
            enable_math=not use_flash,
            enable_mem_efficient=False,
        )

        # Warmup
        for _ in range(3):
            with torch.backends.cuda.sdp_kernel(**kernel_flags):
                _ = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

        # Timed runs with CUDA events for accurate GPU timing.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_iterations):
            with torch.backends.cuda.sdp_kernel(**kernel_flags):
                out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        end.record()
        torch.cuda.synchronize()

        time_ms = start.elapsed_time(end) / num_iterations
        memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

        del q, k, v, out
        torch.cuda.empty_cache()

        return {
            "time_ms": round(time_ms, 3),
            "memory_mb": round(memory_mb, 1),
            "seq_len": seq_len,
            "phase": "prefill",
            "backend": "flash" if use_flash else "math",
            "num_heads": n_heads,
            "head_dim": d_head,
            "status": "success",
            "using_real_config": True,
        }
    except Exception as e:
        return {
            "time_ms": 0,
            "memory_mb": 0,
            "status": f"error: {str(e)[:100]}",
            "phase": "prefill",
        }
def run_decode_with_real_model(
    model,
    attention_layer,
    kv_cache_len: int,
    num_tokens: int = 10,
    batch_size: int = 1,
    num_iterations: int = 3,
    use_flash: bool = True,
) -> dict:
    """
    Benchmark decode-phase attention through a REAL model's attention layer.

    Decode generates one token at a time: a single query token attends to
    all past keys/values, simulating the memory-bound decode phase.

    Args:
        model: Loaded HuggingFace model
        attention_layer: Extracted attention module
        kv_cache_len: Length of the KV cache (context)
        num_tokens: Number of tokens to simulate generating
        batch_size: Batch size
        num_iterations: Iterations for averaging
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with per-token timing and memory stats
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    device = torch.device("cuda")
    dtype = torch.float16

    # One-token query simulating a single decode step at position kv_cache_len.
    hidden_dim = model.config.hidden_size
    query_hidden = torch.randn(batch_size, 1, hidden_dim, dtype=dtype, device=device)
    position_ids = torch.tensor([[kv_cache_len]], device=device).expand(batch_size, 1)

    # Restrict SDPA to exactly one backend.
    kernel_flags = dict(
        enable_flash=use_flash,
        enable_math=not use_flash,
        enable_mem_efficient=False,
    )

    try:
        # Warmup
        with torch.backends.cuda.sdp_kernel(**kernel_flags):
            with torch.no_grad():
                for _ in range(2):
                    _ = attention_layer(query_hidden, position_ids=position_ids)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

        # Time num_tokens decode steps, repeated num_iterations times.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        with torch.backends.cuda.sdp_kernel(**kernel_flags):
            with torch.no_grad():
                start.record()
                for _ in range(num_tokens * num_iterations):
                    output = attention_layer(query_hidden, position_ids=position_ids)
                end.record()
        torch.cuda.synchronize()

        total_time_ms = start.elapsed_time(end)
        memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

        del query_hidden
        torch.cuda.empty_cache()

        return {
            "time_ms_per_token": round(total_time_ms / (num_tokens * num_iterations), 4),
            "total_time_ms": round(total_time_ms / num_iterations, 3),
            "memory_mb": round(memory_mb, 1),
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "using_real_model": True,
            "status": "success",
        }
    except Exception as e:
        return {
            "time_ms_per_token": 0,
            "total_time_ms": 0,
            "memory_mb": 0,
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "status": f"error: {str(e)[:80]}",
        }
def run_decode_benchmark(
    model_name: str,
    kv_cache_len: int,
    num_tokens: int = 10,
    batch_size: int = 1,
    num_iterations: int = 5,
    use_flash: bool = True,
) -> dict:
    """
    Benchmark decode-phase SDPA with REAL model dimensions.

    Properly simulates decode by:
    - Creating a single query token (Q with seq_len=1)
    - Creating KV cache tensors with kv_cache_len tokens
    - Handling GQA by expanding KV heads to match Q heads

    Args:
        model_name: Key from MODEL_CONFIGS (model is loaded to get real config)
        kv_cache_len: Length of KV cache (context length)
        num_tokens: Number of decode tokens to simulate
        batch_size: Batch size
        num_iterations: Iterations for timing
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with time_ms_per_token, memory_mb, and status
    """
    if not torch.cuda.is_available():
        return {"time_ms_per_token": 0, "memory_mb": 0, "status": "error: CUDA not available"}

    device = torch.device("cuda")
    dtype = torch.float16

    try:
        # Real architecture values, not constants.
        real_config = get_real_model_config(model_name)
        n_heads = real_config["num_heads"]
        n_kv_heads = real_config["num_kv_heads"]
        d_head = real_config["head_dim"]

        # Single query token: [batch, num_heads, 1, head_dim]
        q = torch.randn(batch_size, n_heads, 1, d_head, dtype=dtype, device=device)

        # KV cache with the model's KV head count: [batch, num_kv_heads, kv_cache_len, head_dim]
        kv_shape = (batch_size, n_kv_heads, kv_cache_len, d_head)
        k_cache = torch.randn(*kv_shape, dtype=dtype, device=device)
        v_cache = torch.randn(*kv_shape, dtype=dtype, device=device)

        # GQA: replicate KV heads so they line up with the query heads.
        if n_kv_heads < n_heads:
            repeats = n_heads // n_kv_heads
            k_cache = k_cache.repeat_interleave(repeats, dim=1)
            v_cache = v_cache.repeat_interleave(repeats, dim=1)

        # Restrict SDPA to exactly one backend.
        kernel_flags = dict(
            enable_flash=use_flash,
            enable_math=not use_flash,
            enable_mem_efficient=False,
        )

        # Warmup
        for _ in range(3):
            with torch.backends.cuda.sdp_kernel(**kernel_flags):
                _ = F.scaled_dot_product_attention(q, k_cache, v_cache)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

        # Timed runs - simulate generating num_tokens, num_iterations times.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_tokens * num_iterations):
            with torch.backends.cuda.sdp_kernel(**kernel_flags):
                out = F.scaled_dot_product_attention(q, k_cache, v_cache)
        end.record()
        torch.cuda.synchronize()

        total_time_ms = start.elapsed_time(end)
        memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

        del q, k_cache, v_cache, out
        torch.cuda.empty_cache()

        return {
            "time_ms_per_token": round(total_time_ms / (num_tokens * num_iterations), 4),
            "total_time_ms": round(total_time_ms / num_iterations, 3),
            "memory_mb": round(memory_mb, 1),
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "backend": "flash" if use_flash else "math",
            "num_heads": n_heads,
            "num_kv_heads": n_kv_heads,
            "head_dim": d_head,
            "status": "success",
            "using_real_config": True,
        }
    except Exception as e:
        return {
            "time_ms_per_token": 0,
            "total_time_ms": 0,
            "memory_mb": 0,
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "status": f"error: {str(e)[:100]}",
        }
# Legacy function kept for backwards compatibility
def simulate_prefill_attention(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    head_dim: int,
    num_iterations: int = 5,
    use_flash: bool = True,
) -> dict:
    """
    Legacy: Simulate prefill phase attention with random tensors.

    Use run_prefill_with_real_model() for real model benchmarks.

    Args:
        batch_size: Batch size
        num_heads: Number of attention heads
        seq_len: Sequence length (full prompt)
        head_dim: Dimension per head
        num_iterations: Number of timed iterations
        use_flash: Whether to restrict SDPA to the FlashAttention backend

    Returns:
        Dict with time_ms (avg per iteration), memory_mb (peak), seq_len
        and phase, or {"error": ...} when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}
    device = torch.device("cuda")
    dtype = torch.float16
    Q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    K = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    V = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    if use_flash:
        enable_math, enable_flash_flag, enable_mem_efficient = False, True, False
    else:
        enable_math, enable_flash_flag, enable_mem_efficient = True, False, False
    # Warmup (best-effort: the requested backend may reject this shape/dtype)
    for _ in range(2):
        try:
            with torch.backends.cuda.sdp_kernel(
                enable_flash=enable_flash_flag, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient
            ):
                _ = F.scaled_dot_product_attention(Q, K, V)
        except Exception:
            pass
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_iterations):
        # BUG FIX: the original fallback retried the identical call inside the
        # same restricted-backend context, which would fail again. Fall back to
        # PyTorch's default kernel selection instead.
        try:
            with torch.backends.cuda.sdp_kernel(
                enable_flash=enable_flash_flag, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient
            ):
                output = F.scaled_dot_product_attention(Q, K, V)
        except Exception:
            output = F.scaled_dot_product_attention(Q, K, V)
    end.record()
    torch.cuda.synchronize()
    total_time_ms = start.elapsed_time(end)
    avg_time_ms = total_time_ms / num_iterations
    peak_memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
    del Q, K, V, output
    torch.cuda.empty_cache()
    return {
        "time_ms": avg_time_ms,
        "memory_mb": peak_memory_mb,
        "seq_len": seq_len,
        "phase": "prefill",
    }
# Legacy function kept for backwards compatibility
def simulate_decode_attention(
    batch_size: int,
    num_heads: int,
    kv_cache_len: int,
    head_dim: int,
    num_tokens: int = 10,
    use_flash: bool = True,
) -> dict:
    """
    Legacy: Simulate decode phase attention with random tensors.

    Use run_decode_with_real_model() for real model benchmarks.

    Args:
        batch_size: Batch size
        num_heads: Number of attention heads
        kv_cache_len: Length of the simulated KV cache
        head_dim: Dimension per head
        num_tokens: Number of decode steps to time
        use_flash: Whether to restrict SDPA to the FlashAttention backend

    Returns:
        Dict with time_ms_per_token, total_time_ms, memory_mb (peak),
        kv_cache_len, num_tokens and phase, or {"error": ...} when CUDA
        is unavailable.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}
    device = torch.device("cuda")
    dtype = torch.float16
    K_cache = torch.randn(batch_size, num_heads, kv_cache_len, head_dim, device=device, dtype=dtype)
    V_cache = torch.randn(batch_size, num_heads, kv_cache_len, head_dim, device=device, dtype=dtype)
    Q = torch.randn(batch_size, num_heads, 1, head_dim, device=device, dtype=dtype)
    if use_flash:
        enable_math, enable_flash_flag, enable_mem_efficient = False, True, False
    else:
        enable_math, enable_flash_flag, enable_mem_efficient = True, False, False
    # Warmup (best-effort: the requested backend may reject this shape/dtype)
    for _ in range(2):
        try:
            with torch.backends.cuda.sdp_kernel(
                enable_flash=enable_flash_flag, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient
            ):
                _ = F.scaled_dot_product_attention(Q, K_cache, V_cache)
        except Exception:
            pass
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_tokens):
        # BUG FIX: the original fallback retried the identical call inside the
        # same restricted-backend context, which would fail again. Fall back to
        # PyTorch's default kernel selection instead.
        try:
            with torch.backends.cuda.sdp_kernel(
                enable_flash=enable_flash_flag, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient
            ):
                output = F.scaled_dot_product_attention(Q, K_cache, V_cache)
        except Exception:
            output = F.scaled_dot_product_attention(Q, K_cache, V_cache)
    end.record()
    torch.cuda.synchronize()
    total_time_ms = start.elapsed_time(end)
    avg_time_per_token_ms = total_time_ms / num_tokens
    peak_memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
    del Q, K_cache, V_cache, output
    torch.cuda.empty_cache()
    return {
        "time_ms_per_token": avg_time_per_token_ms,
        "total_time_ms": total_time_ms,
        "memory_mb": peak_memory_mb,
        "kv_cache_len": kv_cache_len,
        "num_tokens": num_tokens,
        "phase": "decode",
    }
def run_prefill_decode_comparison(
    model_name: str,
    context_length: int,
    decode_tokens: int = 32,
) -> tuple:
    """
    Compare prefill and decode phases using a REAL HuggingFace model.

    Benchmarks SDPA (flash vs math backends) for both phases with the
    model's actual dimensions, then builds the comparison chart, the
    KV-cache growth chart, and a markdown insight summary.

    Returns:
        (results dict, comparison chart, KV cache chart, insight text)
    """
    if model_name not in MODEL_CONFIGS:
        return {"error": f"Unknown model: {model_name}"}, None, None, "Error: Unknown model"

    # All architecture values come from model.config, not constants.
    try:
        real_config = get_real_model_config(model_name)
    except Exception as e:
        return {"error": f"Failed to load model: {str(e)[:50]}"}, None, None, f"Error: {str(e)[:50]}"

    results = {
        "model": model_name,
        "context_length": context_length,
        "decode_tokens": decode_tokens,
        "real_config": real_config,
        "using_real_config": True,
    }

    # Prefill: full-prompt SDPA with real model dimensions, both backends.
    results["prefill"] = {
        "flash": run_prefill_benchmark(
            model_name=model_name, seq_len=context_length, batch_size=1, use_flash=True
        ),
        "math": run_prefill_benchmark(
            model_name=model_name, seq_len=context_length, batch_size=1, use_flash=False
        ),
    }
    # Decode: single-token SDPA against a simulated KV cache, both backends.
    results["decode"] = {
        "flash": run_decode_benchmark(
            model_name=model_name, kv_cache_len=context_length,
            num_tokens=decode_tokens, batch_size=1, use_flash=True
        ),
        "math": run_decode_benchmark(
            model_name=model_name, kv_cache_len=context_length,
            num_tokens=decode_tokens, batch_size=1, use_flash=False
        ),
    }

    # Architecture summary surfaced for the UI.
    results["model_info"] = {
        key: real_config[key]
        for key in ("num_heads", "num_kv_heads", "head_dim", "num_layers", "gqa_ratio")
    }

    comparison_chart = create_comparison_chart(results)
    kv_cache_chart = create_kv_cache_chart(model_name, context_length, decode_tokens)

    insight = generate_phase_insight(results)
    # Append a footnote noting the real model config that was benchmarked.
    if results.get("using_real_config"):
        insight += (
            f"\n\n---\n\n*Benchmarked using real **{model_name}** config "
            f"({real_config['num_heads']} heads, {real_config['head_dim']}d, "
            f"GQA {real_config['gqa_ratio']}:1)*"
        )

    return results, comparison_chart, kv_cache_chart, insight
def create_comparison_chart(results: dict) -> go.Figure:
    """Build the side-by-side bar chart of prefill vs decode timings."""
    def _num(d: dict, key: str) -> float:
        # Missing keys and explicit None both count as zero.
        val = d.get(key, 0)
        return 0 if val is None else val

    prefill_math_time = _num(results["prefill"]["math"], "time_ms")
    prefill_flash_time = _num(results["prefill"]["flash"], "time_ms")
    decode_math_time = _num(results["decode"]["math"], "time_ms_per_token")
    decode_flash_time = _num(results["decode"]["flash"], "time_ms_per_token")

    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("<b>Prefill Time</b> (Full Prompt)", "<b>Decode Time</b> (Per Token)"),
        horizontal_spacing=0.15,
        vertical_spacing=0.15,
    )

    # One bar pair per subplot: (col, name, [math, flash], label format).
    bar_specs = [
        (1, "Prefill", [prefill_math_time, prefill_flash_time], "{:.2f}ms"),
        (2, "Decode", [decode_math_time, decode_flash_time], "{:.3f}ms"),
    ]
    for col, trace_name, values, fmt in bar_specs:
        fig.add_trace(
            go.Bar(
                x=["Math<br>(Standard)", "Flash<br>Attention"],
                y=values,
                marker_color=["#ef4444", "#22c55e"],
                text=[f"<b>{fmt.format(v)}</b>" for v in values],
                textposition="inside",
                textangle=0,
                insidetextanchor="middle",
                textfont=dict(color="white", size=12),
                name=trace_name,
                showlegend=False,
            ),
            row=1, col=col
        )

    def _speedup(slow: float, fast: float) -> float:
        # Default to 1x when either measurement failed (zero timing).
        return slow / fast if slow > 0 and fast > 0 else 1.0

    prefill_speedup = _speedup(prefill_math_time, prefill_flash_time)
    decode_speedup = _speedup(decode_math_time, decode_flash_time)

    fig.update_layout(
        title=dict(
            text=f"<b>Prefill vs Decode: FlashAttention Speedup</b><br>"
                 f"<span style='font-size:13px;color:#16a34a'>"
                 f"Prefill: {prefill_speedup:.1f}× faster | Decode: {decode_speedup:.1f}× faster</span>",
            x=0.5,
            font=dict(size=15),
        ),
        height=380,
        margin=dict(l=60, r=40, t=100, b=60),
        yaxis_title="Time (ms)",
        yaxis2_title="Time (ms)",
    )

    # Extra headroom above the tallest bar so in-bar labels stay visible.
    fig.update_yaxes(range=[0, max(prefill_math_time, prefill_flash_time) * 1.15], row=1, col=1)
    fig.update_yaxes(range=[0, max(decode_math_time, decode_flash_time) * 1.15], row=1, col=2)
    return fig
def create_kv_cache_chart(model_name: str, context_length: int, decode_tokens: int) -> go.Figure:
    """
    Create chart showing KV cache growth during generation.

    Uses REAL model config values from model.config, not constants.

    Args:
        model_name: Model name to load config from
        context_length: Number of context tokens (prefill)
        decode_tokens: Number of decode tokens to generate

    Returns:
        Plotly figure showing KV cache growth
    """
    # Get REAL config from loaded model (no constants!)
    real_config = get_real_model_config(model_name)
    num_kv_heads = real_config["num_kv_heads"]
    head_dim = real_config["head_dim"]
    num_layers = real_config["num_layers"]

    # KV cache per layer: 2 (K+V) × kv_heads × head_dim × 2 (FP16 bytes)
    bytes_per_token_per_layer = 2 * num_kv_heads * head_dim * 2
    total_bytes_per_token = bytes_per_token_per_layer * num_layers

    def cache_mb(tokens: int) -> float:
        # KV-cache footprint in MB for a given number of cached tokens.
        return (tokens * total_bytes_per_token) / (1024 * 1024)

    # Sample ~50 evenly spaced token counts across the full generation.
    total_tokens = context_length + decode_tokens
    token_counts = list(range(0, total_tokens + 1, max(1, total_tokens // 50)))
    if token_counts[-1] != total_tokens:
        token_counts.append(total_tokens)

    fig = go.Figure()

    # Prefill region (0 to context_length)
    prefill_tokens = [t for t in token_counts if t <= context_length]
    fig.add_trace(go.Scatter(
        x=prefill_tokens,
        y=[cache_mb(t) for t in prefill_tokens],
        mode="lines",
        name="Prefill Phase",
        fill="tozeroy",
        line=dict(color="#3b82f6", width=2),
        fillcolor="rgba(59, 130, 246, 0.3)",
    ))

    # Decode region (context_length to end)
    decode_tokens_list = [t for t in token_counts if t >= context_length]
    fig.add_trace(go.Scatter(
        x=decode_tokens_list,
        y=[cache_mb(t) for t in decode_tokens_list],
        mode="lines",
        name="Decode Phase",
        fill="tozeroy",
        line=dict(color="#22c55e", width=2),
        fillcolor="rgba(34, 197, 94, 0.3)",
    ))

    # Vertical marker at the prefill→decode boundary.
    cache_at_context = cache_mb(context_length)
    fig.add_vline(
        x=context_length,
        line_dash="dash",
        line_color="rgba(0, 0, 0, 0.5)",
        annotation_text=f"Prefill→Decode<br>({cache_at_context:.1f} MB)",
        annotation_position="top",
    )

    fig.update_layout(
        title=dict(
            text=f"KV Cache Growth ({num_kv_heads} KV heads × {num_layers} layers)",
            x=0.5,
        ),
        xaxis_title="Tokens Processed",
        yaxis_title="KV Cache Size (MB)",
        height=300,
        margin=dict(l=50, r=50, t=60, b=50),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.25,
            xanchor="center",
            x=0.5,
        ),
        yaxis=dict(rangemode='tozero'),
    )
    return fig
def generate_phase_insight(results: dict) -> str:
    """Build the markdown insight summary from comparison results."""
    def _num(d: dict, key: str) -> float:
        # Missing keys and explicit None both count as zero.
        val = d.get(key, 0)
        return 0 if val is None else val

    prefill_math_time = _num(results["prefill"]["math"], "time_ms")
    prefill_flash_time = _num(results["prefill"]["flash"], "time_ms")
    decode_math_time = _num(results["decode"]["math"], "time_ms_per_token")
    decode_flash_time = _num(results["decode"]["flash"], "time_ms_per_token")

    def _speedup(slow: float, fast: float) -> float:
        # Default to 1x when either measurement failed (zero timing).
        return slow / fast if slow > 0 and fast > 0 else 1.0

    prefill_speedup = _speedup(prefill_math_time, prefill_flash_time)
    decode_speedup = _speedup(decode_math_time, decode_flash_time)

    context_length = results["context_length"]
    decode_tokens = results["decode_tokens"]

    return f"""### Key Observations
**Prefill Phase** (processing {context_length} tokens):
- Standard attention: **{prefill_math_time:.2f}ms**
- FlashAttention: **{prefill_flash_time:.2f}ms**
- Speedup: **{prefill_speedup:.1f}×**
**Decode Phase** (generating {decode_tokens} tokens):
- Standard attention: **{decode_math_time:.3f}ms/token**
- FlashAttention: **{decode_flash_time:.3f}ms/token**
- Speedup: **{decode_speedup:.1f}×**
---
### Why the Difference?
1. **Prefill is compute-bound** with N² attention operations
- FlashAttention's memory efficiency provides significant speedup
- Larger contexts benefit more (quadratic scaling)
2. **Decode is memory-bound** with 1×N attention per token
- Each decode step is fast but sequential
- KV cache read dominates, limiting FlashAttention's advantage
3. **Optimal strategy**: FlashAttention helps most during prefill;
decode phase benefits from KV cache optimizations (GQA/MQA)
"""
def get_attention_pattern_chart(context_length: int) -> go.Figure:
    """Visualize prefill vs decode attention patterns as scatter grids."""
    # Operation counts shown in the subplot titles.
    prefill_ops = context_length * context_length  # N² attention
    decode_ops_per_token = context_length          # 1×N per decode token

    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=(
            f"<b>Prefill:</b> {context_length}×{context_length} = {prefill_ops:,} ops",
            f"<b>Decode:</b> 1×{context_length} = {decode_ops_per_token:,} ops/token"
        ),
        horizontal_spacing=0.15,
    )

    # Pick a display grid size that shrinks as the context grows.
    if context_length <= 16:
        grid_n = context_length
    elif context_length <= 128:
        grid_n = 12
    elif context_length <= 512:
        grid_n = 10
    else:
        grid_n = 8  # Smaller grid for very large contexts

    # Larger markers on smaller grids.
    marker_size = max(10, 22 - grid_n)

    # Prefill: causal (lower-triangular) pattern — query r attends keys 0..r.
    prefill_cells = [(c, r) for r in range(grid_n) for c in range(r + 1)]
    fig.add_trace(
        go.Scatter(
            x=[c for c, _ in prefill_cells],
            y=[r for _, r in prefill_cells],
            mode="markers",
            marker=dict(
                size=marker_size,
                color="#3b82f6",
                symbol="square",
            ),
            name="Attends",
            showlegend=False,
            hovertemplate="Query %{y} → Key %{x}<extra></extra>",
        ),
        row=1, col=1
    )

    # Decode: each successive step attends to a sequence one token longer.
    num_decode_steps = 6
    base_context = max(4, grid_n - num_decode_steps)
    decode_cells = [
        (c, step)
        for step in range(num_decode_steps)
        for c in range(min(base_context + step + 1, grid_n))
    ]
    fig.add_trace(
        go.Scatter(
            x=[c for c, _ in decode_cells],
            y=[s for _, s in decode_cells],
            mode="markers",
            marker=dict(
                size=marker_size + 4,
                color="#22c55e",
                symbol="square",
            ),
            name="Attends",
            showlegend=False,
            hovertemplate="Decode step %{y} → Key %{x}<extra></extra>",
        ),
        row=1, col=2
    )

    # Axis ranges sized to the display grid.
    fig.update_xaxes(
        title_text="Key positions",
        range=[-0.5, grid_n - 0.5],
        dtick=2,
        row=1, col=1
    )
    fig.update_xaxes(
        title_text="Key positions (KV cache)",
        range=[-0.5, grid_n - 0.5],
        dtick=2,
        row=1, col=2
    )
    fig.update_yaxes(
        title_text="Query positions",
        range=[-0.5, grid_n - 0.5],
        dtick=2,
        row=1, col=1
    )
    fig.update_yaxes(
        title_text="Decode steps",
        range=[-0.5, num_decode_steps - 0.5],
        dtick=1,
        row=1, col=2
    )

    fig.update_layout(
        height=380,
        margin=dict(l=60, r=30, t=70, b=50),
        plot_bgcolor="rgba(241, 245, 249, 0.5)",
    )
    return fig