"""Steering workspace UI: patch explorer + feature steering list + DynaDiff.

Combines patch-level feature exploration with brain-steering controls.
Pure computation lives in explorer.steering_logic; this module handles
Bokeh widgets, callbacks, and layout.

Exports (layout):
    patch_explorer_panel — column: patch figure + info + results
    gt_brain_div         — Div: GT brain response
    steer_brain_div      — Div: steering direction brain map
    steered_brain_div    — Div: expected steered brain
    active_features_div  — Div: steering feature tile cards
    dynadiff_panel       — column: run button + status + output (or None)

Exports (JS bridges — needed by main.py):
    patch_load_bridge    — TextInput: JS -> Python image loading
    feat_action_bridge   — TextInput: JS -> Python remove/set_lam (or None)

Exports (public API for other panels):
    add_feature(feat, lam, threshold)
    set_preset(entries, label)
    set_nsd_sample(basename)
"""
import threading

from bokeh.events import MouseMove
from bokeh.io import curdoc
from bokeh.layouts import column, row
from bokeh.models import (
    Button,
    ColumnDataSource,
    CustomJS,
    Div,
    TextInput,
)
from bokeh.plotting import figure

from ..args import args
from ..state import active_ds
from ..steering_logic import (
    compute_patch_activations,
    get_top_features_for_patches,
    resolve_nsd_basename,
    parse_nsd_img_idx,
    compute_steering_direction,
    compute_steered_fmri,
    validate_feature,
    make_steering_entry,
    validate_reconstruction,
    run_reconstruction,
    load_gt_fmri,
)
from ..brain import (
    HAS_DYNADIFF,
    render_fmri_brain_compact_b64,
)
from ..rendering import (
    load_image,
    parse_img_label,
    pil_to_bokeh_rgba,
    make_search_result_html,
    make_active_features_tile_html,
)

# ══════════════════════════════════════════════════════════════════
# Session state (reset each Bokeh session via module reimport)
# ══════════════════════════════════════════════════════════════════


class _Session:
    """Per-session mutable state for the steering workspace."""

    img_idx = None       # image index loaded in patch explorer
    patch_z = None       # (n_patches, d_sae) activations
    nsd_basename = None  # e.g. 'nsd_22910'
    gt_fmri = None       # raw fMRI (N_VOXELS,) array


# Side length (px) of the square patch figure; also passed to
# pil_to_bokeh_rgba when rasterising the loaded image.
_PATCH_FIG_PX = 400

# ══════════════════════════════════════════════════════════════════
# Patch Explorer
# ══════════════════════════════════════════════════════════════════

# Background image. The placeholder extent (16) is replaced with the
# dataset's patch-grid size every time an image is loaded.
_patch_bg_source = ColumnDataSource(data=dict(
    image=[], x=[0], y=[0], dw=[16], dh=[16],
))

# Patch-grid side length of the dataset that is active at import time.
_pg0 = active_ds()['patch_grid']


def _make_grid_source(pg: int) -> ColumnDataSource:
    """Return a source with one cell-centre point per patch of a pg x pg grid.

    Cells are enumerated row-major with row 0 at the top, so data index
    ``i`` corresponds to ``(row, col) == divmod(i, pg)``.
    """
    pr = [r for r in range(pg) for _ in range(pg)]
    pc = list(range(pg)) * pg
    return ColumnDataSource(data=dict(
        x=[c + 0.5 for c in pc],
        y=[pg - r - 0.5 for r in pr],  # flip so row 0 is drawn at the top
        row=pr,
        col=pc,
    ))


_patch_grid_source = _make_grid_source(_pg0)

_patch_fig = figure(
    width=_PATCH_FIG_PX,
    height=_PATCH_FIG_PX,
    x_range=(0, _pg0),
    y_range=(0, _pg0),
    tools=["tap", "reset"],
    title="Click or drag to paint patch selection",
    toolbar_location="above",
    visible=False,
)

