"""
Prefill vs Decode phase comparison module.

Demonstrates the key difference between:
- Prefill: Process entire prompt in parallel (N² attention complexity)
- Decode: Generate one token at a time (N attention per token, but sequential)

Uses REAL HuggingFace model attention layers for accurate benchmarking.
"""

import functools

import torch
import torch.nn.functional as F
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from .constants import MODEL_CONFIGS, ATTENTION_BACKENDS
from .models import load_model
from .attention_utils import (
    extract_attention_layer,
    create_attention_inputs,
    benchmark_attention_layer,
    get_model_attention_info,
)

_BYTES_PER_MB = 1024 * 1024


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

def _sdpa_backend_flags(use_flash: bool) -> dict:
    """Return ``sdp_kernel`` kwargs that enable exactly one SDPA backend.

    FlashAttention when ``use_flash`` is truthy, otherwise the reference
    "math" implementation. The memory-efficient backend is always disabled so
    the comparison is strictly flash vs. math.
    """
    use_flash = bool(use_flash)
    return {
        "enable_flash": use_flash,
        "enable_math": not use_flash,
        "enable_mem_efficient": False,
    }


def _time_sdpa(query, key, value, *, use_flash, num_calls, warmup_calls, is_causal=False):
    """Time ``num_calls`` back-to-back SDPA calls on the current CUDA device.

    The backend-selection context is entered once, outside both loops, so the
    Python cost of entering/exiting it is not counted as kernel time (this
    matters most for the tiny single-query decode kernels).

    Args:
        query, key, value: CUDA tensors passed straight to
            ``F.scaled_dot_product_attention``.
        use_flash: Select FlashAttention (True) or the math backend (False).
        num_calls: Number of timed calls.
        warmup_calls: Number of untimed calls issued first.
        is_causal: Forwarded to ``scaled_dot_product_attention``.

    Returns:
        ``(total_ms, peak_mb)``: total GPU time across all timed calls in
        milliseconds, and peak allocated CUDA memory in MB measured from just
        before the timed calls (live inputs included).

    Raises:
        Whatever SDPA raises, e.g. when the selected backend cannot handle
        these inputs.
    """
    with torch.backends.cuda.sdp_kernel(**_sdpa_backend_flags(use_flash)):
        for _ in range(warmup_calls):
            F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_calls):
            F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
        end.record()
    torch.cuda.synchronize()

    total_ms = start.elapsed_time(end)
    peak_mb = torch.cuda.max_memory_allocated() / _BYTES_PER_MB
    return total_ms, peak_mb


def _safe_number(d: dict, key: str, default=0):
    """Return ``d[key]``, or ``default`` when the key is missing or None."""
    val = d.get(key, default)
    return val if val is not None else default


def _speedup(baseline_ms, flash_ms) -> float:
    """Return ``baseline_ms / flash_ms``, or 1.0 if either timing is not positive."""
    if baseline_ms > 0 and flash_ms > 0:
        return baseline_ms / flash_ms
    return 1.0


# ---------------------------------------------------------------------------
# Model configuration
# ---------------------------------------------------------------------------

@functools.lru_cache(maxsize=32)
def _cached_real_model_config(model_name: str) -> dict:
    """Load ``model_name`` once and extract its attention architecture values.

    Cached because a single comparison run needs the config several times and
    each uncached call goes through ``load_model``. ``lru_cache`` does not
    store exceptions, so a failed load is retried on the next call.
    """
    config = load_model(model_name).config

    num_heads = config.num_attention_heads
    # Attribute may be absent or explicitly None; both mean plain MHA.
    num_kv_heads = getattr(config, "num_key_value_heads", None) or num_heads
    # Prefer an explicit head_dim when the config provides one; it need not
    # equal hidden_size // num_attention_heads.
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_heads

    return {
        "num_layers": config.num_hidden_layers,
        "num_heads": num_heads,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "hidden_size": config.hidden_size,
        "model_type": getattr(config, "model_type", "unknown"),
        "gqa_ratio": num_heads // num_kv_heads if num_kv_heads > 0 else 1,
    }


