cdpearlman (Cursor) committed
Commit d60cfe2 · 1 parent: 3f991b4

feat(output): Add token scrubber with per-position top-5 next-token probabilities

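The change exploits a basic property of causal language models: the logits at sequence index j are the distribution over the token at index j+1, so a single forward pass over prompt + generated text already contains the top-k candidates for every generated position. A minimal standalone sketch of that indexing (gpt2 is a placeholder model; this mirrors the new compute_per_position_top5 below, not the repo's exact code):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")           # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    prompt = "The capital of France is"
    full_text = prompt + " Paris, a city"                 # prompt + generated tokens
    ids = tok(full_text, return_tensors="pt")["input_ids"]
    prompt_len = tok(prompt, return_tensors="pt")["input_ids"].shape[1]

    with torch.no_grad():
        logits = model(ids).logits[0]                     # [seq_len, vocab]

    for i in range(ids.shape[1] - prompt_len):
        # Logits at index prompt_len - 1 + i predict the token at prompt_len + i.
        probs = torch.softmax(logits[prompt_len - 1 + i], dim=-1)
        top_p, top_ix = torch.topk(probs, k=5)
        actual_id = ids[0, prompt_len + i].item()
        print(i, tok.decode([actual_id]), round(probs[actual_id].item(), 3),
              [tok.decode([t]) for t in top_ix.tolist()])

Note one caveat this shares with the commit's approach: tokenizing the prompt separately can, in rare cases, split differently at the prompt/generation boundary than tokenizing the full text does.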
app.py CHANGED
@@ -28,7 +28,8 @@ from components.model_selector import create_model_selector
 from components.glossary import create_glossary_modal
 from components.pipeline import (create_pipeline_container, create_tokenization_content,
                                  create_embedding_content, create_attention_content,
-                                 create_mlp_content, create_output_content)
+                                 create_mlp_content, create_output_content,
+                                 _build_token_display, _build_top5_chart)
 from components.investigation_panel import create_investigation_panel, create_attribution_results_display
 from components.ablation_panel import create_selected_heads_display, create_ablation_results_display
 from components.chatbot import create_chatbot_container, render_messages
@@ -375,9 +376,12 @@ def run_generation(n_clicks, model_name, prompt, max_new_tokens, beam_width, pat
     # the full-sequence analysis runs when the user selects a beam.
     if max_new_tokens == 1:
         full_text = results[0]['text']
+        # Pass original_prompt so per-position top-5 is computed for the scrubber
+        activation_data = execute_forward_pass(model, tokenizer, full_text, config,
+                                               original_prompt=prompt)
     else:
         full_text = prompt
-    activation_data = execute_forward_pass(model, tokenizer, full_text, config)
+        activation_data = execute_forward_pass(model, tokenizer, full_text, config)
 
     results_ui = []
     if max_new_tokens > 1:
@@ -418,10 +422,11 @@ def run_generation(n_clicks, model_name, prompt, max_new_tokens, beam_width, pat
      Output('session-activation-store', 'data', allow_duplicate=True)],
     Input({'type': 'result-item', 'index': ALL}, 'n_clicks'),
     [State('generation-results-store', 'data'),
-     State('session-activation-store', 'data')],
+     State('session-activation-store', 'data'),
+     State('session-original-prompt-store', 'data')],
     prevent_initial_call=True
 )
-def store_selected_beam(n_clicks_list, results_data, existing_activation_data):
+def store_selected_beam(n_clicks_list, results_data, existing_activation_data, original_prompt_data):
     """
     Store selected beam and re-run forward pass on the full sequence.
 
@@ -490,7 +495,12 @@ def store_selected_beam(n_clicks_list, results_data, existing_activation_data):
         model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model.eval()
-        new_activation_data = execute_forward_pass(model, tokenizer, result['text'], config)
+        # Pass original_prompt so per-position top-5 data is computed for the scrubber
+        orig_prompt = original_prompt_data.get('prompt', '') if original_prompt_data else ''
+        new_activation_data = execute_forward_pass(
+            model, tokenizer, result['text'], config,
+            original_prompt=orig_prompt
+        )
     except Exception as e:
         import traceback
         traceback.print_exc()
@@ -594,9 +604,19 @@ def update_pipeline_content(activation_data, model_name):
     # Stage 5: Output
     # Get original prompt for context display
     original_prompt = activation_data.get('prompt', '')
+    # Per-position data for the scrubber (populated when original_prompt was given)
+    per_position_data = activation_data.get('per_position_top5', [])
+    generated_tokens = activation_data.get('generated_tokens', [])
+    scrubber_prompt = activation_data.get('original_prompt', original_prompt)
+
     outputs.append(f"→ {predicted_token}")
-    outputs.append(create_output_content(top_tokens, predicted_token, predicted_prob,
-                                         original_prompt=original_prompt))
+    outputs.append(create_output_content(
+        top_tokens, predicted_token, predicted_prob,
+        original_prompt=original_prompt,
+        per_position_data=per_position_data,
+        generated_tokens=generated_tokens,
+        prompt_text=scrubber_prompt
+    ))
 
     return tuple(outputs)
 
@@ -606,6 +626,41 @@ def update_pipeline_content(activation_data, model_name):
         return tuple(empty_outputs)
 
 
+# ============================================================================
+# CALLBACKS: Output Scrubber
+# ============================================================================
+
+@app.callback(
+    [Output('output-token-display', 'children'),
+     Output('output-top5-chart', 'children')],
+    [Input('output-scrubber-slider', 'value')],
+    [State('session-activation-store', 'data')],
+    prevent_initial_call=True
+)
+def update_output_scrubber(position, activation_data):
+    """Update the token display and top-5 chart when the scrubber slider moves."""
+    if activation_data is None or position is None:
+        return no_update, no_update
+
+    per_position_data = activation_data.get('per_position_top5', [])
+    generated_tokens = activation_data.get('generated_tokens', [])
+    prompt_text = activation_data.get('original_prompt', activation_data.get('prompt', ''))
+
+    if not per_position_data or not generated_tokens:
+        return no_update, no_update
+
+    # Clamp position to valid range
+    position = max(0, min(position, len(per_position_data) - 1))
+    pos_data = per_position_data[position]
+
+    token_display = _build_token_display(
+        prompt_text, generated_tokens, position, pos_data['actual_prob']
+    )
+    top5_chart = _build_top5_chart(pos_data['top5'], pos_data.get('actual_token'))
+
+    return token_display, top5_chart
+
+
 # ============================================================================
 # CALLBACKS: Sidebar
 # ============================================================================
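For readers unfamiliar with Dash, the update_output_scrubber callback above is an instance of the standard Store + Slider pattern. A self-contained toy version (component ids and data here are invented, not the app's real components):

    from dash import Dash, dcc, html, Input, Output, State, no_update

    app = Dash(__name__)
    app.layout = html.Div([
        dcc.Store(id='demo-store', data={'tokens': ['Hello', ' world', '!']}),
        dcc.Slider(id='demo-slider', min=0, max=2, step=1, value=0),
        html.Div(id='demo-display'),
    ])

    @app.callback(Output('demo-display', 'children'),
                  Input('demo-slider', 'value'),
                  State('demo-store', 'data'))
    def show_token(position, data):
        if data is None or position is None:
            return no_update
        tokens = data['tokens']
        position = max(0, min(int(position), len(tokens) - 1))  # clamp, as the app does
        return f"Token {position + 1} of {len(tokens)}: {tokens[position]!r}"

    if __name__ == '__main__':
        app.run(debug=True)

Because the per-position data lives in a dcc.Store, moving the slider re-renders only the token display and chart; no forward pass is repeated.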
components/pipeline.py CHANGED
@@ -689,18 +689,218 @@ def create_mlp_content(layer_count=None, hidden_dim=None, intermediate_dim=None)
     ])
 
 
-def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=None,
-                          top5_chart=None, original_prompt=None):
+def _build_token_display(prompt_text, generated_tokens, position, actual_prob):
+    """Build the token display for a given scrubber position.
+
+    Args:
+        prompt_text: Original prompt string.
+        generated_tokens: List of generated token strings.
+        position: Current slider position (0-indexed into generated_tokens).
+        actual_prob: Probability of the highlighted token at this position.
+    """
+    # Context = prompt + all generated tokens before the current position
+    context_parts = [
+        html.Span(prompt_text, style={
+            'color': '#6c757d',
+            'fontFamily': 'monospace',
+            'fontSize': '15px'
+        })
+    ]
+    for j in range(position):
+        context_parts.append(
+            html.Span(generated_tokens[j], style={
+                'color': '#6c757d',
+                'fontFamily': 'monospace',
+                'fontSize': '15px'
+            })
+        )
+
+    # Highlighted token at the current position
+    highlighted = html.Span(generated_tokens[position], style={
+        'padding': '4px 8px',
+        'backgroundColor': '#00f2fe',
+        'color': '#1a1a2e',
+        'borderRadius': '4px',
+        'fontFamily': 'monospace',
+        'fontWeight': '600',
+        'fontSize': '15px',
+        'marginLeft': '2px'
+    })
+
+    confidence = html.Div([
+        html.Span(
+            f"{actual_prob:.1%} confidence" if actual_prob else "",
+            style={'color': '#6c757d', 'fontSize': '13px', 'marginTop': '8px', 'display': 'block'}
+        )
+    ])
+
+    return html.Div([
+        html.Div([
+            html.Span(
+                f"Token {position + 1} of {len(generated_tokens)}:",
+                style={'color': '#495057', 'marginBottom': '12px', 'display': 'block', 'fontWeight': '500'}
+            ),
+            html.Div(context_parts + [highlighted], style={'display': 'inline'}),
+            confidence
+        ], style={'textAlign': 'center'})
+    ], style={
+        'padding': '20px', 'backgroundColor': 'white', 'borderRadius': '8px',
+        'border': '2px solid #00f2fe', 'marginBottom': '16px'
+    })
+
+
+def _build_top5_chart(top5_data, actual_token=None):
+    """Build the top-5 bar chart for a single scrubber position.
+
+    Args:
+        top5_data: List of {'token': str, 'probability': float}.
+        actual_token: The token that was actually generated (highlighted if present).
+    """
+    tokens = [entry['token'] for entry in top5_data]
+    probs = [entry['probability'] for entry in top5_data]
+
+    # Highlight the actual chosen token if it appears in the top 5
+    colors = []
+    actual_in_top5 = False
+    for t in tokens:
+        if actual_token and t.strip() == actual_token.strip():
+            colors.append('#00f2fe')
+            actual_in_top5 = True
+        else:
+            colors.append('#4facfe')
+
+    fig = go.Figure(go.Bar(
+        x=probs,
+        y=tokens,
+        orientation='h',
+        marker_color=colors,
+        text=[f"{p:.1%}" for p in probs],
+        textposition='outside',
+        hovertemplate='%{y} (%{x:.1%})<extra></extra>'
+    ))
+
+    fig.update_layout(
+        title="Top 5 Next-Token Predictions",
+        xaxis_title="Probability",
+        yaxis_title="Token",
+        height=250,
+        margin=dict(l=20, r=60, t=40, b=20),
+        paper_bgcolor='rgba(0,0,0,0)',
+        plot_bgcolor='rgba(0,0,0,0)',
+        yaxis=dict(autorange='reversed')
+    )
+
+    children = [dcc.Graph(figure=fig, config={'displayModeBar': False})]
+
+    # If the actual token is not in the top 5, add a note below
+    if actual_token and not actual_in_top5:
+        children.append(html.Div([
+            html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '6px'}),
+            html.Span([
+                "The actual token \"", html.Strong(actual_token.strip()),
+                "\" was not in the top 5 predictions at this position."
+            ], style={'color': '#6c757d', 'fontSize': '13px'})
+        ], style={'padding': '8px 12px'}))
+
+    return html.Div(children, style={
+        'backgroundColor': 'white', 'borderRadius': '8px', 'border': '1px solid #e2e8f0'
+    })
+
+
+def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=None,
+                          top5_chart=None, original_prompt=None,
+                          per_position_data=None, generated_tokens=None,
+                          prompt_text=None):
     """
     Create content for the output selection stage.
+
+    When per_position_data is available, the output is an interactive scrubber
+    that lets the user step through each generated-token position. Otherwise
+    it falls back to the previous static display.
 
     Args:
-        top_tokens: List of (token, probability) tuples for top predictions
-        predicted_token: The final predicted token
-        predicted_prob: Probability of the predicted token
-        top5_chart: Optional Plotly figure for top-5 visualization
-        original_prompt: Original input prompt to show context with prediction
+        top_tokens: List of (token, probability) tuples for top predictions (static mode).
+        predicted_token: The final predicted token (static mode).
+        predicted_prob: Probability of the predicted token (static mode).
+        top5_chart: Optional Plotly figure for top-5 visualization (static mode).
+        original_prompt: Original input prompt to show context with prediction (static mode).
+        per_position_data: List of per-position dicts from compute_per_position_top5 (scrubber mode).
+        generated_tokens: List of generated token strings (scrubber mode).
+        prompt_text: Original prompt text for context display (scrubber mode).
     """
+    # --- Scrubber mode ---
+    if per_position_data and generated_tokens:
+        num_positions = len(per_position_data)
+        prompt_display = prompt_text or original_prompt or ""
+
+        content_items = [
+            html.Div([
+                html.H5("What happens here:", style={'color': '#495057', 'marginBottom': '8px'}),
+                html.P([
+                    "The model converts the final hidden state into a ",
+                    html.Strong("probability distribution"),
+                    " over all possible next tokens. Use the slider below to step through "
+                    "each generated token and see the model's top predictions at that point."
+                ], style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '16px'})
+            ])
+        ]
+
+        # Slider / scrubber
+        slider_marks = {i: {'label': generated_tokens[i].strip() or repr(generated_tokens[i])}
+                        for i in range(num_positions)}
+        content_items.append(
+            html.Div([
+                html.Span("Step through generated tokens:",
+                          style={'color': '#495057', 'fontWeight': '500', 'display': 'block',
+                                 'marginBottom': '8px'}),
+                dcc.Slider(
+                    id='output-scrubber-slider',
+                    min=0,
+                    max=max(num_positions - 1, 0),
+                    step=1,
+                    value=0,
+                    marks=slider_marks,
+                    included=False,
+                )
+            ], style={'marginBottom': '20px', 'padding': '12px 16px',
+                      'backgroundColor': '#f8f9fa', 'borderRadius': '8px',
+                      'border': '1px solid #dee2e6'})
+        )
+
+        # Initial render at position 0
+        pos0 = per_position_data[0]
+        content_items.append(
+            html.Div(
+                _build_token_display(prompt_display, generated_tokens, 0, pos0['actual_prob']),
+                id='output-token-display'
+            )
+        )
+        content_items.append(
+            html.Div(
+                _build_top5_chart(pos0['top5'], pos0.get('actual_token')),
+                id='output-top5-chart'
+            )
+        )
+
+        # Disclaimer
+        content_items.append(
+            html.Div([
+                html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '8px'}),
+                html.Span([
+                    html.Strong("Note on Token Selection: "),
+                    "While the probabilities above show the model's raw preference for the immediate next token, the final choice ",
+                    "can be influenced by other factors. Techniques like ", html.Strong("Beam Search"),
+                    " look ahead at multiple possible sequences to find the best overall result, rather than just the single most likely token at each step. ",
+                    "Additionally, architectures like ", html.Strong("Mixture of Experts (MoE)"),
+                    " might route processing through different specialized internal networks, which can affect the final output distribution."
+                ], style={'color': '#6c757d', 'fontSize': '13px'})
+            ], style={'marginTop': '16px', 'padding': '12px', 'backgroundColor': '#f8f9fa',
+                      'borderRadius': '6px', 'border': '1px solid #dee2e6'})
+        )
+
+        return html.Div(content_items)
+
+    # --- Static fallback (prompt-only analysis, no generated tokens yet) ---
     content_items = [
         html.Div([
             html.H5("What happens here:", style={'color': '#495057', 'marginBottom': '8px'}),
@@ -711,75 +911,49 @@ def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=
         ])
     ]
 
-    # Predicted token display with full prompt context
     if predicted_token:
-        # Build the full prompt + predicted token display
         prompt_display = original_prompt if original_prompt else ""
-
         content_items.append(
             html.Div([
                 html.Div([
                     html.Span("Model prediction:", style={'color': '#495057', 'marginBottom': '12px', 'display': 'block', 'fontWeight': '500'}),
                     html.Div([
-                        # Original prompt (dimmed)
                         html.Span(prompt_display, style={
-                            'color': '#6c757d',
-                            'fontFamily': 'monospace',
-                            'fontSize': '15px'
+                            'color': '#6c757d', 'fontFamily': 'monospace', 'fontSize': '15px'
                         }),
-                        # Predicted token (highlighted)
                         html.Span(predicted_token, style={
-                            'padding': '4px 8px',
-                            'backgroundColor': '#00f2fe',
-                            'color': '#1a1a2e',
-                            'borderRadius': '4px',
-                            'fontFamily': 'monospace',
-                            'fontWeight': '600',
-                            'fontSize': '15px',
-                            'marginLeft': '2px'
+                            'padding': '4px 8px', 'backgroundColor': '#00f2fe',
+                            'color': '#1a1a2e', 'borderRadius': '4px',
+                            'fontFamily': 'monospace', 'fontWeight': '600',
+                            'fontSize': '15px', 'marginLeft': '2px'
                        })
                     ], style={'display': 'inline'}),
-                    # Confidence indicator
                     html.Div([
                         html.Span(f"{predicted_prob:.1%} confidence" if predicted_prob else "", style={
-                            'color': '#6c757d',
-                            'fontSize': '13px',
-                            'marginTop': '8px',
-                            'display': 'block'
+                            'color': '#6c757d', 'fontSize': '13px', 'marginTop': '8px', 'display': 'block'
                        })
                    ])
                ], style={'textAlign': 'center'})
             ], style={'padding': '20px', 'backgroundColor': 'white', 'borderRadius': '8px',
                       'border': '2px solid #00f2fe', 'marginBottom': '16px'})
        )
 
-    # Top-5 bar chart with improved hover formatting
     if top_tokens:
        tokens = [t[0] for t in top_tokens[:5]]
        probs = [t[1] for t in top_tokens[:5]]
 
        fig = go.Figure(go.Bar(
-            x=probs,
-            y=tokens,
-            orientation='h',
+            x=probs, y=tokens, orientation='h',
            marker_color=['#00f2fe' if i == 0 else '#4facfe' for i in range(len(tokens))],
-            text=[f"{p:.1%}" for p in probs],
-            textposition='outside',
-            # Format hover to show "Token (X%)" instead of long decimals
+            text=[f"{p:.1%}" for p in probs], textposition='outside',
            hovertemplate='%{y} (%{x:.1%})<extra></extra>'
        ))
-
        fig.update_layout(
-            title="Top 5 Predictions",
-            xaxis_title="Probability",
-            yaxis_title="Token",
-            height=250,
-            margin=dict(l=20, r=60, t=40, b=20),
-            paper_bgcolor='rgba(0,0,0,0)',
-            plot_bgcolor='rgba(0,0,0,0)',
+            title="Top 5 Predictions", xaxis_title="Probability", yaxis_title="Token",
+            height=250, margin=dict(l=20, r=60, t=40, b=20),
+            paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)',
            yaxis=dict(autorange='reversed')
        )
-
        content_items.append(
            html.Div([
                dcc.Graph(figure=fig, config={'displayModeBar': False})
@@ -792,7 +966,6 @@ def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=
         ], style={'backgroundColor': 'white', 'borderRadius': '8px', 'border': '1px solid #e2e8f0'})
     )
 
-    # Disclaimer about token selection drivers
     content_items.append(
         html.Div([
             html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '8px'}),
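A toy invocation of the new scrubber mode (the data shapes follow compute_per_position_top5's documented return format; the token strings and probabilities are invented):

    from components.pipeline import create_output_content

    per_position_data = [
        {'position': 0,
         'top5': [{'token': ' Paris', 'probability': 0.62},
                  {'token': ' Lyon', 'probability': 0.11}],
         'actual_token': ' Paris', 'actual_prob': 0.62},
        {'position': 1,
         'top5': [{'token': ',', 'probability': 0.41},
                  {'token': '.', 'probability': 0.33}],
         'actual_token': ',', 'actual_prob': 0.41},
    ]
    layout = create_output_content(
        per_position_data=per_position_data,
        generated_tokens=[' Paris', ','],
        prompt_text='The capital of France is',
    )

Calling it with only the static-mode arguments (top_tokens, predicted_token, predicted_prob) still produces the previous static display, which is what keeps pre-beam-selection views working.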
plans.md CHANGED
@@ -1,11 +1,12 @@
 - specs on what each attention head does
-- change attention to entire generated sequence
-- output slider to look at each token
-  - put in a more obvious place?
 - experiment results side by side comparison
 - output streaming for chatbot
-- change width of chatbot window
 - shorter, concise responses in system prompt
 - add video links to glossary
 - three blue one brown
-- add output token generation to attention, tokenization, etc
+
+Done:
+- change attention to entire generated sequence
+- change width of chatbot window
+- add output token generation to attention, tokenization, etc
+- output slider to look at each token (scrubber with top-5 at each position)
tests/conftest.py CHANGED
@@ -5,6 +5,10 @@ Provides reusable mock data structures and synthetic tensors
 to test utility functions without loading actual ML models.
 """
 
+# Disable TensorFlow before any other imports (mirrors app.py)
+import os
+os.environ["USE_TF"] = "0"
+
 import pytest
 import torch
 import numpy as np
tests/test_model_patterns.py CHANGED
@@ -263,6 +263,136 @@ class TestMultiLayerHeadAblation:
         assert '99' in result['error']  # Should mention the invalid layer
 
 
+class TestComputePerPositionTop5:
+    """Tests for compute_per_position_top5 function."""
+
+    def _make_mock_output(self, seq_len, vocab_size=10):
+        """Create a mock model output with predictable logits.
+
+        At each position i, logits[0, i, i % vocab_size] = 10.0 (highest), so
+        the top-1 token at position i is always index i % vocab_size. All
+        other logits are 1.0.
+        """
+        logits = torch.ones(1, seq_len, vocab_size)
+        for i in range(seq_len):
+            # Make token (i % vocab_size) the top prediction at position i
+            logits[0, i, i % vocab_size] = 10.0
+
+        class MockOutput:
+            pass
+        out = MockOutput()
+        out.logits = logits
+        return out
+
+    def _make_mock_tokenizer(self, vocab_size=10):
+        """Create a mock tokenizer that decodes token IDs to 'tok_N'."""
+        from unittest.mock import MagicMock
+        tok = MagicMock()
+        def decode_fn(ids, skip_special_tokens=False):
+            if isinstance(ids, list) and len(ids) == 1:
+                return f"tok_{ids[0]}"
+            return "".join(f"tok_{i}" for i in ids)
+        tok.decode = decode_fn
+        return tok
+
+    def test_returns_correct_number_of_positions(self):
+        """With prompt_token_count=3 and seq_len=7, should return 4 positions (7-3)."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=7, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        # Full sequence has 7 tokens, prompt has 3, so 4 generated tokens
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        assert len(result) == 4  # positions 0, 1, 2, 3
+
+    def test_single_generated_token(self):
+        """With 1 generated token, should return exactly 1 position."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=4, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        assert len(result) == 1
+        assert result[0]['position'] == 0
+
+    def test_each_position_has_top_k_entries(self):
+        """Each position should have exactly top_k entries in its top5 list."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=8, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        for pos_data in result:
+            assert len(pos_data['top5']) == 5
+
+    def test_top_k_3(self):
+        """Should respect a custom top_k parameter."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=6, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=3)
+        for pos_data in result:
+            assert len(pos_data['top5']) == 3
+
+    def test_probabilities_sorted_descending(self):
+        """Top-5 probabilities should be in descending order."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=6, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        for pos_data in result:
+            probs = [entry['probability'] for entry in pos_data['top5']]
+            assert probs == sorted(probs, reverse=True)
+
+    def test_probabilities_are_valid(self):
+        """All probabilities should be between 0 and 1."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=6, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        for pos_data in result:
+            for entry in pos_data['top5']:
+                assert 0.0 <= entry['probability'] <= 1.0
+            assert 0.0 <= pos_data['actual_prob'] <= 1.0
+
+    def test_actual_token_field_present(self):
+        """Each position should have actual_token and actual_prob fields."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=6, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        for pos_data in result:
+            assert 'actual_token' in pos_data
+            assert 'actual_prob' in pos_data
+            assert isinstance(pos_data['actual_token'], str)
+            assert isinstance(pos_data['actual_prob'], float)
+
+    def test_position_indices_sequential(self):
+        """Position indices should be 0, 1, 2, ... N-1."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=8, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        positions = [r['position'] for r in result]
+        assert positions == list(range(5))  # 8 - 3 = 5 positions
+
+    def test_does_not_include_position_beyond_sequence(self):
+        """Should NOT produce a position that predicts beyond the last token."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=5, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        # prompt=3, seq=5, so 2 generated tokens -> positions 0 and 1
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        assert len(result) == 2
+        # Position 0: logits at index 2 (prompt_len-1), predicts token at index 3
+        # Position 1: logits at index 3, predicts token at index 4
+        # No position for logits at index 4 (it would predict beyond the sequence)
+
+    def test_prompt_equals_sequence_returns_empty(self):
+        """When prompt_token_count == seq_len (no generated tokens), return empty."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=3, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        assert result == []
+
+
 class TestFullSequenceAttentionData:
     """
     Tests verifying that activation data for full sequences (prompt + generated output)
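The new class can be run on its own with pytest's -k filter, e.g.:

    pytest tests/test_model_patterns.py -k ComputePerPositionTop5 -q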
todo.md CHANGED
@@ -208,3 +208,20 @@
 - [x] Added 5 tests in `test_model_patterns.py` (`TestFullSequenceAttentionData`) verifying attention matrix dimensions match full sequence length
 - Attention visualization now covers the entire chosen output (input + generated tokens), not just the input prompt
 - No changes needed in `model_patterns.py`, `beam_search.py`, `pipeline.py`, or `head_detection.py`
+
+## Completed: Output Token Scrubber
+
+- [x] Add `compute_per_position_top5()` to `utils/model_patterns.py` — extracts top-5 next-token probabilities at each generated-token position from a single forward pass
+- [x] Add `original_prompt` parameter to `execute_forward_pass()` — when provided, computes per-position top-5 data and stores it in activation_data
+- [x] Export `compute_per_position_top5` in `utils/__init__.py`
+- [x] Update `run_generation()` in app.py — passes `original_prompt=prompt` for single-token generation
+- [x] Update `store_selected_beam()` in app.py — reads the original prompt from the session store and passes it to the forward pass
+- [x] Rewrite `create_output_content()` in `components/pipeline.py` — scrubber mode with `dcc.Slider`, token display, and top-5 chart; falls back to static mode when no per-position data
+- [x] Add `_build_token_display()` and `_build_top5_chart()` helpers in pipeline.py
+- [x] Add `update_output_scrubber()` callback in app.py — responds to slider changes, updates token highlight and chart
+- [x] Update `update_pipeline_content()` in app.py — extracts per-position data and passes it to the output content
+- [x] Add 10 tests for `compute_per_position_top5` in `test_model_patterns.py`
+- [x] Fix `conftest.py` to set `USE_TF=0` for test import compatibility
+- [x] All 100 tests pass
+- Scrubber shows prompt context (gray) + highlighted token (cyan) + top-5 bar chart at each slider position
+- Pre-beam-selection falls back to the static output display; the scrubber activates after beam selection or single-token generation
utils/__init__.py CHANGED
@@ -4,7 +4,8 @@ from .model_patterns import (load_model_and_get_patterns, execute_forward_pass,
                              execute_forward_pass_with_head_ablation,
                              execute_forward_pass_with_multi_layer_head_ablation,
                              merge_token_probabilities,
-                             compute_global_top5_tokens, detect_significant_probability_increases,
+                             compute_global_top5_tokens, compute_per_position_top5,
+                             detect_significant_probability_increases,
                              evaluate_sequence_ablation, generate_bertviz_model_view_html)
 from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
 from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig
@@ -25,6 +26,7 @@ __all__ = [
     'generate_bertviz_html',
     'merge_token_probabilities',
     'compute_global_top5_tokens',
+    'compute_per_position_top5',
     'detect_significant_probability_increases',
     'generate_bertviz_model_view_html',
utils/model_patterns.py CHANGED
@@ -125,6 +125,98 @@ def compute_global_top5_tokens(model_output, tokenizer, top_k: int = 5) -> List[
     return [{'token': t, 'probability': p} for t, p in merged[:top_k]]
 
 
+def compute_per_position_top5(model_output, tokenizer, prompt_token_count: int, top_k: int = 5) -> List[Dict[str, Any]]:
+    """
+    Compute top-K next-token probabilities at each generated-token position.
+
+    Uses logits already produced by the forward pass on the full sequence
+    (prompt + generated tokens). Position i in the returned list corresponds
+    to the prediction of generated token g_i given the prefix up to g_{i-1}.
+
+    Args:
+        model_output: Output from model(**inputs) containing logits [1, seq_len, vocab].
+        tokenizer: Tokenizer for decoding token IDs.
+        prompt_token_count: Number of tokens in the original prompt (P).
+        top_k: Number of top tokens per position (default 5).
+
+    Returns:
+        List of dicts, one per generated token position::
+
+            [
+                {
+                    "position": 0,
+                    "top5": [{"token": str, "probability": float}, ...],
+                    "actual_token": str,   # token actually generated at this position
+                    "actual_prob": float   # its probability at this position
+                },
+                ...
+            ]
+    """
+    seq_len = model_output.logits.shape[1]
+    num_generated = seq_len - prompt_token_count
+    if num_generated <= 0:
+        return []
+
+    results = []
+    with torch.no_grad():
+        # The actual token at generated position i sits at input index
+        # prompt_token_count + i. The caller attaches the full-sequence
+        # input_ids to model_output (see execute_forward_pass); only when
+        # those ids are missing do we fall back to the argmax prediction.
+        all_logits = model_output.logits[0]  # [seq_len, vocab]
+
+        for i in range(num_generated):
+            logit_idx = prompt_token_count - 1 + i   # index into logits
+            next_token_idx = prompt_token_count + i  # index of the actual next token
+
+            probs = F.softmax(all_logits[logit_idx], dim=-1)
+
+            # --- top-K with merge ---
+            top_probs, top_indices = torch.topk(probs, k=min(top_k * 2, len(probs)))
+            candidates = [
+                (tokenizer.decode([idx.item()], skip_special_tokens=False), prob.item())
+                for idx, prob in zip(top_indices, top_probs)
+            ]
+            merged = merge_token_probabilities(candidates)
+            top5 = [{'token': t, 'probability': p} for t, p in merged[:top_k]]
+
+            # --- actual token at this position ---
+            # The actual next token is whichever token the model was given at
+            # next_token_idx. The logits alone cannot tell us that reliably,
+            # so the caller stores the input_ids that produced these logits
+            # on model_output; if they are absent we fall back to the argmax,
+            # which is only correct for greedy decoding.
+            actual_token_id = None
+            if hasattr(model_output, 'input_ids') and model_output.input_ids is not None:
+                actual_token_id = model_output.input_ids[0, next_token_idx].item()
+            elif hasattr(model_output, '_input_ids'):
+                actual_token_id = model_output._input_ids[0, next_token_idx].item()
+
+            if actual_token_id is not None:
+                actual_token = tokenizer.decode([actual_token_id], skip_special_tokens=False)
+                actual_prob = probs[actual_token_id].item()
+            else:
+                # Fallback: use the argmax as "actual" (only correct for greedy)
+                top_prob, top_idx = probs.max(dim=-1)
+                actual_token = tokenizer.decode([top_idx.item()], skip_special_tokens=False)
+                actual_prob = top_prob.item()
+
+            results.append({
+                'position': i,
+                'top5': top5,
+                'actual_token': actual_token,
+                'actual_prob': float(actual_prob),
+            })
+
+    return results
+
+
 def get_actual_model_output(model_output, tokenizer) -> Tuple[str, float]:
     """
     Extract the predicted token from model's output.
@@ -148,16 +240,21 @@ def get_actual_model_output(model_output, tokenizer) -> Tuple[str, float]:
     return token_str, top_prob.item()
 
 
-def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any], ablation_config: Optional[Dict[int, List[int]]] = None) -> Dict[str, Any]:
+def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any],
+                         ablation_config: Optional[Dict[int, List[int]]] = None,
+                         original_prompt: Optional[str] = None) -> Dict[str, Any]:
     """
     Execute forward pass with PyVene IntervenableModel to capture activations from specified modules.
 
     Args:
         model: Loaded transformer model
         tokenizer: Loaded tokenizer
-        prompt: Input text prompt
+        prompt: Input text prompt (may be the full sequence: original prompt + generated tokens)
         config: Dict with module lists like {"attention_modules": [...], "block_modules": [...], ...}
         ablation_config: Optional dict mapping layer numbers to list of head indices to ablate.
+        original_prompt: When provided, enables per-position top-5 computation for
+            the output scrubber. If prompt contains generated tokens beyond
+            original_prompt, each generated-token position gets its own top-5 data.
 
     Returns:
         JSON-serializable dict with captured activations and metadata
@@ -255,6 +352,30 @@ def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any],
     except Exception as e:
         print(f"Warning: Could not extract model output: {e}")
 
+    # --- Per-position top-5 for the output scrubber ---
+    per_position_top5 = []
+    prompt_token_count = None
+    generated_tokens = []
+    if original_prompt is not None:
+        prompt_ids = tokenizer(original_prompt, return_tensors="pt")["input_ids"]
+        prompt_token_count = prompt_ids.shape[1]
+        seq_len = inputs["input_ids"].shape[1]
+        num_generated = seq_len - prompt_token_count
+
+        if num_generated > 0:
+            # Attach input_ids to model_output so compute_per_position_top5
+            # can look up the actual token at each position.
+            model_output.input_ids = inputs["input_ids"]
+            per_position_top5 = compute_per_position_top5(
+                model_output, tokenizer, prompt_token_count, top_k=5
+            )
+            # Decode each generated token individually for slider marks
+            full_ids = inputs["input_ids"][0].tolist()
+            generated_tokens = [
+                tokenizer.decode([full_ids[prompt_token_count + i]], skip_special_tokens=False)
+                for i in range(num_generated)
+            ]
+
     # Build output dictionary
     result = {
         "model": getattr(model.config, "name_or_path", "unknown"),
@@ -267,7 +388,11 @@ def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any],
         "norm_parameters": norm_parameters,
         "norm_data": norm_data,
         "actual_output": actual_output,
-        "global_top5_tokens": global_top5_tokens  # New: global top 5 from final output
+        "global_top5_tokens": global_top5_tokens,
+        "per_position_top5": per_position_top5,
+        "prompt_token_count": prompt_token_count,
+        "generated_tokens": generated_tokens,
+        "original_prompt": original_prompt,
     }
 
     print(f"Captured {len(captured)} module outputs using PyVene")