"""
Benchmark module for FlashAttention Explorer.

GPU benchmark functions for comparing attention backends using real
HuggingFace models.
"""

import torch
import torch.nn.functional as F
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from .constants import GPU_SPECS, ATTENTION_BACKENDS, MODEL_CONFIGS, DEFAULT_GPU, DEFAULT_MODEL
from .models import load_model, clear_model_cache
from .attention_utils import (
    extract_attention_layer,
    create_attention_inputs,
    benchmark_attention_layer,
    get_model_attention_info,
)

# Result keys used for the three SDPA backends, in reporting order.
_SDPA_BACKEND_NAMES = ("math", "flash", "mem_efficient")

# Known GPUs as (substring of the lower-cased device name, GPU_SPECS key,
# inline specs). Exactly one of the last two fields is set. Order matters:
# longer names must precede their prefixes ("a100" before "a10", "l40"
# before "l4"). GPU_SPECS keys are resolved lazily, only on a match.
_KNOWN_GPUS = (
    (
        "h200",
        None,
        {
            "name": "NVIDIA H200",
            "tflops_fp16": 989,  # Same compute as H100
            "bandwidth_gbps": 4800,  # HBM3e: 4.8 TB/s
            "sram_kb": 256,
        },
    ),
    ("h100", "H100", None),
    ("a100", "A100_80GB", None),
    ("a10", "A10G", None),
    ("l40", None, {"name": "NVIDIA L40S", "tflops_fp16": 362, "bandwidth_gbps": 864, "sram_kb": 192}),
    ("l4", None, {"name": "NVIDIA L4", "tflops_fp16": 121, "bandwidth_gbps": 300, "sram_kb": 96}),
    ("t4", None, {"name": "NVIDIA T4", "tflops_fp16": 65, "bandwidth_gbps": 320, "sram_kb": 64}),
    ("v100", None, {"name": "NVIDIA V100", "tflops_fp16": 125, "bandwidth_gbps": 900, "sram_kb": 128}),
    ("4090", None, {"name": "NVIDIA RTX 4090", "tflops_fp16": 330, "bandwidth_gbps": 1008, "sram_kb": 128}),
)

# Rough per-architecture guesses for GPUs not listed above, as
# (minimum compute-capability major, FP16 ops per SM per cycle, clock in GHz,
# GB/s of bandwidth per GB of memory). First matching row wins.
_ARCH_ESTIMATES = (
    (9, 512, 1.8, 50),  # Hopper and newer
    (8, 256, 1.5, 25),  # Ampere
    (7, 128, 1.4, 28),  # Volta/Turing
)
_OLDER_ARCH_ESTIMATE = (64, 1.2, 20)

# Measured roofline points as (metrics key, label, colour, arrow x offset,
# arrow y offset). MemEff is offset sideways to avoid overlapping Flash.
_MEASURED_POINTS = (
    ("math", "Math", "#dc2626", 0, -40),
    ("flash", "Flash", "#16a34a", 0, -40),
    ("mem_efficient", "MemEff", "#2563eb", 30, -30),
)

# Illustrative roofline points used when no measurements are available, as
# (label, legend name, arithmetic intensity, fraction of peak TFLOPS,
# text colour, marker colour, border colour).
_THEORETICAL_POINTS = (
    ("Standard", "Standard (Theoretical)", 10, 0.15,
     "#dc2626", "rgba(220, 38, 38, 0.6)", "rgba(220, 38, 38, 0.5)"),
    ("FlashAttention", "Flash (Theoretical)", 200, 0.7,
     "#16a34a", "rgba(22, 163, 74, 0.6)", "rgba(22, 163, 74, 0.5)"),
)


