Spaces:
Sleeping
Sleeping
Marlin Lee
UI redesign: consistent design system, card layout, unified colors and typography
"""
Steering workspace UI: patch explorer + feature steering list + DynaDiff.

Combines patch-level feature exploration with brain-steering controls.
Pure computation lives in explorer.steering_logic; this module handles
Bokeh widgets, callbacks, and layout.

Exports (layout):
    patch_explorer_panel  — column: patch figure + info + results
    gt_brain_div          — Div: GT brain response
    steer_brain_div       — Div: steering direction brain map
    steered_brain_div     — Div: expected steered brain
    active_features_div   — Div: steering feature tile cards
    dynadiff_panel        — column: run button + status + output (or None)

Exports (JS bridges — needed by main.py):
    patch_load_bridge     — TextInput: JS -> Python image loading
    feat_action_bridge    — TextInput: JS -> Python remove/set_lam (or None)

Exports (public API for other panels):
    add_feature(feat, lam, threshold)
    set_preset(entries, label)
    set_nsd_sample(basename)
"""
import html
import math
import threading

from bokeh.events import MouseMove
from bokeh.io import curdoc
from bokeh.layouts import column, row
from bokeh.models import (
    Button, ColumnDataSource, CustomJS, Div, TextInput,
)
from bokeh.plotting import figure

from ..args import args
from ..state import active_ds
from ..steering_logic import (
    compute_patch_activations, get_top_features_for_patches,
    resolve_nsd_basename, parse_nsd_img_idx,
    compute_steering_direction, compute_steered_fmri,
    validate_feature, make_steering_entry,
    validate_reconstruction, run_reconstruction, load_gt_fmri,
)
from ..brain import (
    HAS_DYNADIFF,
    render_fmri_brain_compact_b64,
)
from ..rendering import (
    load_image, parse_img_label, pil_to_bokeh_rgba,
    make_search_result_html, make_active_features_tile_html,
)
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Session state (reset each Bokeh session via module reimport) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class _Session: | |
| """Per-session mutable state for the steering workspace.""" | |
| img_idx = None # image index loaded in patch explorer | |
| patch_z = None # (n_patches, d_sae) activations | |
| nsd_basename = None # e.g. 'nsd_22910' | |
| gt_fmri = None # raw fMRI (N_VOXELS,) array | |
# Width and height, in pixels, of the square patch-explorer figure.
_PATCH_FIG_PX = 400

# ──────────────────────────────────────────────────────────────────
# Patch Explorer
# ──────────────────────────────────────────────────────────────────
# Background-image source; starts empty and is filled by _on_load_image
# with the loaded RGBA array stretched over the patch grid.
_patch_bg_source = ColumnDataSource(data={
    'image': [], 'x': [0], 'y': [0], 'dw': [16], 'dh': [16],
})
# Patch-grid side length of the dataset active at import time.
_pg0 = active_ds()['patch_grid']
def _make_grid_source(pg: int) -> ColumnDataSource:
    """Build the data source for a ``pg`` x ``pg`` grid of patch cells.

    Cells are listed row-major (flat index = row * pg + col). Each cell
    carries its centre (x, y) in data units and its integer row / col.
    Row 0 is drawn at the top, hence the flipped y coordinate.
    """
    cells = [(r, c) for r in range(pg) for c in range(pg)]
    rows = [r for r, _ in cells]
    cols = [c for _, c in cells]
    return ColumnDataSource(data=dict(
        x=[c + 0.5 for c in cols],
        y=[pg - r - 0.5 for r in rows],
        row=rows,
        col=cols,
    ))
