File size: 16,474 Bytes
4c1c394
 
 
 
 
62166b9
 
 
 
 
 
 
 
 
4c1c394
 
 
62166b9
4c1c394
 
 
 
7911b1a
 
 
 
 
62166b9
7911b1a
8e457b7
 
fb9c7be
 
 
8e457b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ed3016
8e457b7
8ed3016
 
8e457b7
 
 
 
 
7911b1a
4c1c394
 
 
 
4a5ed93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c1c394
7911b1a
 
 
4c1c394
 
fb9c7be
4c1c394
 
 
 
fb9c7be
4c1c394
62166b9
4c1c394
 
 
 
62166b9
4c1c394
 
 
 
 
 
 
 
 
4a5ed93
 
 
4c1c394
 
4a5ed93
 
 
 
 
 
 
 
 
 
4c1c394
 
 
 
62166b9
4c1c394
 
62166b9
4c1c394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62166b9
 
 
4c1c394
62166b9
 
 
 
 
fb9c7be
62166b9
 
 
 
 
 
 
 
 
 
 
 
4c1c394
62166b9
4c1c394
62166b9
 
 
 
 
 
 
 
 
4a5ed93
 
 
4c1c394
 
62166b9
 
4a5ed93
62166b9
 
 
 
fb9c7be
 
4a5ed93
 
 
 
4c1c394
 
62166b9
 
4a5ed93
62166b9
 
 
4a5ed93
62166b9
 
 
 
 
fb9c7be
62166b9
fb9c7be
62166b9
4a5ed93
 
 
 
62166b9
 
 
 
 
4a5ed93
62166b9
 
 
 
4a5ed93
62166b9
 
 
 
 
fb9c7be
62166b9
fb9c7be
62166b9
4a5ed93
 
 
 
62166b9
 
 
 
 
 
 
 
4a5ed93
62166b9
 
 
 
 
 
 
 
 
 
4a5ed93
62166b9
 
 
 
 
 
 
 
 
 
 
 
 
4a5ed93
62166b9
 
 
4a5ed93
 
62166b9
 
 
 
4a5ed93
62166b9
 
 
 
 
 
 
 
 
 
1361d88
62166b9
 
 
 
4a5ed93
 
 
 
 
62166b9
 
 
4a5ed93
 
 
 
62166b9
4c1c394
62166b9
4a5ed93
 
 
 
 
 
4c1c394
 
62166b9
 
 
 
4a5ed93
62166b9
 
 
4a5ed93
 
 
 
4c1c394
 
62166b9
 
 
 
 
4a5ed93
 
 
 
 
4c1c394
 
62166b9
 
 
 
 
 
 
4a5ed93
 
 
4c1c394
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
"""
SAE Feature Explorer β€” Bokeh server entry point.

Launch with:  bokeh serve scripts/explorer --port 5006 --args --data ...

Layout:
  Upper workspace  β€” active steering & composition
    left  : example presets sidebar
    center: patch explorer | active features tile strip | DynaDiff controls & output
  ── dashed divider ──────────────────────────────────────────────────────────
  Lower workspace  β€” feature search & analysis
    left  : CLIP search input + result cards
    center: feature activation MEI grid
    right : feature naming + cortical profile + SAE summary
"""

import html
import os
import random
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))

# ---------------------------------------------------------------------------
# Multi-session fix: Bokeh re-executes main.py for every new browser session
# but Python caches imported submodules.  Modules that create Bokeh model
# instances at import time (widgets, all panels) must be cleared so each
# session gets fresh instances.  Data/logic modules (state, datasets, brain,
# inference, rendering) are deliberately kept cached.
# ---------------------------------------------------------------------------
_pkg_dir = os.path.dirname(os.path.abspath(__file__))
# File stems (basename up to the first '.') whose modules survive the purge.
# '__init__' keeps package modules themselves cached; only their submodules
# are evicted.
_keep_stems = frozenset(
    ['state', 'datasets', 'args', 'brain', 'inference', 'rendering',
     'steering_logic', 'feature_logic', 'feature_list_logic',
     'clip_search_logic', '__init__'])


