Marlin Lee
UI redesign: consistent design system, card layout, unified colors and typography
4a5ed93
"""
Steering workspace UI: patch explorer + feature steering list + DynaDiff.
Combines patch-level feature exploration with brain-steering controls.
Pure computation lives in explorer.steering_logic; this module handles
Bokeh widgets, callbacks, and layout.
Exports (layout):
patch_explorer_panel — column: patch figure + info + results
gt_brain_div — Div: GT brain response
steer_brain_div — Div: steering direction brain map
steered_brain_div — Div: expected steered brain
active_features_div — Div: steering feature tile cards
dynadiff_panel — column: run button + status + output (or None)
Exports (JS bridges — needed by main.py):
patch_load_bridge — TextInput: JS -> Python image loading
feat_action_bridge — TextInput: JS -> Python remove/set_lam (or None)
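Exports (public API for other panels):
add_feature(feat, lam, threshold)
set_preset(entries, label)
set_nsd_sample(basename)
load_patch_image(image_label)
(Without DynaDiff the first three are no-op stubs and load_patch_image
is not defined.)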
"""
import threading
from bokeh.events import MouseMove
from bokeh.io import curdoc
from bokeh.layouts import column, row
from bokeh.models import (
Button, ColumnDataSource, CustomJS, Div, TextInput,
)
from bokeh.plotting import figure
from ..args import args
from ..state import active_ds
from ..steering_logic import (
compute_patch_activations, get_top_features_for_patches,
resolve_nsd_basename, parse_nsd_img_idx,
compute_steering_direction, compute_steered_fmri,
validate_feature, make_steering_entry,
validate_reconstruction, run_reconstruction, load_gt_fmri,
)
from ..brain import (
HAS_DYNADIFF,
render_fmri_brain_compact_b64,
)
from ..rendering import (
load_image, parse_img_label, pil_to_bokeh_rgba,
make_search_result_html, make_active_features_tile_html,
)
# ══════════════════════════════════════════════════════════════════
# Session state (reset each Bokeh session via module reimport)
# ══════════════════════════════════════════════════════════════════
class _Session:
"""Per-session mutable state for the steering workspace."""
img_idx = None # image index loaded in patch explorer
patch_z = None # (n_patches, d_sae) activations
nsd_basename = None # e.g. 'nsd_22910'
gt_fmri = None # raw fMRI (N_VOXELS,) array
# Pixel size of the square patch-explorer figure; also passed to
# pil_to_bokeh_rgba when an image is loaded.
_PATCH_FIG_PX = 400
# ══════════════════════════════════════════════════════════════════
# Patch Explorer
# ══════════════════════════════════════════════════════════════════
# Background image of the patch figure. Placeholder until _on_load_image
# replaces it (with dw/dh set to the dataset's patch grid size).
_patch_bg_source = ColumnDataSource(data=dict(
image=[], x=[0], y=[0], dw=[16], dh=[16],
))
# Patch-grid size of the dataset active at import time. The figure ranges,
# grid rects and paint JS below are all fixed to this value.
_pg0 = active_ds()['patch_grid']
def _make_grid_source(pg: int) -> ColumnDataSource:
    """Build one rect-glyph row per cell of a ``pg`` x ``pg`` patch grid.

    Cells are enumerated row-major, so row ``i`` of the source is flat patch
    index ``r * pg + c``. Glyph centres sit at ``+0.5`` and the y axis is
    flipped so patch row 0 is drawn at the top of the figure.
    """
    cells = [(r, c) for r in range(pg) for c in range(pg)]
    rows = [r for r, _ in cells]
    cols = [c for _, c in cells]
    centre_x = [c + 0.5 for c in cols]
    centre_y = [pg - r - 0.5 for r in rows]
    return ColumnDataSource(data=dict(
        x=centre_x,
        y=centre_y,
        row=rows,
        col=cols,
    ))
# One rect per patch; the selection on this source is the painted patch set.
_patch_grid_source = _make_grid_source(_pg0)
# Square figure in patch-grid coordinates (one unit per patch). Starts hidden;
# _on_load_image makes it visible once activations are available.
_patch_fig = figure(
width=_PATCH_FIG_PX, height=_PATCH_FIG_PX,
x_range=(0, _pg0), y_range=(0, _pg0),
tools=["tap", "reset"],
title="Click or drag to paint patch selection",
toolbar_location="above",
visible=False,
)
# Drag-to-paint: on each mouse move over the figure, if a mouse button is held
# (tracked by document-level mousedown/mouseup listeners installed once per
# page), add the hovered cell's flat index (row * pg + col, row 0 at the top)
# to the grid selection. Painting only ever adds cells.
# NOTE(review): pg is baked in from _pg0, so a dataset with a different patch
# grid would not be handled without rebuilding this callback.
_paint_js = CustomJS(args=dict(source=_patch_grid_source, pg=_pg0), code="""
if (!window._patch_paint_init) {
window._patch_paint_init = true;
window._patch_btn_held = false;
document.addEventListener('mousedown', () => { window._patch_btn_held = true; });
document.addEventListener('mouseup', () => { window._patch_btn_held = false; });
}
if (!window._patch_btn_held) return;
const x = cb_obj.x, y = cb_obj.y;
if (x === null || y === null || x < 0 || x >= pg || y < 0 || y >= pg) return;
const col = Math.floor(x);
const row = pg - 1 - Math.floor(y);
const flat_idx = row * pg + col;
const sel = source.selected.indices.slice();
if (sel.indexOf(flat_idx) === -1) { sel.push(flat_idx); source.selected.indices = sel; }
""")
_patch_fig.js_on_event(MouseMove, _paint_js)
# Background image layer, then a transparent rect per patch on top of it;
# selected (painted) cells are filled red.
_patch_fig.image_rgba(
source=_patch_bg_source, image='image', x='x', y='y', dw='dw', dh='dh')
_patch_fig.rect(
source=_patch_grid_source, x='x', y='y', width=0.95, height=0.95,
fill_color='yellow', fill_alpha=0.0,
line_color='white', line_alpha=0.35, line_width=0.5,
selection_fill_color='red', selection_fill_alpha=0.45,
nonselection_fill_alpha=0.0, nonselection_line_alpha=0.35,
)
# The figure is a pure image canvas: no axes or grid lines.
_patch_fig.axis.visible = False
_patch_fig.xgrid.visible = False
_patch_fig.ygrid.visible = False
# Top-feature results for the current patch selection.
_patch_results_div = Div(text="", width=310)
# Manual load controls (index box, Load / Clear buttons).
_patch_img_input = TextInput(title="Image Index:", value="0", width=120)
_load_patch_btn = Button(label="Load Image", width=90, button_type="primary")
_clear_patch_btn = Button(label="Clear", width=60)
# Status / hint line shown under the patch figure.
_patch_info_div = Div(
text="<i>Click an image in the Feature Explorer to load it here.</i>",
width=310,
)
# JS bridge: gallery tile onclick -> window._sae_load_patch_image(idx)
patch_load_bridge = TextInput(value="", width=1, height=1, visible=False)
# ── Patch callbacks ──────────────────────────────────────────────
# One-element counter identifying the newest patch-activation job (mutated in
# place so the nested callbacks in _on_load_image can compare against it).
_patch_load_token = [0]
def _on_load_image():
    """Load the image named in ``_patch_img_input`` into the patch explorer.

    Validates the index and draws the image immediately, then computes the
    patch activations on a daemon thread. When the job finishes, and only if
    no newer load has started since, the activations are stored in
    ``_Session.patch_z``, the figure is shown with an empty selection and the
    image's NSD sample (if any) is forwarded to ``set_nsd_sample``.
    """
    try:
        img_idx = parse_img_label(_patch_img_input.value)
    except ValueError:
        _patch_info_div.text = "<b style='color:red'>Invalid image index</b>"
        return
    ds = active_ds()
    if not (0 <= img_idx < ds['n_images']):
        _patch_info_div.text = (
            f"<b style='color:red'>Index out of range "
            f"(0–{ds['n_images'] - 1})</b>")
        return
    try:
        pil = load_image(img_idx)
        pg = ds['patch_grid']
        bokeh_arr = pil_to_bokeh_rgba(pil, _PATCH_FIG_PX)
        _patch_bg_source.data = dict(
            image=[bokeh_arr], x=[0], y=[0], dw=[pg], dh=[pg])
    except Exception as e:
        _patch_info_div.text = f"<b style='color:red'>Error loading image: {e}</b>"
        return
    # Record the index only once the image is on screen, so a failed load
    # leaves the session pointing at the image that is still displayed.
    _Session.img_idx = img_idx
    # Claim the newest-job token: any still-running earlier job becomes stale
    # and must not overwrite this image's activations or widget state.
    _patch_load_token[0] += 1
    my_token = _patch_load_token[0]
    _load_patch_btn.disabled = True
    _patch_info_div.text = (
        "<span style='color:#2563eb'>&#x23F3; Computing patch activations"
        + (" (running GPU inference — first image may take ~10 s)…"
           if not args.sae_path else "…")
        + "</span>"
    )
    doc = curdoc()

    def _bg():
        try:
            z_np = compute_patch_activations(img_idx)
        except Exception as e:
            err = str(e)

            def _show_err(err=err):
                if _patch_load_token[0] != my_token:
                    return  # a newer load now owns the widgets
                _load_patch_btn.disabled = False
                _patch_info_div.text = (
                    f"<b style='color:red'>Error: {err}</b>")
            doc.add_next_tick_callback(_show_err)
            return

        def _apply(z_np=z_np, img_idx=img_idx):
            if _patch_load_token[0] != my_token:
                return  # stale: another image was loaded meanwhile
            _Session.patch_z = z_np
            _load_patch_btn.disabled = False
            _patch_fig.visible = True
            _patch_grid_source.selected.indices = []
            _patch_results_div.text = ""
            # Sync NSD sample for brain steering
            nsd_name = resolve_nsd_basename(img_idx)
            if nsd_name:
                set_nsd_sample(nsd_name)
                _patch_fig.title.text = f"Paint patch selection on {nsd_name}"
            else:
                _patch_fig.title.text = f"Paint patch selection on image {img_idx}"
            if z_np is None:
                _patch_info_div.text = (
                    "<b style='color:#6b7280'>GPU inference unavailable.</b>")
            else:
                _patch_info_div.text = "Paint patches to find features."
        doc.add_next_tick_callback(_apply)

    threading.Thread(target=_bg, daemon=True).start()
def _on_patch_select(attr, old, new):
    """Show the top SAE features for the painted patch selection.

    ``on_change`` callback for ``_patch_grid_source.selected.indices``;
    ``new`` holds the selected row numbers of the grid source.
    """
    if _Session.img_idx is None:
        return
    if not new:
        _patch_results_div.text = ""
        _patch_info_div.text = "<i>Selection cleared.</i>"
        return
    if _Session.patch_z is None:
        # No activations for this image (e.g. compute_patch_activations
        # returned None because GPU inference is unavailable): nothing to
        # rank, so don't hand None to the feature search.
        _patch_results_div.text = ""
        _patch_info_div.text = (
            "<b style='color:#6b7280'>Patch activations unavailable.</b>")
        return
    ds = active_ds()
    pg = ds['patch_grid']
    grid = _patch_grid_source.data
    # Map selected grid rows back to flat (row-major) patch indices.
    patch_indices = [grid['row'][i] * pg + grid['col'][i] for i in new]
    feats, _acts, _freqs, _means = get_top_features_for_patches(
        _Session.patch_z, patch_indices)
    _patch_results_div.text = make_search_result_html(
        feats[:10], ds, n_meis=3, size=72)
    _patch_info_div.text = "Click a feature to explore it."
def _on_clear():
    """Drop every painted patch and blank the patch result/info panes."""
    selection = _patch_grid_source.selected
    selection.indices = []
    _patch_results_div.text = ''
    _patch_info_div.text = '<i>Selection cleared.</i>'
def _on_patch_load_bridge(attr, old, new):
    """Load the image requested from JS through ``patch_load_bridge``.

    JS fires window._sae_load_patch_image(idx) -> sets bridge value. Only the
    text before the first '|' is used (presumably the remainder is a nonce
    that makes repeated requests distinct; confirm against the JS side).
    Values whose index part is not an integer are ignored.
    """
    try:
        img_idx = int(new.split('|', 1)[0])
    except ValueError:
        return  # malformed or empty bridge value
    # Only parsing is guarded; errors raised while loading are no longer
    # silently swallowed here.
    _patch_img_input.value = str(img_idx)
    _on_load_image()
# Wire the patch-explorer widgets and JS bridge to their callbacks.
_patch_grid_source.selected.on_change('indices', _on_patch_select)
_load_patch_btn.on_click(_on_load_image)
_clear_patch_btn.on_click(_on_clear)
patch_load_bridge.on_change('value', _on_patch_load_bridge)
# ══════════════════════════════════════════════════════════════════
# Steering List + Brain Visualisation + DynaDiff
# ══════════════════════════════════════════════════════════════════
# Stubs when DynaDiff is unavailable
# Placeholder exports: 1-px divs and None panels, rebound by the
# HAS_DYNADIFF block below (presumably so importers can reference every
# name regardless of DynaDiff availability).
gt_brain_div = Div(text="", width=1)
steer_brain_div = Div(text="", width=1)
steered_brain_div = Div(text="", width=1)
active_features_div = Div(text="", width=1)
dynadiff_panel = None
feat_action_bridge = None
_dd_source = ColumnDataSource(data=dict(feat=[], name=[], lam=[], threshold=[]))
def add_feature(feat: int, lam: float = 3.0, threshold: float = 0.10):
    """Fallback ``add_feature`` used without DynaDiff; ignores its arguments."""
def set_preset(entries: list, label: str = ''):
    """Fallback ``set_preset`` used without DynaDiff; discards the preset."""
def set_nsd_sample(basename: str):
    """Fallback ``set_nsd_sample`` used without DynaDiff; does nothing."""
if HAS_DYNADIFF:
    # ── Data source & widgets ────────────────────────────────────
    # Steering list: one row per feature (id, display name, lam, threshold).
    _dd_source = ColumnDataSource(
        data=dict(feat=[], name=[], lam=[], threshold=[]))
    # Hidden JS -> Python bridge for tile actions (handled by _on_feat_action).
    feat_action_bridge = TextInput(value="", width=1, visible=False)
    # Full-size replacements for the 1-px placeholder divs.
    gt_brain_div = Div(text="", width=410)
    steer_brain_div = Div(text="", width=460)
    steered_brain_div = Div(text="", width=480)
    active_features_div = Div(text="", width=460)
    # Reconstruction status line, output area and run button.
    _dd_status_div = Div(text="", width=460)
    _dd_output_div = Div(text="", width=460)
    _dd_run_btn = Button(
        label="Steer & Reconstruct", button_type="primary", width=200)
# ── Brain visualisation updates ──────────────────────────────
def _steerings_from_source():
    """Snapshot the steering table as ``(feats, lams, thresholds)`` lists."""
    columns = _dd_source.data
    return tuple(list(columns[key]) for key in ('feat', 'lam', 'threshold'))
# Render tokens (one-element lists so nested callbacks can mutate them): each
# asynchronous renderer bumps its counter before starting a job and applies
# the result only if the counter is unchanged, discarding stale renders.
_steer_render_token = [0]     # steering-direction brain map
_steered_render_token = [0]   # expected steered brain
_gt_render_token = [0]        # GT brain response
_tiles_render_token = [0]     # active-feature tile cards
def _update_steer_brain():
    """Re-render the combined steering-direction brain map asynchronously.

    Clears ``steer_brain_div`` at once; if the steering list is non-empty,
    computes the direction and its brain render on a daemon thread and shows
    it only if no newer request was made in the meantime.
    """
    # Bump the token before any early return so a render still in flight
    # from an earlier list cannot repaint the div after it was cleared.
    _steer_render_token[0] += 1
    my_token = _steer_render_token[0]
    steer_brain_div.text = ''
    feats, lams, thrs = _steerings_from_source()
    if not feats:
        return
    doc = curdoc()

    def _bg():
        combined = compute_steering_direction(feats, lams, thrs)
        b64 = render_fmri_brain_compact_b64(
            combined, 'Steering Direction (φ sum)')

        def _apply():
            if _steer_render_token[0] == my_token:
                steer_brain_div.text = (
                    f'<img src="data:image/png;base64,{b64}" '
                    f'style="max-width:100%"/>'
                    if b64 else '')
        doc.add_next_tick_callback(_apply)

    threading.Thread(target=_bg, daemon=True).start()
def _update_steered_brain():
    """Re-render the expected steered brain (GT fMRI + steering) asynchronously.

    Clears ``steered_brain_div`` at once; when both a GT fMRI and at least one
    steering feature exist, computes and renders the steered response on a
    daemon thread and shows it only if it is still the newest request.
    """
    # Bump the token before any early return so a stale in-flight render
    # cannot repaint the div after the GT sample or the list was cleared.
    _steered_render_token[0] += 1
    my_token = _steered_render_token[0]
    steered_brain_div.text = ''
    fmri = _Session.gt_fmri
    if fmri is None:
        return
    feats, lams, thrs = _steerings_from_source()
    if not feats:
        return
    doc = curdoc()

    def _bg():
        steered = compute_steered_fmri(fmri, feats, lams, thrs)
        b64 = render_fmri_brain_compact_b64(
            steered, 'Expected Steered Brain')

        def _apply():
            if _steered_render_token[0] != my_token:
                return
            steered_brain_div.text = (
                f'<img src="data:image/png;base64,{b64}" '
                f'style="max-width:100%"/>'
                if b64 else '')
        doc.add_next_tick_callback(_apply)

    threading.Thread(target=_bg, daemon=True).start()
def _update_active_tiles():
    """Re-render the steering-feature tile cards.

    An empty list is rendered synchronously; otherwise a placeholder is shown
    and the tiles are built on a daemon thread, applied only if no newer
    request was made in the meantime.
    """
    # Bump first so an in-flight render for a previous (non-empty) list
    # cannot overwrite the empty-state tiles rendered below.
    _tiles_render_token[0] += 1
    my_token = _tiles_render_token[0]
    feats = list(_dd_source.data['feat'])
    lams = list(_dd_source.data['lam'])
    if not feats:
        active_features_div.text = make_active_features_tile_html(
            [], active_ds(), removable=True)
        return
    active_features_div.text = (
        '<div style="color:#6b7280;font-style:italic;font-size:11px;'
        'padding:6px">Rendering tiles&#x2026;</div>')
    doc = curdoc()

    def _bg():
        html = make_active_features_tile_html(
            feats, active_ds(), removable=True, lams=lams)

        def _apply():
            if _tiles_render_token[0] == my_token:
                active_features_div.text = html
        doc.add_next_tick_callback(_apply)

    threading.Thread(target=_bg, daemon=True).start()
def _on_source_change(attr, old, new):
    """Refresh the tiles and both brain maps after the steering table changes.

    The refresh is scheduled for the next tick instead of running inline, so
    it does not nest inside the caller's document lock (which otherwise
    triggers Bokeh's _pending_writes error).
    """
    doc = curdoc()

    def _refresh_all():
        for refresh in (_update_active_tiles,
                        _update_steer_brain,
                        _update_steered_brain):
            refresh()

    doc.add_next_tick_callback(_refresh_all)
_dd_source.on_change('data', _on_source_change)
# ── GT brain loading ─────────────────────────────────────────
def _load_gt_brain(nsd_basename):
    """Load the GT fMRI for an NSD sample and render its brain (threaded).

    Non-NSD basenames clear the GT state immediately. Otherwise the fMRI is
    loaded, and rendered unless the dataset's ``gt_brain_cache`` already holds
    a render, on a daemon thread. The newest request then sets
    ``_Session.gt_fmri``, fills ``gt_brain_div`` and refreshes the expected
    steered brain. A failed load leaves ``_Session.gt_fmri`` as None (rather
    than the previous sample's data); load/render errors are shown in
    ``gt_brain_div``.
    """
    # Bump before the early return so an in-flight load for the previous
    # sample is discarded even when this sample has no GT data.
    _gt_render_token[0] += 1
    my_token = _gt_render_token[0]
    nsd_img_idx = parse_nsd_img_idx(nsd_basename)
    if nsd_img_idx is None:
        _Session.gt_fmri = None
        gt_brain_div.text = ''
        _update_steered_brain()  # gt_fmri is None -> clears the steered map
        return
    doc = curdoc()

    def _bg():
        err = None
        try:
            _, fmri = load_gt_fmri(nsd_basename)
        except Exception as exc:
            fmri, err = None, str(exc)
        if _gt_render_token[0] != my_token:
            return
        # Use precomputed GT brain render if available
        b64 = active_ds().get('gt_brain_cache', {}).get(nsd_img_idx)
        if b64 is None and fmri is not None:
            try:
                b64 = render_fmri_brain_compact_b64(fmri, 'GT Brain Response')
            except Exception as exc:
                err = str(exc)

        def _apply(fmri=fmri, b64=b64, err=err):
            if _gt_render_token[0] != my_token:
                return
            _Session.gt_fmri = fmri
            if b64:
                gt_brain_div.text = (
                    f'<img src="data:image/png;base64,{b64}" '
                    f'style="max-width:100%"/>')
            elif err:
                gt_brain_div.text = (
                    f'<span style="color:#dc2626">'
                    f'GT brain unavailable: {err}</span>')
            else:
                gt_brain_div.text = ''
            _update_steered_brain()
        doc.add_next_tick_callback(_apply)

    threading.Thread(target=_bg, daemon=True).start()
# ── Public API ───────────────────────────────────────────────
def add_feature(feat: int, lam: float = 3.0,
                threshold: float = 0.10):
    """Append ``feat`` to the steering list (validated, no duplicates).

    Writes the outcome (validation error, duplicate notice or success)
    to the DynaDiff status div.
    """
    problem = validate_feature(feat)
    if problem:
        _dd_status_div.text = (
            f'<span style="color:#dc2626">{problem}</span>')
        return
    already_listed = feat in list(_dd_source.data['feat'])
    if already_listed:
        _dd_status_div.text = (
            f'<i style="color:#6b7280">'
            f'Feature {feat} already in list.</i>')
        return
    entry = make_steering_entry(feat, lam, threshold)
    columns = {key: list(vals) for key, vals in _dd_source.data.items()}
    for key in ('feat', 'name', 'lam', 'threshold'):
        columns[key].append(entry[key])
    _dd_source.data = columns
    _dd_status_div.text = (
        f'<i style="color:#6b7280">'
        f'Feature {feat} added to steering.</i>')
def set_preset(entries: list, label: str = ''):
    """Replace the whole steering list with ``entries``.

    Each entry is a mapping with a required ``'feat'`` and optional ``'lam'``
    (default 3.0) and ``'threshold'`` (default 0.10). A non-empty ``label``
    is reported in the status div.
    """
    built = [
        make_steering_entry(
            int(raw['feat']),
            float(raw.get('lam', 3.0)),
            float(raw.get('threshold', 0.10)),
        )
        for raw in entries
    ]
    _dd_source.data = {
        key: [e[key] for e in built]
        for key in ('feat', 'name', 'lam', 'threshold')
    }
    if not label:
        return
    _dd_status_div.text = (
        f'<i style="color:#6b7280">Loaded preset: {label}</i>')
def set_nsd_sample(basename: str):
    """Make ``basename`` the NSD sample being steered.

    Does nothing when it is already current; otherwise records it and starts
    loading its GT brain.
    """
    if basename != _Session.nsd_basename:
        _Session.nsd_basename = basename
        _load_gt_brain(basename)
def load_patch_image(image_label: str):
    """Load ``image_label`` (name or index string) into the patch explorer.

    Same as typing it into the index box and pressing "Load Image": it drives
    the full chain of image display, patch activations, NSD sample detection
    and GT brain rendering.
    """
    _patch_img_input.value = image_label
    _on_load_image()
# ── Feature action bridge (remove / set_lam from HTML) ───────
def _remove_feat_action(msg):
    """Handle ``remove_feat:<feat>``: drop that feature's row, if present."""
    try:
        feat = int(msg.split(':')[1])
    except (ValueError, IndexError):
        return
    feats = list(_dd_source.data['feat'])
    if feat not in feats:
        return
    idx = feats.index(feat)
    _dd_source.data = {k: [v for i, v in enumerate(vals) if i != idx]
                       for k, vals in _dd_source.data.items()}
    _dd_status_div.text = ''
def _set_lam_action(msg):
    """Handle ``set_lam:<feat>:<value>``: update that feature's lam."""
    import math
    parts = msg.split(':', 2)
    if len(parts) != 3:
        return
    try:
        feat = int(parts[1])
        new_val = float(parts[2])
    except ValueError:
        return
    # float() accepts 'nan'/'inf'; those would be fed straight into the
    # steering-direction / steered-fMRI computations, so reject them.
    if not math.isfinite(new_val):
        return
    feats = list(_dd_source.data['feat'])
    if feat not in feats:
        return
    idx = feats.index(feat)
    new_lams = list(_dd_source.data['lam'])
    if new_lams[idx] == new_val:
        return  # unchanged: skip a redundant re-render of every panel
    new_lams[idx] = new_val
    new_data = dict(_dd_source.data)
    new_data['lam'] = new_lams
    _dd_source.data = new_data
def _on_feat_action(attr, old, new):
    """Apply a tile action sent from JS through ``feat_action_bridge``.

    Messages are ``remove_feat:<feat>`` or ``set_lam:<feat>:<value>``,
    followed by a ``|<timestamp>`` suffix added by window._dd_feat_action
    (so repeating an action still changes the bridge value). Malformed
    messages, unknown features and non-finite lam values are ignored.
    """
    msg = new.split('|')[0]
    if msg.startswith('remove_feat:'):
        _remove_feat_action(msg)
    elif msg.startswith('set_lam:'):
        _set_lam_action(msg)
feat_action_bridge.on_change('value', _on_feat_action)
# Once the page is loaded, define window._dd_feat_action(msg): it writes msg
# plus a '|<timestamp>' suffix into the hidden bridge, so sending the same
# action twice still changes the value and fires _on_feat_action.
# (Presumably called from the tile HTML built by
# make_active_features_tile_html(..., removable=True); confirm.)
curdoc().js_on_event('document_ready', CustomJS(
args=dict(bridge=feat_action_bridge),
code="""
window._dd_feat_action = function(msg) {
bridge.value = msg + '|' + Date.now();
};
""",
))
# ── DynaDiff reconstruction ──────────────────────────────────
def _reconstruct_thread(sample_idxs, steerings, doc,
                        nsd_img_idx=None):
    """Run a DynaDiff reconstruction off the UI thread and publish the result.

    Results are applied through next-tick callbacks on ``doc``. On success the
    steered image (or a "No steered output." note) goes into
    ``_dd_output_div`` and the status is cleared. On failure the error is
    shown in the status div. Both outcomes re-enable the run button.
    """
    try:
        resp = run_reconstruction(sample_idxs, steerings, seed=42,
                                  nsd_img_idx=nsd_img_idx)
        steer_b64 = resp.get('steered_img')
        html = (
            f'<img src="data:image/png;base64,{steer_b64}" '
            f'style="max-width:100%;border-radius:4px;'
            f'border:1px solid #ddd"/>'
            if steer_b64 else
            '<div style="color:#aaa;font-style:italic">'
            'No steered output.</div>'
        )

        def _publish(html=html):
            _dd_output_div.text = html
            _dd_status_div.text = ''
            _dd_run_btn.disabled = False

        doc.add_next_tick_callback(_publish)
    except Exception as exc:
        msg = str(exc)

        def _report(msg=msg):
            _dd_status_div.text = (
                f'<span style="color:#dc2626">Error: {msg}</span>')
            _dd_run_btn.disabled = False

        doc.add_next_tick_callback(_report)
def _on_reconstruct():
    """Validate the steering list and launch a DynaDiff reconstruction.

    Targets the NSD sample behind the patch explorer's loaded image when it
    has one, otherwise the current steering sample. Validation errors go to
    the status div. On success the run button is disabled and
    ``_reconstruct_thread`` runs on a daemon thread, re-enabling it when done.
    """
    feats, lams, thrs = _steerings_from_source()
    # Prefer the currently loaded patch image's NSD basename. Route it
    # through set_nsd_sample (not a bare _Session assignment) so the GT brain
    # is loaded too; a bare assignment would make a later set_nsd_sample()
    # for the same name early-return and never load it.
    if _Session.img_idx is not None:
        img_basename = resolve_nsd_basename(_Session.img_idx)
        if img_basename:
            set_nsd_sample(img_basename)
    nsd_basename = _Session.nsd_basename
    sample_idxs, steerings, err = validate_reconstruction(
        nsd_basename, feats, lams, thrs)
    if err:
        _dd_status_div.text = (
            f'<span style="color:#dc2626">{err}</span>')
        return
    nsd_img_idx = parse_nsd_img_idx(nsd_basename)
    _dd_run_btn.disabled = True
    _dd_status_div.text = (
        '<i style="color:#6b7280">'
        'Running DynaDiff reconstruction…</i>')
    threading.Thread(
        target=_reconstruct_thread,
        args=(sample_idxs, steerings, curdoc(), nsd_img_idx),
        daemon=True,
    ).start()
_dd_run_btn.on_click(_on_reconstruct)
# DynaDiff panel: the hidden action bridge (placed in the layout, presumably
# so it is part of the document and can sync), run button + status line,
# then the reconstruction output.
dynadiff_panel = column(
feat_action_bridge,
row(_dd_run_btn, _dd_status_div),
_dd_output_div,
)
# ══════════════════════════════════════════════════════════════════
# Layout exports
# ══════════════════════════════════════════════════════════════════
# Patch explorer: paintable figure, status/hint line, then top-feature results.
patch_explorer_panel = column(
_patch_fig,
_patch_info_div,
_patch_results_div,
)