a0y0346
Fix Online Softmax: move subplot titles inside chart area (y=0.95)
7720e5f
"""
Visualizer for FlashAttention concepts.
CPU-only animations showing tiling, online softmax, and memory hierarchy.
"""
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def create_tiling_grid(
    seq_len: int = 8,
    block_size: int = 2,
    current_step: int = 0,
    causal: bool = False
) -> go.Figure:
    """
    Create a grid visualization showing FlashAttention tile processing.

    The attention matrix is drawn as a ``num_blocks x num_blocks`` grid
    (``num_blocks = seq_len // block_size``): query blocks are rows (Q[0] at
    the top) and key blocks are columns. Non-masked tiles are numbered in
    row-major order; tiles numbered below ``current_step`` are drawn as done,
    the tile numbered ``current_step`` is highlighted and labelled, and the
    rest are pending. With ``causal=True``, tiles whose key block index
    exceeds the query block index are drawn as masked and are not numbered.

    Args:
        seq_len: Sequence length (number of tokens)
        block_size: Size of each tile block
        current_step: Current step in the animation (0-indexed)
        causal: Whether to use causal masking
    Returns:
        Plotly figure with the tiling grid and a color legend
    """
    num_blocks = seq_len // block_size
    # Same step count the animation controls use, so the title agrees with them.
    total_tiles = get_max_steps(seq_len, block_size, causal)

    current_color = "rgba(249, 115, 22, 0.8)"  # Orange
    done_color = "rgba(34, 197, 94, 0.6)"  # Green
    pending_color = "rgba(229, 231, 235, 0.5)"  # Light gray
    masked_color = "rgba(200, 200, 200, 0.3)"

    fig = go.Figure()
    annotations = []
    tile_idx = 0  # processing-order index of the next non-masked tile
    for i in range(num_blocks):  # Query blocks (rows)
        for j in range(num_blocks):  # Key blocks (columns)
            x0, x1 = j, j + 1
            # Flip rows so Q[0] is drawn at the top of the grid.
            y0, y1 = num_blocks - i - 1, num_blocks - i
            masked = causal and j > i  # future keys under causal attention

            if masked:
                color = masked_color
            elif tile_idx < current_step:
                color = done_color
            elif tile_idx == current_step:
                color = current_color
                annotations.append(dict(
                    x=(x0 + x1) / 2,
                    y=(y0 + y1) / 2,
                    # <b> markup (as used elsewhere in this module) rather than
                    # font.weight, which only recent plotly releases accept.
                    text=f"<b>Q[{i}]×K[{j}]</b>",
                    showarrow=False,
                    font=dict(size=10, color="white"),
                ))
            else:
                color = pending_color

            fig.add_shape(
                type="rect",
                x0=x0, y0=y0, x1=x1, y1=y1,
                line=dict(color="rgba(0,0,0,0.3)", width=1),
                fillcolor=color,
            )
            if not masked:
                tile_idx += 1

    # Axis labels: K[...] along the top, Q[...] down the left side.
    for b in range(num_blocks):
        annotations.append(dict(
            x=b + 0.5,
            y=num_blocks + 0.2,
            text=f"K[{b}]",
            showarrow=False,
            font=dict(size=9, color="gray"),
        ))
        annotations.append(dict(
            x=-0.3,
            y=num_blocks - b - 0.5,
            text=f"Q[{b}]",
            showarrow=False,
            font=dict(size=9, color="gray"),
        ))

    # Clamp so steps past the end (everything done) don't read e.g. "Step 17/16".
    shown_step = min(current_step + 1, total_tiles)
    fig.update_layout(
        annotations=annotations,
        xaxis=dict(
            range=[-0.5, num_blocks + 0.5],
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            title="Key Blocks →",
        ),
        yaxis=dict(
            range=[-0.5, num_blocks + 0.5],
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            scaleanchor="x",
            title="← Query Blocks",
        ),
        height=350,
        margin=dict(l=50, r=20, t=40, b=50),
        title=dict(
            text=f"Attention Matrix Tiling (Step {shown_step}/{total_tiles})",
            x=0.5,
        ),
    )

    # Shapes cannot appear in a legend, so add one empty marker trace per
    # tile status to act as the legend entries.
    legend_items = [
        ("Current", current_color),
        ("Done", done_color),
        ("Pending", pending_color),
    ]
    if causal:
        legend_items.append(("Masked", masked_color))
    for name, color in legend_items:
        fig.add_trace(go.Scatter(
            x=[None], y=[None],
            mode="markers",
            marker=dict(size=15, color=color, symbol="square"),
            name=name,
            showlegend=True,
        ))

    # layout.showlegend must be True: when it is False plotly hides the
    # whole legend regardless of the traces' own showlegend flags.
    fig.update_layout(
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.25,
            xanchor="center",
            x=0.5,
        ),
    )
    return fig
