Spaces:
Sleeping
Sleeping
| """ | |
| Benchmark module for FlashAttention Explorer. | |
| GPU benchmark functions for comparing attention backends using real HuggingFace models. | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| from .constants import GPU_SPECS, ATTENTION_BACKENDS, MODEL_CONFIGS, DEFAULT_GPU, DEFAULT_MODEL | |
| from .models import load_model, clear_model_cache | |
| from .attention_utils import ( | |
| extract_attention_layer, | |
| create_attention_inputs, | |
| benchmark_attention_layer, | |
| get_model_attention_info, | |
| ) | |
def detect_gpu() -> dict:
    """
    Detect the actual GPU and return its specs.

    Known GPUs are recognised by substring match on the lower-cased CUDA
    device name. Unknown GPUs get specs estimated from the SM count, compute
    capability and total memory.

    Returns:
        Dict with GPU name and specs. Always contains "name", "detected",
        "tflops_fp16", "bandwidth_gbps", "memory_gb" and "sram_kb" (the latter
        four come from GPU_SPECS for table-backed entries). When a GPU is
        present it also contains "detected_name" (raw CUDA device name); for
        unknown GPUs it additionally contains "estimated" and
        "compute_capability".
    """
    if not torch.cuda.is_available():
        # Spread the default specs first so that a "name" entry inside the
        # spec dict cannot overwrite the explicit CPU label.
        return {**GPU_SPECS[DEFAULT_GPU], "name": "CPU (No GPU)", "detected": False}

    gpu_name_raw = torch.cuda.get_device_name(0)
    gpu_name = gpu_name_raw.lower()

    # Get memory in GB for dynamic spec estimation
    try:
        mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    except Exception:
        mem_gb = 24  # fallback

    detected = {"detected": True, "detected_name": gpu_name_raw}

    # Ordered so that more specific substrings are tested first
    # ("a100" before "a10", "l40" before "l4"). Each entry maps a substring to
    # either a GPU_SPECS key or an inline (name, tflops_fp16, bandwidth_gbps,
    # sram_kb) tuple whose memory size is taken from the device itself.
    known_gpus = (
        ("h200", ("NVIDIA H200", 989, 4800, 256)),  # HBM3e: 4.8 TB/s
        ("h100", "H100"),
        ("a100", "A100_80GB"),
        ("a10", "A10G"),
        ("l40", ("NVIDIA L40S", 362, 864, 192)),
        ("l4", ("NVIDIA L4", 121, 300, 96)),
        ("t4", ("NVIDIA T4", 65, 320, 64)),
        ("v100", ("NVIDIA V100", 125, 900, 128)),
        ("4090", ("NVIDIA RTX 4090", 330, 1008, 128)),
    )
    for needle, spec in known_gpus:
        if needle not in gpu_name:
            continue
        if isinstance(spec, str):
            return {**detected, **GPU_SPECS[spec]}
        name, tflops_fp16, bandwidth_gbps, sram_kb = spec
        return {
            **detected,
            "name": name,
            "tflops_fp16": tflops_fp16,
            "bandwidth_gbps": bandwidth_gbps,
            "memory_gb": round(mem_gb),
            "sram_kb": sram_kb,
        }

    # Unknown GPU - estimate specs using compute capability and SM count.
    compute_capability = "unknown"
    try:
        props = torch.cuda.get_device_properties(0)
        sm_count = props.multi_processor_count
        major, minor = torch.cuda.get_device_capability(0)
        compute_capability = f"{major}.{minor}"

        # Rough per-architecture constants: FP16 ops per SM per cycle, a
        # typical clock, and memory bandwidth per GB of capacity.
        if major >= 9:  # Hopper and newer (Ada reports 8.x and lands below)
            flops_per_sm = 512
            clock_ghz = 1.8
            bw_per_gb_mem = 50
        elif major >= 8:  # Ampere / Ada
            flops_per_sm = 256
            clock_ghz = 1.5
            bw_per_gb_mem = 25
        elif major >= 7:  # Volta/Turing
            flops_per_sm = 128
            clock_ghz = 1.4
            bw_per_gb_mem = 28
        else:  # Older
            flops_per_sm = 64
            clock_ghz = 1.2
            bw_per_gb_mem = 20

        # SMs x ops/SM/cycle x GHz gives GFLOPS; x2 for FMA; /1000 -> TFLOPS
        est_tflops = (sm_count * flops_per_sm * clock_ghz * 2) / 1000
        est_bw = mem_gb * bw_per_gb_mem
    except Exception:
        # Fallback if properties query fails
        est_tflops = 125
        est_bw = 600

    return {
        **detected,
        "name": gpu_name_raw,
        "tflops_fp16": round(est_tflops),
        "bandwidth_gbps": round(est_bw),
        "memory_gb": round(mem_gb),
        "sram_kb": 128,
        "estimated": True,  # Flag that these are estimated from compute capability
        "compute_capability": compute_capability,
    }
# (backend name, enable_math, enable_flash, enable_mem_efficient)
_SDPA_BACKEND_FLAGS = (
    ("math", True, False, False),
    ("flash", False, True, False),
    ("mem_efficient", False, False, True),
)


def _benchmark_sdpa_backends(Q, K, V, num_iterations: int, warmup_iterations: int) -> dict:
    """Time F.scaled_dot_product_attention(Q, K, V) once per SDPA backend.

    Returns a dict keyed by backend name; each value holds "time_ms",
    "memory_mb" and "status". A backend that raises is recorded with None
    measurements and a truncated error message instead of aborting the run.
    """
    measurements = {}
    for backend_name, enable_math, enable_flash, enable_mem_efficient in _SDPA_BACKEND_FLAGS:
        try:
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            # NOTE(review): newer PyTorch releases deprecate
            # torch.backends.cuda.sdp_kernel in favour of
            # torch.nn.attention.sdpa_kernel; kept as-is for compatibility
            # with whatever version this project pins - confirm before
            # migrating.
            with torch.backends.cuda.sdp_kernel(
                enable_flash=enable_flash,
                enable_math=enable_math,
                enable_mem_efficient=enable_mem_efficient,
            ):
                # Warmup
                for _ in range(warmup_iterations):
                    _ = F.scaled_dot_product_attention(Q, K, V)
                torch.cuda.synchronize()

                # Timed runs, measured with CUDA events
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
                for _ in range(num_iterations):
                    _ = F.scaled_dot_product_attention(Q, K, V)
                end.record()
                torch.cuda.synchronize()

                time_ms = start.elapsed_time(end) / num_iterations
                memory_mb = torch.cuda.max_memory_allocated() / 1e6

            measurements[backend_name] = {
                "time_ms": round(time_ms, 3),
                "memory_mb": round(memory_mb, 1),
                "status": "success",
            }
        except Exception as e:
            measurements[backend_name] = {
                "time_ms": None,
                "memory_mb": None,
                "status": f"error: {str(e)[:50]}",
            }
    return measurements


def _attach_speedups(results: dict) -> None:
    """Add a "speedup" (math time / backend time) to each timed backend, in place."""
    base_time = results.get("math", {}).get("time_ms")
    if not base_time:
        return
    for backend_name, *_ in _SDPA_BACKEND_FLAGS:
        entry = results.get(backend_name, {})
        if entry.get("time_ms"):
            entry["speedup"] = round(base_time / entry["time_ms"], 2)


def run_attention_benchmark(
    model_name: str = None,
    seq_len: int = 1024,
    batch_size: int = 1,
    num_iterations: int = 10,
    warmup_iterations: int = 3,
    # Legacy parameters (used if model_name is None)
    num_heads: int = 16,
    head_dim: int = 64,
) -> dict:
    """
    Benchmark three SDPA backends using a real HuggingFace model's attention layer.

    Three modes, tried in this order:
      1. model_name is in MODEL_CONFIGS and the model's first attention layer
         can be run: that layer is benchmarked via benchmark_attention_layer.
      2. model_name is in MODEL_CONFIGS but the layer probe fails: random
         Q/K/V with the model's head count / head dim are benchmarked with
         F.scaled_dot_product_attention ("fallback_mode" is set).
      3. model_name is None or not in MODEL_CONFIGS: random Q/K/V with the
         legacy num_heads / head_dim arguments are benchmarked.

    Args:
        model_name: Name of the model from MODEL_CONFIGS (e.g., "SmolLM2-360M")
            If None, falls back to legacy random tensor mode
        seq_len: Sequence length (number of tokens)
        batch_size: Batch size
        num_iterations: Number of timed iterations
        warmup_iterations: Number of warmup iterations
        num_heads: (Legacy) Number of attention heads if model_name is None
        head_dim: (Legacy) Dimension per head if model_name is None

    Returns:
        Dict with timing and memory results per backend ("math", "flash",
        "mem_efficient"), or {"error": ...} when CUDA is unavailable or the
        model path raises.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    device = torch.device("cuda")
    dtype = torch.float16

    # If model_name is provided, use real model dimensions for benchmarking
    if model_name is not None and model_name in MODEL_CONFIGS:
        try:
            model = load_model(model_name)
            attn_info = get_model_attention_info(model)
            model_num_heads = attn_info["num_attention_heads"]
            model_head_dim = attn_info["head_dim"]

            results = {"model_name": model_name, "using_real_model": True}
            results["model_info"] = attn_info

            # First try: probe the actual attention layer with a short run.
            attention_layer_works = False
            try:
                attention_layer = extract_attention_layer(model, layer_idx=0)
                hidden_states, position_ids = create_attention_inputs(
                    model, batch_size, seq_len, device, dtype
                )
                test_result = benchmark_attention_layer(
                    attention_layer=attention_layer,
                    hidden_states=hidden_states,
                    position_ids=position_ids,
                    backend="flash",
                    num_iterations=2,
                    warmup_iterations=1,
                )
                if test_result.get("time_ms") is not None:
                    attention_layer_works = True
                del hidden_states, position_ids
                torch.cuda.empty_cache()
            except Exception as layer_error:
                print(f"[run_attention_benchmark] Attention layer extraction failed: {layer_error}")
                attention_layer_works = False

            if attention_layer_works:
                # Use actual attention layer
                hidden_states, position_ids = create_attention_inputs(
                    model, batch_size, seq_len, device, dtype
                )
                for backend in ["math", "flash", "mem_efficient"]:
                    results[backend] = benchmark_attention_layer(
                        attention_layer=attention_layer,
                        hidden_states=hidden_states,
                        position_ids=position_ids,
                        backend=backend,
                        num_iterations=num_iterations,
                        warmup_iterations=warmup_iterations,
                    )
                del hidden_states, position_ids
                torch.cuda.empty_cache()
            else:
                # Fallback: Use F.scaled_dot_product_attention with real model dimensions
                print(f"[run_attention_benchmark] Falling back to SDPA with model dimensions")
                results["fallback_mode"] = True
                shape = (batch_size, model_num_heads, seq_len, model_head_dim)
                Q = torch.randn(*shape, device=device, dtype=dtype)
                K = torch.randn(*shape, device=device, dtype=dtype)
                V = torch.randn(*shape, device=device, dtype=dtype)
                results.update(
                    _benchmark_sdpa_backends(Q, K, V, num_iterations, warmup_iterations)
                )
                del Q, K, V
                torch.cuda.empty_cache()

            _attach_speedups(results)
            return results
        except Exception as e:
            return {"error": f"Failed to load model: {str(e)[:100]}"}

    # Legacy mode: Use raw SDPA with random tensors (fallback)
    results = {"using_real_model": False}
    shape = (batch_size, num_heads, seq_len, head_dim)
    Q = torch.randn(*shape, device=device, dtype=dtype)
    K = torch.randn(*shape, device=device, dtype=dtype)
    V = torch.randn(*shape, device=device, dtype=dtype)

    results.update(_benchmark_sdpa_backends(Q, K, V, num_iterations, warmup_iterations))

    # Calculate speedups relative to math backend
    _attach_speedups(results)

    # Clean up
    del Q, K, V
    torch.cuda.empty_cache()
    return results
