"""
Benchmark module for FlashAttention Explorer.
GPU benchmark functions for comparing attention backends using real HuggingFace models.
"""
import torch
import torch.nn.functional as F
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from .constants import GPU_SPECS, ATTENTION_BACKENDS, MODEL_CONFIGS, DEFAULT_GPU, DEFAULT_MODEL
from .models import load_model, clear_model_cache
from .attention_utils import (
extract_attention_layer,
create_attention_inputs,
benchmark_attention_layer,
get_model_attention_info,
)
def detect_gpu() -> dict:
    """
    Detect the actual GPU and return its specs.

    Returns:
        Dict with GPU name and specs. Known GPUs get curated spec entries;
        unknown GPUs get specs estimated from compute capability and SM count
        and are flagged with ``"estimated": True``.
    """
    if not torch.cuda.is_available():
        return {"name": "CPU (No GPU)", "detected": False, **GPU_SPECS[DEFAULT_GPU]}
    gpu_name_raw = torch.cuda.get_device_name(0)
    gpu_name = gpu_name_raw.lower()
    # Get memory in GB for dynamic spec estimation
    try:
        mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    except Exception:
        mem_gb = 24  # fallback
    # Match against known GPUs (ordered from newest to oldest).
    # Substring order matters: "h200" before "h100", "a100" before "a10",
    # and "l40" before "l4", since the shorter names are substrings.
    if "h200" in gpu_name:
        # H200 specs - HBM3e memory, very high bandwidth
        return {
            "detected": True,
            "detected_name": gpu_name_raw,
            "name": "NVIDIA H200",
            "tflops_fp16": 989,  # Same compute as H100
            "bandwidth_gbps": 4800,  # HBM3e: 4.8 TB/s
            "memory_gb": round(mem_gb),
            "sram_kb": 256,
        }
    elif "h100" in gpu_name:
        return {"detected": True, "detected_name": gpu_name_raw, **GPU_SPECS["H100"]}
    elif "a100" in gpu_name:
        return {"detected": True, "detected_name": gpu_name_raw, **GPU_SPECS["A100_80GB"]}
    elif "a10" in gpu_name:
        return {"detected": True, "detected_name": gpu_name_raw, **GPU_SPECS["A10G"]}
    elif "l40" in gpu_name:
        # L40S specs
        return {
            "detected": True,
            "detected_name": gpu_name_raw,
            "name": "NVIDIA L40S",
            "tflops_fp16": 362,
            "bandwidth_gbps": 864,
            "memory_gb": round(mem_gb),
            "sram_kb": 192,
        }
    elif "l4" in gpu_name:
        # L4 specs
        return {
            "detected": True,
            "detected_name": gpu_name_raw,
            "name": "NVIDIA L4",
            "tflops_fp16": 121,
            "bandwidth_gbps": 300,
            "memory_gb": round(mem_gb),
            "sram_kb": 96,
        }
    elif "t4" in gpu_name:
        return {
            "detected": True,
            "detected_name": gpu_name_raw,
            "name": "NVIDIA T4",
            "tflops_fp16": 65,
            "bandwidth_gbps": 320,
            "memory_gb": round(mem_gb),
            "sram_kb": 64,
        }
    elif "v100" in gpu_name:
        return {
            "detected": True,
            "detected_name": gpu_name_raw,
            "name": "NVIDIA V100",
            "tflops_fp16": 125,
            "bandwidth_gbps": 900,
            "memory_gb": round(mem_gb),
            "sram_kb": 128,
        }
    elif "rtx 4090" in gpu_name or "4090" in gpu_name:
        return {
            "detected": True,
            "detected_name": gpu_name_raw,
            "name": "NVIDIA RTX 4090",
            "tflops_fp16": 330,
            "bandwidth_gbps": 1008,
            "memory_gb": round(mem_gb),
            "sram_kb": 128,
        }
    else:
        # Unknown GPU - estimate specs using compute capability and SM count.
        # These are the best indicators of performance we can query.
        # Fix: initialize compute_capability up front instead of the fragile
        # `'major' in dir()` probe the original used in the return dict.
        compute_capability = "unknown"
        try:
            props = torch.cuda.get_device_properties(0)
            sm_count = props.multi_processor_count
            major, minor = torch.cuda.get_device_capability(0)
            compute_capability = f"{major}.{minor}"
            # FP16 FLOPs per SM per cycle varies by architecture
            # Ampere (8.x): 256 FP16 ops/SM/cycle, Hopper (9.x): 512
            # Clock speed ~1.5-2 GHz typically
            if major >= 9:  # Hopper/Ada
                flops_per_sm = 512
                clock_ghz = 1.8
                bw_per_gb_mem = 50  # Rough: HBM3 ~50 GB/s per GB capacity
            elif major >= 8:  # Ampere
                flops_per_sm = 256
                clock_ghz = 1.5
                bw_per_gb_mem = 25  # HBM2e
            elif major >= 7:  # Volta/Turing
                flops_per_sm = 128
                clock_ghz = 1.4
                bw_per_gb_mem = 28
            else:  # Older
                flops_per_sm = 64
                clock_ghz = 1.2
                bw_per_gb_mem = 20
            # Estimate TFLOPS: SMs × FLOPs/SM/cycle × clock × 2 (FMA)
            est_tflops = (sm_count * flops_per_sm * clock_ghz * 2) / 1000
            est_bw = mem_gb * bw_per_gb_mem
        except Exception:
            # Fallback if properties query fails
            est_tflops = 125
            est_bw = 600
        return {
            "detected": True,
            "detected_name": gpu_name_raw,
            "name": gpu_name_raw,
            "tflops_fp16": round(est_tflops),
            "bandwidth_gbps": round(est_bw),
            "memory_gb": round(mem_gb),
            "sram_kb": 128,
            "estimated": True,  # Flag that these are estimated from compute capability
            "compute_capability": compute_capability,
        }
