# flash-attention-explorer / src/prefill_decode.py
# Commit a0y0346 (af9b854): Refactor benchmarks to use real model.config values
"""
Prefill vs Decode phase comparison module.
Demonstrates the key difference between:
- Prefill: Process entire prompt in parallel (N² attention complexity)
- Decode: Generate one token at a time (N attention per token, but sequential)
Uses REAL HuggingFace model attention layers for accurate benchmarking.
"""
import torch
import torch.nn.functional as F
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from .constants import MODEL_CONFIGS, ATTENTION_BACKENDS
from .models import load_model
from .attention_utils import (
extract_attention_layer,
create_attention_inputs,
benchmark_attention_layer,
get_model_attention_info,
)
def get_real_model_config(model_name: str) -> dict:
    """
    Load a model and read its ACTUAL architecture values from model.config.

    All values come straight from the loaded model's config object, never
    from the hardcoded MODEL_CONFIGS constants.

    Args:
        model_name: Key from MODEL_CONFIGS (e.g., "SmolLM2-360M")

    Returns:
        Dict with real model configuration values
    """
    cfg = load_model(model_name).config

    heads = cfg.num_attention_heads
    # Models without GQA have no num_key_value_heads; fall back to MHA.
    kv_heads = getattr(cfg, 'num_key_value_heads', heads)

    return {
        "num_layers": cfg.num_hidden_layers,
        "num_heads": heads,
        "num_kv_heads": kv_heads,
        "head_dim": cfg.hidden_size // heads,
        "hidden_size": cfg.hidden_size,
        "model_type": getattr(cfg, 'model_type', 'unknown'),
        "gqa_ratio": heads // kv_heads if kv_heads > 0 else 1,
    }