# Drag-to-paint: document-level mousedown/mouseup listeners (installed once)
# track whether the mouse button is held; while it is, every MouseMove adds
# the patch under the cursor to the grid selection. Patches are only ever
# added here, never removed. `pg` is fixed to the import-time grid size.
_paint_js = CustomJS(args=dict(source=_patch_grid_source, pg=_pg0), code=""" if (!window._patch_paint_init) { window._patch_paint_init = true; window._patch_btn_held = false; document.addEventListener('mousedown', () => { window._patch_btn_held = true; }); document.addEventListener('mouseup', () => { window._patch_btn_held = false; }); } if (!window._patch_btn_held) return; const x = cb_obj.x, y = cb_obj.y; if (x === null || y === null || x < 0 || x >= pg || y < 0 || y >= pg) return; const col = Math.floor(x); const row = pg - 1 - Math.floor(y); const flat_idx = row * pg + col; const sel = source.selected.indices.slice(); if (sel.indexOf(flat_idx) === -1) { sel.push(flat_idx); source.selected.indices = sel; } """)
_patch_fig.js_on_event(MouseMove, _paint_js)

_patch_fig.image_rgba(
    source=_patch_bg_source, image='image', x='x', y='y', dw='dw', dh='dh')
# Transparent grid cells that turn red when selected.
_patch_fig.rect(
    source=_patch_grid_source,
    x='x', y='y', width=0.95, height=0.95,
    fill_color='yellow', fill_alpha=0.0,
    line_color='white', line_alpha=0.35, line_width=0.5,
    selection_fill_color='red',
    selection_fill_alpha=0.45,
    nonselection_fill_alpha=0.0,
    nonselection_line_alpha=0.35,
)
_patch_fig.axis.visible = False
_patch_fig.xgrid.visible = False
_patch_fig.ygrid.visible = False

_patch_results_div = Div(text="", width=310)
_patch_img_input = TextInput(title="Image Index:", value="0", width=120)
_load_patch_btn = Button(label="Load Image", width=90, button_type="primary")
_clear_patch_btn = Button(label="Clear", width=60)
_patch_info_div = Div(
    text="Click an image in the Feature Explorer to load it here.",
    width=310,
)

# JS bridge: gallery tile onclick -> window._sae_load_patch_image(idx)
patch_load_bridge = TextInput(value="", width=1, height=1, visible=False)

# Bumped on every image load so background results belonging to a
# superseded load are discarded. Needed because the JS bridge can start a
# new load while the Load button is still disabled.
_load_token = [0]

# ── Patch callbacks ──────────────────────────────────────────────


def _on_load_image():
    """Load the image named in the index input into the patch explorer.

    The image is decoded and displayed synchronously. Patch activations are
    computed on a daemon thread via compute_patch_activations and applied
    on the next Bokeh tick. Only the most recent load's result is applied.
    """
    try:
        img_idx = parse_img_label(_patch_img_input.value)
    except ValueError:
        _patch_info_div.text = "Invalid image index"
        return
    ds = active_ds()
    if not (0 <= img_idx < ds['n_images']):
        _patch_info_div.text = (
            f"Index out of range "
            f"(0–{ds['n_images'] - 1})")
        return
    try:
        pil = load_image(img_idx)
        pg = ds['patch_grid']
        bokeh_arr = pil_to_bokeh_rgba(pil, _PATCH_FIG_PX)
        _patch_bg_source.data = dict(
            image=[bokeh_arr], x=[0], y=[0], dw=[pg], dh=[pg])
    except Exception as e:
        _patch_info_div.text = f"Error loading image: {e}"
        return

    # Commit the new image only once it is actually displayed. Drop the
    # previous image's activations so a selection painted before the new
    # ones arrive is not scored against the wrong image.
    _Session.img_idx = img_idx
    _Session.patch_z = None
    _load_token[0] += 1
    my_token = _load_token[0]

    _load_patch_btn.disabled = True
    _patch_info_div.text = (
        "⏳ Computing patch activations"
        + (" (running GPU inference — first image may take ~10 s)…"
           if not args.sae_path else "…")
        + ""
    )
    doc = curdoc()

    def _bg():
        try:
            z_np = compute_patch_activations(img_idx)
        except Exception as e:
            err = str(e)

            def _show_err(err=err):
                if _load_token[0] != my_token:
                    return  # superseded by a newer load
                _load_patch_btn.disabled = False
                _patch_info_div.text = (
                    f"Error: {err}")

            doc.add_next_tick_callback(_show_err)
            return

        def _apply(z_np=z_np, img_idx=img_idx):
            if _load_token[0] != my_token:
                return  # superseded by a newer load
            _Session.patch_z = z_np
            _load_patch_btn.disabled = False
            _patch_fig.visible = True
            _patch_grid_source.selected.indices = []
            _patch_results_div.text = ""
            # Sync NSD sample for brain steering
            nsd_name = resolve_nsd_basename(img_idx)
            if nsd_name:
                set_nsd_sample(nsd_name)
                _patch_fig.title.text = f"Paint patch selection on {nsd_name}"
            else:
                _patch_fig.title.text = f"Paint patch selection on image {img_idx}"
            if z_np is None:
                _patch_info_div.text = (
                    "GPU inference unavailable.")
            else:
                _patch_info_div.text = "Paint patches to find features."

        doc.add_next_tick_callback(_apply)

    threading.Thread(target=_bg, daemon=True).start()


