Marlin Lee
UI redesign: consistent design system, card layout, unified colors and typography
4a5ed93
"""
SAE Feature Explorer — Bokeh server entry point.
Launch with: bokeh serve scripts/explorer --port 5006 --args --data ...
Layout:
Upper workspace — active steering & composition
left : example presets sidebar
center: patch explorer + GT brain response | steering direction + active
features tile strip | DynaDiff steered brain & controls (when available)
── divider (solid horizontal rule) ─────────────────────────────────────────
Lower workspace — feature search & analysis
left : CLIP search input + result cards + feature gallery
center: feature header + naming/labeler + feature activation MEI grid
right : cortical profile + SAE summary
"""
import os
import random
import sys
# Make the project's ``src`` directory importable. Bokeh re-executes this
# script for every new session, so drop any earlier copy before re-inserting
# it at the front; a bare insert would grow sys.path by one entry per session.
_src_dir = os.path.join(os.path.dirname(__file__), '..', '..', 'src')
while _src_dir in sys.path:
    sys.path.remove(_src_dir)
sys.path.insert(0, _src_dir)
del _src_dir
# ---------------------------------------------------------------------------
# Multi-session fix: Bokeh re-executes main.py for every new browser session
# but Python caches imported submodules. Modules that create Bokeh model
# instances at import time (widgets, all panels) must be cleared so each
# session gets fresh instances. Data/logic modules (state, datasets, brain,
# inference, rendering) are deliberately kept cached.
# ---------------------------------------------------------------------------
_pkg_dir = os.path.dirname(os.path.abspath(__file__))
_keep_stems = frozenset(
    ['state', 'datasets', 'args', 'brain', 'inference', 'rendering',
     'steering_logic', 'feature_logic', 'feature_list_logic',
     'clip_search_logic', '__init__'])


def _clear_widget_modules():
    """Evict this package's view-building modules from ``sys.modules``.

    A module is evicted when its ``__file__`` lies inside ``_pkg_dir``, it is
    not the running script (``__name__``), and its file stem is not listed in
    ``_keep_stems``. Each evicted submodule is also removed as an attribute of
    its parent package, so a later ``from pkg import child`` cannot pick up the
    stale module object and re-executes the submodule instead.
    """
    # Match "<pkg_dir>/" rather than "<pkg_dir>" so that a sibling directory
    # sharing the prefix (e.g. "<pkg_dir>_old/") is not treated as ours.
    pkg_prefix = os.path.join(_pkg_dir, '')
    for name in list(sys.modules):
        mod = sys.modules.get(name)
        if mod is None or name == __name__:
            continue
        mod_file = getattr(mod, '__file__', None) or ''
        if not mod_file.startswith(pkg_prefix):
            continue
        if os.path.basename(mod_file).split('.')[0] in _keep_stems:
            continue
        parent_name, dot, child = name.rpartition('.')
        if dot:
            parent = sys.modules.get(parent_name)
            if parent is not None:
                try:
                    delattr(parent, child)
                except AttributeError:
                    pass
        del sys.modules[name]


