cdpearlman commited on
Commit
670d882
·
1 Parent(s): 78cc5b4

Fixed ablation graphs first token population, still need to fix original graphs when no beam selected

Browse files
Files changed (3) hide show
  1. app.py +3 -80
  2. components/ablation_panel.py +128 -7
  3. plans.md +4 -2
app.py CHANGED
@@ -973,6 +973,8 @@ def update_ablation_scrubber(position, original_data, ablated_data):
973
  if position is None or not original_data or not ablated_data:
974
  import dash
975
  return dash.no_update
 
 
976
 
977
  orig_pos_data = original_data.get('per_position_top5', [])
978
  abl_pos_data = ablated_data.get('per_position_top5', [])
@@ -980,75 +982,6 @@ def update_ablation_scrubber(position, original_data, ablated_data):
980
  orig_tokens = original_data.get('generated_tokens', [])
981
  abl_tokens = ablated_data.get('generated_tokens', [])
982
 
983
- # Helper to build token map
984
- def build_token_map(tokens, current_pos, changed_indices):
985
- from dash import html
986
- elements = []
987
- for i, token in enumerate(tokens):
988
- if i > 0: elements.append(html.Span(" → ", style={'color': '#ced4da', 'margin': '0 4px'}))
989
-
990
- is_current = i == current_pos
991
- is_changed = i in changed_indices
992
-
993
- style = {'fontWeight': 'bold' if is_current else 'normal'}
994
- if is_current:
995
- style['color'] = '#ffffff'
996
- style['backgroundColor'] = '#dc3545' if is_changed else '#28a745'
997
- style['padding'] = '2px 6px'
998
- style['borderRadius'] = '4px'
999
- elif is_changed:
1000
- style['color'] = '#dc3545'
1001
-
1002
- elements.append(html.Span(f"T{i} ({token.strip()})", style=style))
1003
- return elements
1004
-
1005
- # Helper to build text box
1006
- def build_text_box(prompt_text, tokens, current_pos, changed_indices):
1007
- from dash import html
1008
- elements = [html.Span(prompt_text, style={'color': '#6c757d'})]
1009
- for i, token in enumerate(tokens):
1010
- is_current = i == current_pos
1011
- is_changed = i in changed_indices
1012
-
1013
- style = {}
1014
- if is_current:
1015
- style['backgroundColor'] = '#ffc107' if is_changed else '#0dcaf0'
1016
- style['color'] = '#000'
1017
- style['borderRadius'] = '3px'
1018
- style['padding'] = '0 2px'
1019
- style['fontWeight'] = 'bold'
1020
-
1021
- elements.append(html.Span(token, style=style))
1022
- return elements
1023
-
1024
- def build_chart(pos_data, actual_token, main_color):
1025
- import plotly.graph_objs as go
1026
- if not pos_data: return go.Figure().update_layout(margin=dict(l=0, r=0, t=0, b=0), height=200)
1027
-
1028
- top5 = pos_data.get('top5', [])
1029
- tokens = [t['token'] for t in reversed(top5)]
1030
- probs = [t['probability'] for t in reversed(top5)]
1031
-
1032
- colors = []
1033
- for t in tokens:
1034
- if t == actual_token:
1035
- colors.append(main_color)
1036
- else:
1037
- colors.append('#e2e8f0' if main_color == '#4c51bf' else '#f8d7da')
1038
-
1039
- fig = go.Figure(go.Bar(
1040
- x=probs, y=tokens, orientation='h', marker_color=colors,
1041
- text=[f"{p:.1%}" for p in probs], textposition='auto'
1042
- ))
1043
- fig.update_layout(
1044
- margin=dict(l=0, r=0, t=0, b=0), height=200,
1045
- xaxis=dict(visible=False, range=[0, 1]),
1046
- yaxis=dict(tickfont=dict(size=12)),
1047
- paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)',
1048
- showlegend=False
1049
- )
1050
- return fig
1051
-
1052
  # Find changed indices
1053
  changed_indices = set()
1054
  for i in range(max(len(orig_tokens), len(abl_tokens))):
@@ -1066,23 +999,13 @@ def update_ablation_scrubber(position, original_data, ablated_data):
1066
  orig_chart = []
