Spaces:
Sleeping
Sleeping
Commit ·
670d882
1
Parent(s): 78cc5b4
Fixed the ablation graphs so they are populated for the first token; the original graphs still need to be fixed for the case where no beam is selected
Browse files- app.py +3 -80
- components/ablation_panel.py +128 -7
- plans.md +4 -2
app.py
CHANGED
|
@@ -973,6 +973,8 @@ def update_ablation_scrubber(position, original_data, ablated_data):
|
|
| 973 |
if position is None or not original_data or not ablated_data:
|
| 974 |
import dash
|
| 975 |
return dash.no_update
|
|
|
|
|
|
|
| 976 |
|
| 977 |
orig_pos_data = original_data.get('per_position_top5', [])
|
| 978 |
abl_pos_data = ablated_data.get('per_position_top5', [])
|
|
@@ -980,75 +982,6 @@ def update_ablation_scrubber(position, original_data, ablated_data):
|
|
| 980 |
orig_tokens = original_data.get('generated_tokens', [])
|
| 981 |
abl_tokens = ablated_data.get('generated_tokens', [])
|
| 982 |
|
| 983 |
-
# Helper to build token map
|
| 984 |
-
def build_token_map(tokens, current_pos, changed_indices):
|
| 985 |
-
from dash import html
|
| 986 |
-
elements = []
|
| 987 |
-
for i, token in enumerate(tokens):
|
| 988 |
-
if i > 0: elements.append(html.Span(" → ", style={'color': '#ced4da', 'margin': '0 4px'}))
|
| 989 |
-
|
| 990 |
-
is_current = i == current_pos
|
| 991 |
-
is_changed = i in changed_indices
|
| 992 |
-
|
| 993 |
-
style = {'fontWeight': 'bold' if is_current else 'normal'}
|
| 994 |
-
if is_current:
|
| 995 |
-
style['color'] = '#ffffff'
|
| 996 |
-
style['backgroundColor'] = '#dc3545' if is_changed else '#28a745'
|
| 997 |
-
style['padding'] = '2px 6px'
|
| 998 |
-
style['borderRadius'] = '4px'
|
| 999 |
-
elif is_changed:
|
| 1000 |
-
style['color'] = '#dc3545'
|
| 1001 |
-
|
| 1002 |
-
elements.append(html.Span(f"T{i} ({token.strip()})", style=style))
|
| 1003 |
-
return elements
|
| 1004 |
-
|
| 1005 |
-
# Helper to build text box
|
| 1006 |
-
def build_text_box(prompt_text, tokens, current_pos, changed_indices):
|
| 1007 |
-
from dash import html
|
| 1008 |
-
elements = [html.Span(prompt_text, style={'color': '#6c757d'})]
|
| 1009 |
-
for i, token in enumerate(tokens):
|
| 1010 |
-
is_current = i == current_pos
|
| 1011 |
-
is_changed = i in changed_indices
|
| 1012 |
-
|
| 1013 |
-
style = {}
|
| 1014 |
-
if is_current:
|
| 1015 |
-
style['backgroundColor'] = '#ffc107' if is_changed else '#0dcaf0'
|
| 1016 |
-
style['color'] = '#000'
|
| 1017 |
-
style['borderRadius'] = '3px'
|
| 1018 |
-
style['padding'] = '0 2px'
|
| 1019 |
-
style['fontWeight'] = 'bold'
|
| 1020 |
-
|
| 1021 |
-
elements.append(html.Span(token, style=style))
|
| 1022 |
-
return elements
|
| 1023 |
-
|
| 1024 |
-
def build_chart(pos_data, actual_token, main_color):
|
| 1025 |
-
import plotly.graph_objs as go
|
| 1026 |
-
if not pos_data: return go.Figure().update_layout(margin=dict(l=0, r=0, t=0, b=0), height=200)
|
| 1027 |
-
|
| 1028 |
-
top5 = pos_data.get('top5', [])
|
| 1029 |
-
tokens = [t['token'] for t in reversed(top5)]
|
| 1030 |
-
probs = [t['probability'] for t in reversed(top5)]
|
| 1031 |
-
|
| 1032 |
-
colors = []
|
| 1033 |
-
for t in tokens:
|
| 1034 |
-
if t == actual_token:
|
| 1035 |
-
colors.append(main_color)
|
| 1036 |
-
else:
|
| 1037 |
-
colors.append('#e2e8f0' if main_color == '#4c51bf' else '#f8d7da')
|
| 1038 |
-
|
| 1039 |
-
fig = go.Figure(go.Bar(
|
| 1040 |
-
x=probs, y=tokens, orientation='h', marker_color=colors,
|
| 1041 |
-
text=[f"{p:.1%}" for p in probs], textposition='auto'
|
| 1042 |
-
))
|
| 1043 |
-
fig.update_layout(
|
| 1044 |
-
margin=dict(l=0, r=0, t=0, b=0), height=200,
|
| 1045 |
-
xaxis=dict(visible=False, range=[0, 1]),
|
| 1046 |
-
yaxis=dict(tickfont=dict(size=12)),
|
| 1047 |
-
paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)',
|
| 1048 |
-
showlegend=False
|
| 1049 |
-
)
|
| 1050 |
-
return fig
|
| 1051 |
-
|
| 1052 |
# Find changed indices
|
| 1053 |
changed_indices = set()
|
| 1054 |
for i in range(max(len(orig_tokens), len(abl_tokens))):
|
|
@@ -1066,23 +999,13 @@ def update_ablation_scrubber(position, original_data, ablated_data):
|
|
| 1066 |
orig_chart = []
|
| 1067 |
abl_chart = []
|
| 1068 |
|
| 1069 |
-
from dash import html
|
| 1070 |
-
divergence_indicator = html.Div()
|
| 1071 |
-
|
| 1072 |
if position < len(orig_pos_data):
|
| 1073 |
orig_chart = build_chart(orig_pos_data[position], orig_pos_data[position].get('actual_token'), '#4c51bf')
|
| 1074 |
if position < len(abl_pos_data):
|
| 1075 |
abl_chart = build_chart(abl_pos_data[position], abl_pos_data[position].get('actual_token'), '#e53e3e')
|
| 1076 |
|
| 1077 |
is_diverged = position in changed_indices
|
| 1078 |
-
|
| 1079 |
-
divergence_indicator = html.Div([
|
| 1080 |
-
html.I(className="fas fa-exclamation-circle", style={'color': '#dc3545', 'fontSize': '32px', 'backgroundColor': '#fff5f5', 'borderRadius': '50%', 'padding': '10px', 'boxShadow': '0 0 15px rgba(220,53,69,0.4)'})
|
| 1081 |
-
])
|
| 1082 |
-
else:
|
| 1083 |
-
divergence_indicator = html.Div([
|
| 1084 |
-
html.I(className="fas fa-check-circle", style={'color': '#28a745', 'fontSize': '32px', 'backgroundColor': '#f0fdf4', 'borderRadius': '50%', 'padding': '10px'})
|
| 1085 |
-
])
|
| 1086 |
|
| 1087 |
return orig_map, orig_text_box, orig_chart, abl_map, abl_text_box, abl_chart, divergence_indicator
|
| 1088 |
|
|
|
|
| 973 |
if position is None or not original_data or not ablated_data:
|
| 974 |
import dash
|
| 975 |
return dash.no_update
|
| 976 |
+
|
| 977 |
+
from components.ablation_panel import build_token_map, build_text_box, build_chart, build_divergence_indicator
|
| 978 |
|
| 979 |
orig_pos_data = original_data.get('per_position_top5', [])
|
| 980 |
abl_pos_data = ablated_data.get('per_position_top5', [])
|
|
|
|
| 982 |
orig_tokens = original_data.get('generated_tokens', [])
|
| 983 |
abl_tokens = ablated_data.get('generated_tokens', [])
|
| 984 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 985 |
# Find changed indices
|
| 986 |
changed_indices = set()
|
| 987 |
for i in range(max(len(orig_tokens), len(abl_tokens))):
|
|
|
|
| 999 |
orig_chart = []
|
| 1000 |
abl_chart = []
|
| 1001 |
|
|
|
|
|
|
|
|
|
|
| 1002 |
if position < len(orig_pos_data):
|
| 1003 |
orig_chart = build_chart(orig_pos_data[position], orig_pos_data[position].get('actual_token'), '#4c51bf')
|
| 1004 |
if position < len(abl_pos_data):
|
| 1005 |
abl_chart = build_chart(abl_pos_data[position], abl_pos_data[position].get('actual_token'), '#e53e3e')
|
| 1006 |
|
| 1007 |
is_diverged = position in changed_indices
|
| 1008 |
+
divergence_indicator = build_divergence_indicator(is_diverged)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1009 |
|
| 1010 |
return orig_map, orig_text_box, orig_chart, abl_map, abl_text_box, abl_chart, divergence_indicator
|
| 1011 |
|
components/ablation_panel.py
CHANGED
|
@@ -10,6 +10,102 @@ import plotly.graph_objs as go
|
|
| 10 |
import json
|
| 11 |
from utils.colors import head_color
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def create_ablation_panel():
|
| 14 |
"""Create the main ablation tool content."""
|
| 15 |
return html.Div([
|
|
@@ -268,6 +364,31 @@ def create_ablation_results_display(original_data, ablated_data, selected_heads,
|
|
| 268 |
)
|
| 269 |
], style={'padding': '0 20px 20px 20px'})
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
# Comparison Grid
|
| 272 |
comparison_grid = html.Div([
|
| 273 |
# Original Output Column (Green Theme)
|
|
@@ -277,20 +398,20 @@ def create_ablation_results_display(original_data, ablated_data, selected_heads,
|
|
| 277 |
'borderRadius': '16px', 'fontWeight': 'bold', 'fontSize': '12px',
|
| 278 |
'display': 'inline-block', 'marginBottom': '15px'
|
| 279 |
}),
|
| 280 |
-
html.Div(id='ablation-original-token-map', style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '10px', 'minHeight': '40px', 'textAlign': 'left', 'lineHeight': '1.5'}),
|
| 281 |
-
html.Div(id='ablation-original-text-box', style={
|
| 282 |
'backgroundColor': '#f8f9fa', 'border': '1px solid #dee2e6', 'borderRadius': '8px',
|
| 283 |
'padding': '15px', 'fontFamily': 'monospace', 'fontSize': '14px', 'minHeight': '80px', 'marginBottom': '15px', 'textAlign': 'left', 'whiteSpace': 'pre-wrap'
|
| 284 |
}),
|
| 285 |
html.Div("TOP 5 PREDICTIONS", style={'textAlign': 'center', 'fontWeight': 'bold', 'color': '#495057', 'fontSize': '12px', 'marginBottom': '10px'}),
|
| 286 |
-
dcc.Graph(id='ablation-original-top5-chart', config={'displayModeBar': False}, style={'height': '200px'})
|
| 287 |
], style={
|
| 288 |
'flex': '1', 'border': '2px solid #28a745', 'borderRadius': '12px',
|
| 289 |
'padding': '20px', 'textAlign': 'center', 'backgroundColor': 'white', 'width': '45%'
|
| 290 |
}),
|
| 291 |
|
| 292 |
# Center Divergence Indicator
|
| 293 |
-
html.Div(id='ablation-divergence-indicator', style={
|
| 294 |
'width': '60px', 'display': 'flex', 'flexDirection': 'column',
|
| 295 |
'alignItems': 'center', 'justifyContent': 'center'
|
| 296 |
}),
|
|
@@ -302,13 +423,13 @@ def create_ablation_results_display(original_data, ablated_data, selected_heads,
|
|
| 302 |
'borderRadius': '16px', 'fontWeight': 'bold', 'fontSize': '12px',
|
| 303 |
'display': 'inline-block', 'marginBottom': '15px'
|
| 304 |
}),
|
| 305 |
-
html.Div(id='ablation-ablated-token-map', style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '10px', 'minHeight': '40px', 'textAlign': 'left', 'lineHeight': '1.5'}),
|
| 306 |
-
html.Div(id='ablation-ablated-text-box', style={
|
| 307 |
'backgroundColor': '#f8f9fa', 'border': '1px solid #dee2e6', 'borderRadius': '8px',
|
| 308 |
'padding': '15px', 'fontFamily': 'monospace', 'fontSize': '14px', 'minHeight': '80px', 'marginBottom': '15px', 'textAlign': 'left', 'whiteSpace': 'pre-wrap'
|
| 309 |
}),
|
| 310 |
html.Div("TOP 5 PREDICTIONS", style={'textAlign': 'center', 'fontWeight': 'bold', 'color': '#495057', 'fontSize': '12px', 'marginBottom': '10px'}),
|
| 311 |
-
dcc.Graph(id='ablation-ablated-top5-chart', config={'displayModeBar': False}, style={'height': '200px'})
|
| 312 |
], style={
|
| 313 |
'flex': '1', 'border': '2px solid #dc3545', 'borderRadius': '12px',
|
| 314 |
'padding': '20px', 'textAlign': 'center', 'backgroundColor': 'white', 'width': '45%'
|
|
|
|
| 10 |
import json
|
| 11 |
from utils.colors import head_color
|
| 12 |
|
| 13 |
+
|
| 14 |
+
# ============================================================================
|
| 15 |
+
# Shared rendering helpers (used by both initial render and scrubber callback)
|
| 16 |
+
# ============================================================================
|
| 17 |
+
|
| 18 |
+
def build_token_map(tokens, current_pos, changed_indices):
    """Render the generated token sequence as a row of labelled spans.

    Each token becomes an ``html.Span`` labelled ``T<index> (<token>)``, and
    consecutive tokens are separated by an arrow span.

    Args:
        tokens: Iterable of generated tokens; each one must support
            ``.strip()`` (presumably plain strings -- confirm against caller).
        current_pos: Index of the token to emphasise as the active position.
        changed_indices: Container of indices flagged as changed; only
            membership tests are performed on it.

    Returns:
        A list of ``html.Span`` components (token spans interleaved with
        arrow separators).
    """
    spans = []
    for idx, tok in enumerate(tokens):
        # Separator goes before every token except the first one.
        if idx:
            spans.append(html.Span(" → ", style={'color': '#ced4da', 'margin': '0 4px'}))

        at_cursor = idx == current_pos
        diverged = idx in changed_indices

        if at_cursor:
            # Active position: filled pill, red when changed, green otherwise.
            chip_style = {
                'fontWeight': 'bold',
                'color': '#ffffff',
                'backgroundColor': '#dc3545' if diverged else '#28a745',
                'padding': '2px 6px',
                'borderRadius': '4px',
            }
        elif diverged:
            # Changed but not active: red text only.
            chip_style = {'fontWeight': 'normal', 'color': '#dc3545'}
        else:
            chip_style = {'fontWeight': 'normal'}

        spans.append(html.Span(f"T{idx} ({tok.strip()})", style=chip_style))
    return spans
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def build_text_box(prompt_text, tokens, current_pos, changed_indices):
    """Render the prompt followed by the generated tokens as inline spans.

    The prompt is emitted first in a muted grey span. Every generated token
    follows in its own span; only the token at ``current_pos`` receives a
    highlight, whose background colour depends on whether that index is in
    ``changed_indices``.

    Args:
        prompt_text: Text shown before the generated tokens.
        tokens: Iterable of generated tokens, used verbatim as span content.
        current_pos: Index of the token to highlight.
        changed_indices: Container of indices flagged as changed; only
            membership tests are performed on it.

    Returns:
        A list of ``html.Span`` components: the prompt span, then one span
        per token.
    """
    spans = [html.Span(prompt_text, style={'color': '#6c757d'})]
    for idx, tok in enumerate(tokens):
        at_cursor = idx == current_pos
        diverged = idx in changed_indices

        # Only the active token is styled; all others get an empty style.
        highlight = {
            'backgroundColor': '#ffc107' if diverged else '#0dcaf0',
            'color': '#000',
            'borderRadius': '3px',
            'padding': '0 2px',
            'fontWeight': 'bold',
        } if at_cursor else {}

        spans.append(html.Span(tok, style=highlight))
    return spans
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def build_chart(pos_data, actual_token, main_color):
    """Build a horizontal bar chart from one position's ``top5`` entries.

    Args:
        pos_data: Mapping for a single position. Its ``'top5'`` entry (default
            empty) is read as a sequence of mappings, each with ``'token'``
            and ``'probability'`` keys. A falsy value yields an empty
            placeholder figure.
        actual_token: Token whose bar is drawn in ``main_color``.
        main_color: Colour for the ``actual_token`` bar; it also selects the
            muted colour used for every other bar.

    Returns:
        A ``go.Figure`` that is 200px tall with zero margins.
    """
    if not pos_data:
        placeholder = go.Figure()
        return placeholder.update_layout(margin={'l': 0, 'r': 0, 't': 0, 'b': 0}, height=200)

    # Bars are built from the entries in reverse of their stored order.
    ranked = list(reversed(pos_data.get('top5', [])))
    labels = [entry['token'] for entry in ranked]
    shares = [entry['probability'] for entry in ranked]

    # Non-selected bars use a pale tint chosen by the main colour.
    muted = '#e2e8f0' if main_color == '#4c51bf' else '#f8d7da'
    bar_colors = [main_color if label == actual_token else muted for label in labels]

    bars = go.Bar(
        x=shares,
        y=labels,
        orientation='h',
        marker_color=bar_colors,
        text=[f"{share:.1%}" for share in shares],
        textposition='auto',
    )
    figure = go.Figure(bars)
    figure.update_layout(
        margin={'l': 0, 'r': 0, 't': 0, 'b': 0},
        height=200,
        xaxis={'visible': False, 'range': [0, 1]},
        yaxis={'tickfont': {'size': 12}},
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        showlegend=False,
    )
    return figure
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def build_divergence_indicator(is_diverged):
    """Build the icon shown between the original and ablated columns.

    Args:
        is_diverged: Truthy to show the red exclamation icon (with a glow),
            falsy to show the green check icon.

    Returns:
        An ``html.Div`` wrapping a single ``html.I`` icon element.
    """
    if is_diverged:
        icon_class = "fas fa-exclamation-circle"
        icon_style = {
            'color': '#dc3545', 'fontSize': '32px', 'backgroundColor': '#fff5f5',
            'borderRadius': '50%', 'padding': '10px',
            'boxShadow': '0 0 15px rgba(220,53,69,0.4)',
        }
    else:
        icon_class = "fas fa-check-circle"
        icon_style = {
            'color': '#28a745', 'fontSize': '32px', 'backgroundColor': '#f0fdf4',
            'borderRadius': '50%', 'padding': '10px',
        }

    return html.Div([html.I(className=icon_class, style=icon_style)])
|
| 107 |
+
|
| 108 |
+
|
| 109 |
def create_ablation_panel():
|
| 110 |
"""Create the main ablation tool content."""
|
| 111 |
return html.Div([
|
|
|
|
| 364 |
)
|
| 365 |
], style={'padding': '0 20px 20px 20px'})
|
| 366 |
|
| 367 |
+
# --- Pre-populate position 0 content so the display isn't blank ---
|
| 368 |
+
prompt_text = original_data.get('original_prompt', original_data.get('prompt', '')) if original_data else ''
|
| 369 |
+
|
| 370 |
+
# Compute changed indices (same logic as the scrubber callback)
|
| 371 |
+
changed_indices = set()
|
| 372 |
+
for i in range(max(len(orig_tokens), len(abl_tokens))):
|
| 373 |
+
if i >= len(orig_tokens) or i >= len(abl_tokens) or orig_tokens[i] != abl_tokens[i]:
|
| 374 |
+
changed_indices.add(i)
|
| 375 |
+
|
| 376 |
+
pos0 = 0
|
| 377 |
+
init_orig_map = build_token_map(orig_tokens, pos0, set())
|
| 378 |
+
init_abl_map = build_token_map(abl_tokens, pos0, changed_indices)
|
| 379 |
+
init_orig_text = build_text_box(prompt_text, orig_tokens, pos0, set())
|
| 380 |
+
init_abl_text = build_text_box(prompt_text, abl_tokens, pos0, changed_indices)
|
| 381 |
+
|
| 382 |
+
init_orig_chart = go.Figure().update_layout(margin=dict(l=0, r=0, t=0, b=0), height=200)
|
| 383 |
+
init_abl_chart = go.Figure().update_layout(margin=dict(l=0, r=0, t=0, b=0), height=200)
|
| 384 |
+
if pos0 < len(orig_positions):
|
| 385 |
+
init_orig_chart = build_chart(orig_positions[pos0], orig_positions[pos0].get('actual_token'), '#4c51bf')
|
| 386 |
+
if pos0 < len(abl_positions):
|
| 387 |
+
init_abl_chart = build_chart(abl_positions[pos0], abl_positions[pos0].get('actual_token'), '#e53e3e')
|
| 388 |
+
|
| 389 |
+
init_diverged = pos0 in changed_indices
|
| 390 |
+
init_divergence = build_divergence_indicator(init_diverged)
|
| 391 |
+
|
| 392 |
# Comparison Grid
|
| 393 |
comparison_grid = html.Div([
|
| 394 |
# Original Output Column (Green Theme)
|
|
|
|
| 398 |
'borderRadius': '16px', 'fontWeight': 'bold', 'fontSize': '12px',
|
| 399 |
'display': 'inline-block', 'marginBottom': '15px'
|
| 400 |
}),
|
| 401 |
+
html.Div(init_orig_map, id='ablation-original-token-map', style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '10px', 'minHeight': '40px', 'textAlign': 'left', 'lineHeight': '1.5'}),
|
| 402 |
+
html.Div(init_orig_text, id='ablation-original-text-box', style={
|
| 403 |
'backgroundColor': '#f8f9fa', 'border': '1px solid #dee2e6', 'borderRadius': '8px',
|
| 404 |
'padding': '15px', 'fontFamily': 'monospace', 'fontSize': '14px', 'minHeight': '80px', 'marginBottom': '15px', 'textAlign': 'left', 'whiteSpace': 'pre-wrap'
|
| 405 |
}),
|
| 406 |
html.Div("TOP 5 PREDICTIONS", style={'textAlign': 'center', 'fontWeight': 'bold', 'color': '#495057', 'fontSize': '12px', 'marginBottom': '10px'}),
|
| 407 |
+
dcc.Graph(id='ablation-original-top5-chart', figure=init_orig_chart, config={'displayModeBar': False}, style={'height': '200px'})
|
| 408 |
], style={
|
| 409 |
'flex': '1', 'border': '2px solid #28a745', 'borderRadius': '12px',
|
| 410 |
'padding': '20px', 'textAlign': 'center', 'backgroundColor': 'white', 'width': '45%'
|
| 411 |
}),
|
| 412 |
|
| 413 |
# Center Divergence Indicator
|
| 414 |
+
html.Div(init_divergence, id='ablation-divergence-indicator', style={
|
| 415 |
'width': '60px', 'display': 'flex', 'flexDirection': 'column',
|
| 416 |
'alignItems': 'center', 'justifyContent': 'center'
|
| 417 |
}),
|
|
|
|
| 423 |
'borderRadius': '16px', 'fontWeight': 'bold', 'fontSize': '12px',
|
| 424 |
'display': 'inline-block', 'marginBottom': '15px'
|
| 425 |
}),
|
| 426 |
+
html.Div(init_abl_map, id='ablation-ablated-token-map', style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '10px', 'minHeight': '40px', 'textAlign': 'left', 'lineHeight': '1.5'}),
|
| 427 |
+
html.Div(init_abl_text, id='ablation-ablated-text-box', style={
|
| 428 |
'backgroundColor': '#f8f9fa', 'border': '1px solid #dee2e6', 'borderRadius': '8px',
|
| 429 |
'padding': '15px', 'fontFamily': 'monospace', 'fontSize': '14px', 'minHeight': '80px', 'marginBottom': '15px', 'textAlign': 'left', 'whiteSpace': 'pre-wrap'
|
| 430 |
}),
|
| 431 |
html.Div("TOP 5 PREDICTIONS", style={'textAlign': 'center', 'fontWeight': 'bold', 'color': '#495057', 'fontSize': '12px', 'marginBottom': '10px'}),
|
| 432 |
+
dcc.Graph(id='ablation-ablated-top5-chart', figure=init_abl_chart, config={'displayModeBar': False}, style={'height': '200px'})
|
| 433 |
], style={
|
| 434 |
'flex': '1', 'border': '2px solid #dc3545', 'borderRadius': '12px',
|
| 435 |
'padding': '20px', 'textAlign': 'center', 'backgroundColor': 'white', 'width': '45%'
|
plans.md
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
|
| 3 |
Done:
|
| 4 |
- change attention to entire generated sequence
|
|
@@ -9,4 +10,5 @@ Done:
|
|
| 9 |
- output streaming for chatbot
|
| 10 |
- shorter, concise responses in system prompt
|
| 11 |
- add video links to glossary
|
| 12 |
-
- three blue one brown
|
|
|
|
|
|
| 1 |
+
To Do:
|
| 2 |
+
- generating with 1 output token doesn't show the "selected" beam, so users can't see the output from the model
|
| 3 |
|
| 4 |
Done:
|
| 5 |
- change attention to entire generated sequence
|
|
|
|
| 10 |
- output streaming for chatbot
|
| 11 |
- shorter, concise responses in system prompt
|
| 12 |
- add video links to glossary
|
| 13 |
+
- three blue one brown
|
| 14 |
+
- specs on what each attention head does
|