def _estimate_unknown_gpu(gpu_name_raw: str, mem_gb: float) -> dict:
    """Estimate specs for an unrecognised GPU from SM count and compute capability."""
    compute_capability = "unknown"
    try:
        props = torch.cuda.get_device_properties(0)
        sm_count = props.multi_processor_count
        major, minor = torch.cuda.get_device_capability(0)
        compute_capability = f"{major}.{minor}"

        flops_per_sm, clock_ghz, bw_per_gb_mem = _OLDER_ARCH_ESTIMATE
        for min_major, arch_flops, arch_clock, arch_bw in _ARCH_ESTIMATES:
            if major >= min_major:
                flops_per_sm, clock_ghz, bw_per_gb_mem = arch_flops, arch_clock, arch_bw
                break

        # Estimate TFLOPS: SMs x FLOPs/SM/cycle x clock (GHz) x 2 (FMA)
        est_tflops = (sm_count * flops_per_sm * clock_ghz * 2) / 1000
        est_bw = mem_gb * bw_per_gb_mem
    except Exception:
        # Fallback if the properties query fails
        est_tflops = 125
        est_bw = 600

    return {
        "detected": True,
        "detected_name": gpu_name_raw,
        "name": gpu_name_raw,
        "tflops_fp16": round(est_tflops),
        "bandwidth_gbps": round(est_bw),
        "memory_gb": round(mem_gb),
        "sram_kb": 128,
        "estimated": True,  # Flag that these are estimated from compute capability
        "compute_capability": compute_capability,
    }


def detect_gpu() -> dict:
    """
    Detect the GPU at CUDA device 0 and return its specs.

    Returns:
        Dict with "name", "tflops_fp16", "bandwidth_gbps", "memory_gb" and
        "sram_kb" (for GPUs resolved through GPU_SPECS, whatever keys that
        entry provides), plus:
        - "detected": False when CUDA is unavailable, True otherwise
        - "detected_name": raw device name reported by CUDA (GPU only)
        - "estimated" / "compute_capability": only for unrecognised GPUs,
          whose specs are rough estimates
    """
    if not torch.cuda.is_available():
        # Explicit keys go after the spread so that a "name" entry in the
        # default specs (if present) cannot overwrite the CPU label.
        return {**GPU_SPECS[DEFAULT_GPU], "name": "CPU (No GPU)", "detected": False}

    gpu_name_raw = torch.cuda.get_device_name(0)
    gpu_name = gpu_name_raw.lower()

    # Memory in GB, reported as-is for GPUs with inline specs
    try:
        mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    except Exception:
        mem_gb = 24  # fallback

    for needle, specs_key, inline in _KNOWN_GPUS:
        if needle not in gpu_name:
            continue
        if specs_key is not None:
            return {"detected": True, "detected_name": gpu_name_raw, **GPU_SPECS[specs_key]}
        return {
            "detected": True,
            "detected_name": gpu_name_raw,
            "name": inline["name"],
            "tflops_fp16": inline["tflops_fp16"],
            "bandwidth_gbps": inline["bandwidth_gbps"],
            "memory_gb": round(mem_gb),
            "sram_kb": inline["sram_kb"],
        }

    return _estimate_unknown_gpu(gpu_name_raw, mem_gb)


def _sdpa_backend_context(backend_name: str):
    """
    Return a context manager restricting SDPA to a single backend.

    Prefers torch.nn.attention.sdpa_kernel and falls back to the deprecated
    torch.backends.cuda.sdp_kernel on PyTorch builds that lack it.
    NOTE(review): on recent PyTorch the legacy call appears to leave extra
    backends (e.g. cuDNN) enabled by default - confirm against the docs.
    """
    try:
        from torch.nn.attention import SDPBackend, sdpa_kernel
    except ImportError:
        return torch.backends.cuda.sdp_kernel(
            enable_flash=backend_name == "flash",
            enable_math=backend_name == "math",
            enable_mem_efficient=backend_name == "mem_efficient",
        )

    selected = {
        "math": SDPBackend.MATH,
        "flash": SDPBackend.FLASH_ATTENTION,
        "mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
    }[backend_name]
    return sdpa_kernel(selected)


def _benchmark_sdpa_backends(results: dict, Q, K, V, num_iterations: int, warmup_iterations: int) -> None:
    """Time F.scaled_dot_product_attention per backend, storing entries in results."""
    for backend_name in _SDPA_BACKEND_NAMES:
        try:
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()

            with _sdpa_backend_context(backend_name):
                # Warmup
                for _ in range(warmup_iterations):
                    F.scaled_dot_product_attention(Q, K, V)
                torch.cuda.synchronize()

                # Timed runs, measured with CUDA events
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
                for _ in range(num_iterations):
                    F.scaled_dot_product_attention(Q, K, V)
                end.record()
                torch.cuda.synchronize()

            time_ms = start.elapsed_time(end) / num_iterations
            memory_mb = torch.cuda.max_memory_allocated() / 1e6
            results[backend_name] = {
                "time_ms": round(time_ms, 3),
                "memory_mb": round(memory_mb, 1),
                "status": "success",
            }
        except Exception as e:
            # A failing backend is recorded, not raised, so that the
            # remaining backends still run.
            results[backend_name] = {
                "time_ms": None,
                "memory_mb": None,
                "status": f"error: {str(e)[:50]}",
            }


