Spaces:
Sleeping
Sleeping
| """ | |
| Visualizer for FlashAttention concepts. | |
| CPU-only animations showing tiling, online softmax, and memory hierarchy. | |
| """ | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
def create_tiling_grid(
    seq_len: int = 8,
    block_size: int = 2,
    current_step: int = 0,
    causal: bool = False
) -> go.Figure:
    """
    Create a grid visualization showing FlashAttention tile processing.

    Tiles are numbered in row-major order (query block in the outer loop,
    key block in the inner loop). With causal masking, tiles whose key block
    index is greater than the query block index are drawn as masked and do
    not consume a step.

    Args:
        seq_len: Sequence length (number of tokens)
        block_size: Size of each tile block. The grid has
            ``seq_len // block_size`` blocks per side, so a trailing
            partial block is not drawn.
        current_step: Current step in the animation (0-indexed)
        causal: Whether to use causal masking

    Returns:
        Plotly figure with the tiling grid
    """
    num_blocks = seq_len // block_size

    fig = go.Figure()
    annotations = []

    # Running index over the tiles that are actually processed (masked
    # tiles are skipped). After the loops it equals the total step count.
    tile_idx = 0
    for i in range(num_blocks):  # Query blocks (rows)
        for j in range(num_blocks):  # Key blocks (columns)
            x0, x1 = j, j + 1
            # Row 0 is drawn at the top of the plot, hence the flipped y.
            y0, y1 = num_blocks - i - 1, num_blocks - i

            masked = causal and j > i
            if masked:
                # Future keys under causal attention
                color = "rgba(200, 200, 200, 0.3)"
                status = "masked"
            elif tile_idx < current_step:
                color = "rgba(34, 197, 94, 0.6)"  # Green
                status = "done"
            elif tile_idx == current_step:
                color = "rgba(249, 115, 22, 0.8)"  # Orange
                status = "current"
            else:
                color = "rgba(229, 231, 235, 0.5)"  # Light gray
                status = "pending"

            fig.add_shape(
                type="rect",
                x0=x0, y0=y0, x1=x1, y1=y1,
                line=dict(color="rgba(0,0,0,0.3)", width=1),
                fillcolor=color,
            )

            # Only the tile being processed gets a text label.
            if status == "current":
                annotations.append(dict(
                    x=(x0 + x1) / 2,
                    y=(y0 + y1) / 2,
                    text=f"Q[{i}]×K[{j}]",
                    showarrow=False,
                    font=dict(size=10, color="white", weight="bold"),
                ))

            if not masked:
                tile_idx += 1

    total_tiles = tile_idx
    # Clamp so that a step past the last tile (everything done) or a
    # negative step does not render as e.g. "Step 17/16".
    shown_step = min(max(current_step + 1, 0), total_tiles)

    # Axis labels
    for b in range(num_blocks):
        # K labels (top)
        annotations.append(dict(
            x=b + 0.5,
            y=num_blocks + 0.2,
            text=f"K[{b}]",
            showarrow=False,
            font=dict(size=9, color="gray"),
        ))
        # Q labels (left)
        annotations.append(dict(
            x=-0.3,
            y=num_blocks - b - 0.5,
            text=f"Q[{b}]",
            showarrow=False,
            font=dict(size=9, color="gray"),
        ))

    # The tiles are layout shapes, which never appear in a legend, so the
    # legend is built from empty marker traces that only carry a swatch.
    legend_items = [
        ("Current", "rgba(249, 115, 22, 0.8)"),
        ("Done", "rgba(34, 197, 94, 0.6)"),
        ("Pending", "rgba(229, 231, 235, 0.5)"),
    ]
    if causal:
        legend_items.append(("Masked", "rgba(200, 200, 200, 0.3)"))
    for name, color in legend_items:
        fig.add_trace(go.Scatter(
            x=[None], y=[None],
            mode="markers",
            marker=dict(size=15, color=color, symbol="square"),
            name=name,
            showlegend=True,
        ))

    fig.update_layout(
        annotations=annotations,
        xaxis=dict(
            range=[-0.5, num_blocks + 0.5],
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            title="Key Blocks →",
        ),
        yaxis=dict(
            range=[-0.5, num_blocks + 0.5],
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            scaleanchor="x",
            title="← Query Blocks",
        ),
        height=350,
        margin=dict(l=50, r=20, t=40, b=50),
        title=dict(
            text=f"Attention Matrix Tiling (Step {shown_step}/{total_tiles})",
            x=0.5,
        ),
        # Must be enabled at layout level: layout.showlegend=False hides the
        # legend even when individual traces set showlegend=True.
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.25,
            xanchor="center",
            x=0.5,
        ),
    )
    return fig