def run_scaling_benchmark(
    model_name: str = None,
    seq_lengths: list = None,
    batch_size: int = 1,
    # Legacy parameters (used if model_name is None)
    num_heads: int = 16,
    head_dim: int = 64,
) -> dict:
    """
    Run the attention benchmark once per sequence length and collect the results.

    Args:
        model_name: Name of the model from MODEL_CONFIGS (e.g., "SmolLM2-360M")
        seq_lengths: Sequence lengths to test; defaults to [512, 1024, 2048, 4096]
        batch_size: Batch size
        num_heads: (Legacy) Number of attention heads if model_name is None
        head_dim: (Legacy) Dimension per head if model_name is None

    Returns:
        Dict with "seq_lengths", "model_name" and, per backend, parallel
        "time_ms" / "memory_mb" lists (None where a backend produced no
        timing), or {"error": ...} when CUDA is unavailable.
    """
    backend_names = ("math", "flash", "mem_efficient")
    lengths = [512, 1024, 2048, 4096] if seq_lengths is None else seq_lengths

    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    summary = {"seq_lengths": lengths, "model_name": model_name}
    for name in backend_names:
        summary[name] = {"time_ms": [], "memory_mb": []}

    for length in lengths:
        # Fewer iterations than a single benchmark: this runs once per length.
        outcome = run_attention_benchmark(
            model_name=model_name,
            seq_len=length,
            batch_size=batch_size,
            num_iterations=5,
            warmup_iterations=2,
            # Legacy params (ignored if model_name is set)
            num_heads=num_heads,
            head_dim=head_dim,
        )
        for name in backend_names:
            entry = outcome.get(name, {})
            timed = bool(entry.get("time_ms"))
            summary[name]["time_ms"].append(entry["time_ms"] if timed else None)
            summary[name]["memory_mb"].append(entry["memory_mb"] if timed else None)

    return summary
