# a0y0346
# Fix GQA scaling chart x-axis: use log scale to space tick labels properly
# 287497c
"""
GQA/MQA comparison module.
Demonstrates Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
using REAL HuggingFace model configurations and benchmarks.
Key concepts:
- MHA (Multi-Head Attention): Each query head has its own K,V heads (ratio 1:1)
- GQA (Grouped-Query Attention): Multiple query heads share K,V heads (ratio N:1)
- MQA (Multi-Query Attention): All query heads share one K,V head (ratio N:1 where N=num_heads)
"""
import torch
import torch.nn.functional as F
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from .models import load_model
def get_gqa_config_from_model(model_name: str) -> dict:
    """
    Load model and extract REAL GQA configuration from model.config.

    Args:
        model_name: Model name to load

    Returns:
        Dict with GQA configuration from actual model: head counts,
        head_dim, layer count, GQA ratio, and a human-readable
        attention-type label (MHA / GQA / MQA).
    """
    model = load_model(model_name)
    config = model.config
    num_heads = config.num_attention_heads
    # Some configs omit num_key_value_heads or set it to None; both mean
    # "no KV sharing", i.e. classic MHA with num_kv_heads == num_heads.
    num_kv_heads = getattr(config, 'num_key_value_heads', None) or num_heads
    # Prefer an explicit head_dim when the config declares one — some models
    # (e.g. Gemma) use head_dim != hidden_size // num_heads.
    head_dim = getattr(config, 'head_dim', None) or config.hidden_size // num_heads
    num_layers = config.num_hidden_layers
    # Calculate GQA ratio (how many Q heads share each KV head)
    gqa_ratio = num_heads // num_kv_heads if num_kv_heads > 0 else 1
    # Determine attention type
    if num_kv_heads == num_heads:
        attention_type = "MHA"
        attention_description = "Multi-Head Attention (each Q head has own K,V)"
    elif num_kv_heads == 1:
        attention_type = "MQA"
        attention_description = "Multi-Query Attention (all Q heads share 1 K,V)"
    else:
        attention_type = "GQA"
        attention_description = f"Grouped-Query Attention ({gqa_ratio} Q heads share 1 K,V pair)"
    return {
        "model_name": model_name,
        "num_heads": num_heads,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "num_layers": num_layers,
        "hidden_size": config.hidden_size,
        "gqa_ratio": gqa_ratio,
        "attention_type": attention_type,
        "attention_description": attention_description,
        "is_mha": num_heads == num_kv_heads,
        "is_mqa": num_kv_heads == 1,
        "is_gqa": 1 < num_kv_heads < num_heads,
    }
def calculate_kv_cache_memory(
    num_kv_heads: int,
    head_dim: int,
    num_layers: int,
    seq_len: int,
    batch_size: int = 1,
    dtype_bytes: int = 2,  # FP16
) -> float:
    """
    Compute the KV cache footprint in megabytes.

    The cache holds one K and one V tensor per layer, hence the factor of 2:
    2 x num_kv_heads x head_dim x seq_len x num_layers x batch_size x dtype_bytes
    """
    elems_per_layer = 2 * num_kv_heads * head_dim * seq_len
    total_bytes = elems_per_layer * num_layers * batch_size * dtype_bytes
    return total_bytes / (1 << 20)
def compare_attention_memory(model_name: str, seq_len: int) -> dict:
    """
    Compare KV cache memory for MHA vs actual GQA configuration.
    Uses REAL model config to show memory savings.
    """
    cfg = get_gqa_config_from_model(model_name)
    head_dim = cfg["head_dim"]
    num_layers = cfg["num_layers"]

    def _cache_mb(kv_heads: int) -> float:
        # KV cache footprint for a hypothetical KV head count, everything
        # else taken from the real model config.
        return calculate_kv_cache_memory(
            num_kv_heads=kv_heads,
            head_dim=head_dim,
            num_layers=num_layers,
            seq_len=seq_len,
        )

    gqa_mb = _cache_mb(cfg["num_kv_heads"])  # actual configuration
    mha_mb = _cache_mb(cfg["num_heads"])     # as if every Q head had its own K,V
    mqa_mb = _cache_mb(1)                    # single K,V head shared by all

    ratio = mha_mb / gqa_mb if gqa_mb > 0 else 1
    return {
        "model_name": model_name,
        "seq_len": seq_len,
        "gqa_config": cfg,
        "mha_memory_mb": round(mha_mb, 2),
        "gqa_memory_mb": round(gqa_mb, 2),
        "mqa_memory_mb": round(mqa_mb, 2),
        "savings_mb": round(mha_mb - gqa_mb, 2),
        "savings_ratio": round(ratio, 1),
    }