def run_prefill_with_real_model(
    model,
    attention_layer,
    seq_len: int,
    batch_size: int = 1,
    num_iterations: int = 5,
    use_flash: bool = True,
) -> dict:
    """
    Run prefill phase attention using a REAL model's attention layer.

    Prefill processes the entire prompt at once: hidden states of shape
    [batch, seq_len, hidden_dim] go through the real attention layer, so
    the full N×N attention is exercised.

    Args:
        model: Loaded HuggingFace model
        attention_layer: Extracted attention module
        seq_len: Sequence length
        batch_size: Batch size
        num_iterations: Number of timed iterations
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with timing and memory stats
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    # Build properly-shaped inputs for the layer (fp16 on GPU).
    hidden_states, position_ids = create_attention_inputs(
        model, batch_size, seq_len, torch.device("cuda"), torch.float16
    )

    # Delegate warmup + timing to the shared benchmark utility.
    result = benchmark_attention_layer(
        attention_layer=attention_layer,
        hidden_states=hidden_states,
        position_ids=position_ids,
        backend="flash" if use_flash else "math",
        num_iterations=num_iterations,
        warmup_iterations=2,
    )

    # Release benchmark inputs before returning.
    del hidden_states, position_ids
    torch.cuda.empty_cache()

    # Tag the result with phase metadata.
    result["seq_len"] = seq_len
    result["phase"] = "prefill"
    result["using_real_model"] = True
    return result
def run_prefill_benchmark(
    model_name: str,
    seq_len: int,
    batch_size: int = 1,
    num_iterations: int = 10,
    use_flash: bool = True,
) -> dict:
    """
    Benchmark prefill phase using F.scaled_dot_product_attention with REAL model dimensions.

    Uses the model's actual configuration (from model.config) to create
    properly-sized Q, K, V tensors, then benchmarks the SDPA operation
    directly. This is more reliable than calling attention layer forward()
    methods.

    Args:
        model_name: Key from MODEL_CONFIGS (model will be loaded to get real config)
        seq_len: Sequence length (prompt tokens)
        batch_size: Batch size
        num_iterations: Number of timed iterations
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with time_ms, memory_mb, and status ("success" or "error: ...")
    """
    if not torch.cuda.is_available():
        return {"time_ms": 0, "memory_mb": 0, "status": "error: CUDA not available"}
    device = torch.device("cuda")
    dtype = torch.float16
    try:
        # Get REAL config from the loaded model, not from constants.
        real_config = get_real_model_config(model_name)
        num_heads = real_config["num_heads"]
        head_dim = real_config["head_dim"]
        # Q, K, V with the real model's dimensions:
        # [batch, num_heads, seq_len, head_dim]
        Q = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        K = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        V = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        # Select exactly one SDPA kernel implementation.
        if use_flash:
            enable_math, enable_flash, enable_mem_efficient = False, True, False
        else:
            enable_math, enable_flash, enable_mem_efficient = True, False, False
        # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated in newer
        # PyTorch in favor of torch.nn.attention.sdpa_kernel; kept here for
        # consistency with the rest of this module.
        # Fixed: enter the kernel-selection context ONCE, outside the loops,
        # so per-iteration context-manager entry/exit overhead is excluded
        # from the timed region (the original re-entered it every iteration).
        with torch.backends.cuda.sdp_kernel(
            enable_flash=enable_flash,
            enable_math=enable_math,
            enable_mem_efficient=enable_mem_efficient,
        ):
            # Warmup to exclude kernel-selection/cache effects from timing.
            for _ in range(3):
                _ = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()
            # Timed runs using CUDA events for accurate GPU timing.
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(num_iterations):
                output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
            end.record()
        torch.cuda.synchronize()
        time_ms = start.elapsed_time(end) / num_iterations
        memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
        # Release benchmark tensors before returning.
        del Q, K, V, output
        torch.cuda.empty_cache()
        return {
            "time_ms": round(time_ms, 3),
            "memory_mb": round(memory_mb, 1),
            "seq_len": seq_len,
            "phase": "prefill",
            "backend": "flash" if use_flash else "math",
            "num_heads": num_heads,
            "head_dim": head_dim,
            "status": "success",
            "using_real_config": True,
        }
    except Exception as e:
        # Any failure (model load, OOM, unsupported backend) is reported
        # in-band rather than raised, so callers can render partial results.
        return {
            "time_ms": 0,
            "memory_mb": 0,
            "status": f"error: {str(e)[:100]}",
            "phase": "prefill",
        }
def run_decode_with_real_model(
    model,
    attention_layer,
    kv_cache_len: int,
    num_tokens: int = 10,
    batch_size: int = 1,
    num_iterations: int = 3,
    use_flash: bool = True,
) -> dict:
    """
    Run decode phase attention using a REAL model's attention layer.

    Decode is simulated by pushing a single query token (seq_len == 1)
    through the extracted attention module repeatedly, mimicking the
    memory-bound one-token-at-a-time generation phase.

    Args:
        model: Loaded HuggingFace model
        attention_layer: Extracted attention module
        kv_cache_len: Length of the KV cache (context)
        num_tokens: Number of tokens to simulate generating
        batch_size: Batch size
        num_iterations: Iterations for averaging
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with per-token timing and memory stats
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    gpu = torch.device("cuda")
    half = torch.float16

    # One-token query placed at position kv_cache_len (just past the context).
    query_hidden = torch.randn(
        batch_size, 1, model.config.hidden_size, dtype=half, device=gpu
    )
    position_ids = torch.tensor([[kv_cache_len]], device=gpu).expand(batch_size, 1)

    # Exactly one SDPA backend enabled at a time.
    kernel_flags = dict(
        enable_flash=use_flash,
        enable_math=not use_flash,
        enable_mem_efficient=False,
    )

    try:
        # Untimed warmup passes.
        with torch.backends.cuda.sdp_kernel(**kernel_flags), torch.no_grad():
            for _ in range(2):
                _ = attention_layer(query_hidden, position_ids=position_ids)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

        # Time num_tokens simulated decode steps, repeated num_iterations
        # times for averaging, using CUDA events.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        with torch.backends.cuda.sdp_kernel(**kernel_flags), torch.no_grad():
            start.record()
            for _ in range(num_tokens * num_iterations):
                output = attention_layer(query_hidden, position_ids=position_ids)
            end.record()
        torch.cuda.synchronize()

        total_time_ms = start.elapsed_time(end)
        per_token_ms = total_time_ms / (num_tokens * num_iterations)
        memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

        # Release the benchmark input before returning.
        del query_hidden
        torch.cuda.empty_cache()

        return {
            "time_ms_per_token": round(per_token_ms, 4),
            "total_time_ms": round(total_time_ms / num_iterations, 3),
            "memory_mb": round(memory_mb, 1),
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "using_real_model": True,
            "status": "success",
        }
    except Exception as e:
        return {
            "time_ms_per_token": 0,
            "total_time_ms": 0,
            "memory_mb": 0,
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "status": f"error: {str(e)[:80]}",
        }