def create_benchmark_results_table(results: dict) -> str:
    """Create a markdown table from benchmark results.

    Args:
        results: Output of run_attention_benchmark(); either {"error": ...} or
            a dict with optional "math" / "flash" / "mem_efficient" entries.

    Returns:
        A markdown table with one row per backend present in results, or a
        bold error line when results carries an "error" key. Missing
        measurements are rendered as "N/A".
    """
    if "error" in results:
        return f"**Error:** {results['error']}"

    # Build table
    lines = [
        "| Backend | Time (ms) | Memory (MB) | Speedup |",
        "|---------|-----------|-------------|---------|",
    ]
    for backend in ["math", "flash", "mem_efficient"]:
        if backend not in results:
            continue
        r = results[backend]
        name = ATTENTION_BACKENDS.get(backend, backend)
        time_str = f"{r['time_ms']:.2f}" if r.get("time_ms") else "N/A"
        mem_str = f"{r['memory_mb']:.0f}" if r.get("memory_mb") else "N/A"
        # A speedup only exists when both this backend and the math baseline
        # were timed; do not invent "1.0x" for failed runs.
        speedup = r.get("speedup")
        speedup_str = f"{speedup:.1f}×" if speedup is not None else "N/A"
        lines.append(f"| {name} | {time_str} | {mem_str} | {speedup_str} |")
    return "\n".join(lines)