def create_online_softmax_state(
    current_step: int = 0,
    num_tiles: int = 4,
) -> tuple[go.Figure, str]:
    """
    Create visualization of online softmax state (m, l, O) evolution.

    Uses a concrete 8-token example with block_size=2, i.e. the scores of
    Q[0] against four key tiles. Shows how running max (m) and sum (l)
    update, with rescaling when max changes.

    Each tile is described by its block max and ``block_sum_exp`` (the sum of
    exp(score - block_max) over the tile). The running state is derived from
    those with the online-softmax recurrence

        m_new = max(m_old, block_max)
        l_new = l_old * exp(m_old - m_new) + block_sum_exp * exp(block_max - m_new)

    so the bar values and the arithmetic in the explanation text agree.

    Args:
        current_step: Current tile being processed (0-indexed). Clamped to the
            tiles that both have example data and are drawn.
        num_tiles: Number of tile bars to draw. Bars beyond the example data
            are always shown as pending.
    Returns:
        Tuple of (Plotly figure, Markdown explanation text)
    """
    import math

    # (block_max, block_sum_exp) for each key tile of the example.
    block_stats = [(2.1, 3.42), (3.5, 5.21), (2.8, 4.01), (4.2, 6.83)]

    example_data = []
    m_running, l_running = float("-inf"), 0.0
    for tile, (block_max, block_sum_exp) in enumerate(block_stats):
        m_new = max(m_running, block_max)
        # The first tile only initializes the state; after that, a larger
        # block max means the earlier accumulation must be rescaled.
        rescaled = m_running != float("-inf") and block_max > m_running
        rescale_factor = math.exp(m_running - m_new) if rescaled else 1.0
        l_new = l_running * rescale_factor + block_sum_exp * math.exp(block_max - m_new)
        example_data.append({
            "tile": tile,
            "block_max": block_max,
            "block_sum_exp": block_sum_exp,
            "m_before": m_running,
            "m_after": m_new,
            "l_before": l_running,
            "l_after": l_new,
            "rescale_factor": rescale_factor,
            "rescaled": rescaled,
        })
        m_running, l_running = m_new, l_new

    # Clamp so the highlighted bar and the explanation always refer to a tile
    # that has data and is actually drawn (also rejects negative steps).
    last_tile = min(len(example_data), max(num_tiles, 1)) - 1
    step = max(0, min(current_step, last_tile))

    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Running Max (m)", "Running Sum (l)"),
        horizontal_spacing=0.15,
    )

    # Bar colors (shared by both subplots): gray = not yet processed,
    # orange = current tile, red = past rescale event, green = plain update.
    bar_colors = []
    for i in range(num_tiles):
        if i > step:
            bar_colors.append("rgba(200, 200, 200, 0.5)")
        elif i == step:
            bar_colors.append("rgba(249, 115, 22, 0.9)")
        elif example_data[i]["rescaled"]:
            bar_colors.append("rgba(239, 68, 68, 0.7)")
        else:
            bar_colors.append("rgba(34, 197, 94, 0.7)")

    tile_labels = [f"Tile {i}" for i in range(num_tiles)]
    for col, key, trace_name in ((1, "m_after", "m (max)"), (2, "l_after", "l (sum)")):
        # Only tiles up to the current step have a value yet.
        values = [example_data[i][key] if i <= step else None for i in range(num_tiles)]
        fig.add_trace(
            go.Bar(
                x=tile_labels,
                y=[0 if v is None else v for v in values],
                marker_color=bar_colors,
                text=["" if v is None else f"{v:.2f}" for v in values],
                textposition="outside",
                name=trace_name,
            ),
            row=1, col=col,
        )

    # Move subplot titles down so they don't get cut off by Gradio label
    # (at this point the only annotations are the two subplot titles).
    fig.update_annotations(y=0.95, yanchor="top")
    fig.update_layout(
        height=380,
        margin=dict(l=40, r=40, t=30, b=40),
        showlegend=False,
    )
    # Increase y-axis range to make room for text labels above bars
    fig.update_yaxes(range=[0, 14], row=1, col=1)
    fig.update_yaxes(range=[0, 18], row=1, col=2)

    d = example_data[step]
    if d["rescaled"]:
        explanation = f"""**Processing Tile {step} (Keys {step*2}-{step*2+1})**
🔴 **MAX CHANGED!** Block max ({d['block_max']:.2f}) > Previous max ({d['m_before']:.2f})
**Rescaling required:**
- Rescale factor: exp({d['m_before']:.1f} - {d['block_max']:.1f}) = **{d['rescale_factor']:.3f}**
- Previous l rescaled: {d['l_before']:.2f} × {d['rescale_factor']:.3f} = {d['l_before'] * d['rescale_factor']:.2f}
- New l = rescaled + block_sum = **{d['l_after']:.2f}**
- Previous O also rescaled by {d['rescale_factor']:.3f}
*This is the key insight: when max increases, we must rescale all previous accumulations!*
"""
    else:
        explanation = f"""**Processing Tile {step} (Keys {step*2}-{step*2+1})**
✅ No rescaling needed (block max {d['block_max']:.2f} ≤ current max {d['m_after']:.2f})
**Simple accumulation:**
- m stays at: **{d['m_after']:.2f}**
- l += block_sum × exp(block_max - m)
- l = {d['l_before']:.2f} + {d['block_sum_exp']:.2f} × exp({d['block_max']:.1f} - {d['m_after']:.1f}) = **{d['l_after']:.2f}**
"""
    return fig, explanation
