a0y0346 committed on
Commit
c9bdf44
·
1 Parent(s): 7539a21

Phase 2: Add Visualizer tab with tiling animation, online softmax, and memory hierarchy

Browse files
Files changed (2) hide show
  1. app.py +131 -10
  2. src/visualizer.py +490 -0
app.py CHANGED
@@ -31,6 +31,12 @@ from src.constants import (
31
  SEQ_LENGTH_OPTIONS,
32
  )
33
  from src.models import get_available_models, get_model_memory_footprint
 
 
 
 
 
 
34
 
35
 
36
  def create_app() -> gr.Blocks:
@@ -59,22 +65,137 @@ def create_app() -> gr.Blocks:
59
  ## FlashAttention Visualizer
60
 
61
  Understand how FlashAttention processes attention in tiles,
62
- avoiding the O(N²) memory bottleneck.
63
-
64
- *Coming in Phase 2...*
65
  """)
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  with gr.Row():
68
- with gr.Column():
69
- gr.Markdown("### Tiling Animation")
70
- gr.Markdown("*Tile-by-tile attention computation will appear here*")
71
 
72
- with gr.Column():
73
  gr.Markdown("### Online Softmax State")
74
- gr.Markdown("*Running max (m), sum (l), and output (O) tracking*")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- gr.Markdown("### Memory Hierarchy")
77
- gr.Markdown("*HBM ↔ SRAM data movement visualization*")
 
 
 
 
 
 
 
 
 
78
 
79
  # Tab 2: Benchmark (Zero GPU)
80
  with gr.Tab("Benchmark", id="tab-benchmark"):
 
31
  SEQ_LENGTH_OPTIONS,
32
  )
33
  from src.models import get_available_models, get_model_memory_footprint
34
+ from src.visualizer import (
35
+ create_tiling_grid,
36
+ create_online_softmax_state,
37
+ create_memory_hierarchy_diagram,
38
+ get_max_steps,
39
+ )
40
 
41
 
42
  def create_app() -> gr.Blocks:
 
65
  ## FlashAttention Visualizer
66
 
67
  Understand how FlashAttention processes attention in tiles,
68
+ avoiding the O(N²) memory bottleneck. Step through the algorithm
69
+ to see how tiles are processed and how online softmax maintains
70
+ running statistics.
71
  """)
72
 
73
+ # Controls
74
+ with gr.Row():
75
+ with gr.Column(scale=1):
76
+ seq_len_viz = gr.Slider(
77
+ minimum=4,
78
+ maximum=16,
79
+ step=2,
80
+ value=8,
81
+ label="Sequence Length (tokens)",
82
+ )
83
+ with gr.Column(scale=1):
84
+ block_size_viz = gr.Slider(
85
+ minimum=2,
86
+ maximum=4,
87
+ step=1,
88
+ value=2,
89
+ label="Block Size",
90
+ )
91
+ with gr.Column(scale=1):
92
+ causal_viz = gr.Checkbox(
93
+ value=False,
94
+ label="Causal Masking",
95
+ )
96
+
97
+ # Step controls
98
+ with gr.Row():
99
+ step_back_btn = gr.Button("◀ Step Back", size="sm")
100
+ step_slider = gr.Slider(
101
+ minimum=0,
102
+ maximum=15,
103
+ step=1,
104
+ value=0,
105
+ label="Current Step",
106
+ )
107
+ step_forward_btn = gr.Button("Step Forward ▶", size="sm")
108
+ reset_btn = gr.Button("Reset", size="sm", variant="secondary")
109
+
110
+ # Tiling and Online Softmax side by side
111
  with gr.Row():
112
+ with gr.Column(scale=1):
113
+ gr.Markdown("### Attention Matrix Tiling")
114
+ tiling_plot = gr.Plot(label="Tiling View")
115
 
116
+ with gr.Column(scale=1):
117
  gr.Markdown("### Online Softmax State")