# Clickable cells for the import-time grid size.
_patch_grid_source = _make_grid_source(_pg0)
# Square figure in patch-grid units; stays hidden until the first image's
# activations have been computed.
_patch_fig = figure(
    title="Click or drag to paint patch selection",
    width=_PATCH_FIG_PX,
    height=_PATCH_FIG_PX,
    x_range=(0, _pg0),
    y_range=(0, _pg0),
    tools=["tap", "reset"],
    toolbar_location="above",
    visible=False,
)
# Drag-to-paint selection. On every MouseMove over the patch figure, if a
# mouse button is held, the grid cell under the cursor is appended to
# _patch_grid_source's selection (cells are only added, never removed).
# Button state comes from document-level mousedown/mouseup listeners that
# are installed once per page (guarded by window._patch_paint_init).
# NOTE(review): because the listeners are on `document`, a press that
# starts outside the figure also counts as "held" — confirm intended.
# `pg` is the import-time grid size (_pg0). Rows count from the top, so y
# is flipped to produce the same flat index layout as _make_grid_source.
_paint_js = CustomJS(args=dict(source=_patch_grid_source, pg=_pg0), code="""
if (!window._patch_paint_init) {
window._patch_paint_init = true;
window._patch_btn_held = false;
document.addEventListener('mousedown', () => { window._patch_btn_held = true; });
document.addEventListener('mouseup', () => { window._patch_btn_held = false; });
}
if (!window._patch_btn_held) return;
const x = cb_obj.x, y = cb_obj.y;
if (x === null || y === null || x < 0 || x >= pg || y < 0 || y >= pg) return;
const col = Math.floor(x);
const row = pg - 1 - Math.floor(y);
const flat_idx = row * pg + col;
const sel = source.selected.indices.slice();
if (sel.indexOf(flat_idx) === -1) { sel.push(flat_idx); source.selected.indices = sel; }
""")
_patch_fig.js_on_event(MouseMove, _paint_js)
# Layer 1: the loaded image, stretched across the whole patch grid.
_patch_fig.image_rgba(
    source=_patch_bg_source,
    image='image', x='x', y='y', dw='dw', dh='dh',
)
# Layer 2: one cell per patch — transparent with a faint white outline;
# selected cells are filled red.
_patch_fig.rect(
    source=_patch_grid_source,
    x='x', y='y', width=0.95, height=0.95,
    fill_color='yellow', fill_alpha=0.0,
    line_color='white', line_alpha=0.35, line_width=0.5,
    selection_fill_color='red', selection_fill_alpha=0.45,
    nonselection_fill_alpha=0.0, nonselection_line_alpha=0.35,
)
# Only the image and the cells are shown: no axes, no grid lines.
_patch_fig.ygrid.visible = False
_patch_fig.xgrid.visible = False
_patch_fig.axis.visible = False
# Feature search results for the painted patches.
_patch_results_div = Div(width=310, text="")
# Manual image-index entry with its two buttons.
_patch_img_input = TextInput(value="0", title="Image Index:", width=120)
_load_patch_btn = Button(button_type="primary", label="Load Image", width=90)
_clear_patch_btn = Button(width=60, label="Clear")
# Status / hint line shown under the patch figure.
_patch_info_div = Div(
    width=310,
    text="<i>Click an image in the Feature Explorer to load it here.</i>",
)
# Hidden JS bridge: a gallery tile's onclick calls
# window._sae_load_patch_image(idx), which writes into this input's value;
# _on_patch_load_bridge then parses it.
patch_load_bridge = TextInput(visible=False, value="", width=1, height=1)
# ── Patch callbacks ───────────────────────────────────────────────
# One-element list holding a counter bumped on every accepted image load.
# A background result is applied only while its token is still current,
# so a slow, superseded load (e.g. triggered again via patch_load_bridge)
# cannot overwrite a newer one. Same pattern as the brain-render tokens.
_patch_load_token = [0]