def create_online_softmax_state(
    current_step: int = 0,
    num_tiles: int = 4,
) -> tuple[go.Figure, str]:
    """
    Create visualization of online softmax state (m, l, O) evolution.

    Uses a concrete 8-token example with block_size=2.
    Shows how running max (m) and sum (l) update, with rescaling when max changes.
    The running values are derived from the per-tile statistics with the
    same recurrence that the explanation text prints, so the numbers shown
    always agree with the formula.

    Args:
        current_step: Current tile being processed (0-indexed). Values
            outside the example's range are clamped to the first/last tile.
        num_tiles: Total number of tiles (number of bars drawn per chart)

    Returns:
        Tuple of (Plotly figure, explanation text)
    """
    # Per-tile statistics for the example: (block_max, block_sum_exp).
    # Simulating attention scores from Q[0] to all K blocks.
    block_stats = [(2.1, 3.42), (3.5, 5.21), (2.8, 4.01), (4.2, 6.83)]

    # Replay the online-softmax recurrence:
    #   m_new = max(m_old, block_max)
    #   l_new = l_old * exp(m_old - m_new) + block_sum_exp * exp(block_max - m_new)
    example_data = []
    m_prev, l_prev = float("-inf"), 0.0
    for tile, (block_max, block_sum_exp) in enumerate(block_stats):
        m_new = max(m_prev, block_max)
        # The first tile has nothing accumulated yet, so it is not treated
        # as a rescale event even though the max moves away from -inf.
        rescaled = tile > 0 and block_max > m_prev
        rescale_factor = float(np.exp(m_prev - m_new)) if rescaled else 1.0
        l_new = l_prev * rescale_factor + block_sum_exp * float(np.exp(block_max - m_new))
        example_data.append({
            "tile": tile,
            "block_max": block_max,
            "block_sum_exp": block_sum_exp,
            "m_before": m_prev,
            "m_after": m_new,
            "l_before": l_prev,
            "l_after": l_new,
            "rescale_factor": rescale_factor,
            "rescaled": rescaled,
        })
        m_prev, l_prev = m_new, l_new

    # Clamp on both sides: a negative step would otherwise wrap around to
    # the last tile through negative indexing.
    step = max(0, min(current_step, len(example_data) - 1))
    d = example_data[step]

    # Create figure with bar charts showing m and l evolution
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Running Max (m)", "Running Sum (l)"),
        horizontal_spacing=0.15,
    )

    # Historical values up to the current step; later tiles stay empty
    m_values = [example_data[i]["m_after"] if i <= step else None for i in range(num_tiles)]
    l_values = [example_data[i]["l_after"] if i <= step else None for i in range(num_tiles)]

    # Both charts share one colour scheme, highlighting rescaling events
    bar_colors = []
    for i in range(num_tiles):
        if i > step:
            bar_colors.append("rgba(200, 200, 200, 0.5)")
        elif i == step:
            bar_colors.append("rgba(249, 115, 22, 0.9)")  # Orange for current
        elif example_data[i]["rescaled"]:
            bar_colors.append("rgba(239, 68, 68, 0.7)")  # Red for rescale events
        else:
            bar_colors.append("rgba(34, 197, 94, 0.7)")  # Green for normal

    tile_labels = [f"Tile {i}" for i in range(num_tiles)]
    for col, (values, trace_name) in enumerate(
        ((m_values, "m (max)"), (l_values, "l (sum)")), start=1
    ):
        fig.add_trace(
            go.Bar(
                x=tile_labels,
                y=[v if v is not None else 0 for v in values],
                marker_color=bar_colors,
                text=[f"{v:.2f}" if v is not None else "" for v in values],
                textposition="outside",
                name=trace_name,
            ),
            row=1, col=col
        )

    # Move subplot titles down so they don't get cut off by Gradio label
    for annotation in fig['layout']['annotations']:
        annotation['y'] = 0.95
        annotation['yanchor'] = 'top'

    fig.update_layout(
        height=380,
        margin=dict(l=40, r=40, t=30, b=40),
        showlegend=False,
    )
    # Increase y-axis range to make room for text labels above bars
    fig.update_yaxes(range=[0, 14], row=1, col=1)
    fig.update_yaxes(range=[0, 18], row=1, col=2)

    # Generate explanation text (Markdown)
    header = f"**Processing Tile {step} (Keys {step*2}-{step*2+1})**"
    if d["rescaled"]:
        lines = [
            header,
            f"🔴 **MAX CHANGED!** Block max ({d['block_max']:.2f}) > Previous max ({d['m_before']:.2f})",
            "**Rescaling required:**",
            f"- Rescale factor: exp({d['m_before']:.1f} - {d['block_max']:.1f}) = **{d['rescale_factor']:.3f}**",
            f"- Previous l rescaled: {d['l_before']:.2f} × {d['rescale_factor']:.3f} = {d['l_before'] * d['rescale_factor']:.2f}",
            f"- New l = rescaled + block_sum = **{d['l_after']:.2f}**",
            f"- Previous O also rescaled by {d['rescale_factor']:.3f}",
            "*This is the key insight: when max increases, we must rescale all previous accumulations!*",
        ]
    else:
        lines = [
            header,
            f"✅ No rescaling needed (block max {d['block_max']:.2f} ≤ current max {d['m_after']:.2f})",
            "**Simple accumulation:**",
            f"- m stays at: **{d['m_after']:.2f}**",
            "- l += block_sum × exp(block_max - m)",
            f"- l = {d['l_before']:.2f} + {d['block_sum_exp']:.2f} × exp({d['block_max']:.1f} - {d['m_after']:.1f}) = **{d['l_after']:.2f}**",
        ]
    explanation = "\n".join(lines) + "\n"
    return fig, explanation