_clear_widget_modules()
del _clear_widget_modules, _pkg_dir, _keep_stems
from bokeh.io import curdoc
from bokeh.layouts import column, row
from bokeh.models import Div
# ── Global CSS theme ──────────────────────────────────────────────
# Invisible Div (visible=False) whose only job is to carry this <style> block.
# It defines design tokens as CSS custom properties on :root, overrides for
# Bokeh's built-in button / input / slider / table / figure classes, and the
# `sae-card`, `sae-card-header`, `sae-section-title` and `sae-feat-num`
# utility classes referenced from HTML Div text elsewhere in the app.
# Notes on the tokens:
#   * --section-header holds a font-weight (600), not a colour.
#   * --destructive / --success / --warning are defined here but not
#     referenced by any rule in this stylesheet.
#   * .bk-btn-success is restyled with the accent colour, identical to
#     .bk-btn-primary, so "success" buttons render in the accent blue.
# NOTE(review): the layout code uses `styles=` on row/column (Bokeh 3.x API).
# In Bokeh 3.x each widget renders inside its own shadow DOM, so a <style>
# placed in a Div's text may only apply inside that Div's shadow root rather
# than to the whole page. The layout's inline `var(--x, fallback)` values
# hint that these variables are not always resolved. If the theme does not
# take effect, consider `stylesheets=[InlineStyleSheet(css=...)]` on the
# affected models — TODO confirm in the browser.
_theme_css = Div(text="""<style>
:root {
--bg: #f0f2f5;
--card-bg: #ffffff;
--card-border: #e2e5ea;
--card-shadow: 0 1px 3px rgba(0,0,0,0.06), 0 1px 2px rgba(0,0,0,0.04);
--accent: #2563eb;
--accent-hover: #1d4ed8;
--accent-light: #eff4ff;
--text-primary: #1a1d23;
--text-secondary: #4b5563;
--text-muted: #9ca3af;
--destructive: #dc2626;
--success: #059669;
--warning: #d97706;
--section-header: 600;
--radius: 8px;
--radius-sm: 6px;
}
body, .bk-root {
font-family: system-ui, -apple-system, 'Segoe UI', Roboto, sans-serif !important;
background: var(--bg) !important;
color: var(--text-primary);
}
/* Card container utility */
.sae-card {
background: var(--card-bg);
border: 1px solid var(--card-border);
border-radius: var(--radius);
box-shadow: var(--card-shadow);
padding: 14px 16px;
}
.sae-card-header {
font-size: 13px;
font-weight: var(--section-header);
color: var(--text-secondary);
text-transform: uppercase;
letter-spacing: 0.03em;
margin: 0 0 10px 0;
padding-bottom: 6px;
border-bottom: 1px solid var(--card-border);
}
/* Section header */
.sae-section-title {
font-size: 15px;
font-weight: var(--section-header);
color: var(--text-primary);
margin: 0 0 8px 0;
}
/* Feature number badge */
.sae-feat-num {
font-family: 'SF Mono', 'Fira Code', 'Cascadia Code', monospace;
font-size: 11px;
color: var(--text-muted);
background: #f3f4f6;
padding: 1px 5px;
border-radius: 3px;
}
/* Primary button override */
.bk-btn-primary {
background-color: var(--accent) !important;
border-color: var(--accent) !important;
border-radius: var(--radius-sm) !important;
font-weight: 500 !important;
font-size: 13px !important;
}
.bk-btn-primary:hover {
background-color: var(--accent-hover) !important;
border-color: var(--accent-hover) !important;
}
/* Success button override */
.bk-btn-success {
background-color: var(--accent) !important;
border-color: var(--accent) !important;
border-radius: var(--radius-sm) !important;
font-weight: 500 !important;
font-size: 13px !important;
}
.bk-btn-success:hover {
background-color: var(--accent-hover) !important;
border-color: var(--accent-hover) !important;
}
/* Warning (secondary) button override */
.bk-btn-warning {
background-color: transparent !important;
border: 1.5px solid var(--card-border) !important;
color: var(--text-secondary) !important;
border-radius: var(--radius-sm) !important;
font-weight: 500 !important;
font-size: 13px !important;
}
.bk-btn-warning:hover {
background-color: #f9fafb !important;
border-color: var(--accent) !important;
color: var(--accent) !important;
}
/* Light button */
.bk-btn-light {
border-radius: var(--radius-sm) !important;
font-size: 13px !important;
font-weight: 500 !important;
}
/* Default button */
.bk-btn-default {
border-radius: var(--radius-sm) !important;
font-size: 13px !important;
}
/* Slider labels */
.bk-Slider .bk-slider-title {
font-size: 12px !important;
color: var(--text-secondary) !important;
font-weight: 500 !important;
}
/* Text inputs */
.bk-input {
border-radius: var(--radius-sm) !important;
border-color: var(--card-border) !important;
font-size: 13px !important;
}
.bk-input:focus {
border-color: var(--accent) !important;
box-shadow: 0 0 0 2px var(--accent-light) !important;
}
.bk-input-group > label {
font-size: 12px !important;
font-weight: 500 !important;
color: var(--text-secondary) !important;
}
/* DataTable styling */
.bk-data-table {
border-radius: var(--radius) !important;
overflow: hidden;
}
/* Patch figure */
.bk-Figure {
border-radius: var(--radius) !important;
overflow: hidden;
}
</style>""", visible=False)
# ---------- Data bootstrap ----------
# `state` and `datasets` are in the keep-list of the module-clearing step
# above, so `_all_datasets` survives across sessions in this process. The load
# below therefore presumably runs only for the first session, assuming
# load_all_datasets() fills `_all_datasets` in place — confirm.
from .datasets import load_all_datasets
from .state import _all_datasets, active_ds
if not _all_datasets:
    load_all_datasets()