def _on_load_image():
    """Load the image named in ``_patch_img_input`` into the patch explorer.

    Validates the index against the active dataset and draws the image
    right away. Patch activations are then computed on a background thread.
    When they arrive, and no newer load has started, the callback:

    * stores them in ``_Session.patch_z``;
    * shows the figure and resets the selection;
    * passes the image's NSD sample (if any) to ``set_nsd_sample``.
    """
    try:
        img_idx = parse_img_label(_patch_img_input.value)
    except ValueError:
        _patch_info_div.text = "<b style='color:red'>Invalid image index</b>"
        return
    ds = active_ds()
    if not (0 <= img_idx < ds['n_images']):
        _patch_info_div.text = (
            f"<b style='color:red'>Index out of range "
            f"(0–{ds['n_images'] - 1})</b>")
        return
    try:
        pil = load_image(img_idx)
        pg = ds['patch_grid']
        bokeh_arr = pil_to_bokeh_rgba(pil, _PATCH_FIG_PX)
        _patch_bg_source.data = dict(
            image=[bokeh_arr], x=[0], y=[0], dw=[pg], dh=[pg])
    except Exception as e:
        # Exception text may contain '<' or '&'; escape it for the Div.
        _patch_info_div.text = (
            f"<b style='color:red'>Error loading image: "
            f"{html.escape(str(e))}</b>")
        return

    # Record the image only once it is actually displayed. Also forget the
    # previous image's activations, so painting cannot rank them against
    # the new picture while the new ones are being computed.
    _Session.img_idx = img_idx
    _Session.patch_z = None
    _patch_load_token[0] += 1
    my_token = _patch_load_token[0]

    _load_patch_btn.disabled = True
    _patch_info_div.text = (
        "<span style='color:#2563eb'>⏳ Computing patch activations"
        + (" (running GPU inference — first image may take ~10 s)…"
           if not args.sae_path else "…")
        + "</span>"
    )
    doc = curdoc()

    def _bg():
        try:
            z_np = compute_patch_activations(img_idx)
        except Exception as e:
            err = html.escape(str(e))

            def _show_err(err=err):
                if _patch_load_token[0] != my_token:
                    return  # a newer load now owns these widgets
                _load_patch_btn.disabled = False
                _patch_info_div.text = (
                    f"<b style='color:red'>Error: {err}</b>")
            doc.add_next_tick_callback(_show_err)
            return

        def _apply(z_np=z_np, img_idx=img_idx):
            if _patch_load_token[0] != my_token:
                return  # superseded by a newer load; drop stale result
            _Session.patch_z = z_np
            _load_patch_btn.disabled = False
            _patch_fig.visible = True
            _patch_grid_source.selected.indices = []
            _patch_results_div.text = ""
            # Sync NSD sample for brain steering
            nsd_name = resolve_nsd_basename(img_idx)
            if nsd_name:
                set_nsd_sample(nsd_name)
                _patch_fig.title.text = f"Paint patch selection on {nsd_name}"
            else:
                _patch_fig.title.text = f"Paint patch selection on image {img_idx}"
            if z_np is None:
                _patch_info_div.text = (
                    "<b style='color:#6b7280'>GPU inference unavailable.</b>")
            else:
                _patch_info_div.text = "Paint patches to find features."
        doc.add_next_tick_callback(_apply)

    threading.Thread(target=_bg, daemon=True).start()


def _on_patch_select(attr, old, new):
    """List the top features for the painted patches.

    ``new`` holds flat indices into ``_patch_grid_source``. They are mapped
    to (row, col) and then to patch indices of the active dataset's grid.
    The first 10 features returned by get_top_features_for_patches are
    rendered. An empty selection clears the results.
    """
    if _Session.img_idx is None:
        return
    if not new:
        _patch_results_div.text = ""
        _patch_info_div.text = "<i>Selection cleared.</i>"
        return
    if _Session.patch_z is None:
        # Activations are still being computed, or inference is unavailable
        # for this image; there is nothing to rank yet.
        return
    ds = active_ds()
    pg = ds['patch_grid']
    rows = _patch_grid_source.data['row']
    cols = _patch_grid_source.data['col']
    patch_indices = [rows[i] * pg + cols[i] for i in new]
    feats, _acts, _freqs, _means = get_top_features_for_patches(
        _Session.patch_z, patch_indices)
    _patch_results_div.text = make_search_result_html(
        feats[:10], ds, n_meis=3, size=72)
    _patch_info_div.text = "Click a feature to explore it."