def create_memory_hierarchy_diagram(
    algorithm: str = "flash",
    current_step: int = 0,
) -> go.Figure:
    """
    Create a diagram showing HBM vs SRAM memory hierarchy.

    Draws an HBM region holding Q, K, V, O above a smaller SRAM region, plus a
    transfer arrow between them. For ``algorithm == "flash"`` the SRAM holds
    five tiles (one highlighted per step); any other value draws the
    standard-attention case with a full S[N,N] matrix marked as not fitting.

    Args:
        algorithm: "standard" or "flash"
        current_step: For animation purposes; selects which SRAM tile is
            highlighted (``current_step % 5``) in the flash view
    Returns:
        Plotly figure showing memory hierarchy
    """
    fig = go.Figure()

    # HBM box (upper region)
    fig.add_shape(
        type="rect",
        x0=0.05, y0=0.55, x1=0.95, y1=0.95,
        fillcolor="rgba(59, 130, 246, 0.1)",
        line=dict(color="rgba(59, 130, 246, 0.8)", width=2),
    )
    # SRAM box (lower, narrower region)
    fig.add_shape(
        type="rect",
        x0=0.2, y0=0.15, x1=0.8, y1=0.45,
        fillcolor="rgba(34, 197, 94, 0.1)",
        line=dict(color="rgba(34, 197, 94, 0.8)", width=2),
    )

    # HBM matrices (Q, K, V, O)
    matrix_width = 0.15
    hbm_x_start = 0.15
    for i, name in enumerate(["Q", "K", "V", "O"]):
        x = hbm_x_start + i * 0.2
        fig.add_shape(
            type="rect",
            x0=x, y0=0.65, x1=x + matrix_width, y1=0.85,
            fillcolor="rgba(59, 130, 246, 0.3)",
            line=dict(color="rgba(59, 130, 246, 0.6)", width=1),
        )
        fig.add_annotation(
            x=x + matrix_width / 2, y=0.75,
            text=f"<b>{name}</b><br>[N, d]",
            showarrow=False,
            font=dict(size=11),
        )

    # SRAM contents, plus the arrow color and traffic label for each algorithm.
    if algorithm == "flash":
        tiles = ["Q_tile", "K_tile", "V_tile", "S_tile", "O_tile"]
        tile_width = 0.1
        sram_x_start = 0.25
        active = current_step % len(tiles)
        for i, name in enumerate(tiles):
            x = sram_x_start + i * 0.11
            # Highlight current tile being processed
            fill = "rgba(249, 115, 22, 0.5)" if i == active else "rgba(34, 197, 94, 0.3)"
            fig.add_shape(
                type="rect",
                x0=x, y0=0.22, x1=x + tile_width, y1=0.38,
                fillcolor=fill,
                line=dict(color="rgba(34, 197, 94, 0.6)", width=1),
            )
            fig.add_annotation(
                x=x + tile_width / 2, y=0.30,
                text=name.replace("_", "<br>"),
                showarrow=False,
                font=dict(size=9),
            )
        # Show only tile-sized transfers
        arrow_color = "rgba(34, 197, 94, 0.8)"
        traffic_text, traffic_color = "O(B) per tile", "green"
    else:
        # Standard attention - full matrix in SRAM (doesn't fit!)
        fig.add_shape(
            type="rect",
            x0=0.3, y0=0.22, x1=0.7, y1=0.38,
            fillcolor="rgba(239, 68, 68, 0.3)",
            line=dict(color="rgba(239, 68, 68, 0.6)", width=1, dash="dash"),
        )
        fig.add_annotation(
            x=0.5, y=0.30,
            text="S[N,N]<br>❌ Doesn't fit!",
            showarrow=False,
            font=dict(size=10, color="red"),
        )
        # Full-matrix transfers
        arrow_color = "rgba(239, 68, 68, 0.8)"
        traffic_text, traffic_color = "O(N²) traffic!", "red"

    # Transfer arrow from the HBM box edge (y=0.55) down toward SRAM (y=0.48),
    # with the traffic label beside it.
    fig.add_annotation(
        x=0.5, y=0.48,
        ax=0.5, ay=0.55,
        xref="x", yref="y",
        axref="x", ayref="y",
        text="",
        showarrow=True,
        arrowhead=2,
        arrowsize=1.5,
        arrowwidth=2,
        arrowcolor=arrow_color,
    )
    fig.add_annotation(
        x=0.65, y=0.515,
        text=traffic_text,
        showarrow=False,
        font=dict(size=10, color=traffic_color),
        xanchor="left",
    )

    # Region labels
    fig.add_annotation(
        x=0.5, y=0.97,
        text="<b>HBM (High Bandwidth Memory)</b><br>80 GB capacity | 2 TB/s bandwidth | ~400 cycles latency",
        showarrow=False,
        font=dict(size=11),
    )
    fig.add_annotation(
        x=0.5, y=0.12,
        text="<b>SRAM (Shared Memory)</b><br>192 KB capacity | 19 TB/s bandwidth | ~20 cycles latency",
        showarrow=False,
        font=dict(size=11),
    )

    fig.update_layout(
        xaxis=dict(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False),
        height=400,
        margin=dict(l=20, r=20, t=40, b=20),
        title=dict(
            text=f"Memory Hierarchy: {'FlashAttention' if algorithm == 'flash' else 'Standard Attention'}",
            x=0.5,
        ),
    )
    return fig
def get_max_steps(seq_len: int, block_size: int, causal: bool) -> int:
    """Return the number of tiles the tiling animation steps through.

    The attention matrix is split into ``seq_len // block_size`` blocks per
    side. Without masking every tile is visited; with causal masking only the
    tiles on or below the diagonal are, so row ``r`` contributes ``r + 1``.
    """
    blocks_per_side = seq_len // block_size
    if not causal:
        return blocks_per_side * blocks_per_side
    return sum(row + 1 for row in range(blocks_per_side))