def create_benchmark_insight(results: dict) -> str:
    """Create insight text from benchmark results.

    Args:
        results: Output of run_attention_benchmark().

    Returns:
        "" when results carries an "error" key; a short note when either the
        flash or math backend has no timing; otherwise a markdown blurb with
        the flash-vs-math time ratio and peak-memory ratio. The memory ratio
        falls back to 1 when either memory figure is missing or flash memory
        is not positive.
    """
    if "error" in results:
        return ""

    flash = results.get("flash", {})
    baseline = results.get("math", {})
    if not flash.get("time_ms") or not baseline.get("time_ms"):
        return "**Note:** Some backends may not be available on this GPU."

    speedup = baseline["time_ms"] / flash["time_ms"]

    # Memory figures can be absent even when timing succeeded; never let that
    # raise while formatting a purely informational message.
    flash_mem = flash.get("memory_mb")
    baseline_mem = baseline.get("memory_mb")
    if baseline_mem is not None and flash_mem is not None and flash_mem > 0:
        mem_reduction = baseline_mem / flash_mem
    else:
        mem_reduction = 1

    return (
        "**Key Insight:**\n"
        f"FlashAttention is **{speedup:.1f}× faster** and uses **{mem_reduction:.1f}× less memory**!\n"
        "This improvement comes from:\n"
        "- Tiling attention into SRAM-sized blocks\n"
        "- Never materializing the full N×N attention matrix in HBM\n"
        "- Fused kernel avoiding multiple HBM round-trips"
    )