def _clear_widget_modules():
    """Evict this package's UI modules from ``sys.modules``.

    A module is evicted when its ``__file__`` lies inside ``_pkg_dir``, it is
    not this entry-point module, and its file stem is not in ``_keep_stems``.
    For dotted module names the attribute binding on the parent package is
    removed too, so a later ``from .pkg import mod`` re-executes ``mod``
    instead of returning the stale, already-imported submodule.
    """
    # Require a path-separator boundary: a plain startswith(_pkg_dir) would
    # also match a sibling directory sharing the prefix (e.g. "explorer_old").
    prefix = os.path.join(_pkg_dir, '')
    for name, module in list(sys.modules.items()):
        if module is None:
            continue
        path = getattr(module, '__file__', None) or ''
        if name == __name__ or not path.startswith(prefix):
            continue
        if os.path.basename(path).split('.')[0] in _keep_stems:
            continue
        parent_name, dot, child = name.rpartition('.')
        if dot:
            parent = sys.modules.get(parent_name)
            if parent is not None:
                try:
                    delattr(parent, child)
                except AttributeError:
                    pass
        sys.modules.pop(name, None)


_clear_widget_modules()
del _clear_widget_modules, _pkg_dir, _keep_stems

from bokeh.io import curdoc
from bokeh.layouts import column, row
from bokeh.models import Div

# ── Global CSS theme ──────────────────────────────────────────────
# A hidden (visible=False) Div whose markup is a single <style> element:
# design tokens as CSS custom properties on :root, card/header utility
# classes ("sae-card", "sae-card-header", ...) used by HTML fragments built
# in this app, and overrides for Bokeh button / input / slider / table /
# figure classes.  It is placed first in the root layout at the end of this
# script so it is part of every session's document.
# NOTE(review): Bokeh 3 (this file uses the 3.x `styles=` kwarg) renders Div
# markup inside the component's shadow root, so selectors such as :root,
# body and .bk-btn-primary may not reach other widgets; the layouts in this
# file pass var(--x, fallback) values, presumably for that reason.  Confirm
# in the browser; `stylesheets=` / GlobalInlineStyleSheet is the supported
# Bokeh 3 route if these rules are not applied.
# NOTE(review): the ".bk-btn-success" rule uses the accent (blue) colours,
# not --success; confirm this is intentional.
_theme_css = Div(text="""<style>
:root {
  --bg: #f0f2f5;
  --card-bg: #ffffff;
  --card-border: #e2e5ea;
  --card-shadow: 0 1px 3px rgba(0,0,0,0.06), 0 1px 2px rgba(0,0,0,0.04);
  --accent: #2563eb;
  --accent-hover: #1d4ed8;
  --accent-light: #eff4ff;
  --text-primary: #1a1d23;
  --text-secondary: #4b5563;
  --text-muted: #9ca3af;
  --destructive: #dc2626;
  --success: #059669;
  --warning: #d97706;
  --section-header: 600;
  --radius: 8px;
  --radius-sm: 6px;
}
body, .bk-root {
  font-family: system-ui, -apple-system, 'Segoe UI', Roboto, sans-serif !important;
  background: var(--bg) !important;
  color: var(--text-primary);
}
/* Card container utility */
.sae-card {
  background: var(--card-bg);
  border: 1px solid var(--card-border);
  border-radius: var(--radius);
  box-shadow: var(--card-shadow);
  padding: 14px 16px;
}
.sae-card-header {
  font-size: 13px;
  font-weight: var(--section-header);
  color: var(--text-secondary);
  text-transform: uppercase;
  letter-spacing: 0.03em;
  margin: 0 0 10px 0;
  padding-bottom: 6px;
  border-bottom: 1px solid var(--card-border);
}
/* Section header */
.sae-section-title {
  font-size: 15px;
  font-weight: var(--section-header);
  color: var(--text-primary);
  margin: 0 0 8px 0;
}
/* Feature number badge */
.sae-feat-num {
  font-family: 'SF Mono', 'Fira Code', 'Cascadia Code', monospace;
  font-size: 11px;
  color: var(--text-muted);
  background: #f3f4f6;
  padding: 1px 5px;
  border-radius: 3px;
}
/* Primary button override */
.bk-btn-primary {
  background-color: var(--accent) !important;
  border-color: var(--accent) !important;
  border-radius: var(--radius-sm) !important;
  font-weight: 500 !important;
  font-size: 13px !important;
}
.bk-btn-primary:hover {
  background-color: var(--accent-hover) !important;
  border-color: var(--accent-hover) !important;
}
/* Success button override */
.bk-btn-success {
  background-color: var(--accent) !important;
  border-color: var(--accent) !important;
  border-radius: var(--radius-sm) !important;
  font-weight: 500 !important;
  font-size: 13px !important;
}
.bk-btn-success:hover {
  background-color: var(--accent-hover) !important;
  border-color: var(--accent-hover) !important;
}
/* Warning (secondary) button override */
.bk-btn-warning {
  background-color: transparent !important;
  border: 1.5px solid var(--card-border) !important;
  color: var(--text-secondary) !important;
  border-radius: var(--radius-sm) !important;
  font-weight: 500 !important;
  font-size: 13px !important;
}
.bk-btn-warning:hover {
  background-color: #f9fafb !important;
  border-color: var(--accent) !important;
  color: var(--accent) !important;
}
/* Light button */
.bk-btn-light {
  border-radius: var(--radius-sm) !important;
  font-size: 13px !important;
  font-weight: 500 !important;
}
/* Default button */
.bk-btn-default {
  border-radius: var(--radius-sm) !important;
  font-size: 13px !important;
}
/* Slider labels */
.bk-Slider .bk-slider-title {
  font-size: 12px !important;
  color: var(--text-secondary) !important;
  font-weight: 500 !important;
}
/* Text inputs */
.bk-input {
  border-radius: var(--radius-sm) !important;
  border-color: var(--card-border) !important;
  font-size: 13px !important;
}
.bk-input:focus {
  border-color: var(--accent) !important;
  box-shadow: 0 0 0 2px var(--accent-light) !important;
}
.bk-input-group > label {
  font-size: 12px !important;
  font-weight: 500 !important;
  color: var(--text-secondary) !important;
}
/* DataTable styling */
.bk-data-table {
  border-radius: var(--radius) !important;
  overflow: hidden;
}
/* Patch figure */
.bk-Figure {
  border-radius: var(--radius) !important;
  overflow: hidden;
}
</style>""", visible=False)