def run_decode_benchmark(
    model_name: str,
    kv_cache_len: int,
    num_tokens: int = 10,
    batch_size: int = 1,
    num_iterations: int = 5,
    use_flash: bool = True,
) -> dict:
    """
    Benchmark decode phase using F.scaled_dot_product_attention with REAL model dimensions.

    Properly simulates decode by:
    - Creating a single query token (Q with seq_len=1)
    - Creating KV cache tensors with kv_cache_len tokens
    - Handling GQA by expanding KV heads to match Q heads

    Args:
        model_name: Key from MODEL_CONFIGS (model will be loaded to get real config)
        kv_cache_len: Length of KV cache (context length)
        num_tokens: Number of decode tokens to simulate
        batch_size: Batch size
        num_iterations: Iterations for timing
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with time_ms_per_token, memory_mb, and status
    """
    if not torch.cuda.is_available():
        return {"time_ms_per_token": 0, "memory_mb": 0, "status": "error: CUDA not available"}
    device = torch.device("cuda")
    dtype = torch.float16
    try:
        # Get REAL config from the loaded model, not from constants.
        real_config = get_real_model_config(model_name)
        num_heads = real_config["num_heads"]
        num_kv_heads = real_config["num_kv_heads"]
        head_dim = real_config["head_dim"]
        # Single query token: [batch, num_heads, 1, head_dim]
        Q = torch.randn(batch_size, num_heads, 1, head_dim, dtype=dtype, device=device)
        # KV cache with the real model's KV head count:
        # [batch, num_kv_heads, kv_cache_len, head_dim]
        K_cache = torch.randn(batch_size, num_kv_heads, kv_cache_len, head_dim, dtype=dtype, device=device)
        V_cache = torch.randn(batch_size, num_kv_heads, kv_cache_len, head_dim, dtype=dtype, device=device)
        # Handle GQA: expand KV heads to match Q heads if needed.
        if num_kv_heads < num_heads:
            repeat_factor = num_heads // num_kv_heads
            K_cache = K_cache.repeat_interleave(repeat_factor, dim=1)
            V_cache = V_cache.repeat_interleave(repeat_factor, dim=1)
        # Select exactly one SDPA kernel implementation.
        if use_flash:
            enable_math, enable_flash_flag, enable_mem_efficient = False, True, False
        else:
            enable_math, enable_flash_flag, enable_mem_efficient = True, False, False
        # Fixed: enter the kernel-selection context ONCE, outside the loops,
        # so per-token context-manager entry/exit overhead is excluded from
        # the timed region (the original re-entered it on every simulated
        # token, inflating the measured per-token time).
        with torch.backends.cuda.sdp_kernel(
            enable_flash=enable_flash_flag,
            enable_math=enable_math,
            enable_mem_efficient=enable_mem_efficient,
        ):
            # Warmup to exclude kernel-selection/cache effects from timing.
            for _ in range(3):
                _ = F.scaled_dot_product_attention(Q, K_cache, V_cache)
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()
            # Timed runs - simulate generating num_tokens, averaged over
            # num_iterations repetitions, using CUDA events.
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(num_tokens * num_iterations):
                output = F.scaled_dot_product_attention(Q, K_cache, V_cache)
            end.record()
        torch.cuda.synchronize()
        total_time_ms = start.elapsed_time(end)
        time_per_token_ms = total_time_ms / (num_tokens * num_iterations)
        memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
        # Release benchmark tensors before returning.
        del Q, K_cache, V_cache, output
        torch.cuda.empty_cache()
        return {
            "time_ms_per_token": round(time_per_token_ms, 4),
            "total_time_ms": round(total_time_ms / num_iterations, 3),
            "memory_mb": round(memory_mb, 1),
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "backend": "flash" if use_flash else "math",
            "num_heads": num_heads,
            "num_kv_heads": num_kv_heads,
            "head_dim": head_dim,
            "status": "success",
            "using_real_config": True,
        }
    except Exception as e:
        # Report failures in-band so callers can render partial results.
        return {
            "time_ms_per_token": 0,
            "total_time_ms": 0,
            "memory_mb": 0,
            "kv_cache_len": kv_cache_len,
            "num_tokens": num_tokens,
            "phase": "decode",
            "status": f"error: {str(e)[:100]}",
        }