def get_real_model_config(model_name: str) -> dict:
    """
    Load model and extract ACTUAL config values from model.config.

    This function ensures we use real model architecture values,
    NOT hardcoded constants from MODEL_CONFIGS. The model is loaded at most
    once per name per process; later calls reuse the extracted values.

    Args:
        model_name: Key from MODEL_CONFIGS (e.g., "SmolLM2-360M")

    Returns:
        Dict with real model configuration values (a fresh copy per call,
        so callers may mutate it freely).
    """
    return dict(_cached_real_model_config(model_name))


# ---------------------------------------------------------------------------
# Prefill benchmarks
# ---------------------------------------------------------------------------

def run_prefill_with_real_model(
    model,
    attention_layer,
    seq_len: int,
    batch_size: int = 1,
    num_iterations: int = 5,
    use_flash: bool = True,
) -> dict:
    """
    Run prefill phase attention using a REAL model's attention layer.

    Prefill processes the entire prompt at once:
    - Hidden states have shape [batch, seq_len, hidden_dim]
    - Full N×N attention matrix computed via the real attention layer

    Args:
        model: Loaded HuggingFace model
        attention_layer: Extracted attention module
        seq_len: Sequence length
        batch_size: Batch size
        num_iterations: Number of timed iterations
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict from ``benchmark_attention_layer`` extended with ``seq_len``,
        ``phase`` and ``using_real_model``; or ``{"error": ...}`` without CUDA.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    device = torch.device("cuda")
    dtype = torch.float16

    # Create proper inputs for the attention layer
    hidden_states, position_ids = create_attention_inputs(
        model, batch_size, seq_len, device, dtype
    )

    result = benchmark_attention_layer(
        attention_layer=attention_layer,
        hidden_states=hidden_states,
        position_ids=position_ids,
        backend="flash" if use_flash else "math",
        num_iterations=num_iterations,
        warmup_iterations=2,
    )

    del hidden_states, position_ids
    torch.cuda.empty_cache()

    result["seq_len"] = seq_len
    result["phase"] = "prefill"
    result["using_real_model"] = True
    return result


def run_prefill_benchmark(
    model_name: str,
    seq_len: int,
    batch_size: int = 1,
    num_iterations: int = 10,
    use_flash: bool = True,
) -> dict:
    """
    Benchmark prefill phase using F.scaled_dot_product_attention with REAL model dimensions.

    Uses the model's actual configuration (from model.config) to create
    properly-sized Q, K, V tensors, then benchmarks causal SDPA directly.
    This is more reliable than calling attention layer forward() methods.
    K and V use the full query head count (no GQA reduction).

    Args:
        model_name: Key from MODEL_CONFIGS (model will be loaded to get real config)
        seq_len: Sequence length (prompt tokens)
        batch_size: Batch size
        num_iterations: Number of timed iterations (must be >= 1)
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with time_ms, memory_mb, and status. Any failure (including an
        unsupported backend or invalid num_iterations) is reported via
        ``status="error: ..."`` rather than raised.
    """
    if not torch.cuda.is_available():
        return {"time_ms": 0, "memory_mb": 0, "status": "error: CUDA not available"}

    device = torch.device("cuda")
    dtype = torch.float16

    try:
        if num_iterations < 1:
            raise ValueError("num_iterations must be >= 1")

        real_config = get_real_model_config(model_name)
        num_heads = real_config["num_heads"]
        head_dim = real_config["head_dim"]

        # Shape: [batch, num_heads, seq_len, head_dim]
        shape = (batch_size, num_heads, seq_len, head_dim)
        Q = torch.randn(shape, dtype=dtype, device=device)
        K = torch.randn(shape, dtype=dtype, device=device)
        V = torch.randn(shape, dtype=dtype, device=device)

        total_ms, memory_mb = _time_sdpa(
            Q, K, V,
            use_flash=use_flash,
            num_calls=num_iterations,
            warmup_calls=3,
            is_causal=True,
        )
        time_ms = total_ms / num_iterations

        del Q, K, V
        torch.cuda.empty_cache()

        return {
            "time_ms": round(time_ms, 3),
            "memory_mb": round(memory_mb, 1),
            "seq_len": seq_len,
            "phase": "prefill",
            "backend": "flash" if use_flash else "math",
            "num_heads": num_heads,
            "head_dim": head_dim,
            "status": "success",
            "using_real_config": True,
        }
    except Exception as e:
        # UI boundary: report the failure in the result instead of raising.
        return {
            "time_ms": 0,
            "memory_mb": 0,
            "status": f"error: {str(e)[:100]}",
            "phase": "prefill",
        }


# ---------------------------------------------------------------------------
# Decode benchmarks
# ---------------------------------------------------------------------------

def run_decode_with_real_model(
    model,
    attention_layer,
    kv_cache_len: int,
    num_tokens: int = 10,
    batch_size: int = 1,
    num_iterations: int = 3,
    use_flash: bool = True,
) -> dict:
    """
    Run decode phase attention using a REAL model's attention layer.

    Decode generates one token at a time: a single query token is fed to the
    attention layer ``num_tokens * num_iterations`` times.

    NOTE(review): the layer is called with only the one-token hidden state and
    ``position_ids``; no past key/value cache is passed, so ``kv_cache_len``
    only sets the position id and the measured attention does not read a
    ``kv_cache_len``-long cache. Treat results as a lower bound.

    Args:
        model: Loaded HuggingFace model (``model.config.hidden_size`` is read)
        attention_layer: Extracted attention module
        kv_cache_len: Position index used for the decode token
        num_tokens: Number of tokens to simulate generating
        batch_size: Batch size
        num_iterations: Iterations for averaging
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with per-token timing and memory stats; failures are reported via
        ``status="error: ..."``. ``{"error": ...}`` when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    device = torch.device("cuda")
    dtype = torch.float16

    # Single-token query input (simulating decode)
    hidden_dim = model.config.hidden_size
    query_hidden = torch.randn(batch_size, 1, hidden_dim, dtype=dtype, device=device)
    position_ids = torch.tensor([[kv_cache_len]], device=device).expand(batch_size, 1)
    num_calls = num_tokens * num_iterations

    try:
        with torch.backends.cuda.sdp_kernel(**_sdpa_backend_flags(use_flash)), torch.no_grad():
            for _ in range(2):
                attention_layer(query_hidden, position_ids=position_ids)
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(num_calls):
                attention_layer(query_hidden, position_ids=position_ids)
            end.record()
        torch.cuda.synchronize()

        total_time_ms = start.elapsed_time(end)
        time_per_token_ms = total_time_ms / num_calls
        memory_mb = torch.cuda.max_memory_allocated() / _BYTES_PER_MB

        del query_hidden
        torch.cuda.empty_cache()

        return {
            "time_ms_per_token": round(time_per_token_ms, 4),
            "total_time_ms": round(total_time_ms / num_iterations, 3),
            "memory_mb": round(memory_mb, 1),
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "using_real_model": True,
            "status": "success",
        }
    except Exception as e:
        return {
            "time_ms_per_token": 0,
            "total_time_ms": 0,
            "memory_mb": 0,
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "status": f"error: {str(e)[:80]}",
        }