118
+ softmax_plot = gr.Plot(label="Running m and l")
119
+ softmax_explanation = gr.Markdown("*Step through to see online softmax updates*")
120
+
121
+ # Memory Hierarchy
122
+ gr.Markdown("### Memory Hierarchy Comparison")
123
+ with gr.Row():
124
+ algo_choice = gr.Radio(
125
+ choices=["flash", "standard"],
126
+ value="flash",
127
+ label="Algorithm",
128
+ )
129
+ memory_plot = gr.Plot(label="Memory Hierarchy")
130
+
131
+ # Event handlers for visualizer
132
def update_visualizations(seq_len, block_size, causal, step):
    """Redraw the tiling grid and online-softmax views for the given state.

    Returns (tiling figure, softmax figure, explanation markdown,
    clamped step) so the step slider stays within the valid range.
    """
    # Keep the step inside [0, max_steps - 1] for the current settings.
    last_step = get_max_steps(seq_len, block_size, causal) - 1
    clamped = max(0, min(step, last_step))

    tiling_fig = create_tiling_grid(seq_len, block_size, clamped, causal)

    # The softmax walkthrough advances one tile per step along one query row.
    tiles = seq_len // block_size
    softmax_fig, explanation = create_online_softmax_state(min(clamped, tiles - 1), tiles)

    return tiling_fig, softmax_fig, explanation, clamped
147
+
148
def update_memory_hierarchy(algo):
    """Redraw the HBM/SRAM diagram for the selected algorithm."""
    diagram = create_memory_hierarchy_diagram(algo)
    return diagram
151
+
152
def step_forward(seq_len, block_size, causal, current_step):
    """Advance the animation one step, saturating at the final step."""
    last = get_max_steps(seq_len, block_size, causal) - 1
    return min(current_step + 1, last)
157
+
158
def step_back(current_step):
    """Go back one step in the animation, never below step 0."""
    return current_step - 1 if current_step > 0 else 0
161
+
162
def reset_step():
    """Return the animation to its first step."""
    return 0
165
+
166
+ # Wire up events
167
+ viz_inputs = [seq_len_viz, block_size_viz, causal_viz, step_slider]
168
+ viz_outputs = [tiling_plot, softmax_plot, softmax_explanation, step_slider]
169
+
170
+ # Update on parameter change
171
+ seq_len_viz.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
172
+ block_size_viz.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
173
+ causal_viz.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
174
+ step_slider.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
175
+
176
+ # Step controls
177
+ step_forward_btn.click(
178
+ fn=step_forward,
179
+ inputs=[seq_len_viz, block_size_viz, causal_viz, step_slider],
180
+ outputs=step_slider
181
+ )
182
+ step_back_btn.click(fn=step_back, inputs=step_slider, outputs=step_slider)
183
+ reset_btn.click(fn=reset_step, outputs=step_slider)
184
+
185
+ # Memory hierarchy
186
+ algo_choice.change(fn=update_memory_hierarchy, inputs=algo_choice, outputs=memory_plot)
187
 
188
+ # Initialize on load
189
+ demo.load(
190
+ fn=update_visualizations,
191
+ inputs=viz_inputs,
192
+ outputs=viz_outputs
193
+ )
194
+ demo.load(
195
+ fn=update_memory_hierarchy,
196
+ inputs=algo_choice,
197
+ outputs=memory_plot
198
+ )
199
 
200
  # Tab 2: Benchmark (Zero GPU)
201
  with gr.Tab("Benchmark", id="tab-benchmark"):