# Legacy function kept for backwards compatibility
def simulate_prefill_attention(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    head_dim: int,
    num_iterations: int = 5,
    use_flash: bool = True,
) -> dict:
    """
    Legacy: Simulate prefill phase attention with random tensors.

    Use run_prefill_with_real_model() for real model benchmarks.

    Args:
        batch_size: Batch size
        num_heads: Number of attention heads
        seq_len: Sequence length
        head_dim: Per-head dimension
        num_iterations: Number of timed iterations
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with time_ms, memory_mb, seq_len, phase, or {"error": ...}
        when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}
    device = torch.device("cuda")
    dtype = torch.float16
    Q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    K = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    V = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    if use_flash:
        enable_math, enable_flash_flag, enable_mem_efficient = False, True, False
    else:
        enable_math, enable_flash_flag, enable_mem_efficient = True, False, False
    # Fixed: enter the kernel-selection context once; re-entering it per
    # iteration (as the original did) adds context-manager overhead to the
    # timed region.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=enable_flash_flag, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient
    ):
        # Warmup; deliberately best-effort — a backend that rejects this
        # shape is ignored here and will surface in the timed loop instead.
        for _ in range(2):
            try:
                _ = F.scaled_dot_product_attention(Q, K, V)
            except Exception:
                pass
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_iterations):
            # Fixed: the original wrapped this call in try/except whose
            # handler just repeated the identical call — the retry could
            # only fail the same way and masked the original traceback.
            output = F.scaled_dot_product_attention(Q, K, V)
        end.record()
    torch.cuda.synchronize()
    total_time_ms = start.elapsed_time(end)
    avg_time_ms = total_time_ms / num_iterations
    peak_memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
    del Q, K, V, output
    torch.cuda.empty_cache()
    return {
        "time_ms": avg_time_ms,
        "memory_mb": peak_memory_mb,
        "seq_len": seq_len,
        "phase": "prefill",
    }
# Legacy function kept for backwards compatibility
def simulate_decode_attention(
    batch_size: int,
    num_heads: int,
    kv_cache_len: int,
    head_dim: int,
    num_tokens: int = 10,
    use_flash: bool = True,
) -> dict:
    """
    Legacy: Simulate decode phase attention with random tensors.

    Use run_decode_with_real_model() for real model benchmarks.

    Args:
        batch_size: Batch size
        num_heads: Number of attention heads
        kv_cache_len: Length of the simulated KV cache
        head_dim: Per-head dimension
        num_tokens: Number of decode tokens to simulate
        use_flash: Whether to use FlashAttention backend

    Returns:
        Dict with per-token timing and memory stats, or {"error": ...}
        when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}
    device = torch.device("cuda")
    dtype = torch.float16
    K_cache = torch.randn(batch_size, num_heads, kv_cache_len, head_dim, device=device, dtype=dtype)
    V_cache = torch.randn(batch_size, num_heads, kv_cache_len, head_dim, device=device, dtype=dtype)
    Q = torch.randn(batch_size, num_heads, 1, head_dim, device=device, dtype=dtype)
    if use_flash:
        enable_math, enable_flash_flag, enable_mem_efficient = False, True, False
    else:
        enable_math, enable_flash_flag, enable_mem_efficient = True, False, False
    # Fixed: enter the kernel-selection context once; re-entering it per
    # token (as the original did) adds context-manager overhead to the
    # timed region.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=enable_flash_flag, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient
    ):
        # Warmup; deliberately best-effort — failures here are ignored and
        # will surface in the timed loop instead.
        for _ in range(2):
            try:
                _ = F.scaled_dot_product_attention(Q, K_cache, V_cache)
            except Exception:
                pass
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_tokens):
            # Fixed: the original wrapped this call in try/except whose
            # handler repeated the identical call — the retry could only
            # fail the same way and masked the original traceback.
            output = F.scaled_dot_product_attention(Q, K_cache, V_cache)
        end.record()
    torch.cuda.synchronize()
    total_time_ms = start.elapsed_time(end)
    avg_time_per_token_ms = total_time_ms / num_tokens
    peak_memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
    del Q, K_cache, V_cache, output
    torch.cuda.empty_cache()
    return {
        "time_ms_per_token": avg_time_per_token_ms,
        "total_time_ms": total_time_ms,
        "memory_mb": peak_memory_mb,
        "kv_cache_len": kv_cache_len,
        "num_tokens": num_tokens,
        "phase": "decode",
    }