def create_scaling_chart(results: dict) -> go.Figure:
    """Build a two-panel figure: execution time and peak memory vs sequence length.

    Args:
        results: Output of run_scaling_benchmark(); either {"error": ...} or a
            dict with "seq_lengths" and per-backend "time_ms" / "memory_mb" lists.

    Returns:
        A plotly figure. On error, an empty figure carrying the error text.
    """
    if "error" in results:
        error_fig = go.Figure()
        error_fig.add_annotation(
            x=0.5, y=0.5,
            text=f"Error: {results['error']}",
            showarrow=False,
            font=dict(size=16, color="red")
        )
        return error_fig

    seq_lengths = results["seq_lengths"]

    # One row, two side-by-side panels
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Execution Time", "Peak Memory"),
        horizontal_spacing=0.12,
    )

    palette = {
        "math": "rgba(239, 68, 68, 0.8)",  # Red
        "flash": "rgba(34, 197, 94, 0.8)",  # Green
        "mem_efficient": "rgba(59, 130, 246, 0.8)",  # Blue
    }

    # (metric key, subplot column); time traces are added before memory traces
    panels = (("time_ms", 1), ("memory_mb", 2))
    for metric, column in panels:
        for backend in ("math", "flash", "mem_efficient"):
            series = results[backend][metric]
            label = ATTENTION_BACKENDS.get(backend, backend)

            # Drop sequence lengths for which this backend has no measurement
            kept = [(n, value) for n, value in zip(seq_lengths, series) if value is not None]
            if not kept:
                continue
            xs, ys = zip(*kept)

            trace_kwargs = dict(
                x=list(xs),
                y=list(ys),
                mode="lines+markers",
                name=label,
                line=dict(color=palette[backend], width=2),
                marker=dict(size=8),
                legendgroup=backend,
            )
            if column == 2:
                # The legend entry already comes from the time panel
                trace_kwargs["showlegend"] = False
            fig.add_trace(go.Scatter(**trace_kwargs), row=1, col=column)

    for column in (1, 2):
        fig.update_xaxes(title_text="Sequence Length", row=1, col=column)
    fig.update_yaxes(title_text="Time (ms)", row=1, col=1)
    fig.update_yaxes(title_text="Memory (MB)", row=1, col=2)

    fig.update_layout(
        height=350,
        margin=dict(l=50, r=50, t=50, b=50),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.3,
            xanchor="center",
            x=0.5
        ),
    )
    return fig
def calculate_attention_flops(seq_len: int, num_heads: int, head_dim: int, batch_size: int = 1) -> float:
    """
    Estimate the FLOPs of one scaled dot-product attention call.

    With S = batch * heads * seq * seq (the number of score entries):
    - Q @ K^T: 2 * S * head_dim
    - Softmax: ~5 * S (exp, sum, div)
    - P @ V:   2 * S * head_dim

    Returns:
        The sum of the three terms above.
    """
    score_entries = batch_size * num_heads * seq_len * seq_len
    matmul_flops = 2 * score_entries * head_dim  # same cost for Q@K^T and P@V
    return matmul_flops + 5 * score_entries + matmul_flops
def calculate_memory_traffic(
    seq_len: int,
    num_heads: int,
    head_dim: int,
    batch_size: int = 1,
    is_flash: bool = False,
    dtype_bytes: int = 2,  # FP16
) -> float:
    """
    Estimate the memory traffic in bytes of one attention call.

    With T = batch * heads * seq * head_dim * dtype_bytes (one Q/K/V/O tensor)
    and A = batch * heads * seq * seq * dtype_bytes (the score matrix):

    FlashAttention (is_flash=True):
    - Read Q, K, V once and write O once: 4 * T
    - No attention matrix traffic is counted

    Standard attention (is_flash=False):
    - The same 4 * T, plus three passes over the score matrix: 3 * A

    Returns:
        The estimated number of bytes moved.
    """
    tensor_bytes = batch_size * num_heads * seq_len * head_dim * dtype_bytes
    total = 3 * tensor_bytes + tensor_bytes  # Q, K, V reads + O write

    if not is_flash:
        # Standard attention also materializes the seq x seq score matrix
        score_bytes = batch_size * num_heads * seq_len * seq_len * dtype_bytes
        total += 3 * score_bytes
    return total