# The widget/panel modules below are imported only after the data load.
# They are cleared and re-imported for every session.
# NOTE(review): this ordering suggests they read active_ds() at import time,
# so keep the load above them — confirm before reordering.
from .brain import HAS_DYNADIFF
from .widgets import make_collapsible
from . import widgets
from .panels import feature as feature_panel
from .panels import feature_list as flist_panel
from .panels import steering as steer_panel
from .panels import clip_search as clip_panel
from .panels import examples as examples_panel
from .inference import warmup_gpu_runner
# ---------- SAE Summary div ----------
def _make_summary_html() -> str:
ds = active_ds()
n_umap_act = int(ds['live_mask'].sum())
n_truly_active = int((ds['freq'] > 0).sum())
n_dead = ds['d_model'] - n_truly_active
tok_label = f"{ds['patch_grid']}×{ds['patch_grid']} = {ds['patch_grid']**2} patches"
backbone_label = ds.get('backbone', 'dinov2').upper()
sae_url = ds.get('sae_url')
dl_row = (f'<tr><td style="padding-right:12px;font-weight:500">SAE weights</td>'
f'<td><a href="{sae_url}" download style="color:#2563eb;text-decoration:none;'
f'font-weight:500">⬇ Download</a></td></tr>'
if sae_url else '')
return f"""
<div class="sae-card" style="margin-bottom:8px;">
<div class="sae-card-header">SAE Summary</div>
<table style="font-size:13px;line-height:1.7;color:#4b5563">
<tr><td style="padding-right:12px;font-weight:500">Active model</td><td><b style="color:#2563eb">{ds['label']}</b></td></tr>
<tr><td style="padding-right:12px;font-weight:500">Backbone</td><td>{backbone_label}</td></tr>
<tr><td style="padding-right:12px;font-weight:500">Dictionary size</td><td>{ds['d_model']:,}</td></tr>
<tr><td style="padding-right:12px;font-weight:500">Active (fired ≥1)</td><td>{n_truly_active:,} ({100*n_truly_active/ds['d_model']:.1f}%)</td></tr>
<tr><td style="padding-right:12px;font-weight:500">Dead</td><td>{n_dead:,} ({100*n_dead/ds['d_model']:.1f}%)</td></tr>
<tr><td style="padding-right:12px;font-weight:500">Images</td><td>{ds['n_images']:,}</td></tr>
<tr><td style="padding-right:12px;font-weight:500">Tokens/image</td><td>{tok_label}</td></tr>
{dl_row}
</table>
</div>"""
# Rendered each time this script runs (once per session); it is placed in the
# collapsible "SAE Summary" section of the lower-right card.
summary_div = Div(width=600, text=_make_summary_html())
# ---------- Global feature navigation callbacks ----------
def _on_go_click():
    """Handle the "Go" button: display the feature typed in ``feature_input``.

    Non-integer or out-of-range input is reported in
    ``feature_panel.stats_div`` instead of being raised.
    """
    raw = widgets.feature_input.value
    try:
        feat = int(raw)
    except (TypeError, ValueError):
        # ValueError: non-numeric text; TypeError: a non-string value such as None.
        feature_panel.stats_div.text = "<h3>Please enter a valid integer</h3>"
        return
    d_model = active_ds()['d_model']
    if not 0 <= feat < d_model:
        feature_panel.stats_div.text = (
            f"<h3>Feature {feat} out of range (0–{d_model - 1})</h3>")
        return
    # Called outside the try: a ValueError raised while rendering the feature
    # must not be misreported as bad user input.
    feature_panel.select_and_display(feat)


