"""
Visualizer for FlashAttention concepts.

CPU-only animations showing tiling, online softmax, and memory hierarchy.
"""

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Fill colors for each tile state, shared by the tiling grid and its legend.
_TILE_COLORS = {
    "current": "rgba(249, 115, 22, 0.8)",  # Orange
    "done": "rgba(34, 197, 94, 0.6)",  # Green
    "pending": "rgba(229, 231, 235, 0.5)",  # Light gray
    "masked": "rgba(200, 200, 200, 0.3)",  # Faded gray
}


def create_tiling_grid(
    seq_len: int = 8,
    block_size: int = 2,
    current_step: int = 0,
    causal: bool = False,
) -> go.Figure:
    """
    Create a grid visualization showing FlashAttention tile processing.

    Tiles are visited in row-major order (query block ``i``, then key block
    ``j``). With ``causal=True`` the tiles above the diagonal (``j > i``) are
    drawn as masked and are not counted as animation steps.

    Args:
        seq_len: Sequence length (number of tokens)
        block_size: Size of each tile block. Only ``seq_len // block_size``
            full blocks are drawn; a trailing partial block is ignored.
        current_step: Current step in the animation (0-indexed)
        causal: Whether to use causal masking

    Returns:
        Plotly figure with the tiling grid
    """
    num_blocks = seq_len // block_size
    total_tiles = get_max_steps(seq_len, block_size, causal)

    fig = go.Figure()
    annotations = []

    # Position of the next non-masked tile in processing order.
    tile_idx = 0
    for i in range(num_blocks):  # Query blocks (rows)
        for j in range(num_blocks):  # Key blocks (columns)
            x0, x1 = j, j + 1
            # Query block 0 is drawn at the top, so flip the row coordinate.
            y0, y1 = num_blocks - i - 1, num_blocks - i

            masked = causal and j > i
            if masked:
                # Future keys are never computed under causal attention.
                status = "masked"
            elif tile_idx < current_step:
                status = "done"
            elif tile_idx == current_step:
                status = "current"
            else:
                status = "pending"

            fig.add_shape(
                type="rect",
                x0=x0, y0=y0, x1=x1, y1=y1,
                line=dict(color="rgba(0,0,0,0.3)", width=1),
                fillcolor=_TILE_COLORS[status],
            )

            # Label only the tile currently being processed.
            if status == "current":
                annotations.append(dict(
                    x=(x0 + x1) / 2,
                    y=(y0 + y1) / 2,
                    text=f"Q[{i}]×K[{j}]",
                    showarrow=False,
                    font=dict(size=10, color="white", weight="bold"),
                ))

            if not masked:
                tile_idx += 1

    # Axis labels: K blocks along the top, Q blocks down the left side.
    for k in range(num_blocks):
        annotations.append(dict(
            x=k + 0.5, y=num_blocks + 0.2,
            text=f"K[{k}]",
            showarrow=False,
            font=dict(size=9, color="gray"),
        ))
        annotations.append(dict(
            x=-0.3, y=num_blocks - k - 0.5,
            text=f"Q[{k}]",
            showarrow=False,
            font=dict(size=9, color="gray"),
        ))

    # Past the last tile the animation is complete; don't report "Step N+1/N".
    display_step = min(current_step + 1, total_tiles)

    fig.update_layout(
        annotations=annotations,
        xaxis=dict(
            range=[-0.5, num_blocks + 0.5],
            showgrid=False, zeroline=False, showticklabels=False,
            title="Key Blocks →",
        ),
        yaxis=dict(
            range=[-0.5, num_blocks + 0.5],
            showgrid=False, zeroline=False, showticklabels=False,
            scaleanchor="x",
            title="← Query Blocks",
        ),
        height=350,
        margin=dict(l=50, r=20, t=40, b=50),
        title=dict(
            text=f"Attention Matrix Tiling (Step {display_step}/{total_tiles})",
            x=0.5,
        ),
        # The legend below is built from dummy traces; layout-level
        # showlegend=False would suppress it entirely.
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom", y=-0.25,
            xanchor="center", x=0.5,
        ),
    )

    # Manual legend: invisible (None-valued) traces that only carry a swatch.
    legend_states = ["current", "done", "pending"]
    if causal:
        legend_states.append("masked")
    for state in legend_states:
        fig.add_trace(go.Scatter(
            x=[None], y=[None],
            mode="markers",
            marker=dict(size=15, color=_TILE_COLORS[state], symbol="square"),
            name=state.capitalize(),
            showlegend=True,
        ))

    return fig