def calculate_roofline_metrics(
    results: dict,
    seq_len: int,
    num_heads: int,
    head_dim: int,
    batch_size: int = 1,
) -> dict:
    """
    Calculate arithmetic intensity and achieved TFLOPS from benchmark results.

    Args:
        results: Output of run_attention_benchmark().
        seq_len: Sequence length used for the benchmark.
        num_heads: Number of attention heads used for the benchmark.
        head_dim: Dimension per head used for the benchmark.
        batch_size: Batch size used for the benchmark.

    Returns:
        Dict keyed by backend name ("math", "flash", "mem_efficient") with
        "flops", "memory_bytes", "time_ms", "achieved_tflops" and
        "arith_intensity". Backends without a usable (non-zero) timing are
        omitted.
    """
    flops = calculate_attention_flops(seq_len, num_heads, head_dim, batch_size)

    metrics = {}
    for backend in ["math", "flash", "mem_efficient"]:
        if backend not in results:
            continue
        time_ms = results[backend].get("time_ms")
        # Treat missing and zero timings alike (as the rest of the module
        # does) so a 0.0 reading cannot cause a division by zero below.
        if not time_ms:
            continue
        time_s = time_ms / 1000.0

        # Calculate achieved TFLOPS
        achieved_tflops = (flops / time_s) / 1e12

        # Memory traffic is modelled, not measured: both fused backends are
        # assumed to avoid materializing the score matrix.
        is_flash = backend in ["flash", "mem_efficient"]
        memory_bytes = calculate_memory_traffic(
            seq_len, num_heads, head_dim, batch_size, is_flash=is_flash
        )

        # Arithmetic intensity = FLOPs / bytes
        arith_intensity = flops / memory_bytes

        metrics[backend] = {
            "flops": flops,
            "memory_bytes": memory_bytes,
            "time_ms": time_ms,
            "achieved_tflops": achieved_tflops,
            "arith_intensity": arith_intensity,
        }
    return metrics