widgets.go_button.on_click(_on_go_click)
def _on_random():
    """Handle the "Random" button: show a randomly chosen active feature.

    Mirrors the chosen index into ``feature_input`` and does nothing when the
    active dataset lists no active features.
    """
    active = active_ds()['active_feats']
    # Use len() rather than truthiness: `not arr` raises for a multi-element
    # NumPy array. NOTE(review): container type of active_feats is not visible
    # here (list or ndarray) — confirm.
    if active is None or len(active) == 0:
        return
    # int() normalises NumPy integer scalars before they reach the widget and
    # panel, matching the plain int passed by _on_go_click.
    feat = int(random.choice(active))
    widgets.feature_input.value = str(feat)
    feature_panel.select_and_display(feat)


widgets.random_btn.on_click(_on_random)
# ---------- JS bridge for gallery / active-feature tile clicks ----------
# On `document_ready` this callback publishes three global JS functions that
# static HTML tiles can call: select a feature, change gallery page, load a
# patch image. Each writes "<value>|<Date.now()>" into its bridge widget.
# The timestamp suffix presumably ensures a repeat click on the same item
# still registers as a new value.
from bokeh.models import CustomJS as _CustomJS

_full_bridge_js = _CustomJS(
    args={
        "feat_inp": flist_panel.gallery_bridge_input,
        "page_inp": flist_panel.gallery_page_input,
        "patch_inp": steer_panel.patch_load_bridge,
    },
    code="""
window._sae_select_feature = function(feat_idx) {
feat_inp.value = String(feat_idx) + '|' + Date.now();
};
window._sae_gallery_page = function(page_num) {
page_inp.value = String(page_num) + '|' + Date.now();
};
window._sae_load_patch_image = function(img_idx) {
patch_inp.value = String(img_idx) + '|' + Date.now();
};
""",
)
curdoc().js_on_event('document_ready', _full_bridge_js)
# ============================================================
# UPPER WORKSPACE — Active Steering & Composition
# ============================================================
def _upper_card_styles():
    """Return a fresh inline-style dict for one upper-workspace card.

    A new dict is built per call so no two Bokeh models share one object.
    """
    return {
        "background": "var(--card-bg, #fff)",
        "border": "1px solid var(--card-border, #e2e5ea)",
        "border-radius": "8px",
        "padding": "12px",
        "box-shadow": "0 1px 3px rgba(0,0,0,0.06)",
    }


def _upper_card_title(title, width, extra_style=""):
    """Return a Div holding a borderless ``sae-card-header`` caption.

    ``extra_style`` is appended verbatim to the inline style, so it must
    start with ``;``.
    """
    return Div(
        text=('<div class="sae-card-header" '
              f'style="border-bottom:none{extra_style}">{title}</div>'),
        width=width,
    )


# Left sidebar: example presets, separated from the cards by a vertical rule.
_upper_left = column(
    examples_panel.examples_panel,
    width=210,
    styles={
        "border-right": "1px solid var(--card-border, #e2e5ea)",
        "padding-right": "10px",
        "margin-right": "8px",
        "min-height": "400px",
    },
)

# Card 1 — ground-truth brain response stacked above the patch explorer.
_patch_header = _upper_card_title("GT Brain Response", width=410)
_patch_col = column(
    _patch_header,
    steer_panel.gt_brain_div,
    steer_panel.patch_explorer_panel,
    styles=_upper_card_styles(),
)

# Card 2 — steering-sum brain map, then the active-feature tiles.
_active_header = _upper_card_title("Steering Direction", width=460)
_active_features_header = _upper_card_title(
    "Active Features", width=460, extra_style=";margin-top:8px")
_active_column = column(
    _active_header,
    steer_panel.steer_brain_div,
    _active_features_header,
    steer_panel.active_features_div,
    width=460,
    styles=_upper_card_styles(),
)

# Card 3 — DynaDiff: expected steered brain above the run controls. Without
# DynaDiff, a 1-px empty Div fills the third slot of the row instead.
if not HAS_DYNADIFF:
    _dd_controls = Div(text="", width=1)
else:
    _dd_controls_header = _upper_card_title("Expected Steered Brain", width=480)
    _dd_run_header = _upper_card_title(
        "Brain Steering", width=480, extra_style=";margin-top:8px")
    _dd_controls = column(
        _dd_controls_header,
        steer_panel.steered_brain_div,
        _dd_run_header,
        steer_panel.dynadiff_panel,
        width=480,
        styles=_upper_card_styles(),
    )