1067
  abl_chart = []
1068
 
1069
- from dash import html
1070
- divergence_indicator = html.Div()
1071
-
1072
  if position < len(orig_pos_data):
1073
  orig_chart = build_chart(orig_pos_data[position], orig_pos_data[position].get('actual_token'), '#4c51bf')
1074
  if position < len(abl_pos_data):
1075
  abl_chart = build_chart(abl_pos_data[position], abl_pos_data[position].get('actual_token'), '#e53e3e')
1076
 
1077
  is_diverged = position in changed_indices
1078
- if is_diverged:
1079
- divergence_indicator = html.Div([
1080
- html.I(className="fas fa-exclamation-circle", style={'color': '#dc3545', 'fontSize': '32px', 'backgroundColor': '#fff5f5', 'borderRadius': '50%', 'padding': '10px', 'boxShadow': '0 0 15px rgba(220,53,69,0.4)'})
1081
- ])
1082
- else:
1083
- divergence_indicator = html.Div([
1084
- html.I(className="fas fa-check-circle", style={'color': '#28a745', 'fontSize': '32px', 'backgroundColor': '#f0fdf4', 'borderRadius': '50%', 'padding': '10px'})
1085
- ])
1086
 
1087
  return orig_map, orig_text_box, orig_chart, abl_map, abl_text_box, abl_chart, divergence_indicator
1088
 
 
973
  if position is None or not original_data or not ablated_data:
974
  import dash
975
  return dash.no_update
976
+
977
+ from components.ablation_panel import build_token_map, build_text_box, build_chart, build_divergence_indicator
978
 
979
  orig_pos_data = original_data.get('per_position_top5', [])
980
  abl_pos_data = ablated_data.get('per_position_top5', [])
 
982
  orig_tokens = original_data.get('generated_tokens', [])
983
  abl_tokens = ablated_data.get('generated_tokens', [])
984
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
985
  # Find changed indices
986
  changed_indices = set()
987
  for i in range(max(len(orig_tokens), len(abl_tokens))):
 
999
  orig_chart = []
1000
  abl_chart = []
1001
 
 
 
 
1002
  if position < len(orig_pos_data):
1003
  orig_chart = build_chart(orig_pos_data[position], orig_pos_data[position].get('actual_token'), '#4c51bf')
1004
  if position < len(abl_pos_data):
1005
  abl_chart = build_chart(abl_pos_data[position], abl_pos_data[position].get('actual_token'), '#e53e3e')
1006
 
1007
  is_diverged = position in changed_indices
1008
+ divergence_indicator = build_divergence_indicator(is_diverged)
 
 
 
 
 
 
 
1009
 
1010
  return orig_map, orig_text_box, orig_chart, abl_map, abl_text_box, abl_chart, divergence_indicator
1011
 
components/ablation_panel.py CHANGED
@@ -10,6 +10,102 @@ import plotly.graph_objs as go
10
  import json
11
  from utils.colors import head_color
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def create_ablation_panel():
14
  """Create the main ablation tool content."""
15
  return html.Div([
@@ -268,6 +364,31 @@ def create_ablation_results_display(original_data, ablated_data, selected_heads,
268
  )
269
  ], style={'padding': '0 20px 20px 20px'})
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  # Comparison Grid
272
  comparison_grid = html.Div([
273
  # Original Output Column (Green Theme)
@@ -277,20 +398,20 @@ def create_ablation_results_display(original_data, ablated_data, selected_heads,
277
  'borderRadius': '16px', 'fontWeight': 'bold', 'fontSize': '12px',
278
  'display': 'inline-block', 'marginBottom': '15px'
279
  }),
280
- html.Div(id='ablation-original-token-map', style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '10px', 'minHeight': '40px', 'textAlign': 'left', 'lineHeight': '1.5'}),
281
- html.Div(id='ablation-original-text-box', style={
282
  'backgroundColor': '#f8f9fa', 'border': '1px solid #dee2e6', 'borderRadius': '8px',
283
  'padding': '15px', 'fontFamily': 'monospace', 'fontSize': '14px', 'minHeight': '80px', 'marginBottom': '15px', 'textAlign': 'left', 'whiteSpace': 'pre-wrap'
284
  }),