def benchmark_gqa_attention(
    model_name: str,
    seq_len: int,
    batch_size: int = 1,
    num_iterations: int = 10,
    use_flash: bool = True,
) -> dict:
    """
    Benchmark attention with real model's GQA configuration.

    Uses actual num_kv_heads from model.config, then expands K,V to the
    query head count before SDPA (the standard GQA execution pattern).

    Args:
        model_name: Model whose attention geometry to benchmark.
        seq_len: Sequence length of the synthetic Q/K/V tensors.
        batch_size: Batch dimension of the synthetic tensors.
        num_iterations: Timed iterations to average over.
        use_flash: Restrict SDPA to the flash backend; else math backend.

    Returns:
        Dict with mean latency (ms), peak CUDA memory (MB), config, and status.
    """
    if not torch.cuda.is_available():
        return {"time_ms": 0, "memory_mb": 0, "status": "error: CUDA not available"}
    device = torch.device("cuda")
    dtype = torch.float16
    try:
        gqa_config = get_gqa_config_from_model(model_name)
        num_heads = gqa_config["num_heads"]
        num_kv_heads = gqa_config["num_kv_heads"]
        head_dim = gqa_config["head_dim"]
        gqa_ratio = gqa_config["gqa_ratio"]
        # Create Q tensor: [batch, num_heads, seq_len, head_dim]
        Q = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        # Create K,V with the model's real (possibly smaller) KV head count
        K = torch.randn(batch_size, num_kv_heads, seq_len, head_dim, dtype=dtype, device=device)
        V = torch.randn(batch_size, num_kv_heads, seq_len, head_dim, dtype=dtype, device=device)
        # Expand K,V to match Q head count (this is what happens in GQA)
        if num_kv_heads < num_heads:
            K = K.repeat_interleave(gqa_ratio, dim=1)
            V = V.repeat_interleave(gqa_ratio, dim=1)
        # Backend selection: flash-only or math-only (mem-efficient always off).
        # Enter the backend context ONCE per phase instead of per iteration,
        # so context-manager entry/exit overhead is not part of the timing.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=use_flash,
            enable_math=not use_flash,
            enable_mem_efficient=False,
        ):
            # Warmup so kernel selection/caching doesn't pollute the timing
            for _ in range(3):
                _ = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()
            # Timed runs bracketed by CUDA events
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(num_iterations):
                output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
            end.record()
            torch.cuda.synchronize()
        time_ms = start.elapsed_time(end) / num_iterations
        memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
        del Q, K, V, output
        torch.cuda.empty_cache()
        return {
            "time_ms": round(time_ms, 3),
            "memory_mb": round(memory_mb, 1),
            "config": gqa_config,
            "attention_type": gqa_config["attention_type"],
            "gqa_ratio": gqa_ratio,
            "status": "success",
        }
    except Exception as e:
        # Best-effort benchmark: report a truncated error instead of raising.
        return {
            "time_ms": 0,
            "memory_mb": 0,
            "status": f"error: {str(e)[:100]}",
        }
def benchmark_mha_attention(
    model_name: str,
    seq_len: int,
    batch_size: int = 1,
    num_iterations: int = 10,
    use_flash: bool = True,
) -> dict:
    """
    Benchmark attention as if model used MHA (for comparison).

    Uses num_heads for both Q and KV to simulate MHA, regardless of the
    model's real num_kv_heads. Same contract as benchmark_gqa_attention.

    Returns:
        Dict with mean latency (ms), peak CUDA memory (MB), and status.
    """
    if not torch.cuda.is_available():
        return {"time_ms": 0, "memory_mb": 0, "status": "error: CUDA not available"}
    device = torch.device("cuda")
    dtype = torch.float16
    try:
        gqa_config = get_gqa_config_from_model(model_name)
        num_heads = gqa_config["num_heads"]
        head_dim = gqa_config["head_dim"]
        # For MHA simulation: K and V get a full set of heads (num_kv_heads == num_heads)
        Q = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        K = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        V = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        # Backend selection: flash-only or math-only (mem-efficient always off).
        # Enter the backend context ONCE per phase instead of per iteration,
        # so context-manager entry/exit overhead is not part of the timing.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=use_flash,
            enable_math=not use_flash,
            enable_mem_efficient=False,
        ):
            # Warmup so kernel selection/caching doesn't pollute the timing
            for _ in range(3):
                _ = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()
            # Timed runs bracketed by CUDA events
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(num_iterations):
                output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
            end.record()
            torch.cuda.synchronize()
        time_ms = start.elapsed_time(end) / num_iterations
        memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
        del Q, K, V, output
        torch.cuda.empty_cache()
        return {
            "time_ms": round(time_ms, 3),
            "memory_mb": round(memory_mb, 1),
            "attention_type": "MHA (simulated)",
            "num_kv_heads": num_heads,
            "status": "success",
        }
    except Exception as e:
        # Best-effort benchmark: report a truncated error instead of raising.
        return {
            "time_ms": 0,
            "memory_mb": 0,
            "status": f"error: {str(e)[:100]}",
        }