def run_prefill_decode_comparison(
    model_name: str,
    context_length: int,
    decode_tokens: int = 32,
) -> tuple:
    """
    Run full comparison between prefill and decode phases using REAL HuggingFace model.

    Uses F.scaled_dot_product_attention with real model dimensions for
    reliable benchmarking; all config values come from model.config, not
    constants.

    Returns:
        (results dict, comparison chart, KV cache chart, insight text)
    """
    if model_name not in MODEL_CONFIGS:
        return {"error": f"Unknown model: {model_name}"}, None, None, "Error: Unknown model"

    # Pull the REAL config from model.config (not constants).
    try:
        real_config = get_real_model_config(model_name)
    except Exception as e:
        return {"error": f"Failed to load model: {str(e)[:50]}"}, None, None, f"Error: {str(e)[:50]}"

    results = {
        "model": model_name,
        "context_length": context_length,
        "decode_tokens": decode_tokens,
        "real_config": real_config,
        "using_real_config": True,
    }

    # Prefill benchmarks (SDPA with real model dimensions): flash, then math.
    results["prefill"] = {
        backend: run_prefill_benchmark(
            model_name=model_name,
            seq_len=context_length,
            batch_size=1,
            use_flash=(backend == "flash"),
        )
        for backend in ("flash", "math")
    }

    # Decode benchmarks (SDPA with KV cache simulation): flash, then math.
    results["decode"] = {
        backend: run_decode_benchmark(
            model_name=model_name,
            kv_cache_len=context_length,
            num_tokens=decode_tokens,
            batch_size=1,
            use_flash=(backend == "flash"),
        )
        for backend in ("flash", "math")
    }

    # Compact model summary for display.
    results["model_info"] = {
        key: real_config[key]
        for key in ("num_heads", "num_kv_heads", "head_dim", "num_layers", "gqa_ratio")
    }

    # Build the charts and the narrative insight.
    comparison_chart = create_comparison_chart(results)
    kv_cache_chart = create_kv_cache_chart(model_name, context_length, decode_tokens)
    insight = generate_phase_insight(results)

    # Append the real-model indicator to the insight text.
    if results.get("using_real_config"):
        model_indicator = f"\n\n---\n\n*Benchmarked using real **{model_name}** config ({real_config['num_heads']} heads, {real_config['head_dim']}d, GQA {real_config['gqa_ratio']}:1)*"
        insight = insight + model_indicator

    return results, comparison_chart, kv_cache_chart, insight