def run_decode_benchmark(
    model_name: str,
    kv_cache_len: int,
    num_tokens: int = 10,
    batch_size: int = 1,
    num_iterations: int = 5,
    use_flash: bool = True,
) -> dict:
    """
    Benchmark decode phase using F.scaled_dot_product_attention with REAL model dimensions.

    Simulates decode by:
    - Creating a single query token (Q with seq_len=1)
    - Creating KV cache tensors with kv_cache_len tokens
    - Handling GQA by expanding KV heads to match Q heads

    Args:
        model_name: Key from MODEL_CONFIGS (model will be loaded to get real config)
        kv_cache_len: Length of KV cache (context length)
        num_tokens: Number of decode tokens to simulate (must be >= 1)
        batch_size: Batch size
        num_iterations: Iterations for timing (must be >= 1)
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with time_ms_per_token, memory_mb, and status. Failures are
        reported via ``status="error: ..."`` rather than raised.
    """
    if not torch.cuda.is_available():
        return {"time_ms_per_token": 0, "memory_mb": 0, "status": "error: CUDA not available"}

    device = torch.device("cuda")
    dtype = torch.float16

    try:
        if num_tokens < 1 or num_iterations < 1:
            raise ValueError("num_tokens and num_iterations must be >= 1")

        real_config = get_real_model_config(model_name)
        num_heads = real_config["num_heads"]
        num_kv_heads = real_config["num_kv_heads"]
        head_dim = real_config["head_dim"]

        # Single query token: [batch, num_heads, 1, head_dim]
        Q = torch.randn(batch_size, num_heads, 1, head_dim, dtype=dtype, device=device)
        # KV cache with the model's KV head count: [batch, num_kv_heads, kv_cache_len, head_dim]
        kv_shape = (batch_size, num_kv_heads, kv_cache_len, head_dim)
        K_cache = torch.randn(kv_shape, dtype=dtype, device=device)
        V_cache = torch.randn(kv_shape, dtype=dtype, device=device)

        # GQA: expand KV heads so SDPA sees matching head counts.
        if num_kv_heads < num_heads:
            repeat_factor = num_heads // num_kv_heads
            K_cache = K_cache.repeat_interleave(repeat_factor, dim=1)
            V_cache = V_cache.repeat_interleave(repeat_factor, dim=1)

        num_calls = num_tokens * num_iterations
        total_time_ms, memory_mb = _time_sdpa(
            Q, K_cache, V_cache,
            use_flash=use_flash,
            num_calls=num_calls,
            warmup_calls=3,
        )
        time_per_token_ms = total_time_ms / num_calls

        del Q, K_cache, V_cache
        torch.cuda.empty_cache()

        return {
            "time_ms_per_token": round(time_per_token_ms, 4),
            "total_time_ms": round(total_time_ms / num_iterations, 3),
            "memory_mb": round(memory_mb, 1),
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "backend": "flash" if use_flash else "math",
            "num_heads": num_heads,
            "num_kv_heads": num_kv_heads,
            "head_dim": head_dim,
            "status": "success",
            "using_real_config": True,
        }
    except Exception as e:
        return {
            "time_ms_per_token": 0,
            "total_time_ms": 0,
            "memory_mb": 0,
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "status": f"error: {str(e)[:100]}",
        }