def _online_softmax_example() -> list[dict]:
    """
    Build the per-tile online-softmax trace for the fixed 4-tile example.

    Only the per-block statistics are hard-coded: the block max and the sum
    of exp(score - block_max) over that block. The running max ``m``, the
    running sum ``l`` and the rescale factors are derived from them with the
    online-softmax update. This keeps every number shown in the explanation
    text arithmetically consistent with the others.

    Returns:
        One dict per tile with keys ``tile``, ``block_max``, ``block_sum_exp``,
        ``m_before``, ``m_after``, ``l_before``, ``l_after``,
        ``rescale_factor`` and ``rescaled``.
    """
    # (block_max, block_sum_exp) for Q[0] against each of the 4 key blocks.
    block_stats = [(2.1, 3.42), (3.5, 5.21), (2.8, 4.01), (4.2, 6.83)]

    trace = []
    running_max, running_sum = float("-inf"), 0.0
    for tile, (block_max, block_sum_exp) in enumerate(block_stats):
        m_before, l_before = running_max, running_sum
        running_max = max(m_before, block_max)
        # Re-expresses the previous accumulation relative to the new max;
        # exp(-inf) == 0 makes the empty initial state contribute nothing.
        carry = float(np.exp(m_before - running_max))
        running_sum = (
            l_before * carry
            + block_sum_exp * float(np.exp(block_max - running_max))
        )
        # The first tile only initializes the state: nothing to rescale.
        rescaled = tile > 0 and running_max > m_before
        trace.append({
            "tile": tile,
            "block_max": block_max,
            "block_sum_exp": block_sum_exp,
            "m_before": m_before,
            "m_after": running_max,
            "l_before": l_before,
            "l_after": running_sum,
            "rescale_factor": carry if rescaled else 1.0,
            "rescaled": rescaled,
        })
    return trace