def create_head_sharing_diagram(model_name: str) -> go.Figure:
    """
    Create visual diagram showing how Q heads share K,V heads.
    Uses REAL model config for accurate visualization.

    Draws three side-by-side panels: MHA (1:1 Q-to-KV), the model's actual
    configuration, and MQA (all Q heads -> one KV head). In each panel,
    Q heads are blue circles on the left (x=0), KV heads green squares on
    the right (x=2), with thin lines showing which Q heads read which KV head.
    """
    gqa_config = get_gqa_config_from_model(model_name)
    num_heads = gqa_config["num_heads"]
    num_kv_heads = gqa_config["num_kv_heads"]
    gqa_ratio = gqa_config["gqa_ratio"]
    attention_type = gqa_config["attention_type"]
    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=(
            f"<b>MHA</b> ({num_heads}:{num_heads})",
            f"<b>{attention_type}</b> ({num_heads}:{num_kv_heads})",
            f"<b>MQA</b> ({num_heads}:1)",
        ),
        horizontal_spacing=0.08,
    )
    # Limit display for visual clarity (at most 8 markers per column)
    display_heads = min(num_heads, 8)
    display_kv = min(num_kv_heads, 8)
    # Colors
    q_color = "#3b82f6"  # Blue for Q heads
    kv_color = "#22c55e"  # Green for KV heads
    # MHA panel: 1:1 mapping — every Q head gets its own KV head at the same height.
    for i in range(display_heads):
        # Q head (top-to-bottom ordering: Q0 at the top)
        fig.add_trace(go.Scatter(
            x=[0], y=[display_heads - 1 - i],
            mode="markers+text",
            marker=dict(size=25, color=q_color, symbol="circle"),
            text=[f"Q{i}"],
            textposition="middle center",
            textfont=dict(color="white", size=9),
            showlegend=False,
            hoverinfo="skip",
        ), row=1, col=1)
        # KV head, aligned with its Q head
        fig.add_trace(go.Scatter(
            x=[2], y=[display_heads - 1 - i],
            mode="markers+text",
            marker=dict(size=25, color=kv_color, symbol="square"),
            text=[f"KV{i}"],
            textposition="middle center",
            textfont=dict(color="white", size=9),
            showlegend=False,
            hoverinfo="skip",
        ), row=1, col=1)
        # Connection line (horizontal, since the pairing is 1:1)
        fig.add_trace(go.Scatter(
            x=[0.3, 1.7], y=[display_heads - 1 - i, display_heads - 1 - i],
            mode="lines",
            line=dict(color="rgba(0,0,0,0.3)", width=1),
            showlegend=False,
            hoverinfo="skip",
        ), row=1, col=1)
    # GQA panel: the model's actual configuration — draw all Q heads first.
    for i in range(display_heads):
        # Q head
        fig.add_trace(go.Scatter(
            x=[0], y=[display_heads - 1 - i],
            mode="markers+text",
            marker=dict(size=25, color=q_color, symbol="circle"),
            text=[f"Q{i}"],
            textposition="middle center",
            textfont=dict(color="white", size=9),
            showlegend=False,
            hoverinfo="skip",
        ), row=1, col=2)
    # KV heads for GQA, spread evenly over the panel height
    for j in range(display_kv):
        # Single KV head sits at mid-height; multiple KV heads are spaced
        # linearly from bottom (j=0 maps to y=0) to top.
        kv_y = (display_heads - 1) * (j / max(display_kv - 1, 1)) if display_kv > 1 else (display_heads - 1) / 2
        fig.add_trace(go.Scatter(
            x=[2], y=[kv_y],
            mode="markers+text",
            marker=dict(size=30, color=kv_color, symbol="square"),
            text=[f"KV{j}"],
            textposition="middle center",
            textfont=dict(color="white", size=9),
            showlegend=False,
            hoverinfo="skip",
        ), row=1, col=2)
        # Draw connections from Q heads to their shared KV head; the min()
        # clamps the fan-out when fewer than gqa_ratio Q heads remain visible.
        heads_per_kv = gqa_ratio
        for k in range(min(heads_per_kv, display_heads - j * heads_per_kv)):
            q_idx = j * heads_per_kv + k
            if q_idx < display_heads:
                fig.add_trace(go.Scatter(
                    x=[0.3, 1.7], y=[display_heads - 1 - q_idx, kv_y],
                    mode="lines",
                    line=dict(color="rgba(0,0,0,0.2)", width=1),
                    showlegend=False,
                    hoverinfo="skip",
                ), row=1, col=2)
    # MQA panel: all Q heads share 1 KV head
    for i in range(display_heads):
        fig.add_trace(go.Scatter(
            x=[0], y=[display_heads - 1 - i],
            mode="markers+text",
            marker=dict(size=25, color=q_color, symbol="circle"),
            text=[f"Q{i}"],
            textposition="middle center",
            textfont=dict(color="white", size=9),
            showlegend=False,
            hoverinfo="skip",
        ), row=1, col=3)
        # Connection to the single KV head at mid-height
        fig.add_trace(go.Scatter(
            x=[0.3, 1.7], y=[display_heads - 1 - i, (display_heads - 1) / 2],
            mode="lines",
            line=dict(color="rgba(0,0,0,0.2)", width=1),
            showlegend=False,
            hoverinfo="skip",
        ), row=1, col=3)
    # Single KV marker for MQA (slightly larger to stand out)
    fig.add_trace(go.Scatter(
        x=[2], y=[(display_heads - 1) / 2],
        mode="markers+text",
        marker=dict(size=35, color=kv_color, symbol="square"),
        text=["KV0"],
        textposition="middle center",
        textfont=dict(color="white", size=10),
        showlegend=False,
        hoverinfo="skip",
    ), row=1, col=3)
    fig.update_layout(
        title=dict(
            text=f"Head Sharing Pattern: {model_name}",
            x=0.5,
        ),
        height=350,
        showlegend=False,
        margin=dict(l=20, r=20, t=60, b=40),
    )
    # Hide axes in all subplots — the panels are pure diagrams, not plots
    for col in [1, 2, 3]:
        fig.update_xaxes(
            showgrid=False, zeroline=False, showticklabels=False,
            range=[-0.5, 2.5], row=1, col=col
        )
        fig.update_yaxes(
            showgrid=False, zeroline=False, showticklabels=False,
            range=[-0.5, display_heads - 0.5], row=1, col=col
        )
    return fig