# ---------------------------------------------------------------------------
# Legacy API
# ---------------------------------------------------------------------------

# Legacy function kept for backwards compatibility
def simulate_prefill_attention(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    head_dim: int,
    num_iterations: int = 5,
    use_flash: bool = True,
) -> dict:
    """
    Legacy: Simulate prefill phase attention with random tensors.

    Use run_prefill_with_real_model() for real model benchmarks.

    Returns ``{"error": ...}`` when CUDA is unavailable or num_iterations < 1.
    Exceptions from SDPA (e.g. an unsupported backend) propagate.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}
    if num_iterations < 1:
        return {"error": "num_iterations must be >= 1"}

    device = torch.device("cuda")
    dtype = torch.float16

    shape = (batch_size, num_heads, seq_len, head_dim)
    Q = torch.randn(shape, device=device, dtype=dtype)
    K = torch.randn(shape, device=device, dtype=dtype)
    V = torch.randn(shape, device=device, dtype=dtype)

    total_time_ms, peak_memory_mb = _time_sdpa(
        Q, K, V, use_flash=use_flash, num_calls=num_iterations, warmup_calls=2
    )

    del Q, K, V
    torch.cuda.empty_cache()

    return {
        "time_ms": total_time_ms / num_iterations,
        "memory_mb": peak_memory_mb,
        "seq_len": seq_len,
        "phase": "prefill",
    }


# Legacy function kept for backwards compatibility
def simulate_decode_attention(
    batch_size: int,
    num_heads: int,
    kv_cache_len: int,
    head_dim: int,
    num_tokens: int = 10,
    use_flash: bool = True,
) -> dict:
    """
    Legacy: Simulate decode phase attention with random tensors.

    Use run_decode_with_real_model() for real model benchmarks.

    Returns ``{"error": ...}`` when CUDA is unavailable or num_tokens < 1.
    Exceptions from SDPA (e.g. an unsupported backend) propagate.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}
    if num_tokens < 1:
        return {"error": "num_tokens must be >= 1"}

    device = torch.device("cuda")
    dtype = torch.float16

    kv_shape = (batch_size, num_heads, kv_cache_len, head_dim)
    K_cache = torch.randn(kv_shape, device=device, dtype=dtype)
    V_cache = torch.randn(kv_shape, device=device, dtype=dtype)
    Q = torch.randn(batch_size, num_heads, 1, head_dim, device=device, dtype=dtype)

    total_time_ms, peak_memory_mb = _time_sdpa(
        Q, K_cache, V_cache, use_flash=use_flash, num_calls=num_tokens, warmup_calls=2
    )

    del Q, K_cache, V_cache
    torch.cuda.empty_cache()

    return {
        "time_ms_per_token": total_time_ms / num_tokens,
        "total_time_ms": total_time_ms,
        "memory_mb": peak_memory_mb,
        "kv_cache_len": kv_cache_len,
        "num_tokens": num_tokens,
        "phase": "decode",
    }