def create_memory_hierarchy_diagram(
    algorithm: str = "flash",
    current_step: int = 0,
) -> go.Figure:
    """
    Create a diagram showing HBM vs SRAM memory hierarchy.

    The HBM box (top) always holds the Q, K, V and O matrices. The SRAM box
    (bottom) holds five small tiles for "flash", or a single dashed
    "doesn't fit" block for any other value of ``algorithm``.

    Args:
        algorithm: "standard" or "flash"
        current_step: For animation purposes; in flash mode the SRAM tile
            at index ``current_step % 5`` is highlighted

    Returns:
        Plotly figure showing memory hierarchy
    """
    fig = go.Figure()
    uses_flash = algorithm == "flash"

    # Outer containers: HBM first, then SRAM
    containers = (
        (0.05, 0.55, 0.95, 0.95, "rgba(59, 130, 246, 0.1)", "rgba(59, 130, 246, 0.8)"),
        (0.2, 0.15, 0.8, 0.45, "rgba(34, 197, 94, 0.1)", "rgba(34, 197, 94, 0.8)"),
    )
    for left, bottom, right, top, fill_rgba, edge_rgba in containers:
        fig.add_shape(
            type="rect",
            x0=left, y0=bottom, x1=right, y1=top,
            fillcolor=fill_rgba,
            line=dict(color=edge_rgba, width=2),
        )

    # Full-size matrices living in HBM
    matrix_width = 0.15
    for col, label in enumerate(("Q", "K", "V", "O")):
        left = 0.15 + col * 0.2
        fig.add_shape(
            type="rect",
            x0=left, y0=0.65, x1=left + matrix_width, y1=0.85,
            fillcolor="rgba(59, 130, 246, 0.3)",
            line=dict(color="rgba(59, 130, 246, 0.6)", width=1),
        )
        fig.add_annotation(
            x=left + matrix_width/2, y=0.75,
            text=f"<b>{label}</b><br>[N, d]",
            showarrow=False,
            font=dict(size=11),
        )

    # SRAM contents depend on the algorithm; the transfer arrow and its
    # caption below differ only in colour and wording.
    if uses_flash:
        tile_names = ["Q_tile", "K_tile", "V_tile", "S_tile", "O_tile"]
        tile_width = 0.1
        active_slot = current_step % len(tile_names)
        for slot, tile_name in enumerate(tile_names):
            left = 0.25 + slot * 0.11
            if slot == active_slot:
                # Tile currently being processed
                tile_fill = "rgba(249, 115, 22, 0.5)"
            else:
                tile_fill = "rgba(34, 197, 94, 0.3)"
            fig.add_shape(
                type="rect",
                x0=left, y0=0.22, x1=left + tile_width, y1=0.38,
                fillcolor=tile_fill,
                line=dict(color="rgba(34, 197, 94, 0.6)", width=1),
            )
            fig.add_annotation(
                x=left + tile_width/2, y=0.30,
                text=tile_name.replace("_", "<br>"),
                showarrow=False,
                font=dict(size=9),
            )
        arrow_rgba = "rgba(34, 197, 94, 0.8)"
        traffic_text = "O(B) per tile"
        traffic_color = "green"
    else:
        # Standard attention: the full score matrix would have to sit in SRAM
        fig.add_shape(
            type="rect",
            x0=0.3, y0=0.22, x1=0.7, y1=0.38,
            fillcolor="rgba(239, 68, 68, 0.3)",
            line=dict(color="rgba(239, 68, 68, 0.6)", width=1, dash="dash"),
        )
        fig.add_annotation(
            x=0.5, y=0.30,
            text="S[N,N]<br>❌ Doesn't fit!",
            showarrow=False,
            font=dict(size=10, color="red"),
        )
        arrow_rgba = "rgba(239, 68, 68, 0.8)"
        traffic_text = "O(N²) traffic!"
        traffic_color = "red"

    # Transfer arrow from the HBM box down to the SRAM box
    fig.add_annotation(
        x=0.5, y=0.48,
        ax=0.5, ay=0.55,
        xref="x", yref="y",
        axref="x", ayref="y",
        text="",
        showarrow=True,
        arrowhead=2,
        arrowsize=1.5,
        arrowwidth=2,
        arrowcolor=arrow_rgba,
    )
    # Caption next to the arrow
    fig.add_annotation(
        x=0.65, y=0.515,
        text=traffic_text,
        showarrow=False,
        font=dict(size=10, color=traffic_color),
        xanchor="left",
    )

    # Box captions: HBM above its box, SRAM below its box
    box_captions = (
        (0.97, "<b>HBM (High Bandwidth Memory)</b><br>80 GB capacity | 2 TB/s bandwidth | ~400 cycles latency"),
        (0.12, "<b>SRAM (Shared Memory)</b><br>192 KB capacity | 19 TB/s bandwidth | ~20 cycles latency"),
    )
    for caption_y, caption in box_captions:
        fig.add_annotation(
            x=0.5, y=caption_y,
            text=caption,
            showarrow=False,
            font=dict(size=11),
        )

    hidden_axis = dict(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False)
    fig.update_layout(
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        height=400,
        margin=dict(l=20, r=20, t=40, b=20),
        title=dict(
            text=f"Memory Hierarchy: {'FlashAttention' if algorithm == 'flash' else 'Standard Attention'}",
            x=0.5,
        ),
    )
    return fig
def get_max_steps(seq_len: int, block_size: int, causal: bool) -> int:
    """
    Calculate total number of steps for the tiling animation.

    The grid has ``seq_len // block_size`` blocks per side. Without causal
    masking every tile is a step; with it, only tiles on or below the
    diagonal count (1 + 2 + ... + blocks per side).
    """
    blocks_per_side = seq_len // block_size
    if not causal:
        return blocks_per_side ** 2
    # Row r of the grid contributes r unmasked tiles.
    return sum(range(blocks_per_side + 1))