285
  html.Div("TOP 5 PREDICTIONS", style={'textAlign': 'center', 'fontWeight': 'bold', 'color': '#495057', 'fontSize': '12px', 'marginBottom': '10px'}),
286
- dcc.Graph(id='ablation-original-top5-chart', config={'displayModeBar': False}, style={'height': '200px'})
287
  ], style={
288
  'flex': '1', 'border': '2px solid #28a745', 'borderRadius': '12px',
289
  'padding': '20px', 'textAlign': 'center', 'backgroundColor': 'white', 'width': '45%'
290
  }),
291
 
292
  # Center Divergence Indicator
293
- html.Div(id='ablation-divergence-indicator', style={
294
  'width': '60px', 'display': 'flex', 'flexDirection': 'column',
295
  'alignItems': 'center', 'justifyContent': 'center'
296
  }),
@@ -302,13 +423,13 @@ def create_ablation_results_display(original_data, ablated_data, selected_heads,
302
  'borderRadius': '16px', 'fontWeight': 'bold', 'fontSize': '12px',
303
  'display': 'inline-block', 'marginBottom': '15px'
304
  }),
305
- html.Div(id='ablation-ablated-token-map', style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '10px', 'minHeight': '40px', 'textAlign': 'left', 'lineHeight': '1.5'}),
306
- html.Div(id='ablation-ablated-text-box', style={
307
  'backgroundColor': '#f8f9fa', 'border': '1px solid #dee2e6', 'borderRadius': '8px',
308
  'padding': '15px', 'fontFamily': 'monospace', 'fontSize': '14px', 'minHeight': '80px', 'marginBottom': '15px', 'textAlign': 'left', 'whiteSpace': 'pre-wrap'
309
  }),
310
  html.Div("TOP 5 PREDICTIONS", style={'textAlign': 'center', 'fontWeight': 'bold', 'color': '#495057', 'fontSize': '12px', 'marginBottom': '10px'}),