def _attach_speedups(results: dict) -> None:
    """Add a "speedup" relative to the math backend to every timed backend entry."""
    base_time = results.get("math", {}).get("time_ms")
    if not base_time:
        return
    for backend in _SDPA_BACKEND_NAMES:
        entry = results.get(backend)
        if isinstance(entry, dict) and entry.get("time_ms"):
            entry["speedup"] = round(base_time / entry["time_ms"], 2)


def _random_qkv(batch_size: int, num_heads: int, seq_len: int, head_dim: int, device, dtype):
    """Create random Q, K, V tensors of shape (batch, heads, seq, head_dim)."""
    shape = (batch_size, num_heads, seq_len, head_dim)
    Q = torch.randn(*shape, device=device, dtype=dtype)
    K = torch.randn(*shape, device=device, dtype=dtype)
    V = torch.randn(*shape, device=device, dtype=dtype)
    return Q, K, V


def _benchmark_model_attention(
    model_name: str,
    seq_len: int,
    batch_size: int,
    num_iterations: int,
    warmup_iterations: int,
    device,
    dtype,
) -> dict:
    """Benchmark using a loaded model; exceptions propagate to the caller."""
    model = load_model(model_name)
    attn_info = get_model_attention_info(model)

    # Dimensions of the real model, used by the SDPA fallback below
    model_num_heads = attn_info["num_attention_heads"]
    model_head_dim = attn_info["head_dim"]

    results = {"model_name": model_name, "using_real_model": True}
    results["model_info"] = attn_info

    # First try: probe whether the model's own attention layer can be run.
    # NOTE(review): the probe uses the "flash" backend, so where flash is
    # unavailable the SDPA fallback is taken even if the layer would work
    # with "math" - confirm this is intended.
    attention_layer = None
    attention_layer_works = False
    try:
        attention_layer = extract_attention_layer(model, layer_idx=0)
        hidden_states, position_ids = create_attention_inputs(
            model, batch_size, seq_len, device, dtype
        )
        test_result = benchmark_attention_layer(
            attention_layer=attention_layer,
            hidden_states=hidden_states,
            position_ids=position_ids,
            backend="flash",
            num_iterations=2,
            warmup_iterations=1,
        )
        if test_result.get("time_ms") is not None:
            attention_layer_works = True
        del hidden_states, position_ids
        torch.cuda.empty_cache()
    except Exception as layer_error:
        print(f"[run_attention_benchmark] Attention layer extraction failed: {layer_error}")
        attention_layer_works = False

    if attention_layer_works:
        # Use the actual attention layer
        hidden_states, position_ids = create_attention_inputs(
            model, batch_size, seq_len, device, dtype
        )
        for backend in _SDPA_BACKEND_NAMES:
            results[backend] = benchmark_attention_layer(
                attention_layer=attention_layer,
                hidden_states=hidden_states,
                position_ids=position_ids,
                backend=backend,
                num_iterations=num_iterations,
                warmup_iterations=warmup_iterations,
            )
        del hidden_states, position_ids
        torch.cuda.empty_cache()
    else:
        # Fallback: F.scaled_dot_product_attention with real model dimensions
        print("[run_attention_benchmark] Falling back to SDPA with model dimensions")
        results["fallback_mode"] = True
        Q, K, V = _random_qkv(batch_size, model_num_heads, seq_len, model_head_dim, device, dtype)
        _benchmark_sdpa_backends(results, Q, K, V, num_iterations, warmup_iterations)
        del Q, K, V
        torch.cuda.empty_cache()

    _attach_speedups(results)
    return results