# The dataset registry lives in the `state` module, which (like `datasets`)
# is on the keep-list of the module-clearing step at the top of this script
# and therefore survives across browser sessions.  The emptiness check means
# loading happens only while the registry is still empty, presumably just
# for the first session of the server process, assuming load_all_datasets()
# fills `_all_datasets` in place (TODO confirm).
# NOTE(review): presumably this must run before the panel imports below,
# which may read the active dataset at import time; keep the ordering.
from .datasets import load_all_datasets
from .state import _all_datasets, active_ds
if not _all_datasets:
    load_all_datasets()

from .brain import HAS_DYNADIFF
from .widgets import make_collapsible
from . import widgets

from .panels import feature as feature_panel
from .panels import feature_list as flist_panel
from .panels import steering as steer_panel
from .panels import clip_search as clip_panel
from .panels import examples as examples_panel

from .inference import warmup_gpu_runner


# ---------- SAE Summary div ----------

def _make_summary_html() -> str:
    ds = active_ds()
    n_umap_act     = int(ds['live_mask'].sum())
    n_truly_active = int((ds['freq'] > 0).sum())
    n_dead         = ds['d_model'] - n_truly_active
    tok_label      = f"{ds['patch_grid']}Γ—{ds['patch_grid']} = {ds['patch_grid']**2} patches"
    backbone_label = ds.get('backbone', 'dinov2').upper()
    sae_url        = ds.get('sae_url')
    dl_row         = (f'<tr><td style="padding-right:12px;font-weight:500">SAE weights</td>'
                      f'<td><a href="{sae_url}" download style="color:#2563eb;text-decoration:none;'
                      f'font-weight:500">⬇ Download</a></td></tr>'
                      if sae_url else '')
    return f"""
<div class="sae-card" style="margin-bottom:8px;">
<div class="sae-card-header">SAE Summary</div>
<table style="font-size:13px;line-height:1.7;color:#4b5563">
<tr><td style="padding-right:12px;font-weight:500">Active model</td><td><b style="color:#2563eb">{ds['label']}</b></td></tr>
<tr><td style="padding-right:12px;font-weight:500">Backbone</td><td>{backbone_label}</td></tr>
<tr><td style="padding-right:12px;font-weight:500">Dictionary size</td><td>{ds['d_model']:,}</td></tr>
<tr><td style="padding-right:12px;font-weight:500">Active (fired β‰₯1)</td><td>{n_truly_active:,} ({100*n_truly_active/ds['d_model']:.1f}%)</td></tr>
<tr><td style="padding-right:12px;font-weight:500">Dead</td><td>{n_dead:,} ({100*n_dead/ds['d_model']:.1f}%)</td></tr>
<tr><td style="padding-right:12px;font-weight:500">Images</td><td>{ds['n_images']:,}</td></tr>
<tr><td style="padding-right:12px;font-weight:500">Tokens/image</td><td>{tok_label}</td></tr>
{dl_row}
</table>
</div>"""

# Initial content is rendered from whichever dataset is active when this
# session's script runs.
summary_div = Div(width=600, text=_make_summary_html())


# ---------- Global feature navigation callbacks ----------