_upper_center = row(_patch_col, _active_column, _dd_controls,
                    styles={"gap": "12px"})
upper_workspace = row(_upper_left, _upper_center)
del _upper_card_styles, _upper_card_title
# ============================================================
# DIVIDER
# ============================================================
# A 1px solid horizontal rule between the upper and lower workspaces.
divider = Div(
    width=1500,
    text="<hr style='border:none;border-top:1px solid #e2e5ea;margin:16px 0'>",
)
# ============================================================
# LOWER WORKSPACE — Feature Search & Analysis
# ============================================================
def _lower_card_styles(padding="12px", extra=None):
    """Return a fresh inline-style dict for one lower-workspace card.

    ``padding`` replaces the default card padding in place; ``extra``
    entries (e.g. margins) are appended after the shared card properties.
    """
    styles = {
        "background": "var(--card-bg, #fff)",
        "border": "1px solid var(--card-border, #e2e5ea)",
        "border-radius": "8px",
        "padding": padding,
        "box-shadow": "0 1px 3px rgba(0,0,0,0.06)",
    }
    styles.update(extra or {})
    return styles


# Left card: CLIP text search (or an "unavailable" note), results, gallery.
if clip_panel.clip_unavailable:
    _clip_row = Div(
        text=("<i style='color:var(--text-muted, #9ca3af);font-size:11px'>"
              "CLIP search unavailable</i>"),
        width=300,
    )
    _clip_results = Div(text="", width=1)
else:
    _clip_row = row(
        clip_panel.clip_query_input,
        clip_panel.clip_search_btn,
        styles={"margin-bottom": "6px", "align-items": "end"},
    )
    _clip_results = clip_panel.clip_results_div
_search_header = Div(
    text='<div class="sae-card-header" style="border-bottom:none">Feature Search</div>',
    width=300,
)
lower_left = column(
    _search_header,
    _clip_row,
    _clip_results,
    # flist_panel.sort_select,  # re-enable for CLIP × φ sort
    flist_panel.gallery_div,
    flist_panel.gallery_bridge_input,
    flist_panel.gallery_page_input,
    width=340,
    styles=_lower_card_styles(extra={"margin-right": "12px"}),
)

# Centre card: feature header + "Add to Steer", name input and labeler,
# zoom/alpha sliders, then the MEI heatmap gallery.
_zoom_controls = row(
    widgets.zoom_slider,
    widgets.heatmap_alpha_slider,
    styles={"gap": "16px", "padding": "4px 0 8px 0"},
)
_feature_header_row = row(
    feature_panel.stats_div,
    feature_panel.add_steer_btn,
    styles={"align-items": "center"},
)
_labeler_row = row(
    flist_panel.gemini_btn,
    flist_panel.gemini_status_div,
    styles={"align-items": "center", "margin-bottom": "4px"},
)
lower_center = column(
    feature_panel.status_div,
    _feature_header_row,
    flist_panel.name_input,
    _labeler_row,
    _zoom_controls,
    feature_panel.top_heatmap_div,
    width=700,
    styles=_lower_card_styles(padding="14px 16px"),
)

# Right card: brain profile plus the collapsible SAE summary.
lower_right = column(
    feature_panel.brain_div,
    make_collapsible("SAE Summary", summary_div),
    width=580,
    styles=_lower_card_styles(padding="14px 16px",
                              extra={"margin-left": "12px"}),
)

lower_workspace = row(lower_left, lower_center, lower_right)
del _lower_card_styles
# ============================================================
# ROOT LAYOUT
# ============================================================
layout = column(
    _theme_css,
    upper_workspace,
    divider,
    lower_workspace,
    # Bridge widget that window._sae_load_patch_image writes into; presumably
    # it must be part of the document for that write to reach the server.
    steer_panel.patch_load_bridge,
    styles={"padding": "12px"},
)
_doc = curdoc()
_doc.add_root(layout)
_doc.title = "SAE Feature Explorer"
del _doc
print("Explorer app ready!")
# NOTE(review): this script re-runs for every new browser session, so the
# warm-up below is invoked once per session — presumably it is idempotent
# and cheap after the first call; confirm.
warmup_gpu_runner()