src/visualizer.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualizer for FlashAttention concepts.
3
+ CPU-only animations showing tiling, online softmax, and memory hierarchy.
4
+ """
5
+
6
+ import numpy as np
7
+ import plotly.graph_objects as go
8
+ from plotly.subplots import make_subplots
9
+
10
+
11
def create_tiling_grid(
    seq_len: int = 8,
    block_size: int = 2,
    current_step: int = 0,
    causal: bool = False
) -> go.Figure:
    """
    Create a grid visualization showing FlashAttention tile processing.

    Each cell is one (query block, key block) tile, colored by its status
    relative to ``current_step``: done (green), current (orange),
    pending (light gray), or masked (faint gray, causal mode only).

    Args:
        seq_len: Sequence length (number of tokens). Assumed to be a
            multiple of ``block_size``.
        block_size: Size of each tile block
        current_step: Current step in the animation (0-indexed)
        causal: Whether to use causal masking (tiles strictly above the
            diagonal are masked and consume no step)

    Returns:
        Plotly figure with the tiling grid
    """
    num_blocks = seq_len // block_size
    # Tiles actually processed: the full grid, or the lower triangle
    # (including the diagonal) when causal masking skips the rest.
    total_tiles = num_blocks * num_blocks if not causal else sum(range(1, num_blocks + 1))

    fig = go.Figure()

    # Walk tiles in processing order (row-major over Q, then K blocks).
    # Masked tiles are drawn but do not consume a step index.
    tile_idx = 0
    annotations = []

    for i in range(num_blocks):  # Query blocks (rows)
        for j in range(num_blocks):  # Key blocks (columns)
            x0, x1 = j, j + 1
            # Flip vertically so Q[0] appears at the top of the plot.
            y0, y1 = num_blocks - i - 1, num_blocks - i

            # Determine tile status
            if causal and j > i:
                # Masked tile (future keys for causal attention)
                color = "rgba(200, 200, 200, 0.3)"
                status = "masked"
            elif tile_idx < current_step:
                color = "rgba(34, 197, 94, 0.6)"  # Green: already processed
                status = "done"
            elif tile_idx == current_step:
                color = "rgba(249, 115, 22, 0.8)"  # Orange: being processed now
                status = "current"
            else:
                color = "rgba(229, 231, 235, 0.5)"  # Light gray: not yet reached
                status = "pending"

            fig.add_shape(
                type="rect",
                x0=x0, y0=y0, x1=x1, y1=y1,
                line=dict(color="rgba(0,0,0,0.3)", width=1),
                fillcolor=color,
            )

            # Label only the active tile to keep the grid readable.
            if status == "current":
                annotations.append(dict(
                    x=(x0 + x1) / 2,
                    y=(y0 + y1) / 2,
                    text=f"Q[{i}]×K[{j}]",
                    showarrow=False,
                    # NOTE(review): font "weight" requires plotly >= 5.22 — confirm pinned version.
                    font=dict(size=10, color="white", weight="bold"),
                ))

            if not (causal and j > i):
                tile_idx += 1

    # Axis labels: K blocks along the top, Q blocks down the left side.
    for i in range(num_blocks):
        # K labels (top)
        annotations.append(dict(
            x=i + 0.5,
            y=num_blocks + 0.2,
            text=f"K[{i}]",
            showarrow=False,
            font=dict(size=9, color="gray"),
        ))
        # Q labels (left)
        annotations.append(dict(
            x=-0.3,
            y=num_blocks - i - 0.5,
            text=f"Q[{i}]",
            showarrow=False,
            font=dict(size=9, color="gray"),
        ))

    # NOTE: no layout-level showlegend=False here — that would suppress the
    # manual legend built below (layout.showlegend overrides trace flags).
    fig.update_layout(
        annotations=annotations,
        xaxis=dict(
            range=[-0.5, num_blocks + 0.5],
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            title="Key Blocks →",
        ),
        yaxis=dict(
            range=[-0.5, num_blocks + 0.5],
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            scaleanchor="x",  # keep tiles square
            title="← Query Blocks",
        ),
        height=350,
        margin=dict(l=50, r=20, t=40, b=50),
        title=dict(
            # tile_idx always equals total_tiles once the loop finishes, so
            # the previous conditional between the two was dead code.
            text=f"Attention Matrix Tiling (Step {current_step + 1}/{total_tiles})",
            x=0.5,
        ),
    )

    # Add legend manually: invisible scatter traces supply the swatches.
    legend_items = [
        ("Current", "rgba(249, 115, 22, 0.8)"),
        ("Done", "rgba(34, 197, 94, 0.6)"),
        ("Pending", "rgba(229, 231, 235, 0.5)"),
    ]
    if causal:
        legend_items.append(("Masked", "rgba(200, 200, 200, 0.3)"))

    for name, color in legend_items:
        fig.add_trace(go.Scatter(
            x=[None], y=[None],
            mode="markers",
            marker=dict(size=15, color=color, symbol="square"),
            name=name,
            showlegend=True,
        ))

    fig.update_layout(
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.25,
            xanchor="center",
            x=0.5,
        )
    )

    return fig
157
+
158
+
159
def create_online_softmax_state(
    current_step: int = 0,
    num_tiles: int = 4,
) -> tuple[go.Figure, str]:
    """
    Create visualization of online softmax state (m, l, O) evolution.

    Uses a concrete 8-token example with block_size=2.
    Shows how running max (m) and sum (l) update, with rescaling when max changes.

    Args:
        current_step: Current tile being processed (0-indexed); clamped to
            the length of the pre-computed example (4 tiles)
        num_tiles: Total number of tiles shown on the x-axis; bars past the
            pre-computed example are rendered empty/gray

    Returns:
        Tuple of (Plotly figure, explanation text)
    """
    # Pre-computed example values for 8 tokens, block_size=2,
    # simulating attention scores from Q[0] to all K blocks.
    # Invariant: each tile's l_before equals the previous tile's l_after.
    example_data = [
        {
            "tile": 0,
            "block_max": 2.1,
            "block_sum_exp": 3.42,
            "m_before": float("-inf"),
            "m_after": 2.1,
            "l_before": 0.0,
            "l_after": 3.42,
            "rescale_factor": 1.0,
            "rescaled": False,
        },
        {
            "tile": 1,
            "block_max": 3.5,
            "block_sum_exp": 5.21,
            "m_before": 2.1,
            "m_after": 3.5,
            "l_before": 3.42,
            # 3.42 * exp(2.1-3.5) + 5.21 = 0.85 + 5.21 = 6.06
            # (was 1.70, which contradicted this arithmetic and tile 2's l_before)
            "l_after": 6.06,
            "rescale_factor": 0.247,  # exp(2.1 - 3.5)
            "rescaled": True,
        },
        {
            "tile": 2,
            "block_max": 2.8,
            "block_sum_exp": 4.01,
            "m_before": 3.5,
            "m_after": 3.5,  # No change - block_max < m
            "l_before": 6.06,
            "l_after": 8.03,  # 6.06 * 1.0 + 4.01 * exp(2.8-3.5)
            "rescale_factor": 1.0,
            "rescaled": False,
        },
        {
            "tile": 3,
            "block_max": 4.2,
            "block_sum_exp": 6.83,
            "m_before": 3.5,
            "m_after": 4.2,
            "l_before": 8.03,
            "l_after": 10.79,  # 8.03 * exp(3.5-4.2) + 6.83
            "rescale_factor": 0.497,  # exp(3.5 - 4.2)
            "rescaled": True,
        },
    ]

    # Clamp to the pre-computed example so out-of-range steps don't crash.
    step = min(current_step, len(example_data) - 1)
    current_data = example_data[step]

    # Side-by-side bar charts tracking m and l across tiles.
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Running Max (m)", "Running Sum (l)"),
        horizontal_spacing=0.15,
    )

    # Historical values up to the current step; None = not yet processed.
    m_values = [example_data[i]["m_after"] if i <= step else None for i in range(num_tiles)]
    l_values = [example_data[i]["l_after"] if i <= step else None for i in range(num_tiles)]

    # Colors: gray = future, orange = current, red = rescale event, green = normal.
    m_colors = []
    l_colors = []
    for i in range(num_tiles):
        if i > step:
            m_colors.append("rgba(200, 200, 200, 0.5)")
            l_colors.append("rgba(200, 200, 200, 0.5)")
        elif i == step:
            m_colors.append("rgba(249, 115, 22, 0.9)")  # Orange for current
            l_colors.append("rgba(249, 115, 22, 0.9)")
        elif example_data[i]["rescaled"]:
            m_colors.append("rgba(239, 68, 68, 0.7)")  # Red for rescale events
            l_colors.append("rgba(239, 68, 68, 0.7)")
        else:
            m_colors.append("rgba(34, 197, 94, 0.7)")  # Green for normal
            l_colors.append("rgba(34, 197, 94, 0.7)")

    # Add bars for m (unprocessed tiles plot as 0 with an empty label)
    fig.add_trace(
        go.Bar(
            x=[f"Tile {i}" for i in range(num_tiles)],
            y=[v if v is not None else 0 for v in m_values],
            marker_color=m_colors,
            text=[f"{v:.2f}" if v is not None else "" for v in m_values],
            textposition="outside",
            name="m (max)",
        ),
        row=1, col=1
    )

    # Add bars for l
    fig.add_trace(
        go.Bar(
            x=[f"Tile {i}" for i in range(num_tiles)],
            y=[v if v is not None else 0 for v in l_values],
            marker_color=l_colors,
            text=[f"{v:.2f}" if v is not None else "" for v in l_values],
            textposition="outside",
            name="l (sum)",
        ),
        row=1, col=2
    )

    fig.update_layout(
        height=280,
        margin=dict(l=40, r=40, t=60, b=40),
        showlegend=False,
    )

    # Fixed ranges so bars don't rescale as the animation advances.
    fig.update_yaxes(range=[0, 12], row=1, col=1)
    fig.update_yaxes(range=[0, 15], row=1, col=2)

    # Generate explanation text (tiles cover 2 keys each in this example)
    d = current_data
    if d["rescaled"]:
        explanation = f"""**Processing Tile {step} (Keys {step*2}-{step*2+1})**