# ---------------------------------------------------------------------------
# Comparison entry point
# ---------------------------------------------------------------------------

def run_prefill_decode_comparison(
    model_name: str,
    context_length: int,
    decode_tokens: int = 32,
) -> tuple:
    """
    Run full comparison between prefill and decode phases using REAL HuggingFace model.

    Uses F.scaled_dot_product_attention with real model dimensions for
    reliable benchmarking. All config values come from model.config, not
    constants.

    Returns:
        ``(results, comparison_chart, kv_cache_chart, insight_text)``. On an
        unknown model or a config-load failure the charts are None and
        ``results`` holds an ``"error"`` key.
    """
    if model_name not in MODEL_CONFIGS:
        return {"error": f"Unknown model: {model_name}"}, None, None, "Error: Unknown model"

    try:
        real_config = get_real_model_config(model_name)
    except Exception as e:
        return {"error": f"Failed to load model: {str(e)[:50]}"}, None, None, f"Error: {str(e)[:50]}"

    results = {
        "model": model_name,
        "context_length": context_length,
        "decode_tokens": decode_tokens,
        "real_config": real_config,
        "using_real_config": True,
    }

    prefill_flash = run_prefill_benchmark(
        model_name=model_name, seq_len=context_length, batch_size=1, use_flash=True,
    )
    prefill_math = run_prefill_benchmark(
        model_name=model_name, seq_len=context_length, batch_size=1, use_flash=False,
    )

    decode_flash = run_decode_benchmark(
        model_name=model_name,
        kv_cache_len=context_length,
        num_tokens=decode_tokens,
        batch_size=1,
        use_flash=True,
    )
    decode_math = run_decode_benchmark(
        model_name=model_name,
        kv_cache_len=context_length,
        num_tokens=decode_tokens,
        batch_size=1,
        use_flash=False,
    )

    results["prefill"] = {"flash": prefill_flash, "math": prefill_math}
    results["decode"] = {"flash": decode_flash, "math": decode_math}

    results["model_info"] = {
        "num_heads": real_config["num_heads"],
        "num_kv_heads": real_config["num_kv_heads"],
        "head_dim": real_config["head_dim"],
        "num_layers": real_config["num_layers"],
        "gqa_ratio": real_config["gqa_ratio"],
    }

    comparison_chart = create_comparison_chart(results)
    kv_cache_chart = create_kv_cache_chart(model_name, context_length, decode_tokens)

    insight = generate_phase_insight(results)
    if results.get("using_real_config"):
        model_indicator = f"\n\n---\n\n*Benchmarked using real **{model_name}** config ({real_config['num_heads']} heads, {real_config['head_dim']}d, GQA {real_config['gqa_ratio']}:1)*"
        insight = insight + model_indicator

    return results, comparison_chart, kv_cache_chart, insight