def _on_go_click():
    """Display the feature whose index is typed into ``widgets.feature_input``.

    Non-integer input, and indices outside ``[0, d_model)`` of the active
    dataset, are reported in ``feature_panel.stats_div`` instead.
    """
    try:
        feat = int(widgets.feature_input.value)
    except (TypeError, ValueError):
        # Keep the try narrow: a ValueError raised while rendering the
        # feature must not be misreported as bad user input.
        feature_panel.stats_div.text = "<h3>Please enter a valid integer</h3>"
        return

    d_model = active_ds()['d_model']
    if 0 <= feat < d_model:
        feature_panel.select_and_display(feat)
    else:
        feature_panel.stats_div.text = (
            f"<h3>Feature {feat} out of range (0–{d_model-1})</h3>")

widgets.go_button.on_click(_on_go_click)


def _on_random():
    """Jump to a randomly chosen entry of the active dataset's ``active_feats``.

    Does nothing when that collection is empty.  The chosen index is also
    written back into ``widgets.feature_input`` so the text box stays in sync.
    """
    candidates = active_ds()['active_feats']
    if candidates:
        picked = random.choice(candidates)
        widgets.feature_input.value = str(picked)
        feature_panel.select_and_display(picked)

widgets.random_btn.on_click(_on_random)


# ---------- JS bridge for gallery / active-feature tile clicks ----------
# On the document's `document_ready` event the browser defines three
# window-level helpers (presumably invoked from click handlers in the
# gallery / tile HTML).  Each one writes "<value>|<Date.now()>" into one of
# the Bokeh input widgets passed as `args`; the timestamp suffix makes
# consecutive writes differ, presumably so repeated clicks on the same tile
# still fire a change on the Python side.

from bokeh.models import CustomJS as _CustomJS
_full_bridge_js = _CustomJS(
    code="""
        window._sae_select_feature = function(feat_idx) {
            feat_inp.value = String(feat_idx) + '|' + Date.now();
        };
        window._sae_gallery_page = function(page_num) {
            page_inp.value = String(page_num) + '|' + Date.now();
        };
        window._sae_load_patch_image = function(img_idx) {
            patch_inp.value = String(img_idx) + '|' + Date.now();
        };
    """,
    args={
        'feat_inp': flist_panel.gallery_bridge_input,
        'page_inp': flist_panel.gallery_page_input,
        'patch_inp': steer_panel.patch_load_bridge,
    },
)
curdoc().js_on_event('document_ready', _full_bridge_js)


# ============================================================
# UPPER WORKSPACE — Active Steering & Composition
# ============================================================

# Look shared by the white "card" columns below.  Each consumer gets its own
# copy via dict(...) so no two Bokeh models hold the same dict object.
_upper_card_style = {
    "background": "var(--card-bg, #fff)",
    "border": "1px solid var(--card-border, #e2e5ea)",
    "border-radius": "8px",
    "padding": "12px",
    "box-shadow": "0 1px 3px rgba(0,0,0,0.06)",
}

# Left sidebar: example presets, set off from the centre by a right border.
_upper_left = column(
    examples_panel.examples_panel,
    width=210,
    styles={
        "border-right": "1px solid var(--card-border, #e2e5ea)",
        "padding-right": "10px",
        "margin-right": "8px",
        "min-height": "400px",
    },
)

# Patch column: ground-truth brain response on top, patch explorer beneath.
_patch_header = Div(
    width=410,
    text='<div class="sae-card-header" style="border-bottom:none">GT Brain Response</div>',
)
_patch_col = column(
    _patch_header,
    steer_panel.gt_brain_div,
    steer_panel.patch_explorer_panel,
    styles=dict(_upper_card_style),
)

# Active-features column: summed steering direction on top, feature tiles beneath.
_active_header = Div(
    width=460,
    text='<div class="sae-card-header" style="border-bottom:none">Steering Direction</div>',
)
_active_features_header = Div(
    width=460,
    text='<div class="sae-card-header" style="border-bottom:none;margin-top:8px">'
         'Active Features</div>',
)
_active_column = column(
    _active_header,
    steer_panel.steer_brain_div,
    _active_features_header,
    steer_panel.active_features_div,
    width=460,
    styles=dict(_upper_card_style),
)

# DynaDiff column (optional): expected steered brain, then run controls and
# output.  Without DynaDiff an empty 1-px Div fills the row's third slot.
if not HAS_DYNADIFF:
    _dd_controls = Div(text="", width=1)