def create_comparison_chart(results: dict) -> go.Figure:
    """Create bar chart comparing prefill vs decode timing.

    Args:
        results: Output of run_prefill_decode_comparison(); must contain
            results["prefill"]["flash"/"math"] (with "time_ms") and
            results["decode"]["flash"/"math"] (with "time_ms_per_token").

    Returns:
        Plotly figure with two side-by-side bar subplots (prefill total
        time, decode per-token time), each comparing math vs flash
        backends, with the computed speedups shown in the title.
    """
    prefill_flash = results["prefill"]["flash"]
    prefill_math = results["prefill"]["math"]
    decode_flash = results["decode"]["flash"]
    decode_math = results["decode"]["math"]
    # Helper to safely get numeric value (handles both missing keys and
    # explicit None values, which benchmark error dicts may contain).
    def safe_get(d, key, default=0):
        val = d.get(key, default)
        return val if val is not None else default
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("<b>Prefill Time</b> (Full Prompt)", "<b>Decode Time</b> (Per Token)"),
        horizontal_spacing=0.15,
        vertical_spacing=0.15,
    )
    # Get max values for proper y-axis scaling with headroom for labels
    prefill_math_time = safe_get(prefill_math, "time_ms", 0)
    prefill_flash_time = safe_get(prefill_flash, "time_ms", 0)
    decode_math_time = safe_get(decode_math, "time_ms_per_token", 0)
    decode_flash_time = safe_get(decode_flash, "time_ms_per_token", 0)
    prefill_max = max(prefill_math_time, prefill_flash_time)
    decode_max = max(decode_math_time, decode_flash_time)
    # Prefill comparison (red = standard math backend, green = flash)
    fig.add_trace(
        go.Bar(
            x=["Math<br>(Standard)", "Flash<br>Attention"],
            y=[prefill_math_time, prefill_flash_time],
            marker_color=["#ef4444", "#22c55e"],
            text=[f"<b>{prefill_math_time:.2f}ms</b>", f"<b>{prefill_flash_time:.2f}ms</b>"],
            textposition="inside",
            textangle=0,
            insidetextanchor="middle",
            textfont=dict(color="white", size=12),
            name="Prefill",
            showlegend=False,
        ),
        row=1, col=1
    )
    # Decode comparison (per token); note the extra decimal place since
    # per-token times are much smaller than full-prompt times.
    fig.add_trace(
        go.Bar(
            x=["Math<br>(Standard)", "Flash<br>Attention"],
            y=[decode_math_time, decode_flash_time],
            marker_color=["#ef4444", "#22c55e"],
            text=[f"<b>{decode_math_time:.3f}ms</b>", f"<b>{decode_flash_time:.3f}ms</b>"],
            textposition="inside",
            textangle=0,
            insidetextanchor="middle",
            textfont=dict(color="white", size=12),
            name="Decode",
            showlegend=False,
        ),
        row=1, col=2
    )
    # Calculate speedups; fall back to 1.0 when either measurement is
    # missing or zero so the title never divides by zero.
    if prefill_math_time > 0 and prefill_flash_time > 0:
        prefill_speedup = prefill_math_time / prefill_flash_time
    else:
        prefill_speedup = 1.0
    if decode_math_time > 0 and decode_flash_time > 0:
        decode_speedup = decode_math_time / decode_flash_time
    else:
        decode_speedup = 1.0
    fig.update_layout(
        title=dict(
            text=f"<b>Prefill vs Decode: FlashAttention Speedup</b><br>"
                 f"<span style='font-size:13px;color:#16a34a'>"
                 f"Prefill: {prefill_speedup:.1f}× faster | Decode: {decode_speedup:.1f}× faster</span>",
            x=0.5,
            font=dict(size=15),
        ),
        height=380,
        margin=dict(l=60, r=40, t=100, b=60),
        yaxis_title="Time (ms)",
        yaxis2_title="Time (ms)",
    )
    # Add more y-axis headroom so inside-bar labels are not clipped
    fig.update_yaxes(range=[0, prefill_max * 1.15], row=1, col=1)
    fig.update_yaxes(range=[0, decode_max * 1.15], row=1, col=2)
    return fig