🔴 **MAX CHANGED!** Block max ({d['block_max']:.2f}) > Previous max ({d['m_before']:.2f})

**Rescaling required:**
- Rescale factor: exp({d['m_before']:.1f} - {d['block_max']:.1f}) = **{d['rescale_factor']:.3f}**
- Previous l rescaled: {d['l_before']:.2f} × {d['rescale_factor']:.3f} = {d['l_before'] * d['rescale_factor']:.2f}
- New l = rescaled + block_sum = **{d['l_after']:.2f}**
- Previous O also rescaled by {d['rescale_factor']:.3f}

*This is the key insight: when max increases, we must rescale all previous accumulations!*
"""
    else:
        explanation = f"""**Processing Tile {step} (Keys {step*2}-{step*2+1})**

✅ No rescaling needed (block max {d['block_max']:.2f} ≤ current max {d['m_after']:.2f})

**Simple accumulation:**
- m stays at: **{d['m_after']:.2f}**
- l += block_sum × exp(block_max - m)
- l = {d['l_before']:.2f} + {d['block_sum_exp']:.2f} × exp({d['block_max']:.1f} - {d['m_after']:.1f}) = **{d['l_after']:.2f}**
"""

    return fig, explanation
319
+
320
+
321
def create_memory_hierarchy_diagram(
    algorithm: str = "flash",
    current_step: int = 0,
) -> go.Figure:
    """
    Create a diagram showing HBM vs SRAM memory hierarchy.

    For "flash", draws tile-sized buffers that fit in SRAM with O(B)
    transfers; for any other value, draws the full N×N score matrix that
    does not fit and the O(N²) HBM traffic it implies.

    Args:
        algorithm: "standard" or "flash"
        current_step: For animation purposes — highlights one SRAM tile
            (flash mode only)

    Returns:
        Plotly figure showing memory hierarchy
    """
    fig = go.Figure()

    # All coordinates below are in normalized [0, 1] plot space.
    # HBM box (upper region)
    fig.add_shape(
        type="rect",
        x0=0.05, y0=0.55, x1=0.95, y1=0.95,
        fillcolor="rgba(59, 130, 246, 0.1)",
        line=dict(color="rgba(59, 130, 246, 0.8)", width=2),
    )

    # SRAM box (lower region, drawn smaller to suggest its tiny capacity)
    fig.add_shape(
        type="rect",
        x0=0.2, y0=0.15, x1=0.8, y1=0.45,
        fillcolor="rgba(34, 197, 94, 0.1)",
        line=dict(color="rgba(34, 197, 94, 0.8)", width=2),
    )

    # HBM matrices (Q, K, V, O) — the full [N, d] tensors live in HBM.
    matrix_width = 0.15
    matrices = ["Q", "K", "V", "O"]
    hbm_x_start = 0.15

    for i, name in enumerate(matrices):
        x = hbm_x_start + i * 0.2
        fig.add_shape(
            type="rect",
            x0=x, y0=0.65, x1=x + matrix_width, y1=0.85,
            fillcolor="rgba(59, 130, 246, 0.3)",
            line=dict(color="rgba(59, 130, 246, 0.6)", width=1),
        )
        fig.add_annotation(
            x=x + matrix_width/2, y=0.75,
            text=f"<b>{name}</b><br>[N, d]",
            showarrow=False,
            font=dict(size=11),
        )

    # SRAM contents differ by algorithm.
    if algorithm == "flash":
        # FlashAttention keeps only block-sized tiles resident in SRAM.
        tiles = ["Q_tile", "K_tile", "V_tile", "S_tile", "O_tile"]
        tile_width = 0.1
        sram_x_start = 0.25

        for i, name in enumerate(tiles):
            x = sram_x_start + i * 0.11
            # Highlight current tile being processed
            is_active = (i == current_step % len(tiles))
            fill = "rgba(249, 115, 22, 0.5)" if is_active else "rgba(34, 197, 94, 0.3)"

            fig.add_shape(
                type="rect",
                x0=x, y0=0.22, x1=x + tile_width, y1=0.38,
                fillcolor=fill,
                line=dict(color="rgba(34, 197, 94, 0.6)", width=1),
            )
            fig.add_annotation(
                x=x + tile_width/2, y=0.30,
                text=name.replace("_", "<br>"),
                showarrow=False,
                font=dict(size=9),
            )

        # Transfer arrow between the boxes: only tile-sized traffic.
        fig.add_annotation(
            x=0.5, y=0.50,
            ax=0.5, ay=0.55,
            xref="x", yref="y",
            axref="x", ayref="y",
            text="",
            showarrow=True,
            arrowhead=2,
            arrowsize=1.5,
            arrowwidth=2,
            arrowcolor="rgba(34, 197, 94, 0.8)",
        )
        fig.add_annotation(
            x=0.5, y=0.52,
            text="O(B) per tile",
            showarrow=False,
            font=dict(size=10, color="green"),
        )
    else:
        # Standard attention - full N×N score matrix in SRAM (doesn't fit!)
        fig.add_shape(
            type="rect",
            x0=0.3, y0=0.22, x1=0.7, y1=0.38,
            fillcolor="rgba(239, 68, 68, 0.3)",
            line=dict(color="rgba(239, 68, 68, 0.6)", width=1, dash="dash"),
        )
        fig.add_annotation(
            x=0.5, y=0.30,
            text="S[N,N]<br>❌ Doesn't fit!",
            showarrow=False,
            font=dict(size=10, color="red"),
        )

        # Transfer arrow: the whole matrix moves between HBM and SRAM.
        fig.add_annotation(
            x=0.5, y=0.50,
            ax=0.5, ay=0.55,
            xref="x", yref="y",
            axref="x", ayref="y",
            text="",
            showarrow=True,
            arrowhead=2,
            arrowsize=1.5,
            arrowwidth=2,
            arrowcolor="rgba(239, 68, 68, 0.8)",
        )
        fig.add_annotation(
            x=0.5, y=0.52,
            text="O(N²) traffic!",
            showarrow=False,
            font=dict(size=10, color="red"),
        )

    # Region labels with indicative hardware specs (A100-class figures).
    fig.add_annotation(
        x=0.5, y=0.97,
        text="<b>HBM (High Bandwidth Memory)</b><br>80 GB capacity | 2 TB/s bandwidth | ~400 cycles latency",
        showarrow=False,
        font=dict(size=11),
    )
    fig.add_annotation(
        x=0.5, y=0.12,
        text="<b>SRAM (Shared Memory)</b><br>192 KB capacity | 19 TB/s bandwidth | ~20 cycles latency",
        showarrow=False,
        font=dict(size=11),
    )

    fig.update_layout(
        xaxis=dict(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False),
        height=400,
        margin=dict(l=20, r=20, t=40, b=20),
        title=dict(
            text=f"Memory Hierarchy: {'FlashAttention' if algorithm == 'flash' else 'Standard Attention'}",
            x=0.5,
        ),
    )

    return fig
483
+
484
+
485
def get_max_steps(seq_len: int, block_size: int, causal: bool) -> int:
    """Calculate total number of steps for the tiling animation.

    One step corresponds to one processed (Q-block, K-block) tile.
    """
    blocks = seq_len // block_size
    # Causal masking skips the strictly-upper-triangular tiles, leaving
    # the triangular number blocks * (blocks + 1) / 2.
    return blocks * (blocks + 1) // 2 if causal else blocks * blocks