a0y0346 commited on
Commit
ac3157f
·
1 Parent(s): c9bdf44

Fix Visualizer tab bugs: softmax data, reset button, causal toggle

Browse files
Files changed (2) hide show
  1. app.py +18 -5
  2. src/visualizer.py +1 -1
app.py CHANGED
@@ -159,9 +159,13 @@ def create_app() -> gr.Blocks:
159
  """Move to previous step."""
160
  return max(current_step - 1, 0)
161
 
162
- def reset_step():
163
- """Reset to step 0."""
164
- return 0
 
 
 
 
165
 
166
  # Wire up events
167
  viz_inputs = [seq_len_viz, block_size_viz, causal_viz, step_slider]
@@ -170,7 +174,12 @@ def create_app() -> gr.Blocks:
170
  # Update on parameter change
171
  seq_len_viz.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
172
  block_size_viz.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
173
- causal_viz.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
 
 
 
 
 
174
  step_slider.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
175
 
176
  # Step controls
@@ -180,7 +189,11 @@ def create_app() -> gr.Blocks:
180
  outputs=step_slider
181
  )
182
  step_back_btn.click(fn=step_back, inputs=step_slider, outputs=step_slider)
183
- reset_btn.click(fn=reset_step, outputs=step_slider)
 
 
 
 
184
 
185
  # Memory hierarchy
186
  algo_choice.change(fn=update_memory_hierarchy, inputs=algo_choice, outputs=memory_plot)
 
159
  """Move to previous step."""
160
  return max(current_step - 1, 0)
161
 
162
+ def reset_step(seq_len, block_size, causal):
163
+ """Reset to step 0 and update visualizations."""
164
+ step = 0
165
+ tiling_fig = create_tiling_grid(seq_len, block_size, step, causal)
166
+ num_tiles = seq_len // block_size
167
+ softmax_fig, explanation = create_online_softmax_state(step, num_tiles)
168
+ return step, tiling_fig, softmax_fig, explanation
169
 
170
  # Wire up events
171
  viz_inputs = [seq_len_viz, block_size_viz, causal_viz, step_slider]
 
174
  # Update on parameter change
175
  seq_len_viz.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
176
  block_size_viz.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
177
+ # Causal toggle resets step to 0 since the grid structure changes
178
+ causal_viz.change(
179
+ fn=reset_step,
180
+ inputs=[seq_len_viz, block_size_viz, causal_viz],
181
+ outputs=[step_slider, tiling_plot, softmax_plot, softmax_explanation]
182
+ )
183
  step_slider.change(fn=update_visualizations, inputs=viz_inputs, outputs=viz_outputs)
184
 
185
  # Step controls
 
189
  outputs=step_slider
190
  )
191
  step_back_btn.click(fn=step_back, inputs=step_slider, outputs=step_slider)
192
+ reset_btn.click(
193
+ fn=reset_step,
194
+ inputs=[seq_len_viz, block_size_viz, causal_viz],
195
+ outputs=[step_slider, tiling_plot, softmax_plot, softmax_explanation]
196
+ )
197
 
198
  # Memory hierarchy
199
  algo_choice.change(fn=update_memory_hierarchy, inputs=algo_choice, outputs=memory_plot)
src/visualizer.py CHANGED
@@ -194,7 +194,7 @@ def create_online_softmax_state(
194
  "m_before": 2.1,
195
  "m_after": 3.5,
196
  "l_before": 3.42,
197
- "l_after": 1.70, # 3.42 * exp(2.1-3.5) + 5.21 = 0.85 + 5.21
198
  "rescale_factor": 0.247, # exp(2.1 - 3.5)
199
  "rescaled": True,
200
  },
 
194
  "m_before": 2.1,
195
  "m_after": 3.5,
196
  "l_before": 3.42,
197
+ "l_after": 6.06, # 3.42 * exp(2.1-3.5) + 5.21 = 0.85 + 5.21 ≈ 6.06
198
  "rescale_factor": 0.247, # exp(2.1 - 3.5)
199
  "rescaled": True,
200
  },