def create_roofline_chart(
    results: dict,
    gpu_specs: dict = None,
    benchmark_metrics: dict = None,
) -> go.Figure:
    """
    Create a roofline chart showing where different attention implementations fall.

    The roofline model shows:
    - X-axis: Arithmetic intensity (FLOPs per byte of memory traffic), log scale
    - Y-axis: Performance (TFLOPS)
    - The roofline is min(peak_compute, bandwidth * intensity)

    Args:
        results: Benchmark results dict (can be empty). Not read by this
            function; kept for interface compatibility.
        gpu_specs: GPU specifications dict (from detect_gpu() or GPU_SPECS).
            Defaults to GPU_SPECS[DEFAULT_GPU].
        benchmark_metrics: Roofline metrics from calculate_roofline_metrics()

    If benchmark_metrics is provided and non-empty, plots MEASURED values and
    widens the axes when needed so every measured point is visible.
    Otherwise, plots theoretical approximations on the default 1..1000
    FLOPs/byte range.

    Returns:
        The plotly figure.
    """
    gpu = GPU_SPECS[DEFAULT_GPU] if gpu_specs is None else gpu_specs
    peak_tflops = gpu["tflops_fp16"]
    bandwidth_gbps = gpu["bandwidth_gbps"]

    # Ridge point: where memory-bound meets compute-bound
    ridge_point = (peak_tflops * 1e12) / (bandwidth_gbps * 1e9)

    # Determine if we have measured data or should use theoretical
    use_measured = benchmark_metrics is not None and len(benchmark_metrics) > 0

    # (metrics key, label, colour, annotation ax, annotation ay)
    measured_styles = (
        ("math", "Math", "#dc2626", 0, -40),
        ("flash", "Flash", "#16a34a", 0, -40),
        ("mem_efficient", "MemEff", "#2563eb", 30, -30),  # offset to avoid overlap
    )
    measured = []
    if use_measured:
        measured = [
            (label, color, ax, ay, benchmark_metrics[key])
            for key, label, color, ax, ay in measured_styles
            if key in benchmark_metrics
        ]

    # Axis extents. The x-axis is logarithmic, so its bounds are expressed in
    # decades; start from 10^0..10^3 and grow to whole decades that contain
    # every measured point (flash intensity grows with sequence length and
    # can exceed 1000 FLOPs/byte).
    x_min_exp, x_max_exp = 0, 3
    y_top = peak_tflops * 1.2  # headroom for text
    positive_intensities = [
        m["arith_intensity"] for *_, m in measured if m["arith_intensity"] > 0
    ]
    if positive_intensities:
        x_min_exp = min(x_min_exp, int(np.floor(np.log10(min(positive_intensities)))))
        x_max_exp = max(x_max_exp, int(np.ceil(np.log10(max(positive_intensities)))))
    if measured:
        y_top = max(y_top, 1.2 * max(m["achieved_tflops"] for *_, m in measured))
    x_hi = 10.0 ** x_max_exp

    fig = go.Figure()

    # Roofline curve
    x_range = np.logspace(x_min_exp, x_max_exp, 100)
    y_roofline = np.minimum(
        peak_tflops,
        bandwidth_gbps * x_range / 1000
    )
    fig.add_trace(go.Scatter(
        x=x_range,
        y=y_roofline,
        mode="lines",
        name="Roofline",
        line=dict(color="rgba(0, 0, 0, 0.6)", width=2),
    ))

    # Memory-bound region (dashed). Sampled along the roofline: on a log x
    # axis a two-point segment is drawn straight and would not follow the
    # bandwidth * intensity curve.
    mem_x = np.append(x_range[x_range < ridge_point], ridge_point)
    fig.add_trace(go.Scatter(
        x=mem_x,
        y=bandwidth_gbps * mem_x / 1000,
        mode="lines",
        name="Memory Bound",
        line=dict(color="rgba(239, 68, 68, 0.5)", width=3, dash="dash"),
    ))

    # Compute-bound region (dashed)
    fig.add_trace(go.Scatter(
        x=[ridge_point, x_hi],
        y=[peak_tflops, peak_tflops],
        mode="lines",
        name="Compute Bound",
        line=dict(color="rgba(34, 197, 94, 0.5)", width=3, dash="dash"),
    ))

    if use_measured:
        # Plot MEASURED data points
        title_suffix = " (Measured)"
        for label, color, ax, ay, m in measured:
            fig.add_trace(go.Scatter(
                x=[m["arith_intensity"]],
                y=[m["achieved_tflops"]],
                mode="markers",
                name=f"{label} ({m['achieved_tflops']:.1f} TFLOPS, {m['time_ms']:.1f}ms)",
                marker=dict(size=16, color=color, symbol="circle",
                            line=dict(color="white", width=2)),
            ))
            # Label as annotation for better visibility. Annotation x is
            # given in log10 units because the axis is logarithmic.
            fig.add_annotation(
                x=np.log10(m["arith_intensity"]),
                y=m["achieved_tflops"],
                text=f"<b>{label}</b><br>{m['time_ms']:.1f}ms",
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=1,
                arrowcolor=color,
                ax=ax,
                ay=ay,
                font=dict(size=10, color=color),
                bgcolor="rgba(255, 255, 255, 0.95)",
                bordercolor=color,
                borderwidth=1,
                borderpad=3,
            )
    else:
        # Plot THEORETICAL approximations
        title_suffix = " (Theoretical)"
        # (trace name, annotation text, intensity, fraction of peak,
        #  marker colour, font colour, border colour)
        theoretical_points = (
            ("Standard (Theoretical)", "<b>Standard</b><br>(theoretical)",
             10, 0.15,
             "rgba(220, 38, 38, 0.6)", "#dc2626", "rgba(220, 38, 38, 0.5)"),
            ("Flash (Theoretical)", "<b>FlashAttention</b><br>(theoretical)",
             200, 0.7,
             "rgba(22, 163, 74, 0.6)", "#16a34a", "rgba(22, 163, 74, 0.5)"),
        )
        for name, text, intensity, peak_fraction, marker_color, font_color, border_color in theoretical_points:
            # Cap the assumed fraction of peak by the roofline at this intensity
            achieved = min(peak_tflops * peak_fraction, bandwidth_gbps * intensity / 1000)
            fig.add_trace(go.Scatter(
                x=[intensity],
                y=[achieved],
                mode="markers",
                name=name,
                marker=dict(size=15, color=marker_color, symbol="circle-open",
                            line=dict(width=2)),
            ))
            fig.add_annotation(
                x=np.log10(intensity),
                y=achieved,
                text=text,
                showarrow=True,
                arrowhead=2,
                ax=0,
                ay=-35,
                font=dict(size=10, color=font_color),
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor=border_color,
                borderwidth=1,
                borderpad=3,
            )

    # Add ridge point marker
    fig.add_trace(go.Scatter(
        x=[ridge_point],
        y=[peak_tflops],
        mode="markers",
        name=f"Ridge Point ({ridge_point:.0f} FLOPs/byte)",
        marker=dict(size=10, color="rgba(0, 0, 0, 0.6)", symbol="diamond"),
    ))

    # Region labels with better visibility (white background)
    fig.add_annotation(
        x=np.log10(5),
        y=peak_tflops * 0.1,
        text="<b>Memory Bound</b><br>(limited by bandwidth)",
        showarrow=False,
        font=dict(size=11, color="#dc2626"),  # Solid red
        bgcolor="rgba(255, 255, 255, 0.9)",
        bordercolor="#dc2626",
        borderwidth=1,
        borderpad=4,
    )
    fig.add_annotation(
        x=np.log10(300),
        y=peak_tflops * 0.65,
        text="<b>Compute Bound</b><br>(limited by TFLOPS)",
        showarrow=False,
        font=dict(size=11, color="#16a34a"),  # Solid green
        bgcolor="rgba(255, 255, 255, 0.9)",
        bordercolor="#16a34a",
        borderwidth=1,
        borderpad=4,
    )

    # Use detected_name if available, otherwise use name
    display_name = gpu.get("detected_name", gpu.get("name", "GPU"))
    # Add estimated indicator if specs were estimated
    estimated_note = " (estimated specs)" if gpu.get("estimated") else ""

    fig.update_layout(
        title=dict(
            text=f"Roofline Model: {display_name}{title_suffix}{estimated_note}<br>"
                 f"<span style='font-size:12px;color:#666'>"
                 f"Peak: {peak_tflops} TFLOPS | Bandwidth: {bandwidth_gbps} GB/s</span>",
            x=0.5,
            font=dict(size=14),
        ),
        xaxis=dict(
            title="Arithmetic Intensity (FLOPs/byte)",
            type="log",
            range=[x_min_exp, x_max_exp],  # log10 units
        ),
        yaxis=dict(
            title="Performance (TFLOPS)",
            range=[0, y_top],
        ),
        height=420,
        margin=dict(l=60, r=40, t=80, b=80),  # More room for title and legend
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.30,
            xanchor="center",
            x=0.5,
            font=dict(size=10),
        ),
        showlegend=True,
    )
    return fig