def create_memory_comparison_chart(model_name: str, seq_len: int) -> go.Figure:
    """
    Build a bar chart comparing KV cache memory for MHA, the model's actual
    attention layout, and MQA, using values derived from the real model config.
    """
    data = compare_attention_memory(model_name, seq_len)
    cfg = data["gqa_config"]
    labels = [
        "MHA<br>(Full)",
        f"{cfg['attention_type']}<br>(Actual)",
        "MQA<br>(Minimal)",
    ]
    values = [
        data["mha_memory_mb"],
        data["gqa_memory_mb"],
        data["mqa_memory_mb"],
    ]
    bar_colors = ["#ef4444", "#22c55e", "#3b82f6"]  # red / green / blue
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=labels,
        y=values,
        marker_color=bar_colors,
        text=[f"{v:.1f} MB" for v in values],
        textposition="outside",
        textfont=dict(size=12),
    ))
    # Annotate the savings factor, pointing at the actual-config (middle) bar.
    ratio = data["savings_ratio"]
    fig.add_annotation(
        x=1.5, y=values[1] * 1.5,
        text=f"<b>{ratio:.1f}× smaller</b>",
        showarrow=True,
        arrowhead=2,
        arrowsize=1,
        arrowwidth=1,
        arrowcolor="#22c55e",
        ax=30,
        ay=-20,
        font=dict(size=11, color="#22c55e"),
    )
    # Extend the y-range past the tallest bar so outside labels stay visible.
    ceiling = max(values) * 1.25
    title_html = (
        f"KV Cache Memory at {seq_len} tokens"
        f"<br><sub>{cfg['num_heads']} Q heads, {cfg['num_kv_heads']} KV heads, "
        f"{cfg['num_layers']} layers</sub>"
    )
    fig.update_layout(
        title=dict(text=title_html, x=0.5),
        yaxis_title="Memory (MB)",
        height=400,
        margin=dict(l=50, r=50, t=90, b=50),
        yaxis=dict(rangemode='tozero', range=[0, ceiling]),
    )
    return fig
