Spaces:
Sleeping
Sleeping
a0y0346 commited on
Commit ·
ac3157f
1
Parent(s): c9bdf44
Fix Visualizer tab bugs: softmax data, reset button, causal toggle
Browse files- app.py +18 -5
- src/visualizer.py +1 -1
app.py
CHANGED
|
@@ -159,9 +159,13 @@ def create_app() -> gr.Blocks:
|
|
| 159 |
"""Move to previous step."""
|
| 160 |
return max(current_step - 1, 0)
|
| 161 |
|
| 162 |
-
def reset_step():
|
| 163 |
-
"""Reset to step 0."""
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
# Wire up events
|
| 167 |
viz_inputs = [seq_len_viz, block_size_viz, causal_viz, step_slider]
|
|
@@ -170,7 +174,12 @@ def create_app() -> gr.Blocks:
|
|
| 170 |
# Update on parameter change
|
| 171 |
seq_len_viz.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
|
| 172 |
block_size_viz.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
step_slider.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
|
| 175 |
|
| 176 |
# Step controls
|
|
@@ -180,7 +189,11 @@ def create_app() -> gr.Blocks:
|
|
| 180 |
outputs=step_slider
|
| 181 |
)
|
| 182 |
step_back_btn.click(fn=step_back, inputs=step_slider, outputs=step_slider)
|
| 183 |
-
reset_btn.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
# Memory hierarchy
|
| 186 |
algo_choice.change(fn=update_memory_hierarchy, inputs=algo_choice, outputs=memory_plot)
|
|
|
|
| 159 |
"""Move to previous step."""
|
| 160 |
return max(current_step - 1, 0)
|
| 161 |
|
| 162 |
+
def reset_step(seq_len, block_size, causal):
|
| 163 |
+
"""Reset to step 0 and update visualizations."""
|
| 164 |
+
step = 0
|
| 165 |
+
tiling_fig = create_tiling_grid(seq_len, block_size, step, causal)
|
| 166 |
+
num_tiles = seq_len // block_size
|
| 167 |
+
softmax_fig, explanation = create_online_softmax_state(step, num_tiles)
|
| 168 |
+
return step, tiling_fig, softmax_fig, explanation
|
| 169 |
|
| 170 |
# Wire up events
|
| 171 |
viz_inputs = [seq_len_viz, block_size_viz, causal_viz, step_slider]
|
|
|
|
| 174 |
# Update on parameter change
|
| 175 |
seq_len_viz.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
|
| 176 |
block_size_viz.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
|
| 177 |
+
# Causal toggle resets step to 0 since the grid structure changes
|
| 178 |
+
causal_viz.change(
|
| 179 |
+
fn=reset_step,
|
| 180 |
+
inputs=[seq_len_viz, block_size_viz, causal_viz],
|
| 181 |
+
outputs=[step_slider, tiling_plot, softmax_plot, softmax_explanation]
|
| 182 |
+
)
|
| 183 |
step_slider.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
|
| 184 |
|
| 185 |
# Step controls
|
|
|
|
| 189 |
outputs=step_slider
|
| 190 |
)
|
| 191 |
step_back_btn.click(fn=step_back, inputs=step_slider, outputs=step_slider)
|
| 192 |
+
reset_btn.click(
|
| 193 |
+
fn=reset_step,
|
| 194 |
+
inputs=[seq_len_viz, block_size_viz, causal_viz],
|
| 195 |
+
outputs=[step_slider, tiling_plot, softmax_plot, softmax_explanation]
|
| 196 |
+
)
|
| 197 |
|
| 198 |
# Memory hierarchy
|
| 199 |
algo_choice.change(fn=update_memory_hierarchy, inputs=algo_choice, outputs=memory_plot)
|
src/visualizer.py
CHANGED
|
@@ -194,7 +194,7 @@ def create_online_softmax_state(
|
|
| 194 |
"m_before": 2.1,
|
| 195 |
"m_after": 3.5,
|
| 196 |
"l_before": 3.42,
|
| 197 |
-
"l_after":
|
| 198 |
"rescale_factor": 0.247, # exp(2.1 - 3.5)
|
| 199 |
"rescaled": True,
|
| 200 |
},
|
|
|
|
| 194 |
"m_before": 2.1,
|
| 195 |
"m_after": 3.5,
|
| 196 |
"l_before": 3.42,
|
| 197 |
+
"l_after": 6.06, # 3.42 * exp(2.1-3.5) + 5.21 = 0.85 + 5.21 ≈ 6.06
|
| 198 |
"rescale_factor": 0.247, # exp(2.1 - 3.5)
|
| 199 |
"rescaled": True,
|
| 200 |
},
|