def _on_clear():
    """Drop the patch selection and reset the results and status text."""
    _patch_grid_source.selected.indices = []
    _patch_results_div.text, _patch_info_div.text = (
        "", "<i>Selection cleared.</i>")
def _on_patch_load_bridge(attr, old, new):
    """Handle an image-load request written into patch_load_bridge by JS.

    The value looks like ``"<idx>|<suffix>"``; only the text before the
    first ``|`` is used. Values that don't parse as an int are ignored.
    """
    try:
        requested = int(new.partition('|')[0])
        _patch_img_input.value = f"{requested}"
        _on_load_image()
    except (IndexError, ValueError):
        return
# Wire the patch-explorer widgets to their callbacks.
_load_patch_btn.on_click(_on_load_image)
_clear_patch_btn.on_click(_on_clear)
_patch_grid_source.selected.on_change('indices', _on_patch_select)
patch_load_bridge.on_change('value', _on_patch_load_bridge)
# ──────────────────────────────────────────────────────────────────
# Steering List + Brain Visualisation + DynaDiff
# ──────────────────────────────────────────────────────────────────
# Placeholder exports for when DynaDiff is unavailable; the
# HAS_DYNADIFF branch below rebinds every one of them.
gt_brain_div, steer_brain_div, steered_brain_div, active_features_div = (
    Div(text="", width=1) for _ in range(4))
dynadiff_panel = feat_action_bridge = None
_dd_source = ColumnDataSource(data=dict(feat=[], name=[], lam=[], threshold=[]))
def add_feature(feat: int, lam: float = 3.0, threshold: float = 0.10):
    """Ignore the request: DynaDiff is disabled, so there is no list."""
    return None
def set_preset(entries: list, label: str = ''):
    """Ignore the preset: DynaDiff is disabled, so there is no list."""
    return None
def set_nsd_sample(basename: str):
    """Ignore the sample: DynaDiff is disabled, so nothing is steered."""
    return None
if HAS_DYNADIFF:
    # ── Data source & widgets ─────────────────────────────────────
    # Steering list: one row per feature with its name, lam and threshold.
    _dd_source = ColumnDataSource(data={
        'feat': [], 'name': [], 'lam': [], 'threshold': []})
    # Hidden bridge for remove / set_lam actions sent from the tile HTML.
    feat_action_bridge = TextInput(visible=False, value="", width=1)
    # Brain-map panels.
    gt_brain_div = Div(width=410, text="")
    steer_brain_div = Div(width=460, text="")
    steered_brain_div = Div(width=480, text="")
    active_features_div = Div(width=460, text="")
    # Reconstruction status, output and run button.
    _dd_status_div = Div(width=460, text="")
    _dd_output_div = Div(width=460, text="")
    _dd_run_btn = Button(
        button_type="primary", label="Steer & Reconstruct", width=200)