def create_scaling_chart(model_name: str) -> go.Figure:
    """
    Plot how the KV cache grows with sequence length for MHA, the model's
    actual attention type, and MQA.
    """
    cfg = get_gqa_config_from_model(model_name)
    seq_lengths = [512, 1024, 2048, 4096, 8192, 16384]
    mha_curve, gqa_curve, mqa_curve = [], [], []
    for length in seq_lengths:
        snapshot = compare_attention_memory(model_name, length)
        mha_curve.append(snapshot["mha_memory_mb"])
        gqa_curve.append(snapshot["gqa_memory_mb"])
        mqa_curve.append(snapshot["mqa_memory_mb"])
    fig = go.Figure()
    # One trace per attention variant; the actual config is drawn heaviest.
    trace_specs = [
        (mha_curve, "MHA (baseline)",
         dict(color="#ef4444", width=2), dict(size=8)),
        (gqa_curve, f"{cfg['attention_type']} (actual)",
         dict(color="#22c55e", width=3), dict(size=10)),
        (mqa_curve, "MQA (minimal)",
         dict(color="#3b82f6", width=2, dash="dash"), dict(size=8)),
    ]
    for ys, label, line_style, marker_style in trace_specs:
        fig.add_trace(go.Scatter(
            x=seq_lengths, y=ys,
            mode="lines+markers",
            name=label,
            line=line_style,
            marker=marker_style,
        ))
    fig.update_layout(
        title=dict(
            text=f"KV Cache Scaling: {model_name}<br><sub>GQA Ratio {cfg['gqa_ratio']}:1</sub>",
            x=0.5,
        ),
        xaxis_title="Sequence Length",
        yaxis_title="KV Cache Memory (MB)",
        height=420,
        margin=dict(l=60, r=50, t=80, b=100),
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.18,
            xanchor="center",
            x=0.5,
        ),
        # Log x-axis with explicit ticks: the doubling sequence lengths are
        # evenly spaced and labelled 512, 1K, 2K, ..., 16K.
        xaxis=dict(
            type="log",
            tickmode="array",
            tickvals=seq_lengths,
            ticktext=[f"{s//1000}K" if s >= 1000 else str(s) for s in seq_lengths],
            tickangle=0,
        ),
        yaxis=dict(rangemode='tozero'),
    )
    return fig
def run_gqa_analysis(model_name: str, seq_len: int) -> tuple:
    """
    Run the full GQA analysis pipeline for one model.

    Returns a 5-tuple: (config dict, head-sharing diagram, memory bar
    chart, scaling chart, markdown insight text).
    """
    # Real config and memory comparison from the actual model
    cfg = get_gqa_config_from_model(model_name)
    mem = compare_attention_memory(model_name, seq_len)
    # Build the three figures
    diagram = create_head_sharing_diagram(model_name)
    bars = create_memory_comparison_chart(model_name, seq_len)
    scaling = create_scaling_chart(model_name)
    # Markdown summary shown alongside the charts
    mqa_factor = mem['mha_memory_mb'] / mem['mqa_memory_mb']
    insight = f"""### {model_name} Attention Configuration
**Type:** {cfg['attention_type']} ({cfg['attention_description']})
**Architecture (from model.config):**
- Query heads: **{cfg['num_heads']}**
- KV heads: **{cfg['num_kv_heads']}**
- Head dimension: **{cfg['head_dim']}**
- Layers: **{cfg['num_layers']}**
- GQA ratio: **{cfg['gqa_ratio']}:1**
---
### KV Cache Memory at {seq_len} tokens
| Configuration | Memory | vs MHA |
|--------------|--------|--------|
| MHA (baseline) | {mem['mha_memory_mb']:.1f} MB | 1.0× |
| **{cfg['attention_type']}** | **{mem['gqa_memory_mb']:.1f} MB** | **{mem['savings_ratio']:.1f}× smaller** |
| MQA (minimal) | {mem['mqa_memory_mb']:.1f} MB | {mqa_factor:.1f}× smaller |
---
### Why GQA?
- **Memory efficiency:** Reduces KV cache by {mem['savings_ratio']:.1f}×
- **Longer contexts:** Same GPU memory supports {mem['savings_ratio']:.1f}× longer sequences
- **Quality preservation:** Better than MQA for maintaining attention quality
- **Compute savings:** Fewer KV projections during generation
"""
    return cfg, diagram, bars, scaling, insight