311
- dcc.Graph(id='ablation-ablated-top5-chart', config={'displayModeBar': False}, style={'height': '200px'})
312
  ], style={
313
  'flex': '1', 'border': '2px solid #dc3545', 'borderRadius': '12px',
314
  'padding': '20px', 'textAlign': 'center', 'backgroundColor': 'white', 'width': '45%'
 
10
  import json
11
  from utils.colors import head_color
12
 
13
+
14
+ # ============================================================================
15
+ # Shared rendering helpers (used by both initial render and scrubber callback)
16
+ # ============================================================================
17
+
18
+ def build_token_map(tokens, current_pos, changed_indices):
19
+ """Build a token map display showing token sequence with current position highlighted."""
20
+ elements = []
21
+ for i, token in enumerate(tokens):
22
+ if i > 0:
23
+ elements.append(html.Span(" → ", style={'color': '#ced4da', 'margin': '0 4px'}))
24
+
25
+ is_current = i == current_pos
26
+ is_changed = i in changed_indices
27
+
28
+ style = {'fontWeight': 'bold' if is_current else 'normal'}
29
+ if is_current:
30
+ style['color'] = '#ffffff'
31
+ style['backgroundColor'] = '#dc3545' if is_changed else '#28a745'
32
+ style['padding'] = '2px 6px'
33
+ style['borderRadius'] = '4px'
34
+ elif is_changed:
35
+ style['color'] = '#dc3545'
36
+
37
+ elements.append(html.Span(f"T{i} ({token.strip()})", style=style))
38
+ return elements
39
+
40
+
41
+ def build_text_box(prompt_text, tokens, current_pos, changed_indices):
42
+ """Build a text box with the prompt and generated tokens, highlighting the current position."""
43
+ elements = [html.Span(prompt_text, style={'color': '#6c757d'})]
44
+ for i, token in enumerate(tokens):
45
+ is_current = i == current_pos
46
+ is_changed = i in changed_indices
47
+
48
+ style = {}
49
+ if is_current:
50
+ style['backgroundColor'] = '#ffc107' if is_changed else '#0dcaf0'
51
+ style['color'] = '#000'
52
+ style['borderRadius'] = '3px'
53
+ style['padding'] = '0 2px'
54
+ style['fontWeight'] = 'bold'
55
+
56
+ elements.append(html.Span(token, style=style))
57
+ return elements
58
+
59
+
60
+ def build_chart(pos_data, actual_token, main_color):
61
+ """Build a horizontal bar chart of top-5 predictions for a given position."""
62
+ if not pos_data:
63
+ return go.Figure().update_layout(margin=dict(l=0, r=0, t=0, b=0), height=200)
64
+
65
+ top5 = pos_data.get('top5', [])
66
+ tokens = [t['token'] for t in reversed(top5)]
67
+ probs = [t['probability'] for t in reversed(top5)]
68
+
69
+ colors = []
70
+ for t in tokens:
71
+ if t == actual_token:
72
+ colors.append(main_color)
73
+ else:
74
+ colors.append('#e2e8f0' if main_color == '#4c51bf' else '#f8d7da')
75
+
76
+ fig = go.Figure(go.Bar(
77
+ x=probs, y=tokens, orientation='h', marker_color=colors,
78
+ text=[f"{p:.1%}" for p in probs], textposition='auto'
79
+ ))
80
+ fig.update_layout(
81
+ margin=dict(l=0, r=0, t=0, b=0), height=200,
82
+ xaxis=dict(visible=False, range=[0, 1]),
83
+ yaxis=dict(tickfont=dict(size=12)),
84
+ paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)',
85
+ showlegend=False
86
+ )
87
+ return fig
88
+
89
+
90
+ def build_divergence_indicator(is_diverged):
91
+ """Build the divergence indicator icon (check or exclamation)."""
92
+ if is_diverged:
93
+ return html.Div([
94
+ html.I(className="fas fa-exclamation-circle", style={
95
+ 'color': '#dc3545', 'fontSize': '32px', 'backgroundColor': '#fff5f5',
96
+ 'borderRadius': '50%', 'padding': '10px',
97
+ 'boxShadow': '0 0 15px rgba(220,53,69,0.4)'
98
+ })
99
+ ])
100
+ else:
101
+ return html.Div([
102
+ html.I(className="fas fa-check-circle", style={
103
+ 'color': '#28a745', 'fontSize': '32px', 'backgroundColor': '#f0fdf4',
104
+ 'borderRadius': '50%', 'padding': '10px'
105
+ })
106
+ ])
107
+
108
+
109
  def create_ablation_panel():
110
  """Create the main ablation tool content."""
111
  return html.Div([
 
364
  )
365
  ], style={'padding': '0 20px 20px 20px'})
366
 
367
+ # --- Pre-populate position 0 content so the display isn't blank ---
368
+ prompt_text = original_data.get('original_prompt', original_data.get('prompt', '')) if original_data else ''
369
+
370
+ # Compute changed indices (same logic as the scrubber callback)
371
+ changed_indices = set()
372
+ for i in range(max(len(orig_tokens), len(abl_tokens))):
373
+ if i >= len(orig_tokens) or i >= len(abl_tokens) or orig_tokens[i] != abl_tokens[i]:
374
+ changed_indices.add(i)
375
+
376
+ pos0 = 0
377
+ init_orig_map = build_token_map(orig_tokens, pos0, set())
378
+ init_abl_map = build_token_map(abl_tokens, pos0, changed_indices)
379
+ init_orig_text = build_text_box(prompt_text, orig_tokens, pos0, set())
380
+ init_abl_text = build_text_box(prompt_text, abl_tokens, pos0, changed_indices)
381
+
382
+ init_orig_chart = go.Figure().update_layout(margin=dict(l=0, r=0, t=0, b=0), height=200)
383
+ init_abl_chart = go.Figure().update_layout(margin=dict(l=0, r=0, t=0, b=0), height=200)
384
+ if pos0 < len(orig_positions):
385
+ init_orig_chart = build_chart(orig_positions[pos0], orig_positions[pos0].get('actual_token'), '#4c51bf')
386
+ if pos0 < len(abl_positions):
387
+ init_abl_chart = build_chart(abl_positions[pos0], abl_positions[pos0].get('actual_token'), '#e53e3e')
388
+
389
+ init_diverged = pos0 in changed_indices
390
+ init_divergence = build_divergence_indicator(init_diverged)
391
+
392
  # Comparison Grid