| # ββ Brain visualisation updates ββββββββββββββββββββββββββββββ | |
| def _steerings_from_source(): | |
| return (list(_dd_source.data['feat']), | |
| list(_dd_source.data['lam']), | |
| list(_dd_source.data['threshold'])) | |
| _tiles_render_token = [0] # discard stale active-feature tile renders | |
| _steer_render_token = [0] # mutable counter to discard stale renders | |
| _gt_render_token = [0] # discard stale GT brain renders | |
| _steered_render_token = [0] # discard stale steered brain renders | |
| def _update_steer_brain(): | |
| feats, lams, thrs = _steerings_from_source() | |
| if not feats: | |
| steer_brain_div.text = '' | |
| return | |
| _steer_render_token[0] += 1 | |
| my_token = _steer_render_token[0] | |
| steer_brain_div.text = '' | |
| doc = curdoc() | |
| def _bg(): | |
| combined = compute_steering_direction(feats, lams, thrs) | |
| b64 = render_fmri_brain_compact_b64( | |
| combined, 'Steering Direction (Ο sum)') | |
| def _apply(): | |
| if _steer_render_token[0] == my_token: | |
| steer_brain_div.text = ( | |
| f'<img src="data:image/png;base64,{b64}" ' | |
| f'style="max-width:100%"/>' | |
| if b64 else '') | |
| doc.add_next_tick_callback(_apply) | |
| threading.Thread(target=_bg, daemon=True).start() | |
| def _update_steered_brain(): | |
| fmri = _Session.gt_fmri | |
| if fmri is None: | |
| steered_brain_div.text = '' | |
| return | |
| feats, lams, thrs = _steerings_from_source() | |
| if not feats: | |
| steered_brain_div.text = '' | |
| return | |
| _steered_render_token[0] += 1 | |
| my_token = _steered_render_token[0] | |
| steered_brain_div.text = '' | |
| doc = curdoc() | |
| def _bg(): | |
| steered = compute_steered_fmri(fmri, feats, lams, thrs) | |
| b64 = render_fmri_brain_compact_b64( | |
| steered, 'Expected Steered Brain') | |
| def _apply(): | |
| if _steered_render_token[0] != my_token: | |
| return | |
| steered_brain_div.text = ( | |
| f'<img src="data:image/png;base64,{b64}" ' | |
| f'style="max-width:100%"/>' | |
| if b64 else '') | |
| doc.add_next_tick_callback(_apply) | |
| threading.Thread(target=_bg, daemon=True).start() | |
| def _update_active_tiles(): | |
| feats = list(_dd_source.data['feat']) | |
| lams = list(_dd_source.data['lam']) | |
| if not feats: | |
| active_features_div.text = make_active_features_tile_html( | |
| [], active_ds(), removable=True) | |
| return | |
| _tiles_render_token[0] += 1 | |
| my_token = _tiles_render_token[0] | |
| active_features_div.text = ( | |
| '<div style="color:#6b7280;font-style:italic;font-size:11px;' | |
| 'padding:6px">Rendering tiles…</div>') | |
| doc = curdoc() | |
| def _bg(): | |
| html = make_active_features_tile_html( | |
| feats, active_ds(), removable=True, lams=lams) | |
| def _apply(): | |
| if _tiles_render_token[0] == my_token: | |
| active_features_div.text = html | |
| doc.add_next_tick_callback(_apply) | |
| threading.Thread(target=_bg, daemon=True).start() | |
| def _on_source_change(attr, old, new): | |
| # Defer to next tick so these updates don't nest inside the | |
| # caller's document lock (avoids Bokeh _pending_writes error). | |
| def _deferred(): | |
| _update_active_tiles() | |
| _update_steer_brain() | |
| _update_steered_brain() | |
| curdoc().add_next_tick_callback(_deferred) | |
| _dd_source.on_change('data', _on_source_change) | |
| # ββ GT brain loading βββββββββββββββββββββββββββββββββββββββββ | |
| def _load_gt_brain(nsd_basename): | |
| """Load GT fMRI and render brain for an NSD image (threaded).""" | |
| nsd_img_idx = parse_nsd_img_idx(nsd_basename) | |
| if nsd_img_idx is None: | |
| _Session.gt_fmri = None | |
| gt_brain_div.text = '' | |
| steered_brain_div.text = '' | |
| return | |
| _gt_render_token[0] += 1 | |
| my_token = _gt_render_token[0] | |
| doc = curdoc() | |
| def _bg(): | |
| _, fmri = load_gt_fmri(nsd_basename) | |
| if _gt_render_token[0] != my_token: | |
| return | |
| # Use precomputed GT brain render if available | |
| cached_b64 = active_ds().get('gt_brain_cache', {}).get(nsd_img_idx) | |
| if cached_b64 is not None: | |
| b64 = cached_b64 | |
| else: | |
| b64 = (render_fmri_brain_compact_b64(fmri, 'GT Brain Response') | |
| if fmri is not None else None) | |
| def _apply(fmri=fmri, b64=b64): | |
| if _gt_render_token[0] != my_token: | |
| return | |
| _Session.gt_fmri = fmri | |
| gt_brain_div.text = ( | |
| f'<img src="data:image/png;base64,{b64}" ' | |
| f'style="max-width:100%"/>' | |
| if b64 else '') | |
| _update_steered_brain() | |
| doc.add_next_tick_callback(_apply) | |
| threading.Thread(target=_bg, daemon=True).start() | |
| # ββ Public API βββββββββββββββββββββββββββββββββββββββββββββββ | |
| def add_feature(feat: int, lam: float = 3.0, | |
| threshold: float = 0.10): | |
| """Add a feature to the steering list.""" | |
| err = validate_feature(feat) | |
| if err: | |
| _dd_status_div.text = ( | |
| f'<span style="color:#dc2626">{err}</span>') | |
| return | |
| if feat in list(_dd_source.data['feat']): | |
| _dd_status_div.text = ( | |
| f'<i style="color:#6b7280">' | |
| f'Feature {feat} already in list.</i>') | |
| return | |
| entry = make_steering_entry(feat, lam, threshold) | |
| new_data = {k: list(v) for k, v in _dd_source.data.items()} | |
| new_data['feat'].append(entry['feat']) | |
| new_data['name'].append(entry['name']) | |
| new_data['lam'].append(entry['lam']) | |
| new_data['threshold'].append(entry['threshold']) | |
| _dd_source.data = new_data | |
| _dd_status_div.text = ( | |
| f'<i style="color:#6b7280">' | |
| f'Feature {feat} added to steering.</i>') | |
| def set_preset(entries: list, label: str = ''): | |
| """Replace the steering list with preset entries.""" | |
| new_data = dict(feat=[], name=[], lam=[], threshold=[]) | |
| for raw in entries: | |
| e = make_steering_entry( | |
| int(raw['feat']), | |
| float(raw.get('lam', 3.0)), | |
| float(raw.get('threshold', 0.10)), | |
| ) | |
| new_data['feat'].append(e['feat']) | |
| new_data['name'].append(e['name']) | |
| new_data['lam'].append(e['lam']) | |
| new_data['threshold'].append(e['threshold']) | |
| _dd_source.data = new_data | |
| if label: | |
| _dd_status_div.text = ( | |
| f'<i style="color:#6b7280">Loaded preset: {label}</i>') | |
| def set_nsd_sample(basename: str): | |
| """Update the NSD sample being steered and load its GT brain.""" | |
| if basename == _Session.nsd_basename: | |
| return | |
| _Session.nsd_basename = basename | |
| _load_gt_brain(basename) | |
| def load_patch_image(image_label: str): | |
| """Load an image into the patch explorer by name/index string. | |
| This triggers the full load chain: image display, patch activations, | |
| NSD sample detection, and GT brain rendering. | |
| """ | |
| _patch_img_input.value = image_label | |
| _on_load_image() | |
| # ββ Feature action bridge (remove / set_lam from HTML) βββββββ | |
| def _on_feat_action(attr, old, new): | |
| msg = new.split('|')[0] | |
| if msg.startswith('remove_feat:'): | |
| try: | |
| feat = int(msg.split(':')[1]) | |
| except (ValueError, IndexError): | |
| return | |
| feats = list(_dd_source.data['feat']) | |
| if feat not in feats: | |
| return | |
| idx = feats.index(feat) | |
| new_data = {k: [v for i, v in enumerate(vals) if i != idx] | |
| for k, vals in _dd_source.data.items()} | |
| _dd_source.data = new_data | |
| _dd_status_div.text = '' | |
| elif msg.startswith('set_lam:'): | |
| parts = msg.split(':', 2) | |
| if len(parts) != 3: | |
| return | |
| try: | |
| feat = int(parts[1]) | |
| new_val = float(parts[2]) | |
| except ValueError: | |
| return | |
| feats = list(_dd_source.data['feat']) | |
| if feat not in feats: | |
| return | |
| idx = feats.index(feat) | |
| new_lams = list(_dd_source.data['lam']) | |
| new_lams[idx] = new_val | |
| new_data = dict(_dd_source.data) | |
| new_data['lam'] = new_lams | |
| _dd_source.data = new_data | |
    # Route value changes of the hidden bridge input to _on_feat_action.
    feat_action_bridge.on_change('value', _on_feat_action)
    # On document_ready, define window._dd_feat_action(msg) for the tile
    # HTML to call. It writes "<msg>|<Date.now()>" into the bridge. The
    # timestamp presumably keeps a repeated identical action from being
    # dropped as "no change"; _on_feat_action ignores everything after '|'.
    curdoc().js_on_event('document_ready', CustomJS(
        args=dict(bridge=feat_action_bridge),
        code="""
window._dd_feat_action = function(msg) {
bridge.value = msg + '|' + Date.now();
};
""",
    ))