def run_attention_benchmark(
    model_name: str = None,
    seq_len: int = 1024,
    batch_size: int = 1,
    num_iterations: int = 10,
    warmup_iterations: int = 3,
    # Legacy parameters (used if model_name is None)
    num_heads: int = 16,
    head_dim: int = 64,
) -> dict:
    """
    Benchmark three SDPA backends using a real HuggingFace model's attention layer.

    Args:
        model_name: Name of the model from MODEL_CONFIGS (e.g., "SmolLM2-360M")
            If None, falls back to legacy random tensor mode
        seq_len: Sequence length (number of tokens)
        batch_size: Batch size
        num_iterations: Number of timed iterations
        warmup_iterations: Number of warmup iterations
        num_heads: (Legacy) Number of attention heads if model_name is None
        head_dim: (Legacy) Dimension per head if model_name is None

    Returns:
        Dict with timing and memory results per backend ("math", "flash",
        "mem_efficient"), each holding time_ms / memory_mb / status and,
        when the math backend succeeded, a "speedup" ratio relative to it.
        On failure returns a dict with a single "error" key.
    """
    # CUDA events and peak-memory stats below require a GPU; bail out early.
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}
    device = torch.device("cuda")
    dtype = torch.float16
    # If model_name is provided, use real model dimensions for benchmarking
    if model_name is not None and model_name in MODEL_CONFIGS:
        try:
            # Load the real HuggingFace model
            model = load_model(model_name)
            # Get model attention info for real dimensions
            attn_info = get_model_attention_info(model)
            # Extract dimensions from real model
            model_num_heads = attn_info["num_attention_heads"]
            model_head_dim = attn_info["head_dim"]
            results = {"model_name": model_name, "using_real_model": True}
            results["model_info"] = attn_info
            # First try: Use actual attention layer forward pass.
            # Probe with a tiny run (2 iterations) before committing to the
            # full benchmark, since some architectures fail under forced
            # backend selection.
            attention_layer_works = False
            try:
                attention_layer = extract_attention_layer(model, layer_idx=0)
                hidden_states, position_ids = create_attention_inputs(
                    model, batch_size, seq_len, device, dtype
                )
                # Test if attention layer works with first backend
                test_result = benchmark_attention_layer(
                    attention_layer=attention_layer,
                    hidden_states=hidden_states,
                    position_ids=position_ids,
                    backend="flash",
                    num_iterations=2,
                    warmup_iterations=1,
                )
                if test_result.get("time_ms") is not None:
                    attention_layer_works = True
                # Free probe inputs before allocating the real benchmark inputs.
                del hidden_states, position_ids
                torch.cuda.empty_cache()
            except Exception as layer_error:
                print(f"[run_attention_benchmark] Attention layer extraction failed: {layer_error}")
                attention_layer_works = False
            if attention_layer_works:
                # Use actual attention layer
                hidden_states, position_ids = create_attention_inputs(
                    model, batch_size, seq_len, device, dtype
                )
                for backend in ["math", "flash", "mem_efficient"]:
                    result = benchmark_attention_layer(
                        attention_layer=attention_layer,
                        hidden_states=hidden_states,
                        position_ids=position_ids,
                        backend=backend,
                        num_iterations=num_iterations,
                        warmup_iterations=warmup_iterations,
                    )
                    results[backend] = result
                del hidden_states, position_ids
                torch.cuda.empty_cache()
            else:
                # Fallback: Use F.scaled_dot_product_attention with real model dimensions
                print(f"[run_attention_benchmark] Falling back to SDPA with model dimensions")
                results["fallback_mode"] = True
                # Create Q, K, V tensors with real model dimensions
                # (layout: batch, heads, seq, head_dim — what SDPA expects).
                Q = torch.randn(batch_size, model_num_heads, seq_len, model_head_dim, device=device, dtype=dtype)
                K = torch.randn(batch_size, model_num_heads, seq_len, model_head_dim, device=device, dtype=dtype)
                V = torch.randn(batch_size, model_num_heads, seq_len, model_head_dim, device=device, dtype=dtype)
                # (name, enable_math, enable_flash, enable_mem_efficient):
                # exactly one kernel is enabled per run to force that backend.
                backends = [
                    ("math", True, False, False),
                    ("flash", False, True, False),
                    ("mem_efficient", False, False, True),
                ]
                for backend_name, enable_math, enable_flash, enable_mem_efficient in backends:
                    try:
                        torch.cuda.reset_peak_memory_stats()
                        torch.cuda.synchronize()
                        # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated
                        # in newer PyTorch in favor of torch.nn.attention.sdpa_kernel
                        # — confirm the targeted torch version.
                        with torch.backends.cuda.sdp_kernel(
                            enable_flash=enable_flash,
                            enable_math=enable_math,
                            enable_mem_efficient=enable_mem_efficient
                        ):
                            # Warmup
                            for _ in range(warmup_iterations):
                                _ = F.scaled_dot_product_attention(Q, K, V)
                            torch.cuda.synchronize()
                            # Timed runs: one event pair around the whole loop,
                            # averaged over num_iterations.
                            start = torch.cuda.Event(enable_timing=True)
                            end = torch.cuda.Event(enable_timing=True)
                            start.record()
                            for _ in range(num_iterations):
                                _ = F.scaled_dot_product_attention(Q, K, V)
                            end.record()
                            torch.cuda.synchronize()
                            time_ms = start.elapsed_time(end) / num_iterations
                            memory_mb = torch.cuda.max_memory_allocated() / 1e6
                            results[backend_name] = {
                                "time_ms": round(time_ms, 3),
                                "memory_mb": round(memory_mb, 1),
                                "status": "success"
                            }
                    except Exception as e:
                        # A backend can be unavailable on this GPU/dtype; record
                        # the failure and keep benchmarking the others.
                        results[backend_name] = {
                            "time_ms": None,
                            "memory_mb": None,
                            "status": f"error: {str(e)[:50]}"
                        }
                del Q, K, V
                torch.cuda.empty_cache()
            # Calculate speedups (relative to the math backend baseline)
            if results.get("math", {}).get("time_ms"):
                base_time = results["math"]["time_ms"]
                for backend in ["math", "flash", "mem_efficient"]:
                    if results.get(backend, {}).get("time_ms"):
                        results[backend]["speedup"] = round(base_time / results[backend]["time_ms"], 2)
            return results
        except Exception as e:
            return {"error": f"Failed to load model: {str(e)[:100]}"}
    # Legacy mode: Use raw SDPA with random tensors (fallback)
    results = {"using_real_model": False}
    # Create input tensors
    Q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    K = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    V = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    # Test each backend (same forced-backend scheme as the fallback above)
    backends = [
        ("math", True, False, False),
        ("flash", False, True, False),
        ("mem_efficient", False, False, True),
    ]
    for backend_name, enable_math, enable_flash, enable_mem_efficient in backends:
        try:
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            with torch.backends.cuda.sdp_kernel(
                enable_flash=enable_flash,
                enable_math=enable_math,
                enable_mem_efficient=enable_mem_efficient
            ):
                # Warmup
                for _ in range(warmup_iterations):
                    _ = F.scaled_dot_product_attention(Q, K, V)
                torch.cuda.synchronize()
                # Timed runs
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
                for _ in range(num_iterations):
                    _ = F.scaled_dot_product_attention(Q, K, V)
                end.record()
                torch.cuda.synchronize()
                time_ms = start.elapsed_time(end) / num_iterations
                memory_mb = torch.cuda.max_memory_allocated() / 1e6
                results[backend_name] = {
                    "time_ms": round(time_ms, 3),
                    "memory_mb": round(memory_mb, 1),
                    "status": "success"
                }
        except Exception as e:
            results[backend_name] = {
                "time_ms": None,
                "memory_mb": None,
                "status": f"error: {str(e)[:50]}"
            }
    # Calculate speedups relative to math backend
    if results.get("math", {}).get("time_ms"):
        base_time = results["math"]["time_ms"]
        for backend in results:
            if isinstance(results[backend], dict) and results[backend].get("time_ms"):
                results[backend]["speedup"] = round(base_time / results[backend]["time_ms"], 2)
    # Clean up
    del Q, K, V
    torch.cuda.empty_cache()
    return results