def run_attention_benchmark(
    model_name: str = None,
    seq_len: int = 1024,
    batch_size: int = 1,
    num_iterations: int = 10,
    warmup_iterations: int = 3,
    # Legacy parameters (used if model_name is None)
    num_heads: int = 16,
    head_dim: int = 64,
) -> dict:
    """
    Benchmark three SDPA backends using a real HuggingFace model's attention layer.

    Args:
        model_name: Name of the model from MODEL_CONFIGS (e.g., "SmolLM2-360M").
            If None, or not a key of MODEL_CONFIGS, falls back to legacy
            random tensor mode.
        seq_len: Sequence length (number of tokens)
        batch_size: Batch size
        num_iterations: Number of timed iterations
        warmup_iterations: Number of warmup iterations
        num_heads: (Legacy) Number of attention heads if model_name is None
        head_dim: (Legacy) Dimension per head if model_name is None

    Returns:
        Dict with timing and memory results per backend ("math", "flash",
        "mem_efficient"), each holding "time_ms", "memory_mb" and "status",
        plus "speedup" when the math backend was timed. On failure the dict
        is {"error": message}: CUDA is unavailable, or an exception was
        raised anywhere in the model path.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    device = torch.device("cuda")
    dtype = torch.float16

    # If a known model_name is provided, use real model dimensions
    if model_name is not None and model_name in MODEL_CONFIGS:
        try:
            return _benchmark_model_attention(
                model_name, seq_len, batch_size, num_iterations, warmup_iterations, device, dtype
            )
        except Exception as e:
            return {"error": f"Failed to load model: {str(e)[:100]}"}

    # Legacy mode: raw SDPA on random tensors with the caller's dimensions
    results = {"using_real_model": False}
    Q, K, V = _random_qkv(batch_size, num_heads, seq_len, head_dim, device, dtype)
    _benchmark_sdpa_backends(results, Q, K, V, num_iterations, warmup_iterations)
    _attach_speedups(results)

    # Clean up
    del Q, K, V
    torch.cuda.empty_cache()

    return results


def run_scaling_benchmark(
    model_name: str = None,
    seq_lengths: list = None,
    batch_size: int = 1,
    # Legacy parameters (used if model_name is None)
    num_heads: int = 16,
    head_dim: int = 64,
) -> dict:
    """
    Benchmark attention backends across multiple sequence lengths.

    Args:
        model_name: Name of the model from MODEL_CONFIGS (e.g., "SmolLM2-360M")
        seq_lengths: Sequence lengths to test (default: 512, 1024, 2048, 4096)
        batch_size: Batch size
        num_heads: (Legacy) Number of attention heads if model_name is None
        head_dim: (Legacy) Dimension per head if model_name is None

    Returns:
        Dict with "seq_lengths", "model_name" and, per backend, "time_ms"
        and "memory_mb" lists aligned with seq_lengths. An entry is None
        where that run produced no timing. Returns {"error": ...} when CUDA
        is unavailable.
    """
    if seq_lengths is None:
        seq_lengths = [512, 1024, 2048, 4096]

    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    results = {"seq_lengths": seq_lengths, "model_name": model_name}
    for backend in _SDPA_BACKEND_NAMES:
        results[backend] = {"time_ms": [], "memory_mb": []}

    for seq_len in seq_lengths:
        bench_result = run_attention_benchmark(
            model_name=model_name,
            seq_len=seq_len,
            batch_size=batch_size,
            num_iterations=5,  # Fewer iterations for scaling test
            warmup_iterations=2,
            # Legacy params (ignored if model_name is set)
            num_heads=num_heads,
            head_dim=head_dim,
        )

        for backend in _SDPA_BACKEND_NAMES:
            entry = bench_result.get(backend, {})
            timed = bool(entry.get("time_ms"))
            results[backend]["time_ms"].append(entry["time_ms"] if timed else None)
            results[backend]["memory_mb"].append(entry["memory_mb"] if timed else None)

    return results


def create_benchmark_results_table(results: dict) -> str:
    """
    Create a markdown table from benchmark results.

    Backends missing from results are omitted. A missing time, memory or
    speedup value is shown as "N/A". Returns an error line instead of a
    table when results contains "error".
    """
    if "error" in results:
        return f"**Error:** {results['error']}"

    lines = [
        "| Backend | Time (ms) | Memory (MB) | Speedup |",
        "|---------|-----------|-------------|---------|",
    ]

    for backend in _SDPA_BACKEND_NAMES:
        if backend not in results:
            continue
        r = results[backend]
        name = ATTENTION_BACKENDS.get(backend, backend)
        time_str = f"{r['time_ms']:.2f}" if r.get("time_ms") is not None else "N/A"
        mem_str = f"{r['memory_mb']:.0f}" if r.get("memory_mb") is not None else "N/A"
        # No speedup is recorded for a failed backend or when the math
        # baseline is missing; "1.0×" there would be misleading.
        speedup_str = f"{r['speedup']:.1f}×" if r.get("speedup") is not None else "N/A"
        lines.append(f"| {name} | {time_str} | {mem_str} | {speedup_str} |")

    return "\n".join(lines)


def create_benchmark_insight(results: dict) -> str:
    """
    Create insight text comparing the flash and math backends.

    Returns "" when results contains "error", and a short note when either
    backend has no timing.
    """
    if "error" in results:
        return ""

    flash_result = results.get("flash", {})
    math_result = results.get("math", {})

    if not flash_result.get("time_ms") or not math_result.get("time_ms"):
        return "**Note:** Some backends may not be available on this GPU."

    speedup = math_result["time_ms"] / flash_result["time_ms"]

    # Fall back to a neutral 1x when either memory figure is missing or the
    # flash figure is not positive, instead of raising on None.
    flash_mem = flash_result.get("memory_mb")
    math_mem = math_result.get("memory_mb")
    if flash_mem is not None and math_mem is not None and flash_mem > 0:
        mem_reduction = math_mem / flash_mem
    else:
        mem_reduction = 1

    return f"""**Key Insight:** FlashAttention is **{speedup:.1f}× faster** and uses **{mem_reduction:.1f}× less memory**!