| # ββ DynaDiff reconstruction ββββββββββββββββββββββββββββββββββ | |
| def _reconstruct_thread(sample_idxs, steerings, doc, | |
| nsd_img_idx=None): | |
| try: | |
| resp = run_reconstruction( | |
| sample_idxs, steerings, seed=42, | |
| nsd_img_idx=nsd_img_idx) | |
| steer_b64 = resp.get('steered_img') | |
| if steer_b64: | |
| html = ( | |
| f'<img src="data:image/png;base64,{steer_b64}" ' | |
| f'style="max-width:100%;border-radius:4px;' | |
| f'border:1px solid #ddd"/>') | |
| else: | |
| html = ('<div style="color:#aaa;font-style:italic">' | |
| 'No steered output.</div>') | |
| def _apply(html=html): | |
| _dd_output_div.text = html | |
| _dd_status_div.text = '' | |
| _dd_run_btn.disabled = False | |
| doc.add_next_tick_callback(_apply) | |
| except Exception as exc: | |
| msg = str(exc) | |
| def _show_err(msg=msg): | |
| _dd_status_div.text = ( | |
| f'<span style="color:#dc2626">Error: {msg}</span>') | |
| _dd_run_btn.disabled = False | |
| doc.add_next_tick_callback(_show_err) | |
| def _on_reconstruct(): | |
| feats, lams, thrs = _steerings_from_source() | |
| # Prefer the currently loaded patch image's NSD basename | |
| nsd_basename = _Session.nsd_basename | |
| if _Session.img_idx is not None: | |
| img_basename = resolve_nsd_basename(_Session.img_idx) | |
| if img_basename: | |
| nsd_basename = img_basename | |
| _Session.nsd_basename = nsd_basename | |
| sample_idxs, steerings, err = validate_reconstruction( | |
| nsd_basename, feats, lams, thrs) | |
| if err: | |
| _dd_status_div.text = ( | |
| f'<span style="color:#dc2626">{err}</span>') | |
| return | |
| nsd_img_idx = parse_nsd_img_idx(nsd_basename) | |
| _dd_run_btn.disabled = True | |
| _dd_status_div.text = ( | |
| '<i style="color:#6b7280">' | |
| 'Running DynaDiff reconstructionβ¦</i>') | |
| threading.Thread( | |
| target=_reconstruct_thread, | |
| args=(sample_idxs, steerings, curdoc(), nsd_img_idx), | |
| daemon=True, | |
| ).start() | |
| _dd_run_btn.on_click(_on_reconstruct) | |
| dynadiff_panel = column( | |
| feat_action_bridge, | |
| row(_dd_run_btn, _dd_status_div), | |
| _dd_output_div, | |
| ) | |
# ──────────────────────────────────────────────────────────────────
# Layout exports
# ──────────────────────────────────────────────────────────────────
# Patch explorer: figure, then the status line, then feature results.
patch_explorer_panel = column(_patch_fig, _patch_info_div, _patch_results_div)