def run_scaling_benchmark(
    model_name: str = None,
    seq_lengths: list = None,
    batch_size: int = 1,
    # Legacy parameters (used if model_name is None)
    num_heads: int = 16,
    head_dim: int = 64,
) -> dict:
    """
    Benchmark attention backends across multiple sequence lengths using a real model.

    Args:
        model_name: Name of the model from MODEL_CONFIGS (e.g., "SmolLM2-360M")
        seq_lengths: List of sequence lengths to test (defaults to 512..4096)
        batch_size: Batch size
        num_heads: (Legacy) Number of attention heads if model_name is None
        head_dim: (Legacy) Dimension per head if model_name is None

    Returns:
        Dict with per-backend lists of time_ms / memory_mb, aligned with
        the "seq_lengths" list; failed runs are recorded as None.
    """
    if seq_lengths is None:
        seq_lengths = [512, 1024, 2048, 4096]
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    backend_names = ("math", "flash", "mem_efficient")
    results = {"seq_lengths": seq_lengths, "model_name": model_name}
    for name in backend_names:
        results[name] = {"time_ms": [], "memory_mb": []}

    for seq_len in seq_lengths:
        single_run = run_attention_benchmark(
            model_name=model_name,
            seq_len=seq_len,
            batch_size=batch_size,
            num_iterations=5,  # Fewer iterations for scaling test
            warmup_iterations=2,
            # Legacy params (ignored if model_name is set)
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # Append a (time, memory) sample per backend; None marks a failure
        # so the x-axis alignment with seq_lengths is preserved.
        for name in backend_names:
            entry = single_run.get(name, {})
            if entry.get("time_ms"):
                results[name]["time_ms"].append(entry["time_ms"])
                results[name]["memory_mb"].append(entry["memory_mb"])
            else:
                results[name]["time_ms"].append(None)
                results[name]["memory_mb"].append(None)
    return results
def create_benchmark_results_table(results: dict) -> str:
    """Create a markdown table from benchmark results.

    Args:
        results: Output of run_attention_benchmark(); may contain an "error" key.

    Returns:
        Markdown table with one row per backend. Backends that failed to run
        show "N/A" for every column — including Speedup, which previously
        defaulted to a misleading "1.0×" for failed backends.
    """
    if "error" in results:
        return f"**Error:** {results['error']}"
    # Build table
    lines = [
        "| Backend | Time (ms) | Memory (MB) | Speedup |",
        "|---------|-----------|-------------|---------|",
    ]
    for backend in ["math", "flash", "mem_efficient"]:
        if backend in results:
            r = results[backend]
            name = ATTENTION_BACKENDS.get(backend, backend)
            if r.get('time_ms'):
                time_str = f"{r['time_ms']:.2f}"
                # Speedup is only meaningful when the backend actually ran.
                speedup_str = f"{r.get('speedup', 1.0):.1f}×"
            else:
                time_str = "N/A"
                speedup_str = "N/A"
            mem_str = f"{r['memory_mb']:.0f}" if r.get('memory_mb') else "N/A"
            lines.append(f"| {name} | {time_str} | {mem_str} | {speedup_str} |")
    return "\n".join(lines)
def create_benchmark_insight(results: dict) -> str:
    """Create insight text from benchmark results."""
    if "error" in results:
        return ""
    # Avoid the name `math` for the dict — it reads like the stdlib module.
    flash_stats = results.get("flash", {})
    math_stats = results.get("math", {})
    if not flash_stats.get("time_ms") or not math_stats.get("time_ms"):
        return "**Note:** Some backends may not be available on this GPU."
    speedup = math_stats["time_ms"] / flash_stats["time_ms"]
    if flash_stats["memory_mb"] > 0:
        mem_reduction = math_stats["memory_mb"] / flash_stats["memory_mb"]
    else:
        mem_reduction = 1
    return f"""**Key Insight:**
FlashAttention is **{speedup:.1f}× faster** and uses **{mem_reduction:.1f}× less memory**!
This improvement comes from:
- Tiling attention into SRAM-sized blocks
- Never materializing the full N×N attention matrix in HBM
- Fused kernel avoiding multiple HBM round-trips"""
def create_scaling_chart(results: dict) -> go.Figure:
    """Create a scaling chart showing time and memory vs sequence length."""
    if "error" in results:
        fig = go.Figure()
        fig.add_annotation(
            x=0.5, y=0.5,
            text=f"Error: {results['error']}",
            showarrow=False,
            font=dict(size=16, color="red")
        )
        return fig

    seq_lengths = results["seq_lengths"]
    # Side-by-side panels: time on the left, peak memory on the right.
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Execution Time", "Peak Memory"),
        horizontal_spacing=0.12,
    )
    palette = {
        "math": "rgba(239, 68, 68, 0.8)",  # Red
        "flash": "rgba(34, 197, 94, 0.8)",  # Green
        "mem_efficient": "rgba(59, 130, 246, 0.8)",  # Blue
    }
    # (metric key, subplot column, whether this panel owns the legend entries)
    panels = [("time_ms", 1, True), ("memory_mb", 2, False)]
    for metric, col, owns_legend in panels:
        for backend in ("math", "flash", "mem_efficient"):
            series = results[backend][metric]
            label = ATTENTION_BACKENDS.get(backend, backend)
            # Drop None samples (failed runs) while keeping x/y aligned.
            points = [(s, v) for s, v in zip(seq_lengths, series) if v is not None]
            if not points:
                continue
            fig.add_trace(
                go.Scatter(
                    x=[p[0] for p in points],
                    y=[p[1] for p in points],
                    mode="lines+markers",
                    name=label,
                    line=dict(color=palette[backend], width=2),
                    marker=dict(size=8),
                    legendgroup=backend,
                    showlegend=owns_legend,
                ),
                row=1, col=col
            )
    fig.update_xaxes(title_text="Sequence Length", row=1, col=1)
    fig.update_xaxes(title_text="Sequence Length", row=1, col=2)
    fig.update_yaxes(title_text="Time (ms)", row=1, col=1)
    fig.update_yaxes(title_text="Memory (MB)", row=1, col=2)
    fig.update_layout(
        height=350,
        margin=dict(l=50, r=50, t=50, b=50),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.3,
            xanchor="center",
            x=0.5
        ),
    )
    return fig