def create_online_softmax_state(
    current_step: int = 0,
    num_tiles: int = 4,
) -> tuple[go.Figure, str]:
    """
    Create visualization of online softmax state (m, l, O) evolution.

    Uses a concrete 8-token example with block_size=2. Shows how the running
    max (m) and sum (l) update, with rescaling when the max changes.

    Args:
        current_step: Current tile being processed (0-indexed); clamped to
            the example's valid tile range.
        num_tiles: Number of tile bars to draw

    Returns:
        Tuple of (Plotly figure, explanation text)
    """
    example_data = _online_softmax_example()
    step = max(0, min(current_step, len(example_data) - 1))
    d = example_data[step]

    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Running Max (m)", "Running Sum (l)"),
        horizontal_spacing=0.15,
    )

    # History up to the current step; later tiles have no value yet.
    m_values = [example_data[i]["m_after"] if i <= step else None for i in range(num_tiles)]
    l_values = [example_data[i]["l_after"] if i <= step else None for i in range(num_tiles)]

    # Both charts share one color scheme that highlights rescaling events.
    bar_colors = []
    for i in range(num_tiles):
        if i > step:
            bar_colors.append("rgba(200, 200, 200, 0.5)")  # Not processed yet
        elif i == step:
            bar_colors.append("rgba(249, 115, 22, 0.9)")  # Orange for current
        elif example_data[i]["rescaled"]:
            bar_colors.append("rgba(239, 68, 68, 0.7)")  # Red for rescale events
        else:
            bar_colors.append("rgba(34, 197, 94, 0.7)")  # Green for normal

    tile_labels = [f"Tile {i}" for i in range(num_tiles)]
    for col, (values, trace_name) in enumerate(
        ((m_values, "m (max)"), (l_values, "l (sum)")), start=1
    ):
        fig.add_trace(
            go.Bar(
                x=tile_labels,
                y=[v if v is not None else 0 for v in values],
                marker_color=bar_colors,
                text=[f"{v:.2f}" if v is not None else "" for v in values],
                textposition="outside",
                name=trace_name,
            ),
            row=1, col=col,
        )

    # Move subplot titles down so they don't get cut off by Gradio label
    for annotation in fig["layout"]["annotations"]:
        annotation["y"] = 0.95
        annotation["yanchor"] = "top"

    fig.update_layout(
        height=380,
        margin=dict(l=40, r=40, t=30, b=40),
        showlegend=False,
    )

    # Increase y-axis range to make room for text labels above bars
    fig.update_yaxes(range=[0, 14], row=1, col=1)
    fig.update_yaxes(range=[0, 18], row=1, col=2)

    if d["rescaled"]:
        explanation = f"""**Processing Tile {step} (Keys {step*2}-{step*2+1})**

🔴 **MAX CHANGED!** Block max ({d['block_max']:.2f}) > Previous max ({d['m_before']:.2f})

**Rescaling required:**
- Rescale factor: exp({d['m_before']:.1f} - {d['block_max']:.1f}) = **{d['rescale_factor']:.3f}**
- Previous l rescaled: {d['l_before']:.2f} × {d['rescale_factor']:.3f} = {d['l_before'] * d['rescale_factor']:.2f}
- New l = rescaled + block_sum = **{d['l_after']:.2f}**
- Previous O also rescaled by {d['rescale_factor']:.3f}

*This is the key insight: when max increases, we must rescale all previous accumulations!*
"""
    else:
        explanation = f"""**Processing Tile {step} (Keys {step*2}-{step*2+1})**

✅ No rescaling needed (block max {d['block_max']:.2f} ≤ current max {d['m_after']:.2f})

**Simple accumulation:**
- m stays at: **{d['m_after']:.2f}**
- l += block_sum × exp(block_max - m)
- l = {d['l_before']:.2f} + {d['block_sum_exp']:.2f} × exp({d['block_max']:.1f} - {d['m_after']:.1f}) = **{d['l_after']:.2f}**
"""

    return fig, explanation


