""" SAE Feature Explorer — Bokeh server entry point. Launch with: bokeh serve scripts/explorer --port 5006 --args --data ... Layout: Upper workspace — active steering & composition left : example presets sidebar center: patch explorer | active features tile strip | DynaDiff controls & output ── dashed divider ────────────────────────────────────────────────────────── Lower workspace — feature search & analysis left : CLIP search input + result cards center: feature activation MEI grid right : feature naming + cortical profile + SAE summary """ import os import random import sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src')) # --------------------------------------------------------------------------- # Multi-session fix: Bokeh re-executes main.py for every new browser session # but Python caches imported submodules. Modules that create Bokeh model # instances at import time (widgets, all panels) must be cleared so each # session gets fresh instances. Data/logic modules (state, datasets, brain, # inference, rendering) are deliberately kept cached. # --------------------------------------------------------------------------- _pkg_dir = os.path.dirname(os.path.abspath(__file__)) _keep_stems = frozenset( ['state', 'datasets', 'args', 'brain', 'inference', 'rendering', 'steering_logic', 'feature_logic', 'feature_list_logic', 'clip_search_logic', '__init__']) def _clear_widget_modules(): for _k in list(sys.modules.keys()): _m = sys.modules.get(_k) if _m is None: continue _f = getattr(_m, '__file__', None) or '' if not _f.startswith(_pkg_dir) or _k == __name__: continue if os.path.basename(_f).split('.')[0] in _keep_stems: continue if '.' in _k: _par = sys.modules.get(_k.rsplit('.', 1)[0]) if _par is not None: try: delattr(_par, _k.rsplit('.', 1)[1]) except AttributeError: pass del sys.modules[_k] _clear_widget_modules() del _clear_widget_modules, _pkg_dir, _keep_stems from bokeh.io import curdoc from bokeh.layouts import column, row from bokeh.models import Div # ── Global CSS theme ────────────────────────────────────────────── _theme_css = Div(text="""""", visible=False) from .datasets import load_all_datasets from .state import _all_datasets, active_ds if not _all_datasets: load_all_datasets() from .brain import HAS_DYNADIFF from .widgets import make_collapsible from . import widgets from .panels import feature as feature_panel from .panels import feature_list as flist_panel from .panels import steering as steer_panel from .panels import clip_search as clip_panel from .panels import examples as examples_panel from .inference import warmup_gpu_runner # ---------- SAE Summary div ---------- def _make_summary_html() -> str: ds = active_ds() n_umap_act = int(ds['live_mask'].sum()) n_truly_active = int((ds['freq'] > 0).sum()) n_dead = ds['d_model'] - n_truly_active tok_label = f"{ds['patch_grid']}×{ds['patch_grid']} = {ds['patch_grid']**2} patches" backbone_label = ds.get('backbone', 'dinov2').upper() sae_url = ds.get('sae_url') dl_row = (f'SAE weights' f'⬇ Download' if sae_url else '') return f"""
SAE Summary
{dl_row}
Active model{ds['label']}
Backbone{backbone_label}
Dictionary size{ds['d_model']:,}
Active (fired ≥1){n_truly_active:,} ({100*n_truly_active/ds['d_model']:.1f}%)
Dead{n_dead:,} ({100*n_dead/ds['d_model']:.1f}%)
Images{ds['n_images']:,}
Tokens/image{tok_label}
""" summary_div = Div(text=_make_summary_html(), width=600) # ---------- Global feature navigation callbacks ---------- def _on_go_click(): try: feat = int(widgets.feature_input.value) if 0 <= feat < active_ds()['d_model']: feature_panel.select_and_display(feat) else: feature_panel.stats_div.text = ( f"

Feature {feat} out of range (0–{active_ds()['d_model']-1})

") except ValueError: feature_panel.stats_div.text = "

Please enter a valid integer

" widgets.go_button.on_click(_on_go_click) def _on_random(): active = active_ds()['active_feats'] if not active: return feat = random.choice(active) widgets.feature_input.value = str(feat) feature_panel.select_and_display(feat) widgets.random_btn.on_click(_on_random) # ---------- JS bridge for gallery / active-feature tile clicks ---------- # Installed once at document_ready so all gallery HTML tiles can call back. # We extend the bridge JS with the patch-load bridge widget reference. from bokeh.models import CustomJS as _CustomJS _full_bridge_js = _CustomJS( args=dict( feat_inp=flist_panel.gallery_bridge_input, page_inp=flist_panel.gallery_page_input, patch_inp=steer_panel.patch_load_bridge, ), code=""" window._sae_select_feature = function(feat_idx) { feat_inp.value = String(feat_idx) + '|' + Date.now(); }; window._sae_gallery_page = function(page_num) { page_inp.value = String(page_num) + '|' + Date.now(); }; window._sae_load_patch_image = function(img_idx) { patch_inp.value = String(img_idx) + '|' + Date.now(); }; """, ) curdoc().js_on_event('document_ready', _full_bridge_js) # ============================================================ # UPPER WORKSPACE — Active Steering & Composition # ============================================================ # Left sidebar: example presets _upper_left = column( examples_panel.examples_panel, width=210, styles={"border-right": "1px solid var(--card-border, #e2e5ea)", "padding-right": "10px", "margin-right": "8px", "min-height": "400px"}, ) # Patch column: GT brain above, then patch explorer _patch_header = Div( text='
GT Brain Response
', width=410, ) _patch_col = column( _patch_header, steer_panel.gt_brain_div, steer_panel.patch_explorer_panel, styles={"background": "var(--card-bg, #fff)", "border": "1px solid var(--card-border, #e2e5ea)", "border-radius": "8px", "padding": "12px", "box-shadow": "0 1px 3px rgba(0,0,0,0.06)"}, ) # Active features column: steering sum brain above, then feature tiles _active_header = Div( text='
Steering Direction
', width=460, ) _active_features_header = Div( text='
' 'Active Features
', width=460, ) _active_column = column( _active_header, steer_panel.steer_brain_div, _active_features_header, steer_panel.active_features_div, width=460, styles={"background": "var(--card-bg, #fff)", "border": "1px solid var(--card-border, #e2e5ea)", "border-radius": "8px", "padding": "12px", "box-shadow": "0 1px 3px rgba(0,0,0,0.06)"}, ) # DynaDiff column: steered brain above output + controls if HAS_DYNADIFF: _dd_controls_header = Div( text='
' 'Expected Steered Brain
', width=480, ) _dd_run_header = Div( text='
' 'Brain Steering
', width=480, ) _dd_controls = column( _dd_controls_header, steer_panel.steered_brain_div, _dd_run_header, steer_panel.dynadiff_panel, width=480, styles={"background": "var(--card-bg, #fff)", "border": "1px solid var(--card-border, #e2e5ea)", "border-radius": "8px", "padding": "12px", "box-shadow": "0 1px 3px rgba(0,0,0,0.06)"}, ) else: _dd_controls = Div(text="", width=1) _upper_center = row( _patch_col, _active_column, _dd_controls, styles={"gap": "12px"}, ) upper_workspace = row(_upper_left, _upper_center) # ============================================================ # DIVIDER # ============================================================ divider = Div( text="
", width=1500, ) # ============================================================ # LOWER WORKSPACE — Feature Search & Analysis # ============================================================ # Left: CLIP search + gallery if not clip_panel.clip_unavailable: _clip_row = row( clip_panel.clip_query_input, clip_panel.clip_search_btn, styles={"margin-bottom": "6px", "align-items": "end"}, ) else: _clip_row = Div( text="" "CLIP search unavailable", width=300, ) _search_header = Div( text='
' 'Feature Search
', width=300, ) _clip_results = clip_panel.clip_results_div if not clip_panel.clip_unavailable else Div(text="", width=1) lower_left = column( _search_header, _clip_row, _clip_results, # flist_panel.sort_select, # re-enable for CLIP × φ sort flist_panel.gallery_div, flist_panel.gallery_bridge_input, flist_panel.gallery_page_input, width=340, styles={"background": "var(--card-bg, #fff)", "border": "1px solid var(--card-border, #e2e5ea)", "border-radius": "8px", "padding": "12px", "box-shadow": "0 1px 3px rgba(0,0,0,0.06)", "margin-right": "12px"}, ) # Center: feature name + Add to Steer + labeler + MEI gallery _zoom_controls = row( widgets.zoom_slider, widgets.heatmap_alpha_slider, styles={"gap": "16px", "padding": "4px 0 8px 0"}, ) _feature_header_row = row( feature_panel.stats_div, feature_panel.add_steer_btn, styles={"align-items": "center"}, ) _labeler_row = row( flist_panel.gemini_btn, flist_panel.gemini_status_div, styles={"align-items": "center", "margin-bottom": "4px"}, ) lower_center = column( feature_panel.status_div, _feature_header_row, flist_panel.name_input, _labeler_row, _zoom_controls, feature_panel.top_heatmap_div, width=700, styles={"background": "var(--card-bg, #fff)", "border": "1px solid var(--card-border, #e2e5ea)", "border-radius": "8px", "padding": "14px 16px", "box-shadow": "0 1px 3px rgba(0,0,0,0.06)"}, ) # Right: brain profile + summary lower_right = column( feature_panel.brain_div, make_collapsible("SAE Summary", summary_div), width=580, styles={"background": "var(--card-bg, #fff)", "border": "1px solid var(--card-border, #e2e5ea)", "border-radius": "8px", "padding": "14px 16px", "box-shadow": "0 1px 3px rgba(0,0,0,0.06)", "margin-left": "12px"}, ) lower_workspace = row(lower_left, lower_center, lower_right) # ============================================================ # ROOT LAYOUT # ============================================================ layout = column(_theme_css, upper_workspace, divider, lower_workspace, steer_panel.patch_load_bridge, styles={"padding": "12px"}) curdoc().add_root(layout) curdoc().title = "SAE Feature Explorer" print("Explorer app ready!") warmup_gpu_runner()