def get_roofline_insight(benchmark_metrics: dict = None) -> str:
    """Return the markdown explanation shown next to the roofline chart.

    Args:
        benchmark_metrics: Optional output of calculate_roofline_metrics().

    Returns:
        The generic explanation, followed either by a measured-results
        section (when both "math" and "flash" metrics are present) or by a
        prompt to run a benchmark.
    """
    explanation = "\n".join((
        "**Why FlashAttention is Faster:**",
        "The roofline model reveals the key insight:",
        "1. **Standard Attention** sits in the **memory-bound** region (left of ridge point)",
        "- Limited by HBM bandwidth, not compute",
        "- Reading/writing the N×N attention matrix dominates runtime",
        "2. **FlashAttention** moves to the **compute-bound** region (right of ridge point)",
        "- By never materializing the full attention matrix",
        "- Arithmetic intensity increases ~20-50×",
        "- Can now utilize most of the GPU's TFLOPS",
        "*The same FLOPs, but 10× less memory traffic = faster execution!*",
    ))

    have_both = bool(
        benchmark_metrics
        and "math" in benchmark_metrics
        and "flash" in benchmark_metrics
    )
    if not have_both:
        return explanation + "\n\n*Run a benchmark to see measured values on the chart!*"

    baseline = benchmark_metrics["math"]
    fused = benchmark_metrics["flash"]
    time_ratio = baseline["time_ms"] / fused["time_ms"]
    intensity_gain = fused["arith_intensity"] / baseline["arith_intensity"]

    # The leading empty entry reproduces the newline that separates the
    # measured section from the generic explanation.
    measured_section = "\n".join((
        "",
        "---",
        "**📊 Measured Results:**",
        f"- **Math backend:** {baseline['achieved_tflops']:.1f} TFLOPS @ {baseline['arith_intensity']:.0f} FLOPs/byte",
        f"- **Flash backend:** {fused['achieved_tflops']:.1f} TFLOPS @ {fused['arith_intensity']:.0f} FLOPs/byte",
        f"- **Speedup:** {time_ratio:.1f}× faster",
        f"- **Intensity increase:** {intensity_gain:.0f}× higher arithmetic intensity",
    ))
    return explanation + measured_section