def create_memory_hierarchy_diagram(
    algorithm: str = "flash",
    current_step: int = 0,
) -> go.Figure:
    """
    Create a diagram showing HBM vs SRAM memory hierarchy.

    Args:
        algorithm: "flash" for FlashAttention; any other value is drawn as
            standard attention.
        current_step: Animation step; with "flash" it selects which SRAM
            tile is highlighted (cycling through the tiles).

    Returns:
        Plotly figure showing memory hierarchy
    """
    fig = go.Figure()

    # HBM box (upper region)
    fig.add_shape(
        type="rect",
        x0=0.05, y0=0.55, x1=0.95, y1=0.95,
        fillcolor="rgba(59, 130, 246, 0.1)",
        line=dict(color="rgba(59, 130, 246, 0.8)", width=2),
    )

    # SRAM box (lower region)
    fig.add_shape(
        type="rect",
        x0=0.2, y0=0.15, x1=0.8, y1=0.45,
        fillcolor="rgba(34, 197, 94, 0.1)",
        line=dict(color="rgba(34, 197, 94, 0.8)", width=2),
    )

    # Full-size matrices stored in HBM
    matrix_width = 0.15
    hbm_x_start = 0.15
    for i, name in enumerate(["Q", "K", "V", "O"]):
        x = hbm_x_start + i * 0.2
        fig.add_shape(
            type="rect",
            x0=x, y0=0.65, x1=x + matrix_width, y1=0.85,
            fillcolor="rgba(59, 130, 246, 0.3)",
            line=dict(color="rgba(59, 130, 246, 0.6)", width=1),
        )
        fig.add_annotation(
            x=x + matrix_width / 2, y=0.75,
            text=f"{name}<br>[N, d]",
            showarrow=False,
            font=dict(size=11),
        )

    if algorithm == "flash":
        # Tile-sized working set held in SRAM
        tiles = ["Q_tile", "K_tile", "V_tile", "S_tile", "O_tile"]
        tile_width = 0.1
        sram_x_start = 0.25
        active = current_step % len(tiles)
        for i, name in enumerate(tiles):
            x = sram_x_start + i * 0.11
            # Highlight the tile being processed at this step
            fill = "rgba(249, 115, 22, 0.5)" if i == active else "rgba(34, 197, 94, 0.3)"
            fig.add_shape(
                type="rect",
                x0=x, y0=0.22, x1=x + tile_width, y1=0.38,
                fillcolor=fill,
                line=dict(color="rgba(34, 197, 94, 0.6)", width=1),
            )
            fig.add_annotation(
                x=x + tile_width / 2, y=0.30,
                text=name.replace("_", "<br>"),
                showarrow=False,
                font=dict(size=9),
            )

        # HBM -> SRAM transfer arrow: only tile-sized transfers
        fig.add_annotation(
            x=0.5, y=0.48, ax=0.5, ay=0.55,
            xref="x", yref="y", axref="x", ayref="y",
            text="",
            showarrow=True,
            arrowhead=2, arrowsize=1.5, arrowwidth=2,
            arrowcolor="rgba(34, 197, 94, 0.8)",
        )
        fig.add_annotation(
            x=0.65, y=0.515,
            text="O(B) per tile",
            showarrow=False,
            font=dict(size=10, color="green"),
            xanchor="left",
        )
    else:
        # Standard attention - full score matrix in SRAM (doesn't fit!)
        fig.add_shape(
            type="rect",
            x0=0.3, y0=0.22, x1=0.7, y1=0.38,
            fillcolor="rgba(239, 68, 68, 0.3)",
            line=dict(color="rgba(239, 68, 68, 0.6)", width=1, dash="dash"),
        )
        fig.add_annotation(
            x=0.5, y=0.30,
            text="S[N,N]<br>❌ Doesn't fit!",
            showarrow=False,
            font=dict(size=10, color="red"),
        )

        # HBM -> SRAM transfer arrow: full-matrix traffic
        fig.add_annotation(
            x=0.5, y=0.48, ax=0.5, ay=0.55,
            xref="x", yref="y", axref="x", ayref="y",
            text="",
            showarrow=True,
            arrowhead=2, arrowsize=1.5, arrowwidth=2,
            arrowcolor="rgba(239, 68, 68, 0.8)",
        )
        fig.add_annotation(
            x=0.65, y=0.515,
            text="O(N²) traffic!",
            showarrow=False,
            font=dict(size=10, color="red"),
            xanchor="left",
        )

    # Memory level labels
    fig.add_annotation(
        x=0.5, y=0.97,
        text="HBM (High Bandwidth Memory)<br>80 GB capacity | 2 TB/s bandwidth | ~400 cycles latency",
        showarrow=False,
        font=dict(size=11),
    )
    fig.add_annotation(
        x=0.5, y=0.12,
        text="SRAM (Shared Memory)<br>192 KB capacity | 19 TB/s bandwidth | ~20 cycles latency",
        showarrow=False,
        font=dict(size=11),
    )

    fig.update_layout(
        xaxis=dict(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False),
        height=400,
        margin=dict(l=20, r=20, t=40, b=20),
        title=dict(
            text=f"Memory Hierarchy: {'FlashAttention' if algorithm == 'flash' else 'Standard Attention'}",
            x=0.5,
        ),
    )

    return fig


def get_max_steps(seq_len: int, block_size: int, causal: bool) -> int:
    """
    Calculate total number of steps for the tiling animation.

    Equals the number of non-masked tiles: ``n * n`` for full attention and
    the lower triangle including the diagonal, ``n * (n + 1) // 2``, for
    causal attention, where ``n = seq_len // block_size``.
    """
    num_blocks = seq_len // block_size
    if causal:
        return num_blocks * (num_blocks + 1) // 2
    return num_blocks * num_blocks