else:
    _dd_controls_header = Div(
        width=480,
        text='<div class="sae-card-header" style="border-bottom:none">'
             'Expected Steered Brain</div>',
    )
    _dd_run_header = Div(
        width=480,
        text='<div class="sae-card-header" style="border-bottom:none;margin-top:8px">'
             'Brain Steering</div>',
    )
    _dd_controls = column(
        _dd_controls_header,
        steer_panel.steered_brain_div,
        _dd_run_header,
        steer_panel.dynadiff_panel,
        width=480,
        styles=dict(_upper_card_style),
    )

_upper_center = row(
    _patch_col, _active_column, _dd_controls,
    styles={"gap": "12px"},
)
upper_workspace = row(_upper_left, _upper_center)

# The shared style was only a construction helper; keep the namespace clean.
del _upper_card_style


# ============================================================
# DIVIDER
# ============================================================

# Thin solid rule separating the two workspaces.
divider = Div(
    width=1500,
    text="<hr style='border:none;border-top:1px solid #e2e5ea;margin:16px 0'>",
)


# ============================================================
# LOWER WORKSPACE — Feature Search & Analysis
# ============================================================

# Look shared by the three lower cards; each takes a fresh copy (with its own
# padding / margin tweaks) so no two Bokeh models share one dict object.
_lower_card_style = {
    "background": "var(--card-bg, #fff)",
    "border": "1px solid var(--card-border, #e2e5ea)",
    "border-radius": "8px",
    "padding": "12px",
    "box-shadow": "0 1px 3px rgba(0,0,0,0.06)",
}

# Left: CLIP search + gallery.  When CLIP is unavailable the query row is
# replaced by a notice and the results slot by an empty 1-px Div.
if clip_panel.clip_unavailable:
    _clip_row = Div(
        text="<i style='color:var(--text-muted, #9ca3af);font-size:11px'>"
             "CLIP search unavailable</i>",
        width=300,
    )
    _clip_results = Div(text="", width=1)
else:
    _clip_row = row(
        clip_panel.clip_query_input,
        clip_panel.clip_search_btn,
        styles={"margin-bottom": "6px", "align-items": "end"},
    )
    _clip_results = clip_panel.clip_results_div

_search_header = Div(
    width=300,
    text='<div class="sae-card-header" style="border-bottom:none">'
         'Feature Search</div>',
)

lower_left = column(
    _search_header,
    _clip_row,
    _clip_results,
    # Insert flist_panel.sort_select here to re-enable the CLIP × φ sort.
    flist_panel.gallery_div,
    flist_panel.gallery_bridge_input,
    flist_panel.gallery_page_input,
    width=340,
    styles={**_lower_card_style, "margin-right": "12px"},
)

# Center: feature name + "Add to Steer", labeler, zoom controls, MEI gallery.
_zoom_controls = row(
    widgets.zoom_slider,
    widgets.heatmap_alpha_slider,
    styles={"gap": "16px", "padding": "4px 0 8px 0"},
)
_feature_header_row = row(
    feature_panel.stats_div,
    feature_panel.add_steer_btn,
    styles={"align-items": "center"},
)
_labeler_row = row(
    flist_panel.gemini_btn,
    flist_panel.gemini_status_div,
    styles={"align-items": "center", "margin-bottom": "4px"},
)

lower_center = column(
    feature_panel.status_div,
    _feature_header_row,
    flist_panel.name_input,
    _labeler_row,
    _zoom_controls,
    feature_panel.top_heatmap_div,
    width=700,
    styles={**_lower_card_style, "padding": "14px 16px"},
)

# Right: brain profile + collapsible SAE summary.
lower_right = column(
    feature_panel.brain_div,
    make_collapsible("SAE Summary", summary_div),
    width=580,
    styles={**_lower_card_style, "padding": "14px 16px", "margin-left": "12px"},
)

lower_workspace = row(lower_left, lower_center, lower_right)

# The shared style was only a construction helper; keep the namespace clean.
del _lower_card_style


# ============================================================
# ROOT LAYOUT
# ============================================================

# The hidden theme Div and the patch-load bridge widget ride along in the
# root column so they are part of this session's document.
layout = column(
    _theme_css,
    upper_workspace,
    divider,
    lower_workspace,
    steer_panel.patch_load_bridge,
    styles={"padding": "12px"},
)
_doc = curdoc()
_doc.add_root(layout)
_doc.title = "SAE Feature Explorer"

print("Explorer app ready!")

# NOTE(review): this script is re-executed for every browser session, so the
# warmup runs per session too; presumably warmup_gpu_runner is cheap or
# idempotent after the first call. Confirm in the inference module.
warmup_gpu_runner()