def create_kv_cache_chart(model_name: str, context_length: int, decode_tokens: int) -> go.Figure:
    """
    Create chart showing KV cache growth during generation.

    Uses REAL model config values from model.config, not constants.

    Args:
        model_name: Model name to load config from
        context_length: Number of context tokens (prefill)
        decode_tokens: Number of decode tokens to generate

    Returns:
        Plotly figure showing KV cache growth
    """
    # Get REAL config from loaded model (no constants!)
    real_config = get_real_model_config(model_name)
    num_kv_heads = real_config["num_kv_heads"]
    head_dim = real_config["head_dim"]
    num_layers = real_config["num_layers"]
    # Calculate KV cache size at each step
    # KV cache per layer: 2 (K+V) × kv_heads × head_dim × 2 (FP16 bytes)
    bytes_per_token_per_layer = 2 * num_kv_heads * head_dim * 2
    total_bytes_per_token = bytes_per_token_per_layer * num_layers
    # Generate sequence of token counts: ~50 sample points across the full
    # range, always ending exactly at context_length + decode_tokens.
    token_counts = list(range(0, context_length + decode_tokens + 1, max(1, (context_length + decode_tokens) // 50)))
    if token_counts[-1] != context_length + decode_tokens:
        token_counts.append(context_length + decode_tokens)
    # Fixed: removed an unused `cache_sizes_mb` list the original computed
    # here — only the per-phase size lists below are ever plotted.
    fig = go.Figure()
    # Prefill region (0 to context_length), drawn in blue
    prefill_tokens = [t for t in token_counts if t <= context_length]
    prefill_sizes = [(t * total_bytes_per_token) / (1024 * 1024) for t in prefill_tokens]
    fig.add_trace(go.Scatter(
        x=prefill_tokens,
        y=prefill_sizes,
        mode="lines",
        name="Prefill Phase",
        fill="tozeroy",
        line=dict(color="#3b82f6", width=2),
        fillcolor="rgba(59, 130, 246, 0.3)",
    ))
    # Decode region (context_length to end), drawn in green
    decode_tokens_list = [t for t in token_counts if t >= context_length]
    decode_sizes = [(t * total_bytes_per_token) / (1024 * 1024) for t in decode_tokens_list]
    fig.add_trace(go.Scatter(
        x=decode_tokens_list,
        y=decode_sizes,
        mode="lines",
        name="Decode Phase",
        fill="tozeroy",
        line=dict(color="#22c55e", width=2),
        fillcolor="rgba(34, 197, 94, 0.3)",
    ))
    # Vertical line at the prefill→decode boundary, annotated with the
    # cache size reached at the end of prefill.
    cache_at_context = (context_length * total_bytes_per_token) / (1024 * 1024)
    fig.add_vline(
        x=context_length,
        line_dash="dash",
        line_color="rgba(0, 0, 0, 0.5)",
        annotation_text=f"Prefill→Decode<br>({cache_at_context:.1f} MB)",
        annotation_position="top",
    )
    fig.update_layout(
        title=dict(
            text=f"KV Cache Growth ({num_kv_heads} KV heads × {num_layers} layers)",
            x=0.5,
        ),
        xaxis_title="Tokens Processed",
        yaxis_title="KV Cache Size (MB)",
        height=300,
        margin=dict(l=50, r=50, t=60, b=50),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.25,
            xanchor="center",
            x=0.5,
        ),
        yaxis=dict(rangemode='tozero'),
    )
    return fig
def generate_phase_insight(results: dict) -> str:
    """Generate insight text from comparison results."""
    def _num(d, key):
        # Treat missing keys and explicit None values as zero.
        v = d.get(key, 0)
        return 0 if v is None else v

    prefill_math_time = _num(results["prefill"]["math"], "time_ms")
    prefill_flash_time = _num(results["prefill"]["flash"], "time_ms")
    decode_math_time = _num(results["decode"]["math"], "time_ms_per_token")
    decode_flash_time = _num(results["decode"]["flash"], "time_ms_per_token")

    # Speedup ratios; default to 1.0 when either measurement is missing or
    # zero so we never divide by zero.
    prefill_speedup = (
        prefill_math_time / prefill_flash_time
        if prefill_math_time > 0 and prefill_flash_time > 0
        else 1.0
    )
    decode_speedup = (
        decode_math_time / decode_flash_time
        if decode_math_time > 0 and decode_flash_time > 0
        else 1.0
    )

    context_length = results["context_length"]
    decode_tokens = results["decode_tokens"]

    return f"""### Key Observations
**Prefill Phase** (processing {context_length} tokens):
- Standard attention: **{prefill_math_time:.2f}ms**
- FlashAttention: **{prefill_flash_time:.2f}ms**
- Speedup: **{prefill_speedup:.1f}×**
**Decode Phase** (generating {decode_tokens} tokens):
- Standard attention: **{decode_math_time:.3f}ms/token**
- FlashAttention: **{decode_flash_time:.3f}ms/token**
- Speedup: **{decode_speedup:.1f}×**
---
### Why the Difference?
1. **Prefill is compute-bound** with N² attention operations
- FlashAttention's memory efficiency provides significant speedup
- Larger contexts benefit more (quadratic scaling)
2. **Decode is memory-bound** with 1×N attention per token
- Each decode step is fast but sequential
- KV cache read dominates, limiting FlashAttention's advantage
3. **Optimal strategy**: FlashAttention helps most during prefill;
decode phase benefits from KV cache optimizations (GQA/MQA)
"""
def get_attention_pattern_chart(context_length: int) -> go.Figure:
    """Create visualization of prefill vs decode attention patterns using scatter.

    Args:
        context_length: Prompt length in tokens; used for the op-count
            subplot titles and to choose the illustrative grid size.

    Returns:
        Plotly figure with two scatter subplots: the causal (lower
        triangular) prefill pattern and the per-step growing decode pattern.
    """
    # Calculate op counts shown in the subplot titles
    prefill_flops = context_length * context_length  # N² attention
    decode_flops_per_token = context_length  # 1×N per decode token
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=(
            f"<b>Prefill:</b> {context_length}×{context_length} = {prefill_flops:,} ops",
            f"<b>Decode:</b> 1×{context_length} = {decode_flops_per_token:,} ops/token"
        ),
        horizontal_spacing=0.15,
    )
    # Prefill: Lower triangular pattern (causal mask)
    # Dynamic size based on context length for visual feedback; the grid is
    # only illustrative, never a 1:1 rendering of large contexts.
    if context_length <= 16:
        size = context_length
    elif context_length <= 128:
        size = 12
    elif context_length <= 512:
        size = 10
    else:
        size = 8  # Smaller grid for very large contexts
    # Adjust marker size based on grid size (smaller grid -> bigger markers)
    marker_size = max(10, 22 - size)
    # Generate coordinates for filled cells (lower triangular)
    prefill_x = []
    prefill_y = []
    for row in range(size):
        for col in range(row + 1):  # Only up to diagonal
            prefill_x.append(col)
            prefill_y.append(row)
    fig.add_trace(
        go.Scatter(
            x=prefill_x,
            y=prefill_y,
            mode="markers",
            marker=dict(
                size=marker_size,
                color="#3b82f6",
                symbol="square",
            ),
            name="Attends",
            showlegend=False,
            hovertemplate="Query %{y} → Key %{x}<extra></extra>",
        ),
        row=1, col=1
    )
    # Decode: Each step attends to a sequence that grows by one key per step
    num_decode_steps = 6
    base_context = max(4, size - num_decode_steps)
    decode_x = []
    decode_y = []
    for step in range(num_decode_steps):
        attend_len = base_context + step + 1
        for col in range(min(attend_len, size)):  # clamp to the drawn grid
            decode_x.append(col)
            decode_y.append(step)
    fig.add_trace(
        go.Scatter(
            x=decode_x,
            y=decode_y,
            mode="markers",
            marker=dict(
                size=marker_size + 4,
                color="#22c55e",
                symbol="square",
            ),
            name="Attends",
            showlegend=False,
            hovertemplate="Decode step %{y} → Key %{x}<extra></extra>",
        ),
        row=1, col=2
    )
    # Update axes with proper ranges (half-cell padding around the grid)
    fig.update_xaxes(
        title_text="Key positions",
        range=[-0.5, size - 0.5],
        dtick=2,
        row=1, col=1
    )
    fig.update_xaxes(
        title_text="Key positions (KV cache)",
        range=[-0.5, size - 0.5],
        dtick=2,
        row=1, col=2
    )
    fig.update_yaxes(
        title_text="Query positions",
        range=[-0.5, size - 0.5],
        dtick=2,
        row=1, col=1
    )
    fig.update_yaxes(
        title_text="Decode steps",
        range=[-0.5, num_decode_steps - 0.5],
        dtick=1,
        row=1, col=2
    )
    fig.update_layout(
        height=380,
        margin=dict(l=60, r=30, t=70, b=50),
        plot_bgcolor="rgba(241, 245, 249, 0.5)",
    )
    return fig