def calculate_attention_flops(seq_len: int, num_heads: int, head_dim: int, batch_size: int = 1) -> float:
    """
    Calculate FLOPs for scaled dot-product attention.

    FLOPs breakdown:
    - Q @ K^T: 2 * batch * heads * seq * seq * head_dim
    - Softmax: ~5 * batch * heads * seq * seq (exp, sum, div)
    - P @ V: 2 * batch * heads * seq * seq * head_dim
    Total: ~4 * batch * heads * seq² * head_dim + 5 * batch * heads * seq²
    """
    # Every term shares the batch * heads * seq^2 factor.
    bhss = batch_size * num_heads * seq_len * seq_len
    matmul_flops = 2 * bhss * head_dim  # one matmul: Q@K^T or P@V
    softmax_flops = 5 * bhss
    return 2 * matmul_flops + softmax_flops
def calculate_memory_traffic(
    seq_len: int,
    num_heads: int,
    head_dim: int,
    batch_size: int = 1,
    is_flash: bool = False,
    dtype_bytes: int = 2,  # FP16
) -> float:
    """
    Calculate memory traffic in bytes for attention.

    Standard attention reads Q/K/V, writes the seq×seq score matrix S,
    reads it back for softmax, writes P, reads P for P @ V, and writes O
    — three full passes over the attention matrix in HBM.

    FlashAttention streams tiles through SRAM, so only the Q/K/V reads
    and the O write touch HBM; the attention matrix is never materialized.
    """
    # Q, K, V reads plus the O write: four head-sized tensors in total.
    head_elems = batch_size * num_heads * seq_len * head_dim
    qkv_and_output = 4 * head_elems * dtype_bytes
    if is_flash:
        return qkv_and_output
    # Standard path: add three passes over the seq×seq score matrix
    # (write S, read+write for softmax, read for P @ V).
    score_matrix = batch_size * num_heads * seq_len * seq_len * dtype_bytes
    return qkv_and_output + 3 * score_matrix