This improvement comes from:
- Tiling attention into SRAM-sized blocks
- Never materializing the full N×N attention matrix in HBM
- Fused kernel avoiding multiple HBM round-trips"""


def create_scaling_chart(results: dict) -> go.Figure:
    """
    Create a two-panel chart of time and peak memory vs sequence length.

    Expects the dict produced by run_scaling_benchmark(). None entries are
    dropped from each series. Returns a figure holding only an error
    annotation when results contains "error".
    """
    if "error" in results:
        fig = go.Figure()
        fig.add_annotation(
            x=0.5, y=0.5,
            text=f"Error: {results['error']}",
            showarrow=False,
            font=dict(size=16, color="red"),
        )
        return fig

    seq_lengths = results["seq_lengths"]

    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Execution Time", "Peak Memory"),
        horizontal_spacing=0.12,
    )

    colors = {
        "math": "rgba(239, 68, 68, 0.8)",           # Red
        "flash": "rgba(34, 197, 94, 0.8)",          # Green
        "mem_efficient": "rgba(59, 130, 246, 0.8)",  # Blue
    }

    # Column 1 plots time, column 2 memory. Only the first column adds
    # legend entries, since legendgroup ties each backend's two traces.
    for col, metric_key in enumerate(("time_ms", "memory_mb"), start=1):
        for backend in _SDPA_BACKEND_NAMES:
            series = results[backend][metric_key]
            valid_points = [(s, v) for s, v in zip(seq_lengths, series) if v is not None]
            if not valid_points:
                continue
            x_vals, y_vals = zip(*valid_points)
            trace_kwargs = dict(
                x=list(x_vals),
                y=list(y_vals),
                mode="lines+markers",
                name=ATTENTION_BACKENDS.get(backend, backend),
                line=dict(color=colors[backend], width=2),
                marker=dict(size=8),
                legendgroup=backend,
            )
            if col == 2:
                trace_kwargs["showlegend"] = False
            fig.add_trace(go.Scatter(**trace_kwargs), row=1, col=col)

    fig.update_xaxes(title_text="Sequence Length", row=1, col=1)
    fig.update_xaxes(title_text="Sequence Length", row=1, col=2)
    fig.update_yaxes(title_text="Time (ms)", row=1, col=1)
    fig.update_yaxes(title_text="Memory (MB)", row=1, col=2)

    fig.update_layout(
        height=350,
        margin=dict(l=50, r=50, t=50, b=50),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.3,
            xanchor="center",
            x=0.5,
        ),
    )

    return fig


def calculate_attention_flops(seq_len: int, num_heads: int, head_dim: int, batch_size: int = 1) -> float:
    """
    Calculate FLOPs for scaled dot-product attention.

    FLOPs breakdown:
    - Q @ K^T: 2 * batch * heads * seq * seq * head_dim
    - Softmax: ~5 * batch * heads * seq * seq (exp, sum, div)
    - P @ V: 2 * batch * heads * seq * seq * head_dim

    Total: ~4 * batch * heads * seq² * head_dim + 5 * batch * heads * seq²
    """
    score_elements = batch_size * num_heads * seq_len * seq_len
    qk_flops = 2 * score_elements * head_dim
    softmax_flops = 5 * score_elements
    pv_flops = 2 * score_elements * head_dim
    return qk_flops + softmax_flops + pv_flops


def calculate_memory_traffic(
    seq_len: int,
    num_heads: int,
    head_dim: int,
    batch_size: int = 1,
    is_flash: bool = False,
    dtype_bytes: int = 2,  # FP16
) -> float:
    """
    Calculate approximate memory traffic in bytes for attention.

    Both variants count:
    - Read Q, K, V: 3 * batch * heads * seq * head_dim * dtype_bytes
    - Write O: batch * heads * seq * head_dim * dtype_bytes

    Standard attention (is_flash=False) adds three passes over the
    seq x seq attention matrix, each batch * heads * seq * seq * dtype_bytes.
    NOTE(review): a stricter count (write S, read S, write P, read P) would
    be four passes; the factor of three is kept as the existing model.

    FlashAttention (is_flash=True) adds nothing: the attention matrix is
    modelled as never being written to HBM.
    """
    qkv_size = 3 * batch_size * num_heads * seq_len * head_dim * dtype_bytes
    output_size = batch_size * num_heads * seq_len * head_dim * dtype_bytes

    if is_flash:
        # FlashAttention: only Q, K, V reads + O write
        return qkv_size + output_size

    # Standard: also materializes the attention matrix
    attention_matrix_size = batch_size * num_heads * seq_len * seq_len * dtype_bytes
    return qkv_size + output_size + 3 * attention_matrix_size


def calculate_roofline_metrics(
    results: dict,
    seq_len: int,
    num_heads: int,
    head_dim: int,
    batch_size: int = 1,
) -> dict:
    """
    Calculate arithmetic intensity and achieved TFLOPS from benchmark results.

    Backends that are absent or have no positive "time_ms" are skipped.

    Returns:
        Dict keyed by backend, each with "flops", "memory_bytes", "time_ms",
        "achieved_tflops" and "arith_intensity".
    """
    flops = calculate_attention_flops(seq_len, num_heads, head_dim, batch_size)

    metrics = {}
    for backend in _SDPA_BACKEND_NAMES:
        time_ms = results.get(backend, {}).get("time_ms")
        # Also skips 0.0, which rounding can produce and which would
        # otherwise divide by zero below.
        if not time_ms:
            continue

        time_s = time_ms / 1000.0
        achieved_tflops = (flops / time_s) / 1e12

        # Memory traffic is an approximation; the memory-efficient backend
        # is modelled like flash (no attention matrix in HBM).
        is_flash = backend in ("flash", "mem_efficient")
        memory_bytes = calculate_memory_traffic(
            seq_len, num_heads, head_dim, batch_size, is_flash=is_flash
        )

        metrics[backend] = {
            "flops": flops,
            "memory_bytes": memory_bytes,
            "time_ms": time_ms,
            "achieved_tflops": achieved_tflops,
            # Arithmetic intensity = FLOPs / bytes
            "arith_intensity": flops / memory_bytes,
        }

    return metrics


def _log_axis_bounds(values, default_lo: int = 0, default_hi: int = 3):
    """Return integer log10 axis bounds covering the default range and all positive values."""
    positive = [v for v in values if v is not None and v > 0]
    if not positive:
        return default_lo, default_hi
    lo = min(default_lo, int(np.floor(np.log10(min(positive)))))
    hi = max(default_hi, int(np.ceil(np.log10(max(positive)))))
    return lo, hi


def _add_labeled_point(fig, x, y, trace_name, marker, label_text, **annotation_style) -> None:
    """Add one marker trace plus an arrow annotation labelling it."""
    fig.add_trace(go.Scatter(
        x=[x],
        y=[y],
        mode="markers",
        name=trace_name,
        marker=marker,
    ))
    # The x-axis is logarithmic, so the annotation x is given as log10.
    fig.add_annotation(
        x=np.log10(x),
        y=y,
        text=label_text,
        showarrow=True,
        arrowhead=2,
        **annotation_style,
    )


def create_roofline_chart(
    results: dict,
    gpu_specs: dict = None,
    benchmark_metrics: dict = None,
) -> go.Figure:
    """
    Create a roofline chart showing where different attention implementations fall.

    The roofline model shows:
    - X-axis: Arithmetic intensity (FLOPs per byte of memory traffic)
    - Y-axis: Performance (TFLOPS)
    - The roofline is min(peak_compute, bandwidth * intensity)

    Args:
        results: Benchmark results dict (can be empty); not read by this
            function, kept for interface compatibility
        gpu_specs: GPU specifications dict (from detect_gpu() or GPU_SPECS);
            defaults to GPU_SPECS[DEFAULT_GPU]
        benchmark_metrics: Roofline metrics from calculate_roofline_metrics()

    If benchmark_metrics is provided and non-empty, plots MEASURED values.
    Otherwise, plots theoretical approximations. The x-axis covers at least
    1-1000 FLOPs/byte and widens to include the ridge point and any
    measured intensity outside that range.
    """
    gpu = GPU_SPECS[DEFAULT_GPU] if gpu_specs is None else gpu_specs

    peak_tflops = gpu["tflops_fp16"]
    bandwidth_gbps = gpu["bandwidth_gbps"]

    # Ridge point: where memory-bound meets compute-bound
    ridge_point = (peak_tflops * 1e12) / (bandwidth_gbps * 1e9)

    # Determine if we have measured data or should use theoretical
    use_measured = benchmark_metrics is not None and len(benchmark_metrics) > 0

    # Size the log axis so measured points cannot fall off the chart.
    axis_values = [ridge_point]
    if use_measured:
        axis_values.extend(
            benchmark_metrics[key]["arith_intensity"]
            for key, *_ in _MEASURED_POINTS
            if key in benchmark_metrics
        )
    x_lo_exp, x_hi_exp = _log_axis_bounds(axis_values)
    x_lo, x_hi = 10.0 ** x_lo_exp, 10.0 ** x_hi_exp

    fig = go.Figure()

    # Roofline curve (GB/s x FLOPs/byte = GFLOPS, divided by 1000 for TFLOPS)
    x_range = np.logspace(x_lo_exp, x_hi_exp, 100)
    y_roofline = np.minimum(peak_tflops, bandwidth_gbps * x_range / 1000)

    fig.add_trace(go.Scatter(
        x=x_range,
        y=y_roofline,
        mode="lines",
        name="Roofline",
        line=dict(color="rgba(0, 0, 0, 0.6)", width=2),
    ))

    # Memory-bound region (dashed)
    fig.add_trace(go.Scatter(
        x=[x_lo, ridge_point],
        y=[bandwidth_gbps * x_lo / 1000, peak_tflops],
        mode="lines",
        name="Memory Bound",
        line=dict(color="rgba(239, 68, 68, 0.5)", width=3, dash="dash"),
    ))

    # Compute-bound region (dashed)
    fig.add_trace(go.Scatter(
        x=[ridge_point, x_hi],
        y=[peak_tflops, peak_tflops],
        mode="lines",
        name="Compute Bound",
        line=dict(color="rgba(34, 197, 94, 0.5)", width=3, dash="dash"),
    ))

    if use_measured:
        # Plot MEASURED data points
        title_suffix = " (Measured)"
        for key, label, color, ax, ay in _MEASURED_POINTS:
            if key not in benchmark_metrics:
                continue
            m = benchmark_metrics[key]
            _add_labeled_point(
                fig,
                m["arith_intensity"],
                m["achieved_tflops"],
                trace_name=f"{label} ({m['achieved_tflops']:.1f} TFLOPS, {m['time_ms']:.1f}ms)",
                marker=dict(size=16, color=color, symbol="circle",
                            line=dict(color="white", width=2)),
                label_text=f"{label}<br>{m['time_ms']:.1f}ms",
                arrowsize=1,
                arrowwidth=1,
                arrowcolor=color,
                ax=ax,
                ay=ay,
                font=dict(size=10, color=color),
                bgcolor="rgba(255, 255, 255, 0.95)",
                bordercolor=color,
                borderwidth=1,
                borderpad=3,
            )
    else:
        # Plot THEORETICAL approximations
        title_suffix = " (Theoretical)"
        for label, trace_name, intensity, peak_fraction, text_color, fill, border in _THEORETICAL_POINTS:
            achieved = min(peak_tflops * peak_fraction, bandwidth_gbps * intensity / 1000)
            _add_labeled_point(
                fig,
                intensity,
                achieved,
                trace_name=trace_name,
                marker=dict(size=15, color=fill, symbol="circle-open", line=dict(width=2)),
                label_text=f"{label}<br>(theoretical)",
                ax=0,
                ay=-35,
                font=dict(size=10, color=text_color),
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor=border,
                borderwidth=1,
                borderpad=3,
            )

    # Ridge point marker
    fig.add_trace(go.Scatter(
        x=[ridge_point],
        y=[peak_tflops],
        mode="markers",
        name=f"Ridge Point ({ridge_point:.0f} FLOPs/byte)",
        marker=dict(size=10, color="rgba(0, 0, 0, 0.6)", symbol="diamond"),
    ))

    # Region labels with a white background for visibility
    fig.add_annotation(
        x=np.log10(5),
        y=peak_tflops * 0.1,
        text="Memory Bound<br>(limited by bandwidth)",
        showarrow=False,
        font=dict(size=11, color="#dc2626"),  # Solid red
        bgcolor="rgba(255, 255, 255, 0.9)",
        bordercolor="#dc2626",
        borderwidth=1,
        borderpad=4,
    )

    fig.add_annotation(
        x=np.log10(300),
        y=peak_tflops * 0.65,
        text="Compute Bound<br>(limited by TFLOPS)",
        showarrow=False,
        font=dict(size=11, color="#16a34a"),  # Solid green
        bgcolor="rgba(255, 255, 255, 0.9)",
        bordercolor="#16a34a",
        borderwidth=1,
        borderpad=4,
    )

    # Use detected_name if available, otherwise use name
    display_name = gpu.get("detected_name", gpu.get("name", "GPU"))

    # Add estimated indicator if specs were estimated
    estimated_note = " (estimated specs)" if gpu.get("estimated") else ""

    fig.update_layout(
        title=dict(
            text=(
                f"Roofline Model: {display_name}{title_suffix}{estimated_note}<br>"
                f"Peak: {peak_tflops} TFLOPS | Bandwidth: {bandwidth_gbps} GB/s"
            ),
            x=0.5,
            font=dict(size=14),
        ),
        xaxis=dict(
            title="Arithmetic Intensity (FLOPs/byte)",
            type="log",
            range=[x_lo_exp, x_hi_exp],  # log10 units on a log axis
        ),
        yaxis=dict(
            title="Performance (TFLOPS)",
            range=[0, peak_tflops * 1.2],  # More headroom for text
        ),
        height=420,
        margin=dict(l=60, r=40, t=80, b=80),  # More room for title and legend
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.30,
            xanchor="center",
            x=0.5,
            font=dict(size=10),
        ),
        showlegend=True,
    )

    return fig


def get_roofline_insight(benchmark_metrics: dict = None) -> str:
    """
    Return markdown insight text for the roofline chart.

    When benchmark_metrics has both "math" and "flash" entries, a measured
    summary is appended; otherwise a prompt to run a benchmark is appended.
    """
    base_insight = """**Why FlashAttention is Faster:**