# ---------------------------------------------------------------------------
# Charts and insight text
# ---------------------------------------------------------------------------

def create_comparison_chart(results: dict) -> go.Figure:
    """Create a two-panel bar chart comparing math vs flash timing.

    Left panel: prefill time for the full prompt. Right panel: decode time
    per token. Missing or None timings are plotted as 0.
    """
    prefill_math_time = _safe_number(results["prefill"]["math"], "time_ms")
    prefill_flash_time = _safe_number(results["prefill"]["flash"], "time_ms")
    decode_math_time = _safe_number(results["decode"]["math"], "time_ms_per_token")
    decode_flash_time = _safe_number(results["decode"]["flash"], "time_ms_per_token")

    prefill_max = max(prefill_math_time, prefill_flash_time)
    decode_max = max(decode_math_time, decode_flash_time)

    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Prefill Time (Full Prompt)", "Decode Time (Per Token)"),
        horizontal_spacing=0.15,
        vertical_spacing=0.15,
    )

    # Plotly renders <br> as a line break in category labels.
    categories = ["Math<br>(Standard)", "Flash<br>Attention"]
    bar_style = dict(
        marker_color=["#ef4444", "#22c55e"],
        textposition="inside",
        textangle=0,
        insidetextanchor="middle",
        textfont=dict(color="white", size=12),
        showlegend=False,
    )

    fig.add_trace(
        go.Bar(
            x=categories,
            y=[prefill_math_time, prefill_flash_time],
            text=[f"{prefill_math_time:.2f}ms", f"{prefill_flash_time:.2f}ms"],
            name="Prefill",
            **bar_style,
        ),
        row=1, col=1,
    )
    fig.add_trace(
        go.Bar(
            x=categories,
            y=[decode_math_time, decode_flash_time],
            text=[f"{decode_math_time:.3f}ms", f"{decode_flash_time:.3f}ms"],
            name="Decode",
            **bar_style,
        ),
        row=1, col=2,
    )

    prefill_speedup = _speedup(prefill_math_time, prefill_flash_time)
    decode_speedup = _speedup(decode_math_time, decode_flash_time)

    fig.update_layout(
        title=dict(
            text=(
                "Prefill vs Decode: FlashAttention Speedup<br>"
                f"Prefill: {prefill_speedup:.1f}× faster | Decode: {decode_speedup:.1f}× faster"
            ),
            x=0.5,
            font=dict(size=15),
        ),
        height=380,
        margin=dict(l=60, r=40, t=100, b=60),
        yaxis_title="Time (ms)",
        yaxis2_title="Time (ms)",
    )

    # Add y-axis headroom for the labels; skip when a panel has no data
    # (e.g. both benchmarks failed) so plotly autoscales instead of using a
    # degenerate [0, 0] range.
    if prefill_max > 0:
        fig.update_yaxes(range=[0, prefill_max * 1.15], row=1, col=1)
    if decode_max > 0:
        fig.update_yaxes(range=[0, decode_max * 1.15], row=1, col=2)

    return fig