def calculate_roofline_metrics(
    results: dict,
    seq_len: int,
    num_heads: int,
    head_dim: int,
    batch_size: int = 1,
) -> dict:
    """
    Calculate arithmetic intensity and achieved TFLOPS from benchmark results.
    Returns dict with measured metrics for each backend.
    """
    flops = calculate_attention_flops(seq_len, num_heads, head_dim, batch_size)
    metrics = {}
    for backend in ("math", "flash", "mem_efficient"):
        # Skip backends that are absent or failed to produce a timing.
        if backend not in results:
            continue
        time_ms = results[backend].get("time_ms")
        if time_ms is None:
            continue
        seconds = time_ms / 1000.0
        achieved_tflops = flops / seconds / 1e12
        # Flash and mem_efficient both avoid materializing the score matrix,
        # so they share the reduced-traffic model.
        traffic_bytes = calculate_memory_traffic(
            seq_len, num_heads, head_dim, batch_size,
            is_flash=(backend != "math"),
        )
        metrics[backend] = {
            "flops": flops,
            "memory_bytes": traffic_bytes,
            "time_ms": time_ms,
            "achieved_tflops": achieved_tflops,
            # Arithmetic intensity = FLOPs per byte of HBM traffic.
            "arith_intensity": flops / traffic_bytes,
        }
    return metrics