The roofline model reveals the key insight:

1. **Standard Attention** sits in the **memory-bound** region (left of ridge point)
   - Limited by HBM bandwidth, not compute
   - Reading/writing the N×N attention matrix dominates runtime

2. **FlashAttention** moves to the **compute-bound** region (right of ridge point)
   - By never materializing the full attention matrix
   - Arithmetic intensity increases ~20-50×
   - Can now utilize most of the GPU's TFLOPS

*The same FLOPs, but 10× less memory traffic = faster execution!*"""

    if benchmark_metrics and "math" in benchmark_metrics and "flash" in benchmark_metrics:
        math_m = benchmark_metrics["math"]
        flash_m = benchmark_metrics["flash"]

        speedup = math_m["time_ms"] / flash_m["time_ms"]
        intensity_ratio = flash_m["arith_intensity"] / math_m["arith_intensity"]

        measured_insight = f"""

---

**📊 Measured Results:**
- **Math backend:** {math_m['achieved_tflops']:.1f} TFLOPS @ {math_m['arith_intensity']:.0f} FLOPs/byte
- **Flash backend:** {flash_m['achieved_tflops']:.1f} TFLOPS @ {flash_m['arith_intensity']:.0f} FLOPs/byte
- **Speedup:** {speedup:.1f}× faster
- **Intensity increase:** {intensity_ratio:.0f}× higher arithmetic intensity"""

        return base_insight + measured_insight

    return base_insight + "\n\n*Run a benchmark to see measured values on the chart!*"