def _on_patch_select(attr, old, new):
    """Show the top SAE features for the currently painted patches.

    ``new`` holds data indices into ``_patch_grid_source``.
    """
    if _Session.img_idx is None:
        return
    if not new:
        _patch_results_div.text = ""
        _patch_info_div.text = "Selection cleared."
        return
    if _Session.patch_z is None:
        # Activations are still being computed, or GPU inference returned
        # nothing for this image — there is nothing to rank yet.
        _patch_results_div.text = ""
        _patch_info_div.text = "Patch activations not available."
        return
    ds = active_ds()
    pg = ds['patch_grid']
    grid = _patch_grid_source.data
    patch_indices = [grid['row'][i] * pg + grid['col'][i] for i in new]
    feats, _acts, _freqs, _means = get_top_features_for_patches(
        _Session.patch_z, patch_indices)
    _patch_results_div.text = make_search_result_html(
        feats[:10], ds, n_meis=3, size=72)
    _patch_info_div.text = "Click a feature to explore it."


def _on_clear():
    """Clear the patch selection and the results panel."""
    _patch_grid_source.selected.indices = []
    _patch_results_div.text = ""
    _patch_info_div.text = "Selection cleared."
def _on_patch_load_bridge(attr, old, new):
    """JS fires window._sae_load_patch_image(idx) -> sets bridge value.

    Only the text before the first ``|`` is parsed as the image index.
    Anything after it is ignored; presumably it is a nonce so repeated
    clicks on the same tile still change the value. Unparseable values
    are dropped.
    """
    try:
        img_idx = int(new.split('|')[0])
    except (ValueError, IndexError):
        return
    # Keep the try narrow: only index parsing is expected to fail here, so
    # errors raised by the load itself are not silently swallowed.
    _patch_img_input.value = str(img_idx)
    _on_load_image()


def load_patch_image(image_label: str):
    """Load an image into the patch explorer by name/index string.

    This triggers the full load chain: image display, patch activations,
    NSD sample detection, and GT brain rendering. GT brain rendering only
    happens when DynaDiff is available; otherwise set_nsd_sample is a
    no-op stub.

    Defined at module level, not inside the ``HAS_DYNADIFF`` branch,
    because patch loading does not depend on DynaDiff.
    """
    _patch_img_input.value = image_label
    _on_load_image()


_patch_grid_source.selected.on_change('indices', _on_patch_select)
_load_patch_btn.on_click(_on_load_image)
_clear_patch_btn.on_click(_on_clear)
patch_load_bridge.on_change('value', _on_patch_load_bridge)

# ══════════════════════════════════════════════════════════════════
# Steering List + Brain Visualisation + DynaDiff
# ══════════════════════════════════════════════════════════════════

# Stubs when DynaDiff is unavailable
gt_brain_div = Div(text="", width=1)
steer_brain_div = Div(text="", width=1)
steered_brain_div = Div(text="", width=1)
active_features_div = Div(text="", width=1)
dynadiff_panel = None
feat_action_bridge = None
_dd_source = ColumnDataSource(data=dict(feat=[], name=[], lam=[], threshold=[]))


def add_feature(feat: int, lam: float = 3.0, threshold: float = 0.10):
    """No-op stub when DynaDiff is disabled."""


def set_preset(entries: list, label: str = ''):
    """No-op stub when DynaDiff is disabled."""