def create_roofline_chart(
    results: dict,
    gpu_specs: dict = None,
    benchmark_metrics: dict = None,
) -> go.Figure:
    """
    Create a roofline chart showing where different attention implementations fall.

    The roofline model shows:
    - X-axis: Arithmetic intensity (FLOPs per byte of memory traffic)
    - Y-axis: Performance (TFLOPS)
    - The roofline is min(peak_compute, bandwidth * intensity)

    Args:
        results: Benchmark results dict (can be empty; kept for API compatibility)
        gpu_specs: GPU specifications dict (from detect_gpu() or GPU_SPECS)
        benchmark_metrics: Roofline metrics from calculate_roofline_metrics()

    If benchmark_metrics is provided, plots MEASURED values.
    Otherwise, plots theoretical approximations.

    Note: multi-line Plotly text uses "<br>" — the source had these strings
    broken across raw lines (invalid f-strings), reconstructed here.
    """
    # Use provided specs or default to the configured default GPU
    gpu = GPU_SPECS[DEFAULT_GPU] if gpu_specs is None else gpu_specs
    peak_tflops = gpu["tflops_fp16"]
    bandwidth_gbps = gpu["bandwidth_gbps"]
    # Ridge point: arithmetic intensity where memory-bound meets compute-bound
    ridge_point = (peak_tflops * 1e12) / (bandwidth_gbps * 1e9)

    fig = go.Figure()

    # Roofline curve: min(peak compute, bandwidth * intensity)
    x_range = np.logspace(0, 3, 100)
    y_roofline = np.minimum(
        peak_tflops,
        bandwidth_gbps * x_range / 1000
    )
    fig.add_trace(go.Scatter(
        x=x_range,
        y=y_roofline,
        mode="lines",
        name="Roofline",
        line=dict(color="rgba(0, 0, 0, 0.6)", width=2),
    ))
    # Memory-bound region (dashed)
    fig.add_trace(go.Scatter(
        x=[1, ridge_point],
        y=[bandwidth_gbps / 1000, peak_tflops],
        mode="lines",
        name="Memory Bound",
        line=dict(color="rgba(239, 68, 68, 0.5)", width=3, dash="dash"),
    ))
    # Compute-bound region (dashed)
    fig.add_trace(go.Scatter(
        x=[ridge_point, 1000],
        y=[peak_tflops, peak_tflops],
        mode="lines",
        name="Compute Bound",
        line=dict(color="rgba(34, 197, 94, 0.5)", width=3, dash="dash"),
    ))

    def _add_measured_point(m: dict, label: str, color: str, ax: int = 0, ay: int = -40) -> None:
        # One measured backend: a solid marker plus a labeled arrow annotation.
        fig.add_trace(go.Scatter(
            x=[m["arith_intensity"]],
            y=[m["achieved_tflops"]],
            mode="markers",
            name=f"{label} ({m['achieved_tflops']:.1f} TFLOPS, {m['time_ms']:.1f}ms)",
            marker=dict(size=16, color=color, symbol="circle",
                        line=dict(color="white", width=2)),
        ))
        fig.add_annotation(
            # Annotation x is given in log-space because the x-axis is log-scaled.
            x=np.log10(m["arith_intensity"]),
            y=m["achieved_tflops"],
            text=f"{label}<br>{m['time_ms']:.1f}ms",
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowwidth=1,
            arrowcolor=color,
            ax=ax,
            ay=ay,
            font=dict(size=10, color=color),
            bgcolor="rgba(255, 255, 255, 0.95)",
            bordercolor=color,
            borderwidth=1,
            borderpad=3,
        )

    # Determine if we have measured data or should use theoretical
    use_measured = benchmark_metrics is not None and len(benchmark_metrics) > 0
    if use_measured:
        # Plot MEASURED data points
        title_suffix = " (Measured)"
        if "math" in benchmark_metrics:
            _add_measured_point(benchmark_metrics["math"], "Math", "#dc2626")
        if "flash" in benchmark_metrics:
            _add_measured_point(benchmark_metrics["flash"], "Flash", "#16a34a")
        if "mem_efficient" in benchmark_metrics:
            # Offset the label so it doesn't overlap the flash point.
            _add_measured_point(benchmark_metrics["mem_efficient"], "MemEff", "#2563eb",
                                ax=30, ay=-30)
    else:
        # Plot THEORETICAL approximations
        title_suffix = " (Theoretical)"
        # Standard attention - memory bound (low arithmetic intensity)
        std_intensity = 10
        std_achieved = min(peak_tflops * 0.15, bandwidth_gbps * std_intensity / 1000)
        fig.add_trace(go.Scatter(
            x=[std_intensity],
            y=[std_achieved],
            mode="markers",
            name="Standard (Theoretical)",
            marker=dict(size=15, color="rgba(220, 38, 38, 0.6)", symbol="circle-open",
                        line=dict(width=2)),
        ))
        fig.add_annotation(
            x=np.log10(std_intensity),
            y=std_achieved,
            text="Standard<br>(theoretical)",
            showarrow=True,
            arrowhead=2,
            ax=0,
            ay=-35,
            font=dict(size=10, color="#dc2626"),
            bgcolor="rgba(255, 255, 255, 0.9)",
            bordercolor="rgba(220, 38, 38, 0.5)",
            borderwidth=1,
            borderpad=3,
        )
        # FlashAttention - compute bound (high arithmetic intensity)
        flash_intensity = 200
        flash_achieved = min(peak_tflops * 0.7, bandwidth_gbps * flash_intensity / 1000)
        fig.add_trace(go.Scatter(
            x=[flash_intensity],
            y=[flash_achieved],
            mode="markers",
            name="Flash (Theoretical)",
            marker=dict(size=15, color="rgba(22, 163, 74, 0.6)", symbol="circle-open",
                        line=dict(width=2)),
        ))
        fig.add_annotation(
            x=np.log10(flash_intensity),
            y=flash_achieved,
            text="FlashAttention<br>(theoretical)",
            showarrow=True,
            arrowhead=2,
            ax=0,
            ay=-35,
            font=dict(size=10, color="#16a34a"),
            bgcolor="rgba(255, 255, 255, 0.9)",
            bordercolor="rgba(22, 163, 74, 0.5)",
            borderwidth=1,
            borderpad=3,
        )

    # Add ridge point marker
    fig.add_trace(go.Scatter(
        x=[ridge_point],
        y=[peak_tflops],
        mode="markers",
        name=f"Ridge Point ({ridge_point:.0f} FLOPs/byte)",
        marker=dict(size=10, color="rgba(0, 0, 0, 0.6)", symbol="diamond"),
    ))
    # Region labels with white backgrounds for visibility
    fig.add_annotation(
        x=np.log10(5),
        y=peak_tflops * 0.1,
        text="Memory Bound<br>(limited by bandwidth)",
        showarrow=False,
        font=dict(size=11, color="#dc2626"),  # Solid red
        bgcolor="rgba(255, 255, 255, 0.9)",
        bordercolor="#dc2626",
        borderwidth=1,
        borderpad=4,
    )
    fig.add_annotation(
        x=np.log10(300),
        y=peak_tflops * 0.65,
        text="Compute Bound<br>(limited by TFLOPS)",
        showarrow=False,
        font=dict(size=11, color="#16a34a"),  # Solid green
        bgcolor="rgba(255, 255, 255, 0.9)",
        bordercolor="#16a34a",
        borderwidth=1,
        borderpad=4,
    )
    # Use detected_name if available, otherwise use name
    display_name = gpu.get("detected_name", gpu.get("name", "GPU"))
    # Add estimated indicator if specs were estimated
    estimated_note = " (estimated specs)" if gpu.get("estimated") else ""
    fig.update_layout(
        title=dict(
            text=(
                f"Roofline Model: {display_name}{title_suffix}{estimated_note}<br>"
                f"Peak: {peak_tflops} TFLOPS | Bandwidth: {bandwidth_gbps} GB/s"
            ),
            x=0.5,
            font=dict(size=14),
        ),
        xaxis=dict(
            title="Arithmetic Intensity (FLOPs/byte)",
            type="log",
            range=[0, 3],  # log axis: 10^0 .. 10^3
        ),
        yaxis=dict(
            title="Performance (TFLOPS)",
            range=[0, peak_tflops * 1.2],  # More headroom for text
        ),
        height=420,
        margin=dict(l=60, r=40, t=80, b=80),  # More room for title and legend
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.30,
            xanchor="center",
            x=0.5,
            font=dict(size=10),
        ),
        showlegend=True,
    )
    return fig