393
  comparison_grid = html.Div([
394
  # Original Output Column (Green Theme)
 
398
  'borderRadius': '16px', 'fontWeight': 'bold', 'fontSize': '12px',
399
  'display': 'inline-block', 'marginBottom': '15px'
400
  }),
401
+ html.Div(init_orig_map, id='ablation-original-token-map', style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '10px', 'minHeight': '40px', 'textAlign': 'left', 'lineHeight': '1.5'}),
402
+ html.Div(init_orig_text, id='ablation-original-text-box', style={
403
  'backgroundColor': '#f8f9fa', 'border': '1px solid #dee2e6', 'borderRadius': '8px',
404
  'padding': '15px', 'fontFamily': 'monospace', 'fontSize': '14px', 'minHeight': '80px', 'marginBottom': '15px', 'textAlign': 'left', 'whiteSpace': 'pre-wrap'
405
  }),
406
  html.Div("TOP 5 PREDICTIONS", style={'textAlign': 'center', 'fontWeight': 'bold', 'color': '#495057', 'fontSize': '12px', 'marginBottom': '10px'}),
407
+ dcc.Graph(id='ablation-original-top5-chart', figure=init_orig_chart, config={'displayModeBar': False}, style={'height': '200px'})
408
  ], style={
409
  'flex': '1', 'border': '2px solid #28a745', 'borderRadius': '12px',
410
  'padding': '20px', 'textAlign': 'center', 'backgroundColor': 'white', 'width': '45%'
411
  }),
412
 
413
  # Center Divergence Indicator
414
+ html.Div(init_divergence, id='ablation-divergence-indicator', style={
415
  'width': '60px', 'display': 'flex', 'flexDirection': 'column',
416
  'alignItems': 'center', 'justifyContent': 'center'
417
  }),
 
423
  'borderRadius': '16px', 'fontWeight': 'bold', 'fontSize': '12px',
424
  'display': 'inline-block', 'marginBottom': '15px'
425
  }),
426
+ html.Div(init_abl_map, id='ablation-ablated-token-map', style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '10px', 'minHeight': '40px', 'textAlign': 'left', 'lineHeight': '1.5'}),
427
+ html.Div(init_abl_text, id='ablation-ablated-text-box', style={
428
  'backgroundColor': '#f8f9fa', 'border': '1px solid #dee2e6', 'borderRadius': '8px',
429
  'padding': '15px', 'fontFamily': 'monospace', 'fontSize': '14px', 'minHeight': '80px', 'marginBottom': '15px', 'textAlign': 'left', 'whiteSpace': 'pre-wrap'
430
  }),
431
  html.Div("TOP 5 PREDICTIONS", style={'textAlign': 'center', 'fontWeight': 'bold', 'color': '#495057', 'fontSize': '12px', 'marginBottom': '10px'}),
432
+ dcc.Graph(id='ablation-ablated-top5-chart', figure=init_abl_chart, config={'displayModeBar': False}, style={'height': '200px'})
433
  ], style={
434
  'flex': '1', 'border': '2px solid #dc3545', 'borderRadius': '12px',
435
  'padding': '20px', 'textAlign': 'center', 'backgroundColor': 'white', 'width': '45%'
plans.md CHANGED
@@ -1,4 +1,5 @@
1
- - specs on what each attention head does
 
2
 
3
  Done:
4
  - change attention to entire generated sequence
@@ -9,4 +10,5 @@ Done:
9
  - output streaming for chatbot
10
  - shorter, concise responses in system prompt
11
  - add video links to glossary
12
- - three blue one brown
 
 
1
+ To Do:
2
+ - generating with 1 output token doesn't show the "selected" beam, so users can't see the output from the model
3
 
4
  Done:
5
  - change attention to entire generated sequence
 
10
  - output streaming for chatbot
11
  - shorter, concise responses in system prompt
12
  - add video links to glossary
13
+ - three blue one brown
14
+ - specs on what each attention head does