def create_kv_cache_chart(model_name: str, context_length: int, decode_tokens: int) -> go.Figure:
    """
    Create chart showing KV cache growth during generation.

    Uses REAL model config values from model.config, not constants.
    Cache size assumes FP16 (2 bytes) storage for both K and V in every layer.

    Args:
        model_name: Model name to load config from
        context_length: Number of context tokens (prefill)
        decode_tokens: Number of decode tokens to generate

    Returns:
        Plotly figure showing KV cache growth
    """
    real_config = get_real_model_config(model_name)
    num_kv_heads = real_config["num_kv_heads"]
    head_dim = real_config["head_dim"]
    num_layers = real_config["num_layers"]

    # Per layer per token: 2 (K+V) × kv_heads × head_dim × 2 (FP16 bytes)
    bytes_per_token_per_layer = 2 * num_kv_heads * head_dim * 2
    total_bytes_per_token = bytes_per_token_per_layer * num_layers

    def to_mb(tokens):
        return (tokens * total_bytes_per_token) / _BYTES_PER_MB

    # ~50 sample points, always including the final token count.
    total_tokens = context_length + decode_tokens
    token_counts = list(range(0, total_tokens + 1, max(1, total_tokens // 50)))
    if token_counts[-1] != total_tokens:
        token_counts.append(total_tokens)

    fig = go.Figure()

    # Prefill region (0 to context_length)
    prefill_tokens = [t for t in token_counts if t <= context_length]
    fig.add_trace(go.Scatter(
        x=prefill_tokens,
        y=[to_mb(t) for t in prefill_tokens],
        mode="lines",
        name="Prefill Phase",
        fill="tozeroy",
        line=dict(color="#3b82f6", width=2),
        fillcolor="rgba(59, 130, 246, 0.3)",
    ))

    # Decode region (context_length to end)
    decode_tokens_list = [t for t in token_counts if t >= context_length]
    fig.add_trace(go.Scatter(
        x=decode_tokens_list,
        y=[to_mb(t) for t in decode_tokens_list],
        mode="lines",
        name="Decode Phase",
        fill="tozeroy",
        line=dict(color="#22c55e", width=2),
        fillcolor="rgba(34, 197, 94, 0.3)",
    ))

    cache_at_context = to_mb(context_length)
    fig.add_vline(
        x=context_length,
        line_dash="dash",
        line_color="rgba(0, 0, 0, 0.5)",
        annotation_text=f"Prefill→Decode<br>({cache_at_context:.1f} MB)",
        annotation_position="top",
    )

    fig.update_layout(
        title=dict(
            text=f"KV Cache Growth ({num_kv_heads} KV heads × {num_layers} layers)",
            x=0.5,
        ),
        xaxis_title="Tokens Processed",
        yaxis_title="KV Cache Size (MB)",
        height=300,
        margin=dict(l=50, r=50, t=60, b=50),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.25,
            xanchor="center",
            x=0.5,
        ),
        yaxis=dict(rangemode="tozero"),
    )

    return fig


def generate_phase_insight(results: dict) -> str:
    """Generate markdown insight text from comparison results.

    Missing or None timings are treated as 0; a speedup is reported as 1.0
    whenever either timing is not positive.
    """
    prefill_math_time = _safe_number(results["prefill"]["math"], "time_ms")
    prefill_flash_time = _safe_number(results["prefill"]["flash"], "time_ms")
    decode_math_time = _safe_number(results["decode"]["math"], "time_ms_per_token")
    decode_flash_time = _safe_number(results["decode"]["flash"], "time_ms_per_token")

    prefill_speedup = _speedup(prefill_math_time, prefill_flash_time)
    decode_speedup = _speedup(decode_math_time, decode_flash_time)

    context_length = results["context_length"]
    decode_tokens = results["decode_tokens"]

    insight = f"""### Key Observations

**Prefill Phase** (processing {context_length} tokens):
- Standard attention: **{prefill_math_time:.2f}ms**
- FlashAttention: **{prefill_flash_time:.2f}ms**
- Speedup: **{prefill_speedup:.1f}×**

**Decode Phase** (generating {decode_tokens} tokens):
- Standard attention: **{decode_math_time:.3f}ms/token**
- FlashAttention: **{decode_flash_time:.3f}ms/token**
- Speedup: **{decode_speedup:.1f}×**

---

### Why the Difference?

1. **Prefill is compute-bound** with N² attention operations
   - FlashAttention's memory efficiency provides significant speedup
   - Larger contexts benefit more (quadratic scaling)

2. **Decode is memory-bound** with 1×N attention per token
   - Each decode step is fast but sequential
   - KV cache read dominates, limiting FlashAttention's advantage

3. **Optimal strategy**: FlashAttention helps most during prefill; decode phase benefits from KV cache optimizations (GQA/MQA)
"""
    return insight


def get_attention_pattern_chart(context_length: int) -> go.Figure:
    """Create a scatter visualization of prefill vs decode attention patterns.

    Left: a causal (lower-triangular) query×key grid for prefill. Right: six
    decode steps, each attending to one more key than the last. The grid is a
    schematic whose size is chosen from ``context_length`` (at least 1 cell).
    """
    prefill_flops = context_length * context_length  # N² attention
    decode_flops_per_token = context_length  # 1×N per decode token

    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=(
            f"Prefill: {context_length}×{context_length} = {prefill_flops:,} ops",
            f"Decode: 1×{context_length} = {decode_flops_per_token:,} ops/token",
        ),
        horizontal_spacing=0.15,
    )

    # Schematic grid size: exact for tiny contexts, then a fixed small grid.
    if context_length <= 16:
        size = max(1, context_length)  # avoid an empty grid / inverted axes
    elif context_length <= 128:
        size = 12
    elif context_length <= 512:
        size = 10
    else:
        size = 8  # Smaller grid for very large contexts

    marker_size = max(10, 22 - size)

    # Prefill: lower-triangular (causal) cells as (key, query) pairs.
    prefill_cells = [(col, row) for row in range(size) for col in range(row + 1)]
    fig.add_trace(
        go.Scatter(
            x=[c for c, _ in prefill_cells],
            y=[r for _, r in prefill_cells],
            mode="markers",
            marker=dict(size=marker_size, color="#3b82f6", symbol="square"),
            name="Attends",
            showlegend=False,
            hovertemplate="Query %{y} → Key %{x}",
        ),
        row=1, col=1,
    )

    # Decode: each step attends to a sequence one token longer than the last.
    num_decode_steps = 6
    base_context = max(4, size - num_decode_steps)
    decode_cells = [
        (col, step)
        for step in range(num_decode_steps)
        for col in range(min(base_context + step + 1, size))
    ]
    fig.add_trace(
        go.Scatter(
            x=[c for c, _ in decode_cells],
            y=[s for _, s in decode_cells],
            mode="markers",
            marker=dict(size=marker_size + 4, color="#22c55e", symbol="square"),
            name="Attends",
            showlegend=False,
            hovertemplate="Decode step %{y} → Key %{x}",
        ),
        row=1, col=2,
    )

    fig.update_xaxes(title_text="Key positions", range=[-0.5, size - 0.5], dtick=2, row=1, col=1)
    fig.update_xaxes(title_text="Key positions (KV cache)", range=[-0.5, size - 0.5], dtick=2, row=1, col=2)
    fig.update_yaxes(title_text="Query positions", range=[-0.5, size - 0.5], dtick=2, row=1, col=1)
    fig.update_yaxes(title_text="Decode steps", range=[-0.5, num_decode_steps - 0.5], dtick=1, row=1, col=2)

    fig.update_layout(
        height=380,
        margin=dict(l=60, r=30, t=70, b=50),
        plot_bgcolor="rgba(241, 245, 249, 0.5)",
    )

    return fig