def get_roofline_insight(benchmark_metrics: dict = None) -> str:
    """Return insight text for the roofline chart."""
    base_insight = """**Why FlashAttention is Faster:**
The roofline model reveals the key insight:
1. **Standard Attention** sits in the **memory-bound** region (left of ridge point)
- Limited by HBM bandwidth, not compute
- Reading/writing the N×N attention matrix dominates runtime
2. **FlashAttention** moves to the **compute-bound** region (right of ridge point)
- By never materializing the full attention matrix
- Arithmetic intensity increases ~20-50×
- Can now utilize most of the GPU's TFLOPS
*The same FLOPs, but 10× less memory traffic = faster execution!*"""
    # Without both baselines we cannot compute measured comparisons.
    has_both = bool(benchmark_metrics) and {"math", "flash"} <= benchmark_metrics.keys()
    if not has_both:
        return base_insight + "\n\n*Run a benchmark to see measured values on the chart!*"
    math_m = benchmark_metrics["math"]
    flash_m = benchmark_metrics["flash"]
    speedup = math_m["time_ms"] / flash_m["time_ms"]
    intensity_ratio = flash_m["arith_intensity"] / math_m["arith_intensity"]
    measured_insight = f"""
---
**📊 Measured Results:**
- **Math backend:** {math_m['achieved_tflops']:.1f} TFLOPS @ {math_m['arith_intensity']:.0f} FLOPs/byte
- **Flash backend:** {flash_m['achieved_tflops']:.1f} TFLOPS @ {flash_m['arith_intensity']:.0f} FLOPs/byte
- **Speedup:** {speedup:.1f}× faster
- **Intensity increase:** {intensity_ratio:.0f}× higher arithmetic intensity"""
    return base_insight + measured_insight