def set_nsd_sample(basename: str):
    """No-op stub when DynaDiff is disabled."""


if HAS_DYNADIFF:
    # NOTE(review): in this copy several HTML literals below are empty.
    # Examples are the f'' image literals, where the computed ``b64`` or
    # ``steer_b64`` is never used. The markup appears to have been lost;
    # confirm against version control that these divs are meant to show
    # the rendered images.

    # ── Data source & widgets ────────────────────────────────────
    _dd_source = ColumnDataSource(
        data=dict(feat=[], name=[], lam=[], threshold=[]))
    feat_action_bridge = TextInput(value="", width=1, visible=False)
    gt_brain_div = Div(text="", width=410)
    steer_brain_div = Div(text="", width=460)
    steered_brain_div = Div(text="", width=480)
    active_features_div = Div(text="", width=460)
    _dd_status_div = Div(text="", width=460)
    _dd_output_div = Div(text="", width=460)
    _dd_run_btn = Button(
        label="Steer & Reconstruct", button_type="primary", width=200)

    # ── Brain visualisation updates ──────────────────────────────

    def _steerings_from_source():
        """Return (feats, lams, thresholds) lists copied from the source."""
        return (list(_dd_source.data['feat']),
                list(_dd_source.data['lam']),
                list(_dd_source.data['threshold']))

    # Each counter is bumped whenever its panel is (re)rendered or cleared.
    # A background render only applies its result if the counter still
    # holds the value it started with.
    _tiles_render_token = [0]     # discard stale active-feature tile renders
    _steer_render_token = [0]     # discard stale steering-direction renders
    _gt_render_token = [0]        # discard stale GT brain renders
    _steered_render_token = [0]   # discard stale steered brain renders

    def _update_steer_brain():
        """Re-render the steering-direction brain map off-thread."""
        feats, lams, thrs = _steerings_from_source()
        # Bump first so an in-flight render cannot overwrite the panel,
        # even when the list is now empty and no new render is started.
        _steer_render_token[0] += 1
        my_token = _steer_render_token[0]
        if not feats:
            steer_brain_div.text = ''
            return
        steer_brain_div.text = ''
        doc = curdoc()

        def _bg():
            combined = compute_steering_direction(feats, lams, thrs)
            b64 = render_fmri_brain_compact_b64(
                combined, 'Steering Direction (φ sum)')

            def _apply():
                if _steer_render_token[0] == my_token:
                    steer_brain_div.text = (
                        f'' if b64 else '')

            doc.add_next_tick_callback(_apply)

        threading.Thread(target=_bg, daemon=True).start()

    def _update_steered_brain():
        """Re-render the expected steered brain (GT fMRI + steering) off-thread."""
        fmri = _Session.gt_fmri
        feats, lams, thrs = _steerings_from_source()
        # Bump first so an in-flight render cannot overwrite the panel
        # after it has been cleared below.
        _steered_render_token[0] += 1
        my_token = _steered_render_token[0]
        if fmri is None or not feats:
            steered_brain_div.text = ''
            return
        steered_brain_div.text = ''
        doc = curdoc()

        def _bg():
            steered = compute_steered_fmri(fmri, feats, lams, thrs)
            b64 = render_fmri_brain_compact_b64(
                steered, 'Expected Steered Brain')

            def _apply():
                if _steered_render_token[0] != my_token:
                    return
                steered_brain_div.text = (
                    f'' if b64 else '')

            doc.add_next_tick_callback(_apply)

        threading.Thread(target=_bg, daemon=True).start()

    def _update_active_tiles():
        """Re-render the active-feature tile cards (off-thread when non-empty)."""
        feats = list(_dd_source.data['feat'])
        lams = list(_dd_source.data['lam'])
        # Bump first so a slow render for an older list cannot replace
        # the (possibly empty) card set produced below.
        _tiles_render_token[0] += 1
        my_token = _tiles_render_token[0]
        if not feats:
            active_features_div.text = make_active_features_tile_html(
                [], active_ds(), removable=True)
            return
        active_features_div.text = 'Rendering tiles…'
        doc = curdoc()

        def _bg():
            html = make_active_features_tile_html(
                feats, active_ds(), removable=True, lams=lams)

            def _apply():
                if _tiles_render_token[0] == my_token:
                    active_features_div.text = html

            doc.add_next_tick_callback(_apply)

        threading.Thread(target=_bg, daemon=True).start()

    def _on_source_change(attr, old, new):
        # Defer to next tick so these updates don't nest inside the
        # caller's document lock (avoids Bokeh _pending_writes error).
        def _deferred():
            _update_active_tiles()
            _update_steer_brain()
            _update_steered_brain()

        curdoc().add_next_tick_callback(_deferred)

    _dd_source.on_change('data', _on_source_change)

    # ── GT brain loading ─────────────────────────────────────────

    def _load_gt_brain(nsd_basename):
        """Load GT fMRI and render brain for an NSD image (threaded).

        A precomputed render from the active dataset's ``gt_brain_cache``
        (looked up by NSD image index) is used when present. Only the most
        recent request is applied.
        """
        # Bump first so a still-running load for a previous sample is
        # discarded, including when this basename cannot be parsed.
        _gt_render_token[0] += 1
        my_token = _gt_render_token[0]
        nsd_img_idx = parse_nsd_img_idx(nsd_basename)
        if nsd_img_idx is None:
            _Session.gt_fmri = None
            gt_brain_div.text = ''
            # With no GT fMRI this clears the steered panel and also
            # invalidates any steered render still in flight.
            _update_steered_brain()
            return
        doc = curdoc()

        def _bg():
            _, fmri = load_gt_fmri(nsd_basename)
            if _gt_render_token[0] != my_token:
                return
            # Use precomputed GT brain render if available
            cached_b64 = active_ds().get('gt_brain_cache', {}).get(nsd_img_idx)
            if cached_b64 is not None:
                b64 = cached_b64
            elif fmri is not None:
                b64 = render_fmri_brain_compact_b64(fmri, 'GT Brain Response')
            else:
                b64 = None

            def _apply(fmri=fmri, b64=b64):
                if _gt_render_token[0] != my_token:
                    return
                _Session.gt_fmri = fmri
                gt_brain_div.text = (
                    f'' if b64 else '')
                _update_steered_brain()

            doc.add_next_tick_callback(_apply)

        threading.Thread(target=_bg, daemon=True).start()

    # ── Public API ───────────────────────────────────────────────

    _ENTRY_KEYS = ('feat', 'name', 'lam', 'threshold')  # columns of _dd_source

    def add_feature(feat: int, lam: float = 3.0, threshold: float = 0.10):
        """Add a feature to the steering list.

        Invalid features (per validate_feature) and duplicates are rejected
        with a status message instead of being added.
        """
        err = validate_feature(feat)
        if err:
            _dd_status_div.text = (
                f'{err}')
            return
        if feat in list(_dd_source.data['feat']):
            _dd_status_div.text = (
                f''
                f'Feature {feat} already in list.')
            return
        entry = make_steering_entry(feat, lam, threshold)
        new_data = {k: list(v) for k, v in _dd_source.data.items()}
        for key in _ENTRY_KEYS:
            new_data[key].append(entry[key])
        # Assign a fresh dict so Bokeh sees one 'data' change event.
        _dd_source.data = new_data
        _dd_status_div.text = (
            f''
            f'Feature {feat} added to steering.')

    def set_preset(entries: list, label: str = ''):
        """Replace the steering list with preset entries.

        Each entry is a mapping with a required ``'feat'`` and optional
        ``'lam'`` (default 3.0) and ``'threshold'`` (default 0.10).
        """
        new_data = {key: [] for key in _ENTRY_KEYS}
        for raw in entries:
            e = make_steering_entry(
                int(raw['feat']),
                float(raw.get('lam', 3.0)),
                float(raw.get('threshold', 0.10)),
            )
            for key in _ENTRY_KEYS:
                new_data[key].append(e[key])
        _dd_source.data = new_data
        if label:
            _dd_status_div.text = (
                f'Loaded preset: {label}')

    def set_nsd_sample(basename: str):
        """Update the NSD sample being steered and load its GT brain.

        A no-op when ``basename`` is already the current sample.
        """
        if basename == _Session.nsd_basename:
            return
        _Session.nsd_basename = basename
        _load_gt_brain(basename)

    # ── Feature action bridge (remove / set_lam from HTML) ───────

    def _on_feat_action(attr, old, new):
        """Handle ``remove_feat:<feat>`` / ``set_lam:<feat>:<value>`` messages.

        The JS side appends ``|<timestamp>`` so repeated identical actions
        still change the bridge value; that suffix is stripped here.
        Malformed messages and unknown features are ignored.
        """
        msg = new.split('|')[0]
        if msg.startswith('remove_feat:'):
            try:
                feat = int(msg.split(':')[1])
            except (ValueError, IndexError):
                return
            feats = list(_dd_source.data['feat'])
            if feat not in feats:
                return
            idx = feats.index(feat)
            new_data = {k: [v for i, v in enumerate(vals) if i != idx]
                        for k, vals in _dd_source.data.items()}
            _dd_source.data = new_data
            _dd_status_div.text = ''
        elif msg.startswith('set_lam:'):
            parts = msg.split(':', 2)
            if len(parts) != 3:
                return
            try:
                feat = int(parts[1])
                new_val = float(parts[2])
            except ValueError:
                return
            feats = list(_dd_source.data['feat'])
            if feat not in feats:
                return
            idx = feats.index(feat)
            new_lams = list(_dd_source.data['lam'])
            new_lams[idx] = new_val
            new_data = dict(_dd_source.data)
            new_data['lam'] = new_lams
            _dd_source.data = new_data

    feat_action_bridge.on_change('value', _on_feat_action)

    curdoc().js_on_event('document_ready', CustomJS(
        args=dict(bridge=feat_action_bridge),
        code=""" window._dd_feat_action = function(msg) { bridge.value = msg + '|' + Date.now(); }; """,
    ))

    # ── DynaDiff reconstruction ──────────────────────────────────

    def _reconstruct_thread(sample_idxs, steerings, doc, nsd_img_idx=None):
        """Run run_reconstruction off the UI thread and post the result.

        Any exception is reported in the status div; the run button is
        re-enabled in both the success and the error path.
        """
        try:
            resp = run_reconstruction(
                sample_idxs, steerings, seed=42, nsd_img_idx=nsd_img_idx)
            steer_b64 = resp.get('steered_img')
            if steer_b64:
                html = (
                    f'')
            else:
                html = 'No steered output.'

            def _apply(html=html):
                _dd_output_div.text = html
                _dd_status_div.text = ''
                _dd_run_btn.disabled = False

            doc.add_next_tick_callback(_apply)
        except Exception as exc:
            msg = str(exc)

            def _show_err(msg=msg):
                _dd_status_div.text = (
                    f'Error: {msg}')
                _dd_run_btn.disabled = False

            doc.add_next_tick_callback(_show_err)

    def _on_reconstruct():
        """Validate the steering list and start a DynaDiff run off-thread."""
        feats, lams, thrs = _steerings_from_source()
        # Prefer the currently loaded patch image's NSD basename. Switch via
        # set_nsd_sample so the GT/steered brain panels follow the sample
        # that is actually reconstructed, and the session is never left
        # pointing at a basename whose GT brain was not loaded.
        if _Session.img_idx is not None:
            img_basename = resolve_nsd_basename(_Session.img_idx)
            if img_basename:
                set_nsd_sample(img_basename)
        nsd_basename = _Session.nsd_basename
        sample_idxs, steerings, err = validate_reconstruction(
            nsd_basename, feats, lams, thrs)
        if err:
            _dd_status_div.text = (
                f'{err}')
            return
        nsd_img_idx = parse_nsd_img_idx(nsd_basename)
        _dd_run_btn.disabled = True
        _dd_status_div.text = (
            ''
            'Running DynaDiff reconstruction…')
        threading.Thread(
            target=_reconstruct_thread,
            args=(sample_idxs, steerings, curdoc(), nsd_img_idx),
            daemon=True,
        ).start()

    _dd_run_btn.on_click(_on_reconstruct)

    dynadiff_panel = column(
        feat_action_bridge,
        row(_dd_run_btn, _dd_status_div),
        _dd_output_div,
    )

# ══════════════════════════════════════════════════════════════════
# Layout exports
# ══════════════════════════════════════════════════════════════════

patch_explorer_panel = column(
    _patch_fig,
    _patch_info_div,
    _patch_results_div,
)