Spaces:
Sleeping
Sleeping
| """ | |
| Interactive SAE Feature Explorer - Bokeh Server App. | |
| Visualizes SAE features with: | |
| - UMAP scatter plot of features (activation-based and dictionary-based) | |
| - Click a feature to see its top-activating images with heatmap overlays | |
| - 75th percentile images for distribution understanding | |
| - Patch explorer: click patches of any image to find active features | |
| - Feature naming: assign names to features, saved to JSON, searchable | |
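| By default all display is driven by pre-computed sidecars (_heatmaps.pt, _patch_acts.pt), | |
| so no GPU or model weights are required at serve time; passing --sae-path additionally loads the backbone + SAE on GPU for on-the-fly patch inference. | |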
| Launch: | |
| bokeh serve explorer_app.py --port 5006 --allow-websocket-origin="*" \ | |
| --session-token-expiration 86400 \ | |
| --args \ | |
| --data ../../smart_init_stability_SAE/explorer_data_d32000_k160_val.pt \ | |
| --image-dir /scratch.global/lee02328/val \ | |
| --extra-image-dir /scratch.global/lee02328/coco/val2017 \ | |
| --primary-label "DINOv3 L24 Spatial (d=32K)" \ | |
| --compare-data ../../smart_init_stability_SAE/explorer_data_18.pt \ | |
| --compare-labels "DINOv3 L18 Spatial (d=20K)" \ | |
| --phi-dir /path/to/phis \ | |
| --brain-data /path/to/brain_meis_dinov3.pt \ | |
| --brain-thumbnails /path/to/nsd_thumbs | |
| Then SSH tunnel: ssh -L 5006:<node>:5006 <user>@<login-node> | |
| Open: http://localhost:5006/explorer_app | |
| """ | |
| import argparse | |
| import os | |
| import io | |
| import json | |
| import base64 | |
| import random | |
| import threading | |
| from collections import OrderedDict | |
| from functools import partial | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import matplotlib.colors as mcolors | |
| from PIL import Image | |
| import sys | |
| sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'src')) | |
| from clip_utils import load_clip, compute_text_embeddings | |
| from bokeh.io import curdoc | |
| from bokeh.layouts import column, row | |
| from bokeh.events import MouseMove | |
| from bokeh.models import ( | |
| ColumnDataSource, HoverTool, Div, Select, TextInput, Button, | |
| DataTable, TableColumn, NumberFormatter, IntEditor, NumberEditor, | |
| Slider, Toggle, RadioButtonGroup, CustomJS, | |
| ) | |
| from bokeh.plotting import figure | |
| from bokeh.palettes import Turbo256 | |
| from bokeh.transform import linear_cmap | |
# ---------- Parse args ----------
# Every CLI flag is declared once in this table as (flag, add_argument options);
# the table order is the registration order and therefore the --help order.
_ARG_SPECS = [
    ("--data", dict(type=str, required=True)),
    ("--image-dir", dict(
        type=str, required=True,
        help="Primary image directory used during precompute")),
    ("--extra-image-dir", dict(
        type=str, default=[], nargs="*",
        help="Additional image directories used during precompute")),
    ("--thumb-size", dict(type=int, default=256)),
    ("--inference-cache-size", dict(
        type=int, default=64,
        help="Number of images to keep in the patch-activations LRU cache")),
    ("--names-file", dict(
        type=str, default=None,
        help=("Path to JSON file for saving feature names "
              "(default: <data>_feature_names.json)"))),
    ("--compare-data", dict(
        type=str, nargs="*", default=[],
        help=("Additional explorer_data.pt files to show in cross-dataset "
              "comparison panel (e.g. layer 18, CLS SAE)"))),
    ("--compare-labels", dict(
        type=str, nargs="*", default=[],
        help="Display labels for each --compare-data file")),
    ("--primary-label", dict(
        type=str, default="Primary",
        help="Display label for the primary --data file")),
    ("--clip-model", dict(
        type=str, default="openai/clip-vit-large-patch14",
        help=("HuggingFace CLIP model ID for free-text search "
              "(only loaded on first out-of-vocab query)"))),
    ("--google-api-key", dict(
        type=str, default=None,
        help=("Google API key for Gemini auto-interp button "
              "(default: GOOGLE_API_KEY env var)"))),
    ("--sae-url", dict(
        type=str, default=None,
        help=("Download URL for the primary dataset's SAE weights — "
              "shown as a link in the summary panel"))),
    ("--compare-sae-urls", dict(
        type=str, nargs="*", default=[],
        help="Download URLs for each --compare-data dataset's SAE weights (in order)")),
    ("--phi-dir", dict(
        type=str, default=None,
        help=("Directory containing Phi_cv_*.npy, phi_c_*.npy, voxel_coords.npy "
              "(brain-alignment data; enables cortical profile and brain leverage features)"))),
    ("--phi-model", dict(
        type=str, default=None,
        help=("Model name substring to match phi files (e.g. 'dinov3', 'dinov2', 'clip_encoder'). "
              "Default: pick largest Phi_cv_*.npy by file size."))),
    ("--dynadiff-dir", dict(
        type=str, default=None,
        help=("Path to the local dynadiff repo. "
              "When provided (with --phi-dir), enables the brain steering panel."))),
    ("--dynadiff-checkpoint", dict(
        type=str, default="dynadiff_padded_sub01.pth",
        help="Checkpoint filename or path (relative to --dynadiff-dir or absolute).")),
    ("--dynadiff-h5", dict(
        type=str, default="extracted_training_data/consolidated_sub01.h5",
        help="Path to fMRI H5 (relative to --dynadiff-dir or absolute).")),
    ("--brain-data", dict(
        type=str, default=None,
        help=("Path to brain_meis.pt produced by precompute_nsd_meis.py. "
              "Adds 'NSD Brain (DINOv2 L11)' as a selectable dataset in the "
              "dataset dropdown, using NSD images and NSD-based UMAPs."))),
    ("--brain-thumbnails", dict(
        type=str, default=None,
        help=("Directory containing NSD JPEG thumbnails (nsd_XXXXX.jpg). "
              "Required with --brain-data if image_paths are not absolute paths."))),
    ("--brain-label", dict(
        type=str, default="NSD Brain (DINOv2 L11)",
        help="Dataset label shown in the dropdown for --brain-data.")),
    ("--sae-path", dict(
        type=str, default=None,
        help=("Path to SAE state-dict .pth file. When provided the backbone + SAE "
              "are loaded on GPU so any image can be explored without pre-computed "
              "patch activations."))),
    ("--backbone", dict(
        type=str, default="dinov2",
        help="Backbone name matching the SAE (default: dinov2).")),
    ("--layer", dict(
        type=int, default=11,
        help="Backbone layer used during SAE training (default: 11).")),
    ("--top-k", dict(
        type=int, default=100,
        help="SAE top-k sparsity (default: 100).")),
]

parser = argparse.ArgumentParser()
for _flag, _options in _ARG_SPECS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()
# ---------- Lazy CLIP model (loaded on first free-text query) ----------
# Single-slot cache: holds None until the first out-of-vocab query is issued,
# afterwards the (model, processor, device) triple.
_clip_handle = [None]


def _get_clip():
    """Return the cached ``(model, processor, device)`` triple, loading CLIP on first use.

    The model named by ``--clip-model`` is placed on ``cuda:0`` when CUDA is
    available and on the CPU otherwise.
    """
    cached = _clip_handle[0]
    if cached is not None:
        return cached

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"[CLIP] Loading {args.clip_model} on {device} (first free-text query)...")
    model, processor = load_clip(device, model_name=args.clip_model)
    _clip_handle[0] = (model, processor, device)
    print("[CLIP] Ready.")
    return _clip_handle[0]
# ---------- GPU backbone + SAE runner (optional, lazy-loaded) ----------
# Single-slot cache: None until loaded, afterwards the 7-tuple
# (forward_fn, sae, transform_fn, n_reg, extract_tokens_fn, backbone_name, device).
_gpu_runner = [None]


def _get_gpu_runner():
    """Load the backbone + SAE on GPU once and return the cached runner tuple.

    Returns:
        The 7-tuple ``(forward_fn, sae, transform_fn, n_reg, extract_tokens_fn,
        backbone_name, device)``, or ``None`` when ``--sae-path`` was not given
        or no CUDA device is available.
    """
    cached = _gpu_runner[0]
    if cached is not None:
        return cached
    if not args.sae_path:
        return None
    if not torch.cuda.is_available():
        print("[GPU runner] No CUDA device — on-the-fly inference disabled.")
        return None

    # Imported lazily so the heavy backbone code is only touched when needed.
    import sys
    import os as _os
    src_dir = _os.path.abspath(_os.path.join(_os.path.dirname(__file__), '..', 'src'))
    sys.path.insert(0, src_dir)
    from backbone_runners import load_batched_backbone
    from precompute_utils import load_sae, extract_tokens as _et

    device = torch.device("cuda:0")
    print(f"[GPU runner] Loading {args.backbone} layer {args.layer} + SAE on {device} ...")
    forward_fn, d_hidden, n_reg, transform_fn = load_batched_backbone(
        args.backbone, args.layer, device)
    # NOTE(review): d_model is the module-level alias rebound by
    # _apply_dataset_globals, so the SAE is built with whichever dataset is
    # active on the first call and then cached — confirm this is intended.
    sae = load_sae(args.sae_path, d_hidden, d_model, args.top_k, device)
    _gpu_runner[0] = (forward_fn, sae, transform_fn, n_reg, _et, args.backbone, device)
    print("[GPU runner] Ready.")
    return _gpu_runner[0]
def _run_gpu_inference(pil_img):
    """Push one PIL image through backbone → SAE.

    Returns:
        A float32 numpy array of SAE activations, documented by the original
        author as ``(n_patches, d_sae)``, or ``None`` when no GPU runner is
        available.
    """
    runner = _get_gpu_runner()
    if runner is None:
        return None

    (forward_fn, sae, transform_fn, n_reg,
     extract_tokens_fn, backbone_name, device) = runner

    batch = transform_fn(pil_img).unsqueeze(0).to(device)                   # (1, C, H, W)
    with torch.inference_mode():
        hidden = forward_fn(batch)                                          # (1, n_tokens, d_hidden)
        tokens = extract_tokens_fn(hidden, backbone_name, 'spatial', n_reg)  # (1, n_patches, d_hidden)
        flat = tokens.reshape(-1, tokens.shape[-1])                         # (n_patches, d_hidden)
        _, z, _ = sae(flat)                                                 # (n_patches, d_sae)
    return z.cpu().float().numpy()
# ---------- Brain alignment (Phi) data ----------
# Populated once from --phi-dir; every name below stays None when the flag is absent.
#   _phi_cv:          (C, V) concept-by-voxel alignment matrix (mmap)
#   _phi_c:           (C,)   per-concept cortical leverage scores
#   _voxel_coords:    (V, 3) MNI coordinates of each voxel
#   _voxel_to_vertex: (V,)   mapping from fsaverage vertices → voxel indices (surface-space phi only)
_phi_cv = None
_phi_c = None
_voxel_coords = None
_voxel_to_vertex = None
_N_VOXELS_DD = 15724     # DynaDiff voxel count
_N_VERTS_FSAVG = 37984   # fsaverage vertex count

if args.phi_dir:
    _pdir = args.phi_dir
    _phi_model_key = (args.phi_model or "").lower()

    def _pick_phi_file(candidates, model_key):
        """Choose one phi file name from ``candidates``.

        Prefers names containing ``model_key`` (case-insensitive; the
        alphabetically first hit wins). Without a key, or when nothing
        matches, the largest file on disk is used. Returns None for an
        empty candidate list.
        """
        if not candidates:
            return None
        if model_key:
            hits = [name for name in candidates if model_key in name.lower()]
            if hits:
                return min(hits)
            print(f"[Phi] WARNING: --phi-model '{model_key}' matched no files in {candidates}; "
                  "falling back to largest file")
        # Fall back to the largest file by size.
        return max(candidates,
                   key=lambda name: os.path.getsize(os.path.join(_pdir, name)))

    # --- Phi_cv matrix ---
    _phi_mat_files = [
        _name for _name in os.listdir(_pdir)
        if _name.lower().startswith('phi_cv') and _name.endswith('.npy')
    ]
    _phi_mat_pick = _pick_phi_file(_phi_mat_files, _phi_model_key)
    if not _phi_mat_pick:
        print(f"[Phi] WARNING: no Phi_cv_*.npy found in {_pdir}")
    else:
        _phi_path = os.path.join(_pdir, _phi_mat_pick)
        _phi_cv = np.load(_phi_path, mmap_mode='r')
        print(f"[Phi] Loaded {_phi_mat_pick}: shape {_phi_cv.shape}, dtype {_phi_cv.dtype}")
        # The second dimension tells voxel-space and surface-space phi apart.
        if _phi_cv.shape[1] == _N_VOXELS_DD:
            print("[Phi] Voxel-space phi detected.")
        elif _phi_cv.shape[1] != _N_VERTS_FSAVG:
            print(f"[Phi] WARNING: unexpected phi dimension {_phi_cv.shape[1]}")
        else:
            _v2v_path = os.path.join(_pdir, 'voxel_to_vertex_map.npy')
            if not os.path.exists(_v2v_path):
                print("[Phi] WARNING: surface-space phi but voxel_to_vertex_map.npy not found")
            else:
                _voxel_to_vertex = np.load(_v2v_path)
                print(f"[Phi] Surface-space phi; loaded voxel_to_vertex_map: {_voxel_to_vertex.shape}")

    # --- phi_c leverage scores ---
    # 'phi_cv*' files share the 'phi_c' prefix and must be excluded explicitly.
    _phi_c_files = [
        _name for _name in os.listdir(_pdir)
        if _name.lower().startswith('phi_c')
        and not _name.lower().startswith('phi_cv')
        and _name.endswith('.npy')
    ]
    _phi_c_pick = _pick_phi_file(_phi_c_files, _phi_model_key)
    if not _phi_c_pick:
        print(f"[Phi] No phi_c_*.npy found in {_pdir} — leverage scores unavailable")
    else:
        _phi_c = np.load(os.path.join(_pdir, _phi_c_pick))
        print(f"[Phi] Leverage scores {_phi_c_pick}: shape {_phi_c.shape}, "
              f"range [{_phi_c.min():.4f}, {_phi_c.max():.4f}]")

    # --- Voxel coordinates ---
    _coords_path = os.path.join(_pdir, 'voxel_coords.npy')
    if not os.path.exists(_coords_path):
        print("[Phi] voxel_coords.npy not found — cortical scatter unavailable")
    else:
        _voxel_coords = np.load(_coords_path)
        print(f"[Phi] Voxel coordinates: {_voxel_coords.shape}")

# True only when a Phi_cv matrix was actually loaded.
HAS_PHI = _phi_cv is not None
# ---------- DynaDiff steering (in-process) ----------
# Enabled when --dynadiff-dir points at an existing directory and Phi data
# (--phi-dir) was loaded. Every failure path leaves HAS_DYNADIFF False so the
# steering panel is simply disabled instead of aborting app start-up.
_dd_loader = None
HAS_DYNADIFF = False
if args.dynadiff_dir and os.path.isdir(args.dynadiff_dir):
    if not HAS_PHI:
        print("[DynaDiff] WARNING: --phi-dir not set; steering panel requires Phi data. Disabling.")
    else:
        try:
            # dynadiff_loader is imported from the directory containing this script.
            sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
            from dynadiff_loader import get_loader
            # A relative --dynadiff-h5 is resolved against --dynadiff-dir.
            _h5 = args.dynadiff_h5
            if not os.path.isabs(_h5):
                _h5 = os.path.join(args.dynadiff_dir, _h5)
            _dd_loader = get_loader(
                dynadiff_dir=args.dynadiff_dir,
                checkpoint=args.dynadiff_checkpoint,
                h5_path=_h5,
                nsd_thumb_dir=args.brain_thumbnails,
                subject_idx=0,
            )
            HAS_DYNADIFF = True
            print(f"[DynaDiff] In-process loader ready (checkpoint: {args.dynadiff_checkpoint})")
        except Exception as _dd_err:
            # Deliberate best-effort: report the problem and continue without steering.
            print(f"[DynaDiff] WARNING: Could not start in-process loader ({_dd_err}). "
                  "Steering panel will be disabled.")
elif args.dynadiff_dir:
    # Previously a mistyped path was ignored without any message; report it the
    # same way a missing --brain-data file is reported.
    print(f"[DynaDiff] WARNING: --dynadiff-dir is not a directory: {args.dynadiff_dir}. "
          "Steering panel will be disabled.")
# ---------- Load all datasets into a unified list ----------
def _load_dataset_dict(path, label, sae_url=None):
    """Read one explorer_data.pt file (plus optional sidecars) into a dataset dict.

    Args:
        path: Location of the explorer_data ``.pt`` file. Sidecars and JSON
            label files are looked up next to it using the same stem.
        label: Display label stored under ``'label'``.
        sae_url: Optional SAE download URL stored under ``'sae_url'``.

    Returns:
        The unified dataset dict consumed by ``_apply_dataset_globals``.
    """
    print(f"Loading [{label}] from {path} ...")
    d = torch.load(path, map_location='cpu', weights_only=False)
    stem = os.path.splitext(path)[0]
    cs = d.get('clip_text_scores', None)

    def _read_labels(json_path):
        # JSON object keys are strings; feature ids are ints everywhere else.
        with open(json_path) as handle:
            return {int(key): text for key, text in json.load(handle).items()}

    # --names-file only overrides the location for the primary dataset.
    if path == args.data and args.names_file:
        names_file = args.names_file
    else:
        names_file = stem + '_feature_names.json'
    feat_names = _read_labels(names_file) if os.path.exists(names_file) else {}

    auto_interp_file = stem + '_auto_interp.json'
    auto_interp = {}
    if os.path.exists(auto_interp_file):
        auto_interp = _read_labels(auto_interp_file)
        print(f" Loaded {len(auto_interp)} auto-interp labels from "
              f"{os.path.basename(auto_interp_file)}")

    # Datasets without a dictionary-based UMAP get an all-NaN placeholder.
    if 'dict_umap_coords' in d:
        dict_umap = d['dict_umap_coords'].numpy()
    else:
        dict_umap = np.full((d['d_model'], 2), np.nan, dtype=np.float32)

    entry = {
        'label': label,
        'path': path,
        'image_paths': d['image_paths'],
        'd_model': d['d_model'],
        'n_images': d['n_images'],
        'patch_grid': d['patch_grid'],
        'image_size': d['image_size'],
        'token_type': d.get('token_type', 'spatial'),
        'backbone': d.get('backbone', 'dinov3'),
        'top_img_idx': d['top_img_idx'],
        'top_img_act': d['top_img_act'],
        'mean_img_idx': d.get('mean_img_idx', d['top_img_idx']),
        'mean_img_act': d.get('mean_img_act', d['top_img_act']),
        'p75_img_idx': d['p75_img_idx'],
        'p75_img_act': d['p75_img_act'],
        'nsd_top_img_idx': d.get('nsd_top_img_idx', None),
        'nsd_top_img_act': d.get('nsd_top_img_act', None),
        'nsd_mean_img_idx': d.get('nsd_mean_img_idx', None),
        'nsd_mean_img_act': d.get('nsd_mean_img_act', None),
        'feature_frequency': d['feature_frequency'],
        'feature_mean_act': d['feature_mean_act'],
        'feature_p75_val': d['feature_p75_val'],
        'umap_coords': d['umap_coords'].numpy(),
        'dict_umap_coords': dict_umap,
        'clip_scores': cs,
        'clip_vocab': d.get('clip_text_vocab', None),
        'clip_embeds': d.get('clip_feature_embeds', None),
        'nsd_clip_embeds': d.get('nsd_clip_feature_embeds', None),
        'clip_scores_f32': cs.float() if cs is not None else None,
        'inference_cache': OrderedDict(),
        'names_file': names_file,
        'auto_interp_file': auto_interp_file,
        'feature_names': feat_names,
        'auto_interp_names': auto_interp,
    }

    # Pre-computed heatmaps sidecar: an absent file leaves every slot as None.
    sidecar = stem + '_heatmaps.pt'
    if os.path.exists(sidecar):
        print(f" Loading pre-computed heatmaps from {os.path.basename(sidecar)} ...")
        hm = torch.load(sidecar, map_location='cpu', weights_only=True)
        has_hm = 'yes (no GPU needed for heatmaps)'
    else:
        hm = {}
        has_hm = 'no'
    for hm_key in ('top_heatmaps', 'mean_heatmaps', 'p75_heatmaps',
                   'nsd_top_heatmaps', 'nsd_mean_heatmaps'):
        entry[hm_key] = hm.get(hm_key)
    # patch_grid stored in sidecar may differ from data (e.g. --force-spatial on CLS SAE)
    entry['heatmap_patch_grid'] = hm.get('patch_grid', d['patch_grid'])

    # Pre-computed patch activations sidecar: enables GPU-free patch
    # exploration for every image the file covers.
    pa_sidecar = stem + '_patch_acts.pt'
    if not os.path.exists(pa_sidecar):
        entry['patch_acts'] = None
    else:
        print(f" Loading pre-computed patch acts from {os.path.basename(pa_sidecar)} ...")
        pa = torch.load(pa_sidecar, map_location='cpu', weights_only=True)
        img_to_row = {int(img): pos
                      for pos, img in enumerate(pa['img_indices'].tolist())}
        entry['patch_acts'] = {
            'feat_indices': pa['feat_indices'],  # (n_unique, n_patches, top_k) int16
            'feat_values': pa['feat_values'],    # (n_unique, n_patches, top_k) float16
            'img_to_row': img_to_row,
        }
        print(f" patch_acts: {len(img_to_row)} images covered (GPU-free patch explorer)")

    entry['sae_url'] = sae_url

    has_clip = 'yes' if (cs is not None or entry.get('clip_embeds') is not None) else 'no'
    has_pa = 'yes' if entry['patch_acts'] else 'no'
    print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
          f"backbone={entry['backbone']}, clip={has_clip}, "
          f"heatmaps={has_hm}, patch_acts={has_pa}")
    return entry
_all_datasets = []


# ---------- Mutable session state ----------
class _S:
    """Namespace for the mutable module-level state shared by all Bokeh callbacks.

    A plain class is used instead of the ``[value]`` mutable-list idiom so that
    any function can read or assign the attributes without ``global`` statements.
    """

    # Index into _all_datasets for the current view.
    active: int = 0
    # Incremented on each feature selection; stale renders bail out.
    render_token: int = 0
    # Set of feature indices matching the current name search, or None.
    search_filter = None
    # Which field drives UMAP point colour.
    color_by: str = "Log Frequency"
    # Active Bokeh timeout handle for the debounced HuggingFace upload.
    hf_push: object = None
    # Image index currently loaded in the patch explorer.
    patch_img = None
    # Cached (n_patches, d_model) float32 activations for the loaded image.
    patch_z = None
def _ds():
    """Look up the dataset dict that ``_S.active`` currently points at."""
    current = _S.active
    return _all_datasets[current]
# Primary dataset — always loaded eagerly
_all_datasets.append(
    _load_dataset_dict(args.data, args.primary_label, sae_url=args.sae_url)
)

# Compare datasets — stored as lazy placeholders; loaded on first access.
# A missing label falls back to the file's basename, a missing SAE URL to None.
for _ci, _cpath in enumerate(args.compare_data):
    if args.compare_labels and _ci < len(args.compare_labels):
        _clabel = args.compare_labels[_ci]
    else:
        _clabel = os.path.basename(_cpath)

    if args.compare_sae_urls and _ci < len(args.compare_sae_urls):
        _csae = args.compare_sae_urls[_ci]
    else:
        _csae = None

    _all_datasets.append(
        dict(label=_clabel, path=_cpath, _lazy=True, sae_url=_csae)
    )
def _load_brain_dataset_dict(path, label, thumb_dir):
    """Load a brain_meis.pt file and return a dataset entry dict.

    Brain MEI files share the same entry schema as regular explorer_data.pt files
    but have a different on-disk layout (NSD image indices, no CLIP embeddings, etc.)
    and may store only basenames in image_paths (resolved via thumb_dir at load time).

    Previously saved feature names and auto-interp labels
    (``<stem>_feature_names.json`` / ``<stem>_auto_interp.json``) are read back
    so labels survive a restart and the next save does not overwrite them.

    Args:
        path: Location of the brain_meis ``.pt`` file.
        label: Display label stored under ``'label'``.
        thumb_dir: Directory prepended to image basenames; a falsy value keeps
            the stored paths untouched.

    Returns:
        The dataset entry dict, or None when the ``.pt`` file cannot be loaded.
    """
    print(f"[Brain] Loading NSD dataset from {path} ...")
    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects —
        # only point --brain-data at trusted files.
        bd = torch.load(path, map_location='cpu', weights_only=False)
    except Exception as err:
        print(f"[Brain] WARNING: Failed to load NSD dataset: {err}")
        return None

    def _read_labels(json_path):
        # A missing file means "no labels yet"; JSON keys are strings, feature ids are ints.
        if not os.path.exists(json_path):
            return {}
        with open(json_path) as handle:
            return {int(key): text for key, text in json.load(handle).items()}

    # Resolve image_paths: prepend thumb_dir when paths are stored as basenames,
    # or when stored as absolute paths that don't exist on this machine.
    raw_paths = bd.get('image_paths', [])
    if raw_paths and thumb_dir and (
        not os.path.isabs(raw_paths[0]) or not os.path.exists(raw_paths[0])
    ):
        bd_paths = [os.path.join(thumb_dir, os.path.basename(p)) for p in raw_paths]
    else:
        bd_paths = raw_paths

    d_model = bd['d_model']
    nan2 = np.full((d_model, 2), np.nan, dtype=np.float32)
    stem = os.path.splitext(path)[0]
    names_file = stem + '_feature_names.json'
    auto_interp_file = stem + '_auto_interp.json'
    # The key may be present with a None value; only convert real tensors.
    clip_scores = bd.get('clip_text_scores', None)

    entry = {
        'label': label,
        'path': path,
        'image_paths': bd_paths,
        'd_model': d_model,
        'n_images': bd.get('n_images', len(bd_paths)),
        'patch_grid': bd.get('patch_grid', 16),
        'image_size': bd.get('image_size', 224),
        'token_type': bd.get('token_type', 'spatial'),
        'backbone': bd.get('backbone', 'dinov2'),
        'top_img_idx': bd['top_img_idx'],
        'top_img_act': bd['top_img_act'],
        'mean_img_idx': bd.get('mean_img_idx', bd['top_img_idx']),
        'mean_img_act': bd.get('mean_img_act', bd['top_img_act']),
        'p75_img_idx': bd.get('p75_img_idx', torch.full((d_model, 1), -1, dtype=torch.long)),
        'p75_img_act': bd.get('p75_img_act', torch.zeros(d_model, 1)),
        'top_heatmaps': None,
        'mean_heatmaps': None,
        'p75_heatmaps': None,
        'heatmap_patch_grid': bd.get('patch_grid', 16),
        'feature_frequency': bd['feature_frequency'],
        'feature_mean_act': bd['feature_mean_act'],
        'feature_p75_val': bd.get('feature_p75_val', torch.zeros(d_model)),
        'umap_coords': bd['umap_coords'].numpy() if 'umap_coords' in bd else nan2,
        'dict_umap_coords': bd['dict_umap_coords'].numpy() if 'dict_umap_coords' in bd else nan2,
        'clip_scores': clip_scores,
        'clip_vocab': bd.get('clip_text_vocab', None),
        'clip_embeds': bd.get('clip_feature_embeds', None),
        'clip_scores_f32': clip_scores.float() if clip_scores is not None else None,
        'inference_cache': OrderedDict(),
        'names_file': names_file,
        'auto_interp_file': auto_interp_file,
        'feature_names': _read_labels(names_file),
        'auto_interp_names': _read_labels(auto_interp_file),
        'sae_url': None,
        'patch_acts': None,
    }

    # Load pre-computed heatmaps sidecar if present.
    sidecar = stem + '_heatmaps.pt'
    if os.path.exists(sidecar):
        print(f"[Brain] Loading heatmaps sidecar: {os.path.basename(sidecar)} ...")
        # NOTE(review): the regular loader reads its heatmap sidecar with
        # weights_only=True; confirm whether this one can do the same.
        bhm = torch.load(sidecar, map_location='cpu', weights_only=False)
        entry['top_heatmaps'] = bhm.get('top_heatmaps')
        entry['mean_heatmaps'] = bhm.get('mean_heatmaps')
        entry['p75_heatmaps'] = bhm.get('p75_heatmaps')
        entry['heatmap_patch_grid'] = bhm.get('patch_grid', bd.get('patch_grid', 16))

    print(f"[Brain] Added '{label}' dataset: "
          f"d_model={d_model}, n_images={entry['n_images']}, backbone={entry['backbone']}")
    return entry
# NSD brain dataset — registered as a regular dataset entry so it appears in
# the dataset dropdown and drives both the MEI image views and the UMAP.
if args.brain_data:
    if os.path.exists(args.brain_data):
        _brain_entry = _load_brain_dataset_dict(
            args.brain_data,
            args.brain_label,
            args.brain_thumbnails or '',
        )
        # The loader returns None when the file could not be read.
        if _brain_entry is not None:
            _all_datasets.append(_brain_entry)
    else:
        print(f"[Brain] WARNING: --brain-data file not found: {args.brain_data}")
def _ensure_loaded(idx):
    """Replace the lazy placeholder at ``idx`` with the fully loaded dataset.

    Entries that are not marked ``'_lazy'`` are left untouched.
    """
    placeholder = _all_datasets[idx]
    if not placeholder.get('_lazy', False):
        return
    print(f"[Lazy load] Loading '{placeholder['label']}' on first access ...")
    _all_datasets[idx] = _load_dataset_dict(
        placeholder['path'],
        placeholder['label'],
        sae_url=placeholder.get('sae_url'),
    )
# Rebuilt by _apply_dataset_globals; maps basename and stem → image index.
_basename_to_idx = {}


def _build_basename_index(paths):
    """Map every path's basename and extension-less stem to its position in ``paths``.

    When two paths share a basename or stem, the later position wins.
    """
    lookup = {}
    for position, path in enumerate(paths):
        filename = os.path.basename(path)
        # Targets are assigned left to right: basename first, then stem.
        lookup[filename] = lookup[os.path.splitext(filename)[0]] = position
    return lookup
def _apply_dataset_globals(idx):
    """Swap every module-level data alias to point at dataset[idx].

    Bokeh callbacks capture module-level names at import time, so the
    simplest way to support dataset switching is to rebind these aliases
    each time the active dataset changes. All callbacks read these names;
    only this function and the initialisation below may write them.

    Args:
        idx: Index into ``_all_datasets``; the entry must already be fully
            loaded (not a lazy placeholder).
    """
    global image_paths, d_model, n_images, patch_grid, image_size, heatmap_patch_grid
    global top_img_idx, top_img_act, mean_img_idx, mean_img_act
    global p75_img_idx, p75_img_act
    global nsd_top_img_idx, nsd_top_img_act, nsd_mean_img_idx, nsd_mean_img_act, HAS_NSD_SUBSET
    global top_heatmaps, mean_heatmaps, p75_heatmaps
    global nsd_top_heatmaps, nsd_mean_heatmaps
    global feature_frequency, feature_mean_act, feature_p75_val
    global umap_coords, dict_umap_coords
    global freq, mean_act, log_freq, p75_np
    global live_mask, live_indices, dict_live_mask, dict_live_indices
    global umap_backup
    global _clip_scores, _clip_vocab, _clip_embeds, _nsd_clip_embeds, _clip_scores_f32, HAS_CLIP
    global feature_names, _names_file, auto_interp_names, _auto_interp_file
    global _active_feats
    global _basename_to_idx

    ds = _all_datasets[idx]

    # --- Images and geometry ---
    image_paths = ds['image_paths']
    _basename_to_idx = _build_basename_index(image_paths)
    d_model = ds['d_model']
    n_images = ds['n_images']
    patch_grid = ds['patch_grid']
    image_size = ds['image_size']

    # --- Per-feature example images ---
    top_img_idx = ds['top_img_idx']
    top_img_act = ds['top_img_act']
    mean_img_idx = ds['mean_img_idx']
    mean_img_act = ds['mean_img_act']
    p75_img_idx = ds['p75_img_idx']
    p75_img_act = ds['p75_img_act']

    # --- Optional NSD subset (absent keys read as None) ---
    nsd_top_img_idx = ds.get('nsd_top_img_idx')
    nsd_top_img_act = ds.get('nsd_top_img_act')
    nsd_mean_img_idx = ds.get('nsd_mean_img_idx')
    nsd_mean_img_act = ds.get('nsd_mean_img_act')
    nsd_top_heatmaps = ds.get('nsd_top_heatmaps')
    nsd_mean_heatmaps = ds.get('nsd_mean_heatmaps')
    HAS_NSD_SUBSET = nsd_top_img_idx is not None

    # --- Pre-computed heatmaps ---
    top_heatmaps = ds.get('top_heatmaps')
    mean_heatmaps = ds.get('mean_heatmaps')
    p75_heatmaps = ds.get('p75_heatmaps')
    heatmap_patch_grid = ds.get('heatmap_patch_grid', patch_grid)

    # --- Feature statistics and UMAP layouts ---
    feature_frequency = ds['feature_frequency']
    feature_mean_act = ds['feature_mean_act']
    feature_p75_val = ds['feature_p75_val']
    umap_coords = ds['umap_coords']
    dict_umap_coords = ds['dict_umap_coords']

    # --- CLIP search data ---
    _clip_scores = ds['clip_scores']
    _clip_vocab = ds['clip_vocab']
    _clip_embeds = ds['clip_embeds']
    _nsd_clip_embeds = ds.get('nsd_clip_embeds')
    _clip_scores_f32 = ds['clip_scores_f32']
    HAS_CLIP = _clip_embeds is not None or (_clip_scores is not None and _clip_vocab is not None)

    # --- Labels ---
    feature_names = ds['feature_names']
    _names_file = ds['names_file']
    auto_interp_names = ds['auto_interp_names']
    _auto_interp_file = ds['auto_interp_file']

    # Derived arrays used by UMAP, feature list, and callbacks
    freq = feature_frequency.numpy()
    mean_act = feature_mean_act.numpy()
    log_freq = np.log10(freq + 1)
    p75_np = feature_p75_val.numpy()

    # Features whose UMAP coordinate is NaN are hidden from the scatter plots.
    live_mask = ~np.isnan(umap_coords[:, 0])
    live_indices = np.where(live_mask)[0]
    dict_live_mask = ~np.isnan(dict_umap_coords[:, 0])
    dict_live_indices = np.where(dict_live_mask)[0]
    umap_backup = dict(
        act_x=umap_coords[live_mask, 0].tolist(),
        act_y=umap_coords[live_mask, 1].tolist(),
        act_feat=live_indices.tolist(),
        dict_x=dict_umap_coords[dict_live_mask, 0].tolist(),
        dict_y=dict_umap_coords[dict_live_mask, 1].tolist(),
        dict_feat=dict_live_indices.tolist(),
    )

    # Features that fired at least once — used by the Random button.
    # One vectorised pass replaces a Python loop of per-element tensor
    # .item() calls; the result is the same list of plain ints.
    _active_feats = np.flatnonzero(freq[:d_model] > 0).tolist()
# Initialise all globals from the primary dataset (index 0 of _all_datasets,
# which is loaded eagerly) so the module-level aliases exist before any
# callback reads them.
_apply_dataset_globals(0)
def _save_names():
    """Write the active dataset's manual feature names to ``_names_file`` as JSON.

    Keys are stringified feature ids in ascending order. A debounced
    HuggingFace upload of the file is scheduled afterwards.
    """
    with open(_names_file, 'w') as out:
        ordered = {str(feat): name for feat, name in sorted(feature_names.items())}
        out.write(json.dumps(ordered, indent=2))
    print(f"Saved {len(feature_names)} feature names to {_names_file}")
    _schedule_hf_push(_names_file)
def _save_auto_interp():
    """Write the active dataset's auto-interp labels to ``_auto_interp_file`` as JSON.

    Keys are stringified feature ids in ascending order. A debounced
    HuggingFace upload of the file is scheduled afterwards.
    """
    with open(_auto_interp_file, 'w') as out:
        ordered = {str(feat): text for feat, text in sorted(auto_interp_names.items())}
        out.write(json.dumps(ordered, indent=2))
    print(f"Saved {len(auto_interp_names)} auto-interp labels to {_auto_interp_file}")
    _schedule_hf_push(_auto_interp_file)
# Files saved since the last upload fired. The debounce timer is shared by all
# files, so every path scheduled inside one window is remembered here and
# uploaded together instead of the newest one silently replacing the others.
_hf_pending_paths = set()


def _schedule_hf_push(names_file_path):
    """Debounce HF dataset upload: waits 2 s after the last save, then pushes in a thread.

    No-op if HF_TOKEN / HF_DATASET_REPO are not set (i.e. running locally).

    Args:
        names_file_path: Local JSON file to upload; it is stored in the dataset
            repo under its basename.
    """
    hf_token = os.environ.get("HF_TOKEN")
    hf_repo = os.environ.get("HF_DATASET_REPO")
    if not (hf_token and hf_repo):
        return

    _hf_pending_paths.add(names_file_path)

    # Restart the debounce window: cancel any already-pending timer.
    if _S.hf_push is not None:
        try:
            curdoc().remove_timeout_callback(_S.hf_push)
        except Exception:
            # Best-effort: the timer may already have fired or been removed.
            pass
        _S.hf_push = None

    def _push_thread(paths):
        # Runs off the event loop so a slow upload never blocks the UI.
        try:
            from huggingface_hub import upload_file
        except Exception as e:
            print(f" Warning: could not push feature names to HF: {e}")
            return
        for path in paths:
            try:
                upload_file(
                    path_or_fileobj=path,
                    path_in_repo=os.path.basename(path),
                    repo_id=hf_repo,
                    repo_type="dataset",
                    token=hf_token,
                    commit_message="Update feature names",
                )
                print(f" Pushed {os.path.basename(path)} to HF dataset {hf_repo}")
            except Exception as e:
                print(f" Warning: could not push feature names to HF: {e}")

    def _fire():
        _S.hf_push = None
        # Snapshot and clear so saves made during the upload start a new batch.
        paths = sorted(_hf_pending_paths)
        _hf_pending_paths.clear()
        threading.Thread(target=_push_thread, args=(paths,), daemon=True).start()

    _S.hf_push = curdoc().add_timeout_callback(_fire, 2000)
def _display_name(feat: int) -> str:
    """Return the table label for ``feat``.

    A non-empty manual name wins; otherwise a non-empty auto-interp label is
    shown with an ``[auto] `` prefix; otherwise the empty string.
    """
    manual = feature_names.get(feat)
    if manual:
        return manual
    auto = auto_interp_names.get(feat)
    if not auto:
        return ""
    return f"[auto] {auto}"
def compute_patch_activations(img_idx):
    """Return (n_patches, d_sae) float32 activations for ``img_idx``, or None.

    Sources are tried in this order for the active dataset:
      1. the dataset's LRU cache;
      2. the pre-computed patch_acts sidecar, expanded from its sparse
         (indices, values) form into a dense array;
      3. live GPU inference through backbone + SAE (requires --sae-path).

    Any result is stored in the per-dataset LRU cache, which is trimmed to
    ``--inference-cache-size`` entries.
    """
    dataset = _all_datasets[_S.active]
    lru = dataset['inference_cache']

    # 1. LRU cache hit: refresh recency and return the stored array.
    if img_idx in lru:
        lru.move_to_end(img_idx)
        return lru[img_idx]

    dense = None

    # 2. Sidecar lookup (complete activations for covered images).
    sidecar = dataset.get('patch_acts')
    sidecar_row = None if sidecar is None else sidecar['img_to_row'].get(img_idx)
    if sidecar_row is not None:
        feat_idx = sidecar['feat_indices'][sidecar_row].numpy()           # (n_patches, top_k) int16
        feat_val = sidecar['feat_values'][sidecar_row].float().numpy()    # (n_patches, top_k) float32
        n_patches = feat_idx.shape[0]
        dense = np.zeros((n_patches, dataset['d_model']), dtype=np.float32)
        patch_rows = np.arange(n_patches)[:, None]
        # Scatter each patch's top-k values into its dense row.
        dense[patch_rows, feat_idx.astype(np.int32)] = feat_val

    # 3. GPU live inference.
    if dense is None:
        try:
            dense = _run_gpu_inference(load_image(img_idx))
        except Exception as _e:
            print(f"[GPU runner] inference failed for img {img_idx}: {_e}")
            dense = None

    if dense is not None:
        lru[img_idx] = dense
        # Evict the least-recently-used entry once over capacity.
        if len(lru) > args.inference_cache_size:
            lru.popitem(last=False)
    return dense
# ---------- Alpha colormap ----------
def create_alpha_cmap(base='jet'):
    """Return a copy of colormap ``base`` whose alpha ramps linearly from 0 to 1.

    Low values therefore render fully transparent and high values fully
    opaque, which lets a heatmap be blended over an image.

    Args:
        base: name of a registered matplotlib colormap.

    Returns:
        A ``LinearSegmentedColormap`` named ``'alpha_cmap'``.
    """
    # ``matplotlib.cm.get_cmap`` (reached here as ``plt.cm.get_cmap``) was
    # deprecated in matplotlib 3.7 and removed in 3.9; ``pyplot.get_cmap`` is
    # the lookup that works on both old and new releases.
    base_cmap = plt.get_cmap(base)
    colors = base_cmap(np.arange(base_cmap.N))
    # Overwrite the alpha channel (last column) with a 0 -> 1 ramp.
    colors[:, -1] = np.linspace(0.0, 1.0, base_cmap.N)
    return mcolors.LinearSegmentedColormap.from_list('alpha_cmap', colors)


ALPHA_JET = create_alpha_cmap('jet')


# ---------- Image helpers ----------
THUMB = args.thumb_size
def _parse_img_label(value):
    """Turn a user-supplied image label into an integer index.

    Resolution order:
      1. ``_basename_to_idx`` lookup on the label with its extension removed,
         then on the label as typed (e.g. 'nsd_31215.jpg', '000000204103');
      2. the whole label parsed as an integer (e.g. '42');
      3. the text after the last underscore parsed as an integer
         (e.g. 'n02655020_475' -> 475).

    The table lookup runs before integer parsing so zero-padded filenames
    such as '000000204103' map to their dataset entry instead of being read
    as raw index 204103.

    Raises:
        ValueError: if none of the strategies produce an integer.
    """
    label = value.strip()
    stem = os.path.splitext(label)[0]

    # Known filenames win over any numeric interpretation.
    for candidate in (stem, label):
        if candidate in _basename_to_idx:
            return _basename_to_idx[candidate]

    try:
        return int(label)
    except ValueError:
        # No separator leaves the whole label as the tail, as with rsplit.
        tail = label.rpartition('_')[2]
    return int(tail)
| def _resolve_img_path(stored_path): | |
| """Resolve a stored image path, searching image dirs first. Returns None on failure.""" | |
| if os.path.isabs(stored_path) and os.path.exists(stored_path): | |
| return stored_path | |
| basename = os.path.basename(stored_path) | |
| for base in filter(None, [args.image_dir] + (args.extra_image_dir or [])): | |
| candidate = os.path.join(base, basename) | |
| if os.path.exists(candidate): | |
| return candidate | |
| if os.path.exists(stored_path): | |
| return stored_path | |
| return None | |
def _load_image_by_path(path):
    """Load a single image as RGB.

    The path is first resolved with ``_resolve_img_path``; if that finds
    nothing, ``path`` is opened as given (and ``Image.open`` raises if it
    does not exist).

    Returns:
        A new RGB ``PIL.Image.Image`` that does not depend on the source file
        staying open.
    """
    resolved = _resolve_img_path(path) or path
    # Image.open is lazy and keeps the file handle open.  convert() loads the
    # pixel data and returns a separate image, so the source can be closed
    # deterministically instead of waiting for garbage collection.
    with Image.open(resolved) as src:
        return src.convert("RGB")
def load_image(img_idx):
    """Return the RGB image stored at position ``img_idx`` of ``image_paths``."""
    stored_path = image_paths[img_idx]
    return _load_image_by_path(stored_path)
def render_heatmap_overlay(img_idx, heatmap_16x16, size=THUMB, cmap=ALPHA_JET, alpha=1.0):
    """Blend a patch-level heatmap over image ``img_idx`` and return a PIL image.

    The image is resized to ``size`` x ``size``.  The heatmap (torch tensor or
    numpy array) is upsampled to the same size with bicubic interpolation and
    divided by its maximum when that maximum is positive.  ``cmap`` maps the
    result to RGBA, and the colormap's alpha channel scaled by ``alpha`` is
    the per-pixel blend weight.
    """
    base = load_image(img_idx).resize((size, size), Image.BILINEAR)
    rgb = np.array(base).astype(np.float32) / 255.0

    if isinstance(heatmap_16x16, torch.Tensor):
        grid = heatmap_16x16.numpy()
    else:
        grid = heatmap_16x16
    upsampled = cv2.resize(grid.astype(np.float32), (size, size),
                           interpolation=cv2.INTER_CUBIC)

    # Normalise to a peak of 1; an all-non-positive map is left untouched.
    peak = upsampled.max()
    if peak > 0:
        upsampled = upsampled / peak

    rgba = cmap(upsampled)
    weight = rgba[:, :, 3:4] * alpha
    mixed = rgb * (1 - weight) + rgba[:, :, :3] * weight
    pixels = np.clip(mixed * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
def render_zoomed_overlay(img_idx, heatmap_16x16, size=THUMB, pg=None, alpha=None,
                          center='peak'):
    """Render the heatmap overlay cropped to the current zoom window.

    When the zoom slider value is at least ``pg`` the whole overlay is
    returned.  For smaller values the overlay is cropped to a square window
    about ``zoom_patches`` patches wide and upscaled to ``size``.

    Args:
        img_idx: index of the image to render.
        heatmap_16x16: patch-grid heatmap (torch tensor or numpy array); it is
            indexed as a ``pg`` x ``pg`` grid.
        size: output edge length in pixels.
        pg: patches per side; defaults to ``heatmap_patch_grid``.
        alpha: overlay opacity; defaults to ``heatmap_alpha_slider.value``.
        center: ``'peak'`` centres the window on the argmax patch (suits
            max-ranked images); ``'centroid'`` centres it on the
            activation-weighted centroid (suits diffuse activations), falling
            back to the grid centre when the heatmap does not sum to a
            positive value.

    Returns:
        A ``size`` x ``size`` PIL image.
    """
    if pg is None:
        pg = heatmap_patch_grid
    if alpha is None:
        alpha = heatmap_alpha_slider.value
    heatmap = heatmap_16x16.numpy() if isinstance(heatmap_16x16, torch.Tensor) else heatmap_16x16

    # Render the full overlay at native resolution so the crop stays sharp.
    overlay = render_heatmap_overlay(img_idx, heatmap, size=image_size, alpha=alpha)
    zoom_patches = int(zoom_slider.value)
    if zoom_patches >= pg:
        return overlay.resize((size, size), Image.BILINEAR)

    # Find the crop centre in patch coordinates.
    if center == 'centroid':
        total = heatmap.sum()
        if total > 0:
            rows = np.arange(pg)
            cols = np.arange(pg)
            peak_row = int(np.average(rows, weights=heatmap.sum(axis=1)))
            peak_col = int(np.average(cols, weights=heatmap.sum(axis=0)))
        else:
            peak_row, peak_col = pg // 2, pg // 2
    else:
        peak_idx = np.argmax(heatmap)
        peak_row, peak_col = divmod(int(peak_idx), pg)

    patch_px = image_size // pg
    # At least one pixel each side so the crop box can never be empty.
    half = max(1, (zoom_patches * patch_px) // 2)
    window = min(2 * half, image_size)
    cy = peak_row * patch_px + patch_px // 2
    cx = peak_col * patch_px + patch_px // 2

    # Slide the window back inside the image rather than truncating it.
    # A truncated (non-square) crop would be stretched by the resize below,
    # distorting the picture and changing the effective zoom near the edges.
    y0 = min(max(0, cy - half), image_size - window)
    x0 = min(max(0, cx - half), image_size - window)
    box = (x0, y0, x0 + window, y0 + window)
    return overlay.crop(box).resize((size, size), Image.BILINEAR)
def pil_to_data_url(img):
    """Encode ``img`` as a JPEG (quality 85) and return it as a ``data:`` URL."""
    with io.BytesIO() as stream:
        img.save(stream, format="JPEG", quality=85)
        jpeg_bytes = stream.getvalue()
    encoded = base64.b64encode(jpeg_bytes).decode("utf-8")
    return "data:image/jpeg;base64," + encoded
| # ---------- Brain / Phi helpers ---------- | |
def _phi_c_for_feat(feat):
    """Return the cortical leverage score φ_c for ``feat`` as a float.

    Returns None when no φ_c data is loaded or ``feat`` is outside
    ``[0, len(_phi_c))``.
    """
    # Reject negative indices explicitly: they would otherwise wrap around
    # and silently return the score of a different feature.
    if _phi_c is None or not 0 <= feat < len(_phi_c):
        return None
    return float(_phi_c[feat])
def _phi_voxel_row(feat):
    """Return the phi row for ``feat`` as a float32 numpy array, or None.

    The row is copied out of ``_phi_cv``.  When ``_voxel_to_vertex`` is set
    it is used to index that row (presumably mapping it into voxel space; the
    original note gives the result length as 15724 — TODO confirm).

    Returns None when no phi matrix is loaded or ``feat`` is outside
    ``[0, _phi_cv.shape[0])``.
    """
    # Reject negative indices explicitly: they would otherwise wrap around
    # and silently return the row of a different feature.
    if _phi_cv is None or not 0 <= feat < _phi_cv.shape[0]:
        return None
    phi_row = np.array(_phi_cv[feat], dtype=np.float32)
    if _voxel_to_vertex is not None:
        return phi_row[_voxel_to_vertex]
    return phi_row
def _render_steering_preview(feats, lams, thresholds):
    """Render the net combined steering direction across all chosen features.

    Computes ``sum_i(lam_i * threshold_mask_i * phi_i / max|phi_i|)`` over the
    features that have a usable phi row and draws it as two scatter views of
    ``_voxel_coords``.

    Args:
        feats: feature indices, looked up with ``_phi_voxel_row``.
        lams: per-feature steering strengths, in the same order as ``feats``.
        thresholds: per-feature fractions; for a value below 1.0 only voxels
            whose |phi| reaches the ``100 * (1 - thr)`` percentile contribute.

    Returns:
        An HTML string with an inline PNG, or "" when there are no features,
        no voxel coordinates, no usable phi rows, or the combined map is
        numerically zero.
    """
    if not feats or _voxel_coords is None:
        return ""
    combined = np.zeros(_N_VOXELS_DD, dtype=np.float32)
    n_valid = 0
    for f, lam, thr in zip(feats, lams, thresholds):
        phi = _phi_voxel_row(f)
        if phi is None:
            continue
        phi_max = float(np.abs(phi).max())
        if phi_max < 1e-12:
            # An all-zero row cannot be normalised and contributes nothing.
            continue
        norm_phi = phi / phi_max
        if thr < 1.0:
            cutoff = float(np.percentile(np.abs(phi), 100.0 * (1.0 - thr)))
            norm_phi = norm_phi * (np.abs(phi) >= cutoff)
        combined += lam * norm_phi
        n_valid += 1
    if n_valid == 0 or np.abs(combined).max() < 1e-12:
        return ""
    vmax = float(np.abs(combined).max()) or 1e-6
    fig, axes = plt.subplots(1, 2, figsize=(8, 3.2), facecolor='#f8f8f8')
    try:
        for ax, (title, xi, yi) in zip(axes, [("Axial (x–y)", 0, 1), ("Coronal (x–z)", 0, 2)]):
            sc = ax.scatter(
                _voxel_coords[:, xi], _voxel_coords[:, yi],
                c=combined, cmap='RdBu_r', s=3, alpha=0.7,
                vmin=-vmax, vmax=vmax, rasterized=True, marker='s',
            )
            ax.set_title(title, fontsize=9)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_facecolor('#f8f8f8')
        fig.subplots_adjust(right=0.88, top=0.85)
        cbar_ax = fig.add_axes([0.91, 0.15, 0.02, 0.65])
        cbar = fig.colorbar(sc, cax=cbar_ax)
        cbar.set_label('Δ fMRI (norm.)', fontsize=8)
        lbl = f'{n_valid} feature{"s" if n_valid != 1 else ""}'
        fig.suptitle(f'Net brain modification — {lbl}', fontsize=10)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=80, bbox_inches='tight', facecolor='#f8f8f8')
    finally:
        # pyplot keeps every open figure registered until it is closed, so
        # close it even when plotting or saving raises.
        plt.close(fig)
    b64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return (
        '<h4 style="margin:6px 0 3px 0;color:#333;font-size:12px">Net Brain Modification</h4>'
        f'<img src="data:image/png;base64,{b64}" '
        'style="max-width:100%;border-radius:4px;border:1px solid #ddd"/>'
    )
def _render_cortical_profile(feat):
    """Render two 2D scatter views of a feature's phi values as inline PNG HTML.

    The views plot ``_voxel_coords`` columns (0, 1) and (0, 2), coloured by
    the row from ``_phi_voxel_row(feat)`` on a symmetric ``RdBu_r`` scale.
    The figure title includes φ_c when ``_phi_c_for_feat`` provides one.

    Returns:
        An HTML block (heading + ``<img>``), or "" when phi data or voxel
        coordinates are unavailable for this feature.
    """
    phi_vox = _phi_voxel_row(feat)
    if phi_vox is None or _voxel_coords is None:
        return ""
    # Fall back to a tiny range so an all-zero row still gets a valid scale.
    vmax = float(np.abs(phi_vox).max()) or 1e-6
    phi_c_val = _phi_c_for_feat(feat)
    phi_c_str = f"φ_c = {phi_c_val:.4f}" if phi_c_val is not None else ""
    fig, axes = plt.subplots(1, 2, figsize=(10, 4.0), facecolor='#f8f8f8')
    try:
        view_pairs = [("Axial (x – y)", 0, 1), ("Coronal (x – z)", 0, 2)]
        for ax, (title, xi, yi) in zip(axes, view_pairs):
            sc = ax.scatter(
                _voxel_coords[:, xi], _voxel_coords[:, yi],
                c=phi_vox, cmap='RdBu_r', s=4, alpha=0.75,
                vmin=-vmax, vmax=vmax, rasterized=True, marker='s',
            )
            ax.set_title(title, fontsize=10)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_facecolor('#f8f8f8')
        fig.subplots_adjust(right=0.88, top=0.88)
        cbar_ax = fig.add_axes([0.91, 0.15, 0.02, 0.65])
        cbar = fig.colorbar(sc, cax=cbar_ax)
        cbar.set_label('Φ weight', fontsize=9)
        fig.suptitle(
            f'Cortical Profile — Feature {feat}' + (f' ({phi_c_str})' if phi_c_str else ''),
            fontsize=11,
        )
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=90, bbox_inches='tight', facecolor='#f8f8f8')
    finally:
        # pyplot keeps every open figure registered until it is closed, so
        # close it even when plotting or saving raises.
        plt.close(fig)
    b64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return (
        '<h3 style="margin:4px 0 6px 0;color:#333;border-bottom:2px solid #e0e0e0;'
        'padding-bottom:4px">Cortical Profile (Φ)</h3>'
        f'<img src="data:image/png;base64,{b64}" '
        'style="max-width:100%;border-radius:4px;border:1px solid #ddd"/>'
    )
| def _status_html(state, msg): | |
| """Return a styled HTML status banner.""" | |
| styles = { | |
| 'idle': 'background:#f5f5f5;border-left:4px solid #bbb;color:#666', | |
| 'loading': 'background:#fff8e0;border-left:4px solid #f0a020;color:#7a5000', | |
| 'ok': 'background:#e8f4e8;border-left:4px solid #2a8a2a;color:#1a5a1a', | |
| 'dead': 'background:#fce8e8;border-left:4px solid #c03030;color:#8a1a1a', | |
| } | |
| style = styles.get(state, styles['idle']) | |
| return f'<div style="{style};padding:7px 12px;border-radius:3px;font-size:13px">{msg}</div>' | |
| # ---------- DynaDiff steering helpers ---------- | |
def _dynadiff_request(sample_idx, steerings, seed):
    """Run a DynaDiff reconstruction through ``_dd_loader``.

    Args:
        sample_idx: sample index forwarded to the loader.
        steerings: list of ``(phi_voxel np.ndarray, lam float, threshold float)``.
        seed: seed forwarded to the loader.

    Returns:
        Whatever ``_dd_loader.reconstruct`` returns (per the original note, a
        dict with baseline_img, steered_img, gt_img and beta_std).

    Raises:
        RuntimeError: if the loader reports status 'loading' or 'error'.
    """
    state, detail = _dd_loader.status
    failure = None
    if state == 'loading':
        failure = 'DynaDiff model still loading — try again shortly'
    elif state == 'error':
        failure = f'DynaDiff model load failed: {detail}'
    if failure is not None:
        raise RuntimeError(failure)
    return _dd_loader.reconstruct(sample_idx, steerings, seed)
| def _make_steering_html(resps, concept_name): | |
| """Build HTML showing GT | Baseline | Steered for one or more trials. | |
| resps: list of (trial_label, resp_dict) pairs. | |
| """ | |
| header = ( | |
| f'<h3 style="margin:4px 0 6px 0;color:#333;border-bottom:2px solid #e0e0e0;' | |
| f'padding-bottom:4px">DynaDiff Steering — {concept_name}</h3>' | |
| ) | |
| rows_html = '' | |
| for trial_label, resp in resps: | |
| parts = [] | |
| for label, key in [('GT', 'gt_img'), | |
| ('Baseline', 'baseline_img'), | |
| ('Steered', 'steered_img')]: | |
| b64 = resp.get(key) | |
| if b64 is None: | |
| img_html = ('<div style="width:160px;height:160px;background:#eee;' | |
| 'display:flex;align-items:center;justify-content:center;' | |
| 'color:#999;font-size:12px">N/A</div>') | |
| else: | |
| img_html = (f'<img src="data:image/png;base64,{b64}" ' | |
| 'style="width:160px;height:160px;object-fit:contain;' | |
| 'border:1px solid #ddd;border-radius:4px"/>') | |
| parts.append( | |
| f'<div style="text-align:center;margin:0 4px">' | |
| f'{img_html}' | |
| f'<div style="font-size:11px;color:#555;margin-top:3px">{label}</div>' | |
| f'</div>' | |
| ) | |
| trial_head = (f'<div style="font-size:11px;font-weight:bold;color:#777;' | |
| f'margin:6px 0 3px 4px">{trial_label}</div>') | |
| rows_html += (trial_head | |
| + '<div style="display:flex;align-items:flex-end;margin-bottom:8px">' | |
| + ''.join(parts) + '</div>') | |
| return header + rows_html | |
def make_image_grid_html(images_info, title):
    """Render a titled, wrapping grid of thumbnails as an HTML string.

    ``images_info`` is a sequence of ``(image, caption)`` pairs.  Each caption
    is split on ``<br>`` and every piece is placed in its own ``<div>`` under
    the thumbnail.  An empty or missing sequence yields the heading followed
    by a "No examples available" note.  Thumbnails are ``min(THUMB, 224)``
    pixels square.
    """
    if not images_info:
        return (f'<h3 style="margin:4px 0 6px 0;color:#444;border-bottom:2px solid #e8e8e8;'
                f'padding-bottom:4px">{title}</h3>'
                f'<p style="color:#aaa;font-style:italic;margin:4px 0">No examples available</p>')

    thumb_w = min(THUMB, 224)
    chunks = [
        (f'<h3 style="margin:4px 0 8px 0;color:#333;border-bottom:2px solid #e0e0e0;'
         f'padding-bottom:4px">{title}</h3>'),
        '<div style="display:flex;flex-wrap:wrap;gap:8px;padding:2px 0 10px 0">',
    ]
    for img, caption in images_info:
        url = pil_to_data_url(img)
        caption_lines = ''.join(f'<div>{piece}</div>' for piece in caption.split('<br>'))
        chunks.append(
            f'<div style="text-align:center;width:{thumb_w}px">'
            f'<img src="{url}" width="{thumb_w}" height="{thumb_w}"'
            f' style="border:1px solid #d0d0d0;border-radius:5px;display:block"/>'
            f'<div style="font-size:10px;color:#555;margin-top:3px;line-height:1.4">'
            f'{caption_lines}</div></div>'
        )
    chunks.append('</div>')
    return ''.join(chunks)
def make_compare_aggregations_html(top_infos, mean_infos, feat, n_each=6, model_label=None):
    """Build a figure-ready, side-by-side panel of two aggregation methods.

    Only "Top Activation" and "Mean Activation" are shown, each as a coloured
    header above a two-column thumbnail grid holding at most ``n_each``
    ``(image, caption)`` entries, so a screenshot of the element stands alone
    as a clean figure panel.  The title reads ``Feature {feat}``, prefixed by
    ``model_label`` when one is given.
    """
    col_thumb = min(THUMB, 160)
    cols_per_row = 2
    strip_w = cols_per_row * col_thumb + (cols_per_row - 1) * 6
    label_prefix = f'{model_label} — ' if model_label else ''

    # Plain white outer container without decoration, so the figure crops cleanly.
    chunks = [
        '<div style="font-family:Arial,Helvetica,sans-serif;background:#ffffff;'
        'padding:16px 20px 14px 20px;display:inline-block">',
        '<div style="font-size:13px;font-weight:bold;color:#222;margin-bottom:14px;'
        'letter-spacing:0.1px">',
        label_prefix,
        f'Feature {feat}</div>',
        '<div style="display:flex;gap:24px;align-items:flex-start">',
    ]

    columns = (
        ("Top Activation", "#2563a8", top_infos),
        ("Mean Activation", "#1a7a4a", mean_infos),
    )
    for method_name, color, infos in columns:
        shown = (infos or [])[:n_each]
        # Column wrapper, bold coloured header, then the thumbnail grid.
        chunks.append(
            f'<div style="display:inline-flex;flex-direction:column">'
            f'<div style="background:{color};color:#ffffff;font-size:13px;font-weight:bold;'
            f'text-align:center;padding:6px 0;border-radius:5px;margin-bottom:10px;'
            f'letter-spacing:0.4px;width:{strip_w}px;box-sizing:border-box">{method_name}</div>'
            f'<div style="display:grid;grid-template-columns:repeat({cols_per_row},{col_thumb}px);gap:6px">'
        )
        if not shown:
            chunks.append(
                '<div style="color:#aaa;font-style:italic;font-size:11px;padding:8px">No images</div>')
        for img, caption in shown:
            url = pil_to_data_url(img)
            cap_html = '<br>'.join(caption.split('<br>'))
            chunks.append(
                f'<div style="text-align:center">'
                f'<img src="{url}" width="{col_thumb}" height="{col_thumb}"'
                f' style="border:1px solid #ccc;border-radius:3px;display:block"/>'
                f'<div style="font-size:9px;color:#555;margin-top:3px;line-height:1.35">'
                f'{cap_html}</div></div>'
            )
        chunks.append('</div></div>')  # close the grid and the column

    chunks.append('</div></div>')  # close the flex row and the outer container
    return ''.join(chunks)
| # ---------- UMAP data source ---------- | |
| # live_mask / live_indices / freq / mean_act / log_freq / umap_backup are all | |
| # already set by _apply_dataset_globals(0) above — just build the source from them. | |
| # Helpers to build phi_c and color_val arrays for any set of feature indices. | |
def _phi_c_vals(indices):
    """Map feature indices to their φ_c leverage values.

    Every entry is 0.0 when no φ_c data is loaded; an index at or beyond the
    end of ``_phi_c`` also maps to 0.0.
    """
    if _phi_c is None:
        return [0.0] * len(indices)
    values = []
    for i in indices:
        values.append(float(_phi_c[i]) if i < len(_phi_c) else 0.0)
    return values
| def _make_point_alphas(n): | |
| """Return uniform 0.6 alpha for all n UMAP points.""" | |
| return [0.6] * n | |
def _make_color_vals(indices):
    """Return the UMAP colour value for each feature index per ``_S.color_by``.

    "Mean Activation" reads ``mean_act``, "Brain Leverage (φ_c)" uses
    ``_phi_c_vals``, and any other setting (normally "Log Frequency") reads
    ``log_freq``.
    """
    mode = _S.color_by
    picks = np.array(indices, dtype=int)
    if mode == "Brain Leverage (φ_c)":
        return _phi_c_vals(indices)
    column = mean_act if mode == "Mean Activation" else log_freq
    return column[picks].tolist()
# One row per live feature; colouring starts out on log frequency.
umap_source = ColumnDataSource(data={
    'x': umap_coords[live_mask, 0].tolist(),
    'y': umap_coords[live_mask, 1].tolist(),
    'feature_idx': live_indices.tolist(),
    'frequency': freq[live_mask].tolist(),
    'log_freq': log_freq[live_mask].tolist(),
    'mean_act': mean_act[live_mask].tolist(),
    'phi_c_val': _phi_c_vals(live_indices.tolist()),
    'color_val': log_freq[live_mask].tolist(),
    'point_alpha': _make_point_alphas(int(live_mask.sum())),
})

# ---------- UMAP figure ----------
_init_log_freq = log_freq[live_mask]
# Colour range spans the 2nd–98th percentile of the initial values; with no
# live features it falls back to the fixed range 0..1.
color_mapper = linear_cmap(
    field_name='color_val',
    palette=Turbo256,
    low=float(np.percentile(_init_log_freq, 2)) if live_mask.any() else 0,
    high=float(np.percentile(_init_log_freq, 98)) if live_mask.any() else 1,
)
| def _set_color_mapper_range(color_vals): | |
| """Update color_mapper low/high to the 2nd–98th percentile of color_vals.""" | |
| if not color_vals: | |
| return | |
| arr = np.array(color_vals) | |
| lo, hi = float(np.percentile(arr, 2)), float(np.percentile(arr, 98)) | |
| if lo == hi: | |
| hi = lo + 1e-6 | |
| color_mapper['transform'].low = lo | |
| color_mapper['transform'].high = hi | |
umap_fig = figure(
    title="UMAP of SAE Features (by activation pattern)",
    width=700,
    height=650,
    tools="pan,wheel_zoom,box_zoom,reset,tap",
    active_scroll="wheel_zoom",
)
umap_scatter = umap_fig.scatter(
    'x',
    'y',
    source=umap_source,
    size=4,
    alpha='point_alpha',
    color=color_mapper,
    selection_color="#FF2222",
    selection_alpha=1.0,
    selection_line_color="white",
    selection_line_width=1.5,
)

# Grow the markers as the user zooms in. The x-range span seen on the first
# callback is remembered in the browser and used as the zoom reference.
_zoom_cb = CustomJS(args={'renderer': umap_scatter, 'x_range': umap_fig.x_range}, code="""
const span = x_range.end - x_range.start;
if (window._umap_base_span === undefined) {
window._umap_base_span = span;
}
const zoom = window._umap_base_span / span;
const new_size = Math.min(12, Math.max(3, 3 * Math.pow(zoom, 0.1)));
renderer.glyph.size = new_size;
renderer.nonselection_glyph.size = new_size;
renderer.selection_glyph.size = Math.max(14, new_size * 3);
""")
umap_fig.x_range.js_on_change('start', _zoom_cb)
umap_fig.x_range.js_on_change('end', _zoom_cb)

# The φ_c tooltip row is only added when phi data is available.
if HAS_PHI:
    _phi_hover = [("Brain φ_c", "@phi_c_val{0.0000}")]
else:
    _phi_hover = []
umap_fig.add_tools(HoverTool(tooltips=[
    ("Feature", "@feature_idx"),
    ("Frequency", "@frequency{0}"),
    ("Mean Act", "@mean_act{0.000}"),
    *_phi_hover,
]))

# ---------- Dataset / model selector ----------
# Option values are dataset positions (as strings); labels come from each dataset.
dataset_select = Select(
    title="Dataset:",
    value="0",
    options=[(str(i), ds['label']) for i, ds in enumerate(_all_datasets)],
    width=250,
)
| def _on_dataset_switch(attr, old, new): | |
| """Handle a change of ``dataset_select.value``: activate dataset ``int(new)`` and refresh the UI. | |
| ``attr``, ``old`` and ``new`` are the arguments this handler is registered to receive via | |
| ``dataset_select.on_change('value', ...)``; ``old`` and ``new`` are parsed with ``int()`` and used | |
| as positions in ``_all_datasets``. | |
| The handler loads the dataset, swaps the module globals, rebuilds the UMAP source and colour range, | |
| resets the selection, search filter and feature list, refreshes the summary, toggles the patch | |
| explorer, and clears CLIP results when CLIP is enabled. If the old and new datasets report the same | |
| ``d_model`` and the feature input held an integer in ``[0, d_model)``, that feature is re-displayed; | |
| otherwise the feature input, stats, brain and status panels are reset. | |
| """ | |
| idx = int(new) | |
| old_idx = int(old) | |
| _ensure_loaded(idx) | |
| # Capture current feature and old d_model before swapping globals | |
| _prev_feat_str = feature_input.value.strip() | |
| _old_d_model = _all_datasets[old_idx]['d_model'] | |
| _S.active = idx | |
| _apply_dataset_globals(idx) # also resets _active_feats | |
| # Rebuild UMAP scatter | |
| # The names read below (live_indices, live_mask, umap_coords, freq, ...) are module globals, | |
| # presumably refreshed by _apply_dataset_globals above — verify against that function. | |
| _feat_ids = live_indices.tolist() | |
| _color_vals = _make_color_vals(_feat_ids) | |
| _phi_c_list = _phi_c_vals(_feat_ids) | |
| umap_source.data = dict( | |
| x=umap_coords[live_mask, 0].tolist(), | |
| y=umap_coords[live_mask, 1].tolist(), | |
| feature_idx=_feat_ids, | |
| frequency=freq[live_mask].tolist(), | |
| log_freq=log_freq[live_mask].tolist(), | |
| mean_act=mean_act[live_mask].tolist(), | |
| phi_c_val=_phi_c_list, | |
| color_val=_color_vals, | |
| point_alpha=_make_point_alphas(len(_feat_ids)), | |
| ) | |
| _set_color_mapper_range(_color_vals) | |
| umap_source.selected.indices = [] | |
| umap_type_select.value = "Activation Pattern" | |
| umap_fig.title.text = f"UMAP — {_all_datasets[idx]['label']}" | |
| # Rebuild feature list | |
| _S.search_filter = None | |
| _apply_order(_get_sorted_order()) | |
| # Update summary panel | |
| summary_div.text = _make_summary_html() | |
| # Show/hide patch explorer depending on token type and data availability. | |
| ds = _all_datasets[idx] | |
| has_heatmaps = ds.get('top_heatmaps') is not None | |
| has_patch_acts = ds.get('patch_acts') is not None | |
| # Exploration needs spatial tokens (the default when 'token_type' is absent) plus at least | |
| # one of the two pre-computed sources. | |
| can_explore = ( | |
| ds.get('token_type', 'spatial') == 'spatial' | |
| and (has_heatmaps or has_patch_acts) | |
| ) | |
| patch_fig.visible = can_explore | |
| patch_info_div.visible = can_explore | |
| if not can_explore: | |
| if ds.get('token_type') == 'cls': | |
| reason = "CLS token — no patch grid" | |
| else: | |
| reason = "no pre-computed heatmaps or patch_acts for this model" | |
| patch_info_div.text = ( | |
| f'<p style="color:#888;font-style:italic">Patch explorer unavailable: {reason}.</p>') | |
| patch_info_div.visible = True | |
| # Update CLIP search hint | |
| if HAS_CLIP: | |
| clip_result_div.text = "" | |
| clip_result_source.data = dict( | |
| feature_idx=[], clip_score=[], frequency=[], mean_act=[], phi_c_val=[], name=[]) | |
| # If the two datasets share the same feature space, re-display the current feature | |
| # (equal d_model is the only check made for "same feature space"). | |
| _same_space = (_all_datasets[idx]['d_model'] == _old_d_model) | |
| _restore_feat = None | |
| if _same_space and _prev_feat_str: | |
| try: | |
| _restore_feat = int(_prev_feat_str) | |
| except ValueError: | |
| pass | |
| if _restore_feat is not None and 0 <= _restore_feat < d_model: | |
| feature_input.value = str(_restore_feat) | |
| update_feature_display(_restore_feat) | |
| else: | |
| feature_input.value = "" | |
| stats_div.text = "<h3>Select a feature to explore</h3>" | |
| brain_div.text = "" | |
| status_div.text = _status_html('idle', 'Model switched — select a feature to explore.') | |
| # Clear the DynaDiff output/status and both heatmap panels. | |
| # NOTE(review): indentation is lost in this view, so whether these clears sit inside the | |
| # else-branch above or run unconditionally cannot be determined here — confirm. | |
| if HAS_DYNADIFF: | |
| _dd_output.text = "" | |
| _dd_status.text = "" | |
| for div in [top_heatmap_div, mean_heatmap_div]: | |
| div.text = "" | |
dataset_select.on_change('value', _on_dataset_switch)

# ---------- Detail panels ----------
status_div = Div(
    width=900,
    text=_status_html('idle', 'Select a feature on the UMAP or from the list to begin.'),
)
stats_div = Div(width=900, text="<h3>Click a feature on the UMAP to explore it</h3>")
top_heatmap_div = Div(width=900, text="")
mean_heatmap_div = Div(width=900, text="")
# Side-by-side aggregation comparison.
compare_agg_div = Div(width=1400, text="")
# Cortical profile for the selected feature.
brain_div = Div(width=900, text="")
| # ---------- DynaDiff steering panel builder ---------- | |
| # Feature list stored in a ColumnDataSource so the DataTable can edit λ and threshold inline. | |
| def _phi_cv_feat_name(feat): | |
| """Best-effort display name for the feature.""" | |
| if feat is None: | |
| return 'unknown' | |
| ds = _all_datasets[_S.active] if _all_datasets else None | |
| if ds and feat in ds.get('feature_names', {}): | |
| return ds['feature_names'][feat] | |
| return f'feat {feat}' | |
def _build_dynadiff_panel():
    """Build the DynaDiff brain-steering panel widgets and wire their callbacks.

    Returns:
        A 5-tuple ``(panel_body, dd_output, dd_status, dd_feat_input,
        dd_sample_input)``.  When ``HAS_DYNADIFF`` is false, ``panel_body``,
        ``dd_feat_input`` and ``dd_sample_input`` are None and the two divs
        are empty 1-pixel-wide stubs, so callers must guard before using the
        inputs.
    """
    # Local import: only this panel needs it, and it keeps the block self-contained.
    from html import escape as _escape_html

    if not HAS_DYNADIFF:
        return None, Div(text="", width=1), Div(text="", width=1), None, None

    # ---- ColumnDataSource backing the feature table ----
    # λ and threshold are edited inline through the DataTable's NumberEditors.
    dd_source = ColumnDataSource(data=dict(feat=[], name=[], lam=[], threshold=[]))
    dd_table = DataTable(
        source=dd_source,
        columns=[
            TableColumn(field='feat', title='#', width=55),
            TableColumn(field='name', title='Feature', width=190),
            TableColumn(field='lam', title='λ', width=60,
                        editor=NumberEditor(),
                        formatter=NumberFormatter(format='0.0')),
            TableColumn(field='threshold', title='Brain%', width=65,
                        editor=NumberEditor(),
                        formatter=NumberFormatter(format='0.00')),
        ],
        editable=True,
        width=460,
        height=130,
        index_position=None,
    )

    # ---- Brain modification preview div ----
    dd_steer_div = Div(text="", width=460)

    def _update_dd_preview():
        """Re-render the combined steering preview from the current table rows."""
        feats = list(dd_source.data['feat'])
        lams = list(dd_source.data['lam'])
        thrs = list(dd_source.data['threshold'])
        dd_steer_div.text = _render_steering_preview(feats, lams, thrs)

    dd_source.on_change('data', lambda attr, old, new: _update_dd_preview())

    # ---- "Add feature" row ----
    dd_feat_input = TextInput(title="Feature index:", placeholder="e.g. 1234", width=120)
    dd_add_lam_input = TextInput(title="λ:", value="3.0", width=65)
    dd_add_thr_select = Select(
        title="Brain %:",
        options=[('0.05', '5%'), ('0.10', '10%'), ('0.25', '25%'), ('1.0', '100%')],
        value='0.10',
        width=90,
    )
    dd_feat_add_btn = Button(label="Add", button_type="success", width=55)
    dd_feat_remove_btn = Button(label="Remove selected", button_type="light", width=130)
    dd_feat_clear_btn = Button(label="Clear all", button_type="light", width=80)

    # ---- Global run controls ----
    dd_sample_input = TextInput(title="Sample idx", value="0", width=180)
    dd_seed_input = TextInput(title="Seed:", value="42", width=70)
    dd_btn = Button(label="Steer & Reconstruct", button_type="primary", width=200)
    dd_status = Div(text="", width=460)
    dd_output = Div(text="", width=460)

    def _on_add_feat():
        """Validate the typed feature index and append a row to the table."""
        try:
            f = int(dd_feat_input.value.strip())
        except ValueError:
            dd_status.text = '<span style="color:#c00">Invalid feature index.</span>'
            return
        n_feats = _phi_cv.shape[0] if _phi_cv is not None else None
        if n_feats is None or not 0 <= f < n_feats:
            # Valid indices are 0..n_feats-1, so report the inclusive upper bound.
            upper = n_feats - 1 if n_feats is not None else '?'
            dd_status.text = (f'<span style="color:#c00">Feature {f} out of range '
                              f'(0–{upper}).</span>')
            return
        try:
            lam = float(dd_add_lam_input.value)
        except ValueError:
            lam = 3.0  # fall back to the input's default value
        threshold = float(dd_add_thr_select.value)
        # Assign a fresh dict so the source's 'data' change callback fires.
        new_data = {k: list(v) for k, v in dd_source.data.items()}
        new_data['feat'].append(f)
        new_data['name'].append(_phi_cv_feat_name(f))
        new_data['lam'].append(lam)
        new_data['threshold'].append(threshold)
        dd_source.data = new_data
        dd_status.text = ''

    def _on_remove_feat():
        """Drop the selected table rows."""
        sel = set(dd_source.selected.indices)
        if not sel:
            dd_status.text = '<span style="color:#888">Select a row first.</span>'
            return
        new_data = {k: [v for i, v in enumerate(vals) if i not in sel]
                    for k, vals in dd_source.data.items()}
        dd_source.data = new_data
        dd_source.selected.indices = []
        dd_status.text = ''

    def _on_clear_feats():
        """Remove every row from the table."""
        dd_source.data = dict(feat=[], name=[], lam=[], threshold=[])
        dd_status.text = ''

    dd_feat_add_btn.on_click(_on_add_feat)
    dd_feat_remove_btn.on_click(_on_remove_feat)
    dd_feat_clear_btn.on_click(_on_clear_feats)

    def _reconstruct_thread(sample_idxs, steerings, seed, feat_name, doc):
        """Worker thread: run one reconstruction per sample, then update the UI.

        Widget updates are scheduled on ``doc`` with ``add_next_tick_callback``
        rather than applied directly from this thread.
        """
        try:
            resps = []
            for i, sidx in enumerate(sample_idxs):
                trial_label = f'Trial {i+1} (sample {sidx})'
                resp = _dynadiff_request(sidx, steerings, seed)
                resps.append((trial_label, resp))
            html = _make_steering_html(resps, feat_name)

            def _apply(html=html):
                dd_output.text = html
                dd_status.text = ''
                dd_btn.disabled = False

            doc.add_next_tick_callback(_apply)
        except Exception as exc:
            # Exception text may contain '<' or '&'; escape it so it is shown
            # verbatim instead of being parsed as markup.
            msg = _escape_html(str(exc))

            def _show_err(msg=msg):
                dd_status.text = f'<span style="color:#c00">Error: {msg}</span>'
                dd_btn.disabled = False

            doc.add_next_tick_callback(_show_err)

    def _on_reconstruct():
        """Validate the inputs and launch the reconstruction in a background thread."""
        feats = list(dd_source.data['feat'])
        lams = list(dd_source.data['lam'])
        thrs = list(dd_source.data['threshold'])
        names = list(dd_source.data['name'])
        if not feats:
            dd_status.text = '<span style="color:#c00">Add at least one feature first.</span>'
            return
        steerings = []
        used_names = []
        for f, lam, thr, name in zip(feats, lams, thrs, names):
            phi = _phi_voxel_row(f)
            if phi is None:
                continue
            try:
                lam_f, thr_f = float(lam), float(thr)
            except (TypeError, ValueError):
                # An inline table edit can leave a cell empty or non-numeric.
                dd_status.text = (f'<span style="color:#c00">Feature {f}: '
                                  f'λ and Brain% must be numbers.</span>')
                return
            steerings.append((phi, lam_f, thr_f))
            # Track only the features actually applied so the result title is accurate.
            used_names.append(str(name))
        if not steerings:
            dd_status.text = '<span style="color:#c00">No phi data for selected features.</span>'
            return
        _raw = dd_sample_input.value.strip()
        try:
            _parsed = _parse_img_label(_raw)
        except ValueError:
            dd_status.text = '<span style="color:#c00">Invalid sample index.</span>'
            return
        # Check model status before proceeding — _nsd_to_sample is empty while
        # loading, so we must gate on status here rather than letting an empty
        # lookup produce a misleading "no trials for this subject" error.
        _dd_cur_status, _dd_cur_err = _dd_loader.status
        if _dd_cur_status == 'loading':
            dd_status.text = ('<span style="color:#f0a020">'
                              'DynaDiff model still loading — try again shortly.</span>')
            return
        if _dd_cur_status == 'error':
            dd_status.text = (f'<span style="color:#c00">'
                              f'DynaDiff model load failed: '
                              f'{_escape_html(str(_dd_cur_err))}</span>')
            return
        # If input looks like an NSD image label (e.g. "nsd_22910"), extract the
        # NSD stimulus index from the trailing integer.  _parse_img_label returns
        # the union-dataset index (e.g. 1431612) which is wrong for DynaDiff —
        # it needs the NSD image number (22910).
        if '_' in _raw:
            try:
                nsd_img_idx = int(_raw.rsplit('_', 1)[-1])
            except ValueError:
                dd_status.text = '<span style="color:#c00">Could not parse NSD image index.</span>'
                return
            sample_idxs = _dd_loader.sample_idxs_for_nsd_img(nsd_img_idx)
            if not sample_idxs:
                dd_status.text = (
                    f'<span style="color:#c00">NSD image {nsd_img_idx} has no trials '
                    f'for this subject.</span>')
                return
        else:
            sample_idxs = [_parsed]
        _n = _dd_loader.n_samples
        if _n is not None and any(not (0 <= s < _n) for s in sample_idxs):
            dd_status.text = f'<span style="color:#c00">sample_idx must be 0–{_n-1}.</span>'
            return
        try:
            seed = int(dd_seed_input.value)
        except ValueError:
            seed = 42  # fall back to the input's default value
        feat_name = ' + '.join(used_names) if used_names else 'unknown'
        dd_btn.disabled = True
        n_trials = len(sample_idxs)
        dd_status.text = (f'<i style="color:#888">Running DynaDiff reconstruction '
                          f'({n_trials} trial{"s" if n_trials > 1 else ""})…</i>')
        # Capture the document here; the worker thread needs it to schedule UI updates.
        doc = curdoc()
        threading.Thread(
            target=_reconstruct_thread,
            args=(sample_idxs, steerings, seed, feat_name, doc),
            daemon=True,
        ).start()

    dd_btn.on_click(_on_reconstruct)

    panel_body = column(
        row(dd_feat_input, dd_add_lam_input, dd_add_thr_select, dd_feat_add_btn),
        row(dd_feat_remove_btn, dd_feat_clear_btn),
        dd_table,
        dd_steer_div,
        row(dd_sample_input, dd_seed_input),
        row(dd_btn, dd_status),
        dd_output,
    )
    return panel_body, dd_output, dd_status, dd_feat_input, dd_sample_input
# ---------- DynaDiff steering widgets ----------
# _dd_feat_input, _dd_status, _dd_output, _dd_sample_input are referenced by
# update_feature_display and _on_dataset_switch — they must be module-level names.
(
    _dd_panel_body,
    _dd_output,
    _dd_status,
    _dd_feat_input,
    _dd_sample_input,
) = _build_dynadiff_panel()

# Name editing widget (defined here so update_feature_display can reference it)
name_input = TextInput(
    width=420,
    title="Feature name (auto-saved):",
    placeholder="Enter a name for this feature...",
)
# Gemini auto-interp button
# The key comes from the CLI flag first, then the environment. An empty string
# (e.g. GOOGLE_API_KEY="") is normalised to None so that the "is None" checks
# below and the truthiness check that attaches the click handler agree;
# otherwise the button would be enabled but do nothing when clicked.
_gemini_api_key = (args.google_api_key or os.environ.get("GOOGLE_API_KEY")) or None
gemini_btn = Button(
    label="Label with Gemini",
    width=140,
    button_type="warning",
    disabled=(_gemini_api_key is None),
)
# Status line next to the button; explains why the button is disabled.
gemini_status_div = Div(text=(
    "<i style='color:#aaa'>No GOOGLE_API_KEY set</i>"
    if _gemini_api_key is None else ""
), width=300)
# Zoom slider — controls neighbourhood size in the zoomed-patch view
zoom_slider = Slider(
    title="Zoom (patches)",
    start=1,
    end=16,
    step=1,
    value=16,
    width=220,
)

# Heatmap opacity slider — controls alpha of the overlay in render_heatmap_overlay
heatmap_alpha_slider = Slider(
    title="Heatmap opacity",
    start=0.0,
    end=1.0,
    step=0.05,
    value=1.0,
    width=220,
)

# View selector: which image ranking to show in the detail panel
_view_options = ["Top (max activation)", "Mean activation", "Compare aggregations"]
view_select = Select(
    title="Image ranking:",
    options=_view_options,
    value="Top (max activation)",
    width=250,
)

# Toggle between the full image set and the NSD sub01 subset rankings.
nsd_subset_toggle = RadioButtonGroup(
    labels=["All images", "NSD sub01"],
    active=0,
    width=220,
)

# Number of images shown per ranking grid in the detail panel.
N_DISPLAY = 6
def update_feature_display(feature_idx):
    """Refresh the detail panel for one SAE feature.

    The stats header, the name field and the status line are updated
    immediately. The image grids, cortical profile and DynaDiff pre-fill are
    produced by ``_render``, which is scheduled as a next-tick callback so the
    "loading" status can be shown first.

    Args:
        feature_idx: Index of the feature to show; converted with ``int()``.
            Callers are expected to have range-checked it.
    """
    # Feature names are typed by users or generated by Gemini and end up inside
    # Div HTML, so they must be escaped before being interpolated.
    from html import escape as _escape

    feat = int(feature_idx)
    # Each call takes a fresh token; a queued _render whose token is no longer
    # current is dropped (see the check at the top of _render).
    _S.render_token += 1
    my_token = _S.render_token

    freq_val = feature_frequency[feat].item()
    mean_val = feature_mean_act[feat].item()
    dead = "DEAD FEATURE" if freq_val == 0 else ""
    feat_name = feature_names.get(feat, "")
    auto_name = auto_interp_names.get(feat, "")

    name_parts = []
    if feat_name:
        name_parts.append(
            f'<div style="color:#1a6faf;font-style:italic;margin:2px 0 3px 0">'
            f'🏷︎ {_escape(feat_name)}'
            f'<span style="font-size:10px;color:#999;margin-left:6px">(manual)</span></div>'
        )
    if auto_name:
        name_parts.append(
            f'<div style="color:#5a9a5a;font-style:italic;margin:2px 0 3px 0">'
            f'🤖 {_escape(auto_name)}'
            f'<span style="font-size:10px;color:#999;margin-left:6px">(auto-interp)</span></div>'
        )
    name_display = "".join(name_parts)

    phi_c_val = _phi_c_for_feat(feat)
    phi_chip = (f' &nbsp;·&nbsp; <b>φ_c:</b> {phi_c_val:.4f}' if phi_c_val is not None else '')
    stats_div.text = (
        f'<h2 style="margin:4px 0">Feature {feat}'
        f'<span style="color:red;margin-left:8px">{dead}</span>'
        f'<span style="font-size:13px;font-weight:normal;color:#555;margin-left:14px">'
        f'<b>Freq:</b> {int(freq_val):,} &nbsp;·&nbsp; '
        f'<b>Mean act:</b> {mean_val:.4f}'
        f'{phi_chip}</span></h2>'
        + name_display
    )
    # The text field gets the raw (unescaped) name: it is a widget value, not HTML.
    name_input.value = feat_name

    if freq_val == 0:
        status_div.text = _status_html(
            'dead', f'Feature {feat} is dead — it never activated on the precompute set.')
        brain_div.text = _render_cortical_profile(feat)  # still show cortical profile if available
        for div in [top_heatmap_div, mean_heatmap_div, compare_agg_div]:
            div.text = ""
        return

    status_div.text = _status_html(
        'loading', f'⏳ Rendering heatmaps for feature {feat}...')

    def _render():
        """Build the image grids for ``feat`` unless a newer request superseded it."""
        # Bail out if the user has already clicked a different feature.
        if _S.render_token != my_token:
            return

        _SLOT_EMPTY = object()  # sentinel: no more stored slots (img_i < 0)

        def _render_one(img_idx_tensor, act_tensor, ranking_idx, heatmap_tensor=None,
                        center='peak'):
            """Render one ranked image.

            Returns ``(image, caption)``, ``None`` when the image file cannot
            be read, or ``_SLOT_EMPTY`` when the slot holds no image.
            """
            img_i = img_idx_tensor[feat, ranking_idx].item()
            if img_i < 0:
                return _SLOT_EMPTY  # no more slots stored for this feature
            try:
                # Use pre-computed heatmap
                if heatmap_tensor is not None and heatmap_patch_grid > 1:
                    hmap = heatmap_tensor[feat, ranking_idx].float().numpy()
                    hmap = hmap.reshape(heatmap_patch_grid, heatmap_patch_grid)
                else:
                    hmap = None
                img_label = os.path.splitext(os.path.basename(image_paths[img_i]))[0]
                act_val = float(act_tensor[feat, ranking_idx].item())
                caption = f"act={act_val:.4f} {img_label}"
                if hmap is None:
                    # No heatmap available: show the plain thumbnail.
                    plain = load_image(img_i).resize((THUMB, THUMB), Image.BILINEAR)
                    return (plain, caption)
                img_out = render_zoomed_overlay(img_i, hmap, size=THUMB, center=center)
                return (img_out, caption)
            except (FileNotFoundError, OSError):
                return None  # image file not available on this machine — skip silently
            except Exception as e:
                # Any other failure is shown in place as a gray placeholder.
                ph = Image.new("RGB", (THUMB, THUMB), "gray")
                return (ph, f"Error: {e}")

        def _collect(idx_tensor, act_tensor, hm_tensor, n, center='peak'):
            """Render up to n images, skipping unavailable files but stopping at empty slots."""
            results = []
            for j in range(min(n, idx_tensor.shape[1])):
                hm = _render_one(idx_tensor, act_tensor, j, hm_tensor, center=center)
                if hm is _SLOT_EMPTY:
                    break  # no more stored slots
                if hm is None:
                    continue  # file missing on this machine — try next slot
                results.append(hm)
            return results

        # --- Top images ---
        # Pick the NSD-subset tensors when the toggle is on and the subset exists.
        _use_nsd = nsd_subset_toggle.active == 1 and HAS_NSD_SUBSET
        _top_idx = nsd_top_img_idx if _use_nsd else top_img_idx
        _top_act = nsd_top_img_act if _use_nsd else top_img_act
        _mean_idx = nsd_mean_img_idx if _use_nsd else mean_img_idx
        _mean_act = nsd_mean_img_act if _use_nsd else mean_img_act
        _top_hm = nsd_top_heatmaps if _use_nsd else top_heatmaps
        _mean_hm = nsd_mean_heatmaps if _use_nsd else mean_heatmaps
        heatmap_infos = _collect(_top_idx, _top_act, _top_hm, N_DISPLAY)
        _subset_label = " [NSD sub01]" if _use_nsd else ""
        top_heatmap_div.text = make_image_grid_html(
            heatmap_infos, f"Top by Max Activation (feature {feat}){_subset_label}")

        # --- Mean-ranked images ---
        mean_hm_infos = _collect(_mean_idx, _mean_act, _mean_hm, N_DISPLAY, center='centroid')
        mean_heatmap_div.text = make_image_grid_html(
            mean_hm_infos, f"Top by Mean Activation (feature {feat}){_subset_label}")

        # Side-by-side aggregation comparison (paper-ready screenshot view)
        compare_agg_div.text = make_compare_aggregations_html(
            heatmap_infos, mean_hm_infos, feat,
            model_label=_all_datasets[_S.active]['label'])

        brain_div.text = _render_cortical_profile(feat)

        # Pre-fill DynaDiff inputs when a feature is selected.
        # Sample input: use the stem of the top NSD MEI when the NSD subset toggle
        # is active (e.g. "nsd_22910"), so the image index passed to DynaDiff
        # refers to the NSD stimulus number, not the union-dataset index.
        if HAS_DYNADIFF:
            _dd_feat_input.value = str(feat)
            _use_nsd_dd = nsd_subset_toggle.active == 1 and HAS_NSD_SUBSET
            if _use_nsd_dd and _dd_sample_input is not None:
                _top_i = nsd_top_img_idx[feat, 0].item()
                if _top_i >= 0:
                    _dd_sample_input.value = os.path.splitext(
                        os.path.basename(image_paths[_top_i]))[0]
            _dd_status.text = (
                '<i style="color:#888">Feature pre-filled → click Add, then Steer & Reconstruct.</i>'
                if _phi_voxel_row(feat) is not None else
                '<span style="color:#c00">No phi data for this feature.</span>'
            )

        status_div.text = _status_html('ok', f'✓ Feature {feat} ready.')
        _update_view_visibility()

    curdoc().add_next_tick_callback(_render)
| # ---------- View visibility ---------- | |
def _update_view_visibility():
    """Show only the detail grid that matches the current ``view_select`` value."""
    current = view_select.value
    panels = (
        ("Top (max activation)", top_heatmap_div),
        ("Mean activation", mean_heatmap_div),
        ("Compare aggregations", compare_agg_div),
    )
    for option_label, panel in panels:
        panel.visible = (current == option_label)


view_select.on_change('value', lambda attr, old, new: _update_view_visibility())
_update_view_visibility()  # set initial state
def _rerender_current_feature(attr, old, new):
    """Redraw the feature named in ``feature_input`` after a display control changes.

    Non-integer or out-of-range input is ignored.
    """
    try:
        current = int(feature_input.value)
        if current < 0 or current >= d_model:
            return
        update_feature_display(current)
    except ValueError:
        return


for _control, _prop in (
    (zoom_slider, 'value'),
    (heatmap_alpha_slider, 'value'),
    (nsd_subset_toggle, 'active'),
):
    _control.on_change(_prop, _rerender_current_feature)
| # ---------- Callbacks ---------- | |
def _umap_alphas_for_selection(selected_pos):
    """Build the per-point alpha list for the UMAP scatter.

    With no selection every point gets 0.6; otherwise the point at
    ``selected_pos`` gets 0.6 and all others 0.2.
    """
    n_points = len(umap_source.data['feature_idx'])
    if selected_pos is not None:
        return [0.6 if pos == selected_pos else 0.2 for pos in range(n_points)]
    return [0.6] * n_points
def on_umap_select(attr, old, new):
    """React to a UMAP selection change: highlight the point and show its feature."""
    if not new:
        # Selection cleared: restore uniform alphas.
        umap_source.data['point_alpha'] = _umap_alphas_for_selection(None)
        return
    picked = new[0]  # only the first selected point drives the detail panel
    umap_source.data['point_alpha'] = _umap_alphas_for_selection(picked)
    feature_idx = umap_source.data['feature_idx'][picked]
    feature_input.value = str(feature_idx)
    update_feature_display(feature_idx)


umap_source.selected.on_change('indices', on_umap_select)
# UMAP type toggle
_umap_type_options = ["Activation Pattern", "Dictionary Geometry"]
umap_type_select = Select(
    title="UMAP Type",
    options=_umap_type_options,
    value="Activation Pattern",
    width=220,
)

# UMAP color select — the phi_c option is offered only when phi data is loaded
_color_options = ["Log Frequency", "Mean Activation"]
if _phi_c is not None:
    _color_options += ["Brain Leverage (φ_c)"]
umap_color_select = Select(
    title="Color by:",
    options=_color_options,
    value="Log Frequency",
    width=200,
)
def _apply_umap_color(color_by, feat_ids):
    """Record the colour mode, then recolour the UMAP points and rescale the mapper."""
    _S.color_by = color_by
    recoloured = _make_color_vals(feat_ids)
    umap_source.data['color_val'] = recoloured
    _set_color_mapper_range(recoloured)
def _on_umap_color_change(attr, old, new):
    """Recolour the currently plotted features when the colour selector changes."""
    _apply_umap_color(new, list(umap_source.data['feature_idx']))


umap_color_select.on_change('value', _on_umap_color_change)
def on_umap_type_change(attr, old, new):
    """Swap the UMAP scatter between the activation and dictionary embeddings.

    Any value other than "Activation Pattern" selects the dictionary layout.
    """
    if new == "Activation Pattern":
        x_key, y_key, feat_key = 'act_x', 'act_y', 'act_feat'
        mask = live_mask
        title = "UMAP of SAE Features (by activation pattern)"
    else:
        x_key, y_key, feat_key = 'dict_x', 'dict_y', 'dict_feat'
        mask = dict_live_mask
        title = "UMAP of SAE Features (by dictionary geometry)"

    feat_ids = umap_backup[feat_key]
    color_vals = _make_color_vals(feat_ids)
    # Replace the whole data dict in one assignment so all columns stay aligned.
    umap_source.data = {
        'x': umap_backup[x_key],
        'y': umap_backup[y_key],
        'feature_idx': feat_ids,
        'frequency': freq[mask].tolist(),
        'log_freq': log_freq[mask].tolist(),
        'mean_act': mean_act[mask].tolist(),
        'phi_c_val': _phi_c_vals(feat_ids),
        'color_val': color_vals,
        'point_alpha': _make_point_alphas(len(feat_ids)),
    }
    umap_fig.title.text = title
    _set_color_mapper_range(color_vals)


umap_type_select.on_change('value', on_umap_type_change)
| # Direct feature input | |
# Text box holding the index of the feature shown in the detail panel.
feature_input = TextInput(
    title="Feature Index:",
    value="",
    width=120,
)
go_button = Button(width=60, label="Go")
random_btn = Button(width=70, label="Random")
def _select_and_display(feat):
    """Show the detail panel for ``feat`` and select its point on the UMAP, if plotted."""
    update_feature_display(feat)
    plotted = umap_source.data['feature_idx']
    if feat not in plotted:
        return  # feature is not part of the current UMAP; leave the selection as is
    umap_source.selected.indices = [plotted.index(feat)]
def on_go_click():
    """Handle the Go button: display the typed feature index or report why not."""
    try:
        feat = int(feature_input.value)
        if not (0 <= feat < d_model):
            stats_div.text = f"<h3>Feature {feat} out of range (0-{d_model-1})</h3>"
            return
        _select_and_display(feat)
    except ValueError:
        stats_div.text = "<h3>Please enter a valid integer</h3>"


go_button.on_click(on_go_click)
def _on_random():
    """Jump to a feature drawn uniformly from ``_active_feats`` (no-op when empty)."""
    if _active_feats:
        chosen = random.choice(_active_feats)
        feature_input.value = str(chosen)
        _select_and_display(chosen)


random_btn.on_click(_on_random)
| # ---------- Sorted feature list ---------- | |
# Initial table contents: all features, highest frequency first.
_init_order = np.argsort(-freq)
feature_list_source = ColumnDataSource(data={
    'feature_idx': _init_order.tolist(),
    'frequency': freq[_init_order].tolist(),
    'mean_act': mean_act[_init_order].tolist(),
    'p75_val': p75_np[_init_order].tolist(),
    'phi_c_val': _phi_c_vals(_init_order.tolist()),
    'name': [_display_name(int(i)) for i in _init_order],
})
def _phi_col():
    """Return a one-element list with the phi_c table column, or [] without phi data."""
    if HAS_PHI:
        phi_column = TableColumn(
            field="phi_c_val",
            title="φ_c",
            width=65,
            formatter=NumberFormatter(format="0.0000"),
        )
        return [phi_column]
    return []
# Sortable table of all features; the phi_c column appears only when phi data is loaded.
feature_table = DataTable(
    source=feature_list_source,
    columns=[
        TableColumn(field="feature_idx", title="Feature", width=60),
        TableColumn(
            field="frequency", title="Freq", width=70,
            formatter=NumberFormatter(format="0,0"),
        ),
        TableColumn(
            field="mean_act", title="Mean Act", width=80,
            formatter=NumberFormatter(format="0.0000"),
        ),
        *_phi_col(),
        TableColumn(field="name", title="Name", width=200),
    ],
    width=500,
    height=500,
    sortable=True,
    index_position=None,
)
# Search state: None = no filter, otherwise a set of matching feature indices
def _get_sorted_order():
    """Return feature indices by descending frequency, restricted to the search filter."""
    ranked = np.argsort(-freq)
    active_filter = _S.search_filter
    if active_filter is None:
        return ranked
    keep = np.isin(ranked, list(active_filter))
    return ranked[keep]
def _apply_order(order):
    """Rebuild every column of the feature table for the given index array."""
    feature_list_source.data = {
        'feature_idx': order.tolist(),
        'frequency': freq[order].tolist(),
        'mean_act': mean_act[order].tolist(),
        'p75_val': p75_np[order].tolist(),
        'phi_c_val': _phi_c_vals(order.tolist()),
        'name': [_display_name(int(idx)) for idx in order],
    }
def _update_table_names():
    """Refresh the name column after saving or deleting a feature name.

    The current row order is kept: the table is rebuilt from the feature
    indices it already shows.
    """
    # Force an integer dtype: when the table is empty (e.g. a search with no
    # matches) np.array([]) would be float64, and _apply_order's fancy
    # indexing (freq[order]) rejects non-integer index arrays.
    current = np.asarray(feature_list_source.data['feature_idx'], dtype=np.int64)
    _apply_order(current)
def _on_table_select(attr, old, new):
    """Open the feature of the first selected table row."""
    if not new:
        return
    picked = feature_list_source.data['feature_idx'][new[0]]
    feature_input.value = str(picked)
    _select_and_display(picked)


feature_list_source.selected.on_change('indices', _on_table_select)
| # ---------- Auto-save name on typing ---------- | |
def on_name_change(attr, old, new):
    """Persist the manual name typed for the feature shown in ``feature_input``.

    A blank name deletes the stored entry. Nothing is written when the feature
    index is not a valid integer in ``[0, d_model)`` or when the stripped name
    equals what is already stored.
    """
    try:
        feat = int(feature_input.value)
    except ValueError:
        return
    if not (0 <= feat < d_model):
        # Never store a name for an index that is not a real feature.
        return
    name = new.strip()
    if name == feature_names.get(feat, ""):
        # update_feature_display assigns name_input.value on every feature
        # switch, which also lands here; skip the file write and the full
        # table rebuild when nothing actually changed.
        return
    if name:
        feature_names[feat] = name
    else:
        feature_names.pop(feat, None)
    _save_names()
    _update_table_names()


name_input.on_change('value', on_name_change)
| # ---------- Gemini auto-interp button ---------- | |
# Upper bound on how many leading MEI entries are considered per Gemini request.
_N_GEMINI_IMAGES = 6
# Model identifier passed to the Gemini generate_content call.
_GEMINI_MODEL = "gemini-2.5-flash"
# NOTE(review): not referenced by the Gemini code visible in this section —
# confirm whether it is still used elsewhere.
_GEMINI_HM_ALPHA = 0.25  # heatmap overlay opacity sent to Gemini
def _gemini_label_thread(feat, mei_items, doc):
    """Run in a worker thread: call Gemini and push the result back to the doc.

    Args:
        feat: Feature index being labelled.
        mei_items: list of (path_str, heatmap_np_or_None) where heatmap is
            (H, W) float32. Only the path is used here; at most the first
            ``_N_GEMINI_IMAGES`` entries are considered.
        doc: Bokeh document on which UI updates are scheduled with
            ``add_next_tick_callback``.

    Every failure, including an empty model response, is reported in
    ``gemini_status_div`` and re-enables the button.
    """
    from html import escape  # model output / error text is shown inside Div HTML

    try:
        from google import genai
        from google.genai import types

        SYSTEM_PROMPT = (
            "You are labeling features of a Sparse Autoencoder (SAE) trained on a "
            "vision transformer. Each SAE feature is a sparse direction in activation "
            "space that fires strongly on certain visual patterns."
        )
        USER_PROMPT = (
            "The images below are the top maximally-activating images for one SAE feature. "
            "In 2–5 words, give a precise label for the visual concept this feature detects. "
            "Be specific — prefer 'dog snout close-up' over 'dog', or 'brick wall texture' "
            "over 'texture'. "
            "Reply with ONLY the label, no explanation, no punctuation at the end."
        )

        client = genai.Client(api_key=_gemini_api_key)
        parts = []
        for path, _heatmap in mei_items[:_N_GEMINI_IMAGES]:
            resolved = _resolve_img_path(path)
            if resolved is None:
                continue
            try:
                img = Image.open(resolved).convert("RGB").resize((224, 224), Image.BILINEAR)
                buf = io.BytesIO()
                img.save(buf, format="JPEG", quality=85)
                parts.append(types.Part.from_bytes(data=buf.getvalue(), mime_type="image/jpeg"))
            except Exception:
                # Best effort: an unreadable image is skipped, the rest are still sent.
                continue

        if not parts:
            def _no_images():
                gemini_btn.disabled = False
                gemini_status_div.text = "<span style='color:#c00'>No images could be loaded.</span>"
            doc.add_next_tick_callback(_no_images)
            return

        parts.append(types.Part.from_text(text=USER_PROMPT))
        response = client.models.generate_content(
            model=_GEMINI_MODEL,
            contents=parts,
            config=types.GenerateContentConfig(system_instruction=SYSTEM_PROMPT),
        )
        # response.text may be missing/empty; report that clearly instead of
        # failing on .strip() or storing a blank label.
        label = (response.text or "").strip().strip(".,;:\"'")
        if not label:
            raise ValueError("Gemini returned an empty response")

        def _apply_label(feat=feat, label=label):
            auto_interp_names[feat] = label
            _save_auto_interp()
            _update_table_names()
            # Refresh the stats panel so the [auto] label appears immediately —
            # but only if this feature is still the one on screen. Redrawing a
            # different feature would replace the user's current view, and the
            # name_input assignment inside update_feature_display would reach
            # on_name_change while feature_input points at another feature.
            try:
                still_selected = int(feature_input.value) == feat
            except ValueError:
                still_selected = False
            if still_selected:
                try:
                    update_feature_display(feat)
                except Exception as exc:
                    # The label is already saved; a failed redraw is not fatal.
                    print(f" [Gemini] feat {feat} display refresh failed: {exc}")
            gemini_btn.disabled = False
            gemini_status_div.text = (
                f"<span style='color:#1a6faf'><b>Labeled:</b> {escape(label)}</span>"
            )
            print(f" [Gemini] feat {feat}: {label}")

        doc.add_next_tick_callback(_apply_label)
    except Exception as e:
        err = str(e)

        def _show_err(err=err):
            gemini_btn.disabled = False
            gemini_status_div.text = (
                f"<span style='color:#c00'>Error: {escape(err[:120])}</span>"
            )
            print(f" [Gemini] feat {feat} error: {err}")

        doc.add_next_tick_callback(_show_err)
def _on_gemini_click():
    """Collect the current feature's top images and start a Gemini labelling thread.

    Reports a message in ``gemini_status_div`` and returns early when the
    feature index is not an integer, is out of range, the feature is dead, or
    no image slots are stored for it.
    """
    try:
        feat = int(feature_input.value)
    except ValueError:
        gemini_status_div.text = "<span style='color:#c00'>Select a feature first.</span>"
        return
    # feature_input is free text: without this check a large value raises
    # IndexError below and a negative one silently indexes from the end.
    if not (0 <= feat < d_model):
        gemini_status_div.text = "<span style='color:#c00'>Feature index out of range.</span>"
        return
    if feature_frequency[feat].item() == 0:
        gemini_status_div.text = "<span style='color:#c00'>Dead feature — no images.</span>"
        return

    n_top_stored = top_img_idx.shape[1]
    mei_items = []
    for j in range(n_top_stored):
        idx = top_img_idx[feat, j].item()
        if idx >= 0:  # negative index marks an empty slot
            hm = None
            if top_heatmaps is not None:
                hm = top_heatmaps[feat, j].float().numpy().reshape(heatmap_patch_grid, heatmap_patch_grid)
            mei_items.append((image_paths[idx], hm))

    if not mei_items:
        gemini_status_div.text = "<span style='color:#c00'>No MEI paths found.</span>"
        return

    # Disable the button until the worker reports back through the document.
    gemini_btn.disabled = True
    gemini_status_div.text = "<i style='color:#888'>Calling Gemini…</i>"
    doc = curdoc()
    t = threading.Thread(
        target=_gemini_label_thread,
        args=(feat, mei_items, doc),
        daemon=True,
    )
    t.start()


# The handler is attached only when an API key is available.
if _gemini_api_key:
    gemini_btn.on_click(_on_gemini_click)
| # ---------- Search by name ---------- | |
# Widgets for filtering the feature table by (manual or auto) name.
search_input = TextInput(
    width=220,
    title="Search feature names:",
    placeholder="Type to search...",
)
search_btn = Button(button_type="primary", label="Search", width=70)
clear_search_btn = Button(label="Clear", width=60)
search_result_div = Div(width=360, text="")
def _do_search():
    """Filter the feature table to names containing the search text.

    Matching is a case-insensitive substring test over both manual and
    auto-interp names. An empty query removes the filter.
    """
    from html import escape  # the query is echoed back inside Div HTML

    query = search_input.value.strip().lower()
    if not query:
        _S.search_filter = None
        search_result_div.text = ""
        _apply_order(_get_sorted_order())
        return

    matches = {
        i for i, name in feature_names.items() if query in name.lower()
    } | {
        i for i, name in auto_interp_names.items() if query in name.lower()
    }
    _S.search_filter = matches
    _apply_order(_get_sorted_order())

    # Escape the user's text so characters like '<' or '&' are shown literally
    # instead of being interpreted as markup.
    shown_query = escape(query)
    if matches:
        search_result_div.text = (
            f'<span style="color:#1a6faf"><b>{len(matches)}</b> feature(s) matching '
            f'&ldquo;{shown_query}&rdquo;</span>'
        )
    else:
        search_result_div.text = (
            f'<span style="color:#c00">No features named &ldquo;{shown_query}&rdquo;</span>'
        )
def _do_clear_search():
    """Empty the search box, drop the filter and restore the full table."""
    search_input.value = ""
    search_result_div.text = ""
    _S.search_filter = None
    _apply_order(_get_sorted_order())


for _button, _handler in ((search_btn, _do_search), (clear_search_btn, _do_clear_search)):
    _button.on_click(_handler)
| # Summary — regenerated on every dataset switch | |
def _make_summary_html():
    """Return the HTML summary box for the currently active dataset.

    Shows the model label, backbone, token type, dictionary size, active/dead
    feature counts, image count, tokens per image and, when the dataset has a
    ``sae_url``, a download link for the SAE weights.
    """
    ds = _all_datasets[_S.active]
    n_truly_active = int((freq > 0).sum())
    n_dead = d_model - n_truly_active
    tok_label = ("CLS global" if ds.get('token_type') == 'cls'
                 else f"{patch_grid}×{patch_grid} = {patch_grid**2} patches")
    backbone_label = ds.get('backbone', 'dinov3').upper()
    sae_url = ds.get('sae_url')
    # The download row is omitted entirely when no URL is configured.
    dl_row = (f'<tr><td><b>SAE weights:</b></td>'
              f'<td><a href="{sae_url}" download style="color:#1a6faf">⬇ Download</a></td></tr>'
              if sae_url else '')
    return f"""
<div style="background:#f0f4f8;padding:12px;border-radius:6px;margin-bottom:8px;">
<h2 style="margin:0 0 8px 0">SAE Feature Explorer</h2>
<table style="font-size:13px;">
<tr><td><b>Active model:</b></td><td><b style="color:#1a6faf">{ds['label']}</b></td></tr>
<tr><td><b>Backbone:</b></td><td>{backbone_label}</td></tr>
<tr><td><b>Token type:</b></td><td>{ds.get('token_type','spatial')}</td></tr>
<tr><td><b>Dictionary size:</b></td><td>{d_model:,}</td></tr>
<tr><td><b>Active (fired ≥1):</b></td><td>{n_truly_active:,} ({100*n_truly_active/d_model:.1f}%)</td></tr>
<tr><td><b>Dead:</b></td><td>{n_dead:,} ({100*n_dead/d_model:.1f}%)</td></tr>
<tr><td><b>Images:</b></td><td>{n_images:,}</td></tr>
<tr><td><b>Tokens/image:</b></td><td>{tok_label}</td></tr>
{dl_row}
</table>
</div>"""


summary_div = Div(text=_make_summary_html(), width=700)
| # ---------- Patch Explorer ---------- | |
| # Click patches of an image to find the top active SAE features for that region. | |
| # Activations are served from pre-computed sidecars (no GPU required at serve time). | |
# Side length, in pixels, of the square patch-explorer figure.
_PATCH_FIG_PX = 400

# Raster-order (row, col) pairs for every patch cell.
# _pr[i] = row index, _pc[i] = col index for flat patch i.
_pr = [i // patch_grid for i in range(patch_grid * patch_grid)]  # 0,0,...,0, 1,1,...,N-1
_pc = [i % patch_grid for i in range(patch_grid * patch_grid)]   # 0,1,...,N-1, 0,1,...

# One rect per patch; y is flipped so row 0 is drawn at the top of the figure.
patch_grid_source = ColumnDataSource(data={
    'x': [c + 0.5 for c in _pc],
    'y': [patch_grid - r - 0.5 for r in _pr],
    'row': _pr,
    'col': _pc,
})

# Background image spanning the whole grid; empty until an image is loaded.
patch_bg_source = ColumnDataSource(data={
    'image': [],
    'x': [0],
    'y': [0],
    'dw': [patch_grid],
    'dh': [patch_grid],
})

patch_fig = figure(
    title="Click or drag to paint patch selection",
    width=_PATCH_FIG_PX,
    height=_PATCH_FIG_PX,
    x_range=(0, patch_grid),
    y_range=(0, patch_grid),
    tools=["tap", "reset"],
    toolbar_location="above",
    visible=False,
)
# Paint-on-drag selection: any patch the mouse passes over while the button
# is held gets added to the selection. We track button state with a
# document-level mousedown/mouseup listener (set up lazily on first move).
#
# The callback receives the patch ColumnDataSource as `source` and the grid
# size as `pg`. It maps the cursor's data coordinates to a flat raster index
# (row * pg + col, with the y axis flipped so row 0 is the top row) and
# appends that index to source.selected.indices if it is not already there.
_paint_js = CustomJS(args=dict(source=patch_grid_source, pg=patch_grid), code="""
if (!window._patch_paint_init) {
window._patch_paint_init = true;
window._patch_btn_held = false;
document.addEventListener('mousedown', () => { window._patch_btn_held = true; });
document.addEventListener('mouseup', () => { window._patch_btn_held = false; });
}
if (!window._patch_btn_held) return;
const x = cb_obj.x, y = cb_obj.y;
if (x === null || y === null || x < 0 || x >= pg || y < 0 || y >= pg) return;
const col = Math.floor(x);
const row = pg - 1 - Math.floor(y);
const flat_idx = row * pg + col;
const sel = source.selected.indices.slice();
if (sel.indexOf(flat_idx) === -1) {
sel.push(flat_idx);
source.selected.indices = sel;
}
""")
# Run the paint callback on every mouse-move event over the figure.
patch_fig.js_on_event(MouseMove, _paint_js)
# Background image layer (drawn first, so the grid sits on top of it).
patch_fig.image_rgba(
    image='image',
    x='x',
    y='y',
    dw='dw',
    dh='dh',
    source=patch_bg_source,
)

# Patch grid: transparent cells with a faint outline; selected cells turn red.
patch_fig.rect(
    x='x',
    y='y',
    width=0.95,
    height=0.95,
    source=patch_grid_source,
    fill_color='yellow',
    fill_alpha=0.0,
    line_color='white',
    line_alpha=0.35,
    line_width=0.5,
    selection_fill_color='red',
    selection_fill_alpha=0.45,
    nonselection_fill_alpha=0.0,
    nonselection_line_alpha=0.35,
)

# Hide axes and grid lines so only the image and patch outlines show.
patch_fig.axis.visible = False
patch_fig.xgrid.visible = False
patch_fig.ygrid.visible = False
# Controls for choosing the image shown in the patch explorer.
patch_img_input = TextInput(title="Image Index:", width=120, value="0")
load_patch_btn = Button(label="Load Image", button_type="primary", width=90)
clear_patch_btn = Button(label="Clear", width=60)

# Rows of the "top features for the selected patches" table.
patch_feat_source = ColumnDataSource(data={
    'feature_idx': [],
    'patch_act': [],
    'frequency': [],
    'mean_act': [],
    'phi_c_val': [],
})

patch_feat_table = DataTable(
    source=patch_feat_source,
    columns=[
        TableColumn(field="feature_idx", title="Feature", width=65),
        TableColumn(
            field="patch_act", title="Patch Act", width=85,
            formatter=NumberFormatter(format="0.0000"),
        ),
        TableColumn(
            field="frequency", title="Freq", width=65,
            formatter=NumberFormatter(format="0,0"),
        ),
        TableColumn(
            field="mean_act", title="Mean Act", width=80,
            formatter=NumberFormatter(format="0.0000"),
        ),
        *_phi_col(),
    ],
    # Widen the table by the phi_c column's width when that column is present.
    width=310 + (65 if HAS_PHI else 0),
    height=350,
    index_position=None,
    sortable=False,
    visible=False,
)

patch_info_div = Div(
    width=310,
    text="<i>Load an image, then click patches to find top features.</i>",
)
def _pil_to_bokeh_rgba(pil_img, size):
    """Convert a PIL image to a ``(size, size)`` uint32 RGBA array, rows flipped.

    The image is resized to a square and converted to RGBA. Each pixel's four
    bytes are packed into one uint32, and the row order is reversed before
    the array is returned as a fresh copy.
    """
    rgba_img = pil_img.resize((size, size), Image.BILINEAR).convert("RGBA")
    pixels = np.array(rgba_img, dtype=np.uint8)
    packed = np.empty((size, size), dtype=np.uint32)
    # Write the RGBA bytes through a uint8 view of the uint32 buffer.
    packed.view(dtype=np.uint8).reshape((size, size, 4))[...] = pixels
    return np.flipud(packed).copy()
def _do_load_patch_image():
    """Load the image named in ``patch_img_input`` into the patch explorer.

    The image is drawn as the figure background right away; its patch
    activations are computed on a worker thread and applied to the document
    through a next-tick callback. Errors are reported in ``patch_info_div``.
    """
    try:
        img_idx = _parse_img_label(patch_img_input.value)
    except ValueError:
        patch_info_div.text = "<b style='color:red'>Invalid image index</b>"
        return
    if not (0 <= img_idx < n_images):
        patch_info_div.text = f"<b style='color:red'>Index out of range (0–{n_images - 1})</b>"
        return
    # NOTE(review): the current-image marker is set before the image load
    # below has succeeded; confirm that is intended.
    _S.patch_img = img_idx
    try:
        pil = load_image(img_idx)
        bokeh_arr = _pil_to_bokeh_rgba(pil, _PATCH_FIG_PX)
        patch_bg_source.data = dict(
            image=[bokeh_arr], x=[0], y=[0], dw=[patch_grid], dh=[patch_grid],
        )
    except Exception as e:
        patch_info_div.text = f"<b style='color:red'>Error loading image: {e}</b>"
        return
    # Show spinner immediately, then compute (possibly slow GPU inference) in background.
    load_patch_btn.disabled = True
    patch_info_div.text = (
        "<span style='color:#1a6faf'>&#x23F3; Computing patch activations"
        + (" (running GPU inference — first image may take ~10 s)…"
           if _gpu_runner[0] is None and args.sae_path else "…")
        + "</span>"
    )
    # Capture the document now: the worker thread schedules UI updates on it.
    _doc = curdoc()

    def _bg():
        # Worker-thread body: compute activations, then hand the result (or the
        # error) back to the document via a next-tick callback.
        try:
            z_np = compute_patch_activations(img_idx)
        except Exception as e:
            err = str(e)

            def _show_err(err=err):
                load_patch_btn.disabled = False
                patch_info_div.text = f"<b style='color:red'>Error: {err}</b>"
            _doc.add_next_tick_callback(_show_err)
            return

        def _apply(z_np=z_np):
            # Store the activations and reset selection and results table.
            _S.patch_z = z_np
            load_patch_btn.disabled = False
            patch_fig.visible = True
            patch_grid_source.selected.indices = []
            patch_feat_source.data = dict(feature_idx=[], patch_act=[], frequency=[], mean_act=[], phi_c_val=[])
            if z_np is None:
                # No activations could be produced for this image.
                patch_feat_table.visible = False
                patch_info_div.text = (
                    f"<b style='color:#888'>Image {img_idx} has no pre-computed patch activations "
                    f"and no GPU runner is available. Pass --sae-path to the explorer to enable "
                    f"live GPU inference for any image.</b>"
                )
                return
            patch_feat_table.visible = True
            _ds = _all_datasets[_S.active]
            _pa = _ds.get('patch_acts')
            # Label the source shown to the user: sidecar lookup vs. live inference.
            source = "patch_acts" if (_pa is not None and img_idx in _pa['img_to_row']) else "GPU inference"
            patch_info_div.text = (
                f"Image {img_idx} loaded ({source}). "
                f"Drag to select a region, or click individual patches."
            )
        _doc.add_next_tick_callback(_apply)

    threading.Thread(target=_bg, daemon=True).start()


load_patch_btn.on_click(_do_load_patch_image)
def _do_clear_patches():
    """Deselect all patches, empty the results table and update the info line."""
    patch_grid_source.selected.indices = []
    patch_feat_source.data = {
        'feature_idx': [],
        'patch_act': [],
        'frequency': [],
        'mean_act': [],
        'phi_c_val': [],
    }
    patch_info_div.text = "<i>Selection cleared.</i>"


clear_patch_btn.on_click(_do_clear_patches)
def _get_top_features_for_patches(patch_indices, top_n=20):
    """Rank features by their activation summed over the selected patches.

    Args:
        patch_indices: indices used to select rows of ``_S.patch_z``.
        top_n: maximum number of features to return.

    Returns:
        Four parallel lists ``(feats, acts, freqs, means)``: feature indices,
        their summed activations, and the matching entries of
        ``feature_frequency`` and ``feature_mean_act``. All four are empty
        when ``_S.patch_z`` is None. Only features whose summed activation is
        strictly positive are kept.
    """
    activations = _S.patch_z
    if activations is None:
        return [], [], [], []

    # Rows are patches, columns are features (per the original author's note):
    # collapse the selected rows into one total per feature.
    totals = activations[patch_indices].sum(axis=0)

    # Descending order by total, truncated to top_n, then drop non-positive totals.
    ranked = np.argsort(-totals)[:top_n]
    ranked = ranked[totals[ranked] > 0]

    feats = ranked.tolist()
    acts = totals[ranked].tolist()
    freqs = []
    means = []
    for feat in feats:
        freqs.append(int(feature_frequency[feat].item()))
        means.append(float(feature_mean_act[feat].item()))
    return feats, acts, freqs, means
def _on_patch_select(attr, old, new):
    """React to a change in the selected patch rectangles.

    ``new`` is the list of selected indices into ``patch_grid_source``. Does
    nothing while no patch image is loaded (``_S.patch_img`` is None); an
    empty selection clears the feature table.
    """
    if _S.patch_img is None:
        return

    if not new:
        patch_feat_source.data = dict(feature_idx=[], patch_act=[], frequency=[], mean_act=[], phi_c_val=[])
        patch_info_div.text = "<i>Selection cleared.</i>"
        return

    # Map each selected rect to its flat patch index: row * patch_grid + col.
    grid_rows = [patch_grid_source.data['row'][sel] for sel in new]
    grid_cols = [patch_grid_source.data['col'][sel] for sel in new]
    flat_ids = []
    for grid_row, grid_col in zip(grid_rows, grid_cols):
        flat_ids.append(grid_row * patch_grid + grid_col)

    feats, acts, freqs, means = _get_top_features_for_patches(flat_ids)
    patch_feat_source.data = dict(
        feature_idx=feats,
        patch_act=acts,
        frequency=freqs,
        mean_act=means,
        phi_c_val=_phi_c_vals(feats),
    )
    patch_info_div.text = (
        f"{len(new)} patch(es) selected → {len(feats)} feature(s) found. "
        f"Click a row below to explore the feature."
    )


patch_grid_source.selected.on_change('indices', _on_patch_select)
def _on_patch_feat_table_select(attr, old, new):
    """Open the feature from the first selected row of the patch-feature table."""
    if new:
        chosen = patch_feat_source.data['feature_idx'][new[0]]
        # Mirror the choice into the feature text box, then display it.
        feature_input.value = str(chosen)
        _select_and_display(chosen)


patch_feat_source.selected.on_change('indices', _on_patch_feat_table_select)
| # ---------- CLIP Text Search ---------- | |
def _build_clip_panel():
    """Build the CLIP text-search panel widgets and callbacks.

    Returns (panel, result_div, result_source).
    When HAS_CLIP is False, result_div and result_source are None and panel is a
    static placeholder Div.
    """
    # Function-scope import: user-typed text and exception messages are placed
    # inside Div HTML below and must be escaped first.
    from html import escape as _html_escape

    if not HAS_CLIP:
        panel = Div(
            text="<i style='color:#aaa'>CLIP text search unavailable — "
                 "run <code>scripts/add_clip_embeddings.py</code> to enable.</i>",
            width=470,
        )
        return panel, None, None

    clip_query_input = TextInput(
        title="Search features by text (CLIP):",
        placeholder="e.g. 'dog', 'red stripes', 'water'...",
        width=280,
    )
    clip_search_btn = Button(label="Search", width=70, button_type="primary")
    result_div = Div(text="", width=360)
    clip_top_k_input = TextInput(title="Top-K results:", value="20", width=70)
    result_source = ColumnDataSource(data=dict(
        feature_idx=[], clip_score=[], frequency=[], mean_act=[], phi_c_val=[], name=[],
    ))
    clip_result_table = DataTable(
        source=result_source,
        columns=[
            TableColumn(field="feature_idx", title="Feature", width=65),
            TableColumn(field="clip_score", title="CLIP score", width=85,
                        formatter=NumberFormatter(format="0.0000")),
            TableColumn(field="frequency", title="Freq", width=65,
                        formatter=NumberFormatter(format="0,0")),
            TableColumn(field="mean_act", title="Mean Act", width=80,
                        formatter=NumberFormatter(format="0.0000")),
        ] + _phi_col() + [
            TableColumn(field="name", title="Name", width=160),
        ],
        width=470 + (65 if HAS_PHI else 0), height=300, index_position=None, sortable=False,
    )

    def _do_search():
        """Score every feature against the text query and fill the results table.

        An exact (case-insensitive) vocab match uses the matching column of
        ``_clip_scores_f32``; any other query is encoded with CLIP and dotted
        with the active feature embeddings. Errors are reported in ``result_div``.
        """
        query = clip_query_input.value.strip()
        if not query:
            result_div.text = "<i>Enter a text query above.</i>"
            return
        try:
            top_k = max(1, int(clip_top_k_input.value))
        except ValueError:
            top_k = 20  # non-numeric Top-K input falls back to the default

        _nsd_active = nsd_subset_toggle.active == 1
        # Use NSD-specific embeds when the subset toggle is active and they are
        # loaded; otherwise use the regular embeds. Resolving this before the
        # branch below means a missing embedding set yields the explicit
        # "no feature embeddings" message instead of an AttributeError on None.
        if _nsd_active and _nsd_clip_embeds is not None:
            _active_embeds = _nsd_clip_embeds
        else:
            _active_embeds = _clip_embeds

        # Check if query matches a vocab term exactly (case-insensitive)
        vocab_lower = [v.lower() for v in (_clip_vocab or [])]
        if _clip_vocab and query.lower() in vocab_lower:
            col = vocab_lower.index(query.lower())
            scores_vec = _clip_scores_f32[:, col]
        elif _active_embeds is not None:
            # Free-text: encode on-the-fly with CLIP, dot with feature image embeds.
            result_div.text = "<i>Encoding query with CLIP…</i>"
            try:
                clip_m, clip_p, clip_dev = _get_clip()
                q_embed = compute_text_embeddings([query], clip_m, clip_p, clip_dev)
                scores_vec = (_active_embeds.float() @ q_embed.T).squeeze(-1)
            except Exception as exc:
                # Best-effort UI boundary: surface the failure rather than crash the callback.
                result_div.text = (
                    f"<span style='color:#c00'>CLIP error: {_html_escape(str(exc))}</span>"
                )
                return
        else:
            result_div.text = (
                f"<span style='color:#c00'>Query not in vocab and no feature embeddings "
                f"available. Try one of: {', '.join((_clip_vocab or [])[:8])}&hellip;</span>"
            )
            return

        # When NSD subset toggle is active, restrict to features with at least one NSD image
        _restrict_to_nsd = _nsd_active and HAS_NSD_SUBSET
        if _restrict_to_nsd:
            nsd_mask = nsd_top_img_idx[:, 0] >= 0
            scores_vec = scores_vec.clone()  # do not mutate the cached score tensor
            scores_vec[~nsd_mask] = float('-inf')

        top_indices = torch.topk(scores_vec, k=min(top_k, len(scores_vec))).indices.tolist()
        # Drop any -inf results (features with no NSD images when subset is active)
        top_indices = [i for i in top_indices if scores_vec[i] > float('-inf')]

        result_source.data = dict(
            feature_idx=top_indices,
            clip_score=[float(scores_vec[i]) for i in top_indices],
            frequency=[int(feature_frequency[i].item()) for i in top_indices],
            mean_act=[float(feature_mean_act[i].item()) for i in top_indices],
            phi_c_val=_phi_c_vals(top_indices),
            name=[_display_name(int(i)) for i in top_indices],
        )
        _subset_note = " [NSD sub01]" if _restrict_to_nsd else ""
        # The query is user input: escape it so characters such as '<' or '&'
        # are shown literally instead of being interpreted as markup.
        result_div.text = (
            f'<span style="color:#1a6faf"><b>{len(top_indices)}</b> features for '
            f'&ldquo;{_html_escape(query)}&rdquo;{_subset_note}</span>'
        )

    clip_search_btn.on_click(_do_search)

    def _on_result_select(attr, old, new):
        """Open the feature from the first selected row of the results table."""
        if not new:
            return
        feat = result_source.data['feature_idx'][new[0]]
        feature_input.value = str(feat)
        _select_and_display(feat)

    result_source.selected.on_change('indices', _on_result_select)

    panel = column(
        row(clip_query_input, clip_top_k_input, clip_search_btn),
        result_div,
        clip_result_table,
    )
    return panel, result_div, result_source


clip_search_panel, clip_result_div, clip_result_source = _build_clip_panel()
| # ---------- Layout ---------- | |
# Top control strip: UMAP options, feature-index entry, and navigation buttons.
controls = row(
    umap_type_select,
    umap_color_select,
    feature_input,
    go_button,
    random_btn,
)

# Feature naming: name box above the Gemini button and its status line.
name_panel = column(name_input, row(gemini_btn, gemini_status_div))

# Search: input and buttons above the result message.
search_panel = column(row(search_input, search_btn, clear_search_btn), search_result_div)

# Search controls stacked on top of the feature table.
feature_list_panel = column(search_panel, feature_table)
def _make_collapsible(title, body, initially_open=False):
    """Return a column holding a toggle button above ``body``.

    The toggle shows or hides ``body`` on the client via a CustomJS callback
    and flips the arrow in its own label. ``body.visible`` is set to
    ``initially_open`` as a side effect.
    """
    arrow = "▼ " if initially_open else "▶ "
    toggle = Toggle(
        label=arrow + title,
        active=initially_open,
        button_type="light",
        width=500,
        height=30,
    )
    body.visible = initially_open

    # The JS below refers to the toggle as `btn`, so that argument key must stay.
    on_toggle = CustomJS(args=dict(body=body, btn=toggle, title=title), code="""
        body.visible = btn.active;
        btn.label = (btn.active ? '▼ ' : '▶ ') + title;
    """)
    toggle.js_on_click(on_toggle)
    return column(toggle, body)
# Patch explorer body: image picker row, patch figure, status line, feature table.
patch_explorer_panel = column(
    row(patch_img_input, load_patch_btn, clear_patch_btn),
    patch_fig,
    patch_info_div,
    patch_feat_table,
)

# Right-hand sections start collapsed (the helper's default).
summary_section = _make_collapsible("SAE Summary", summary_div)
patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)

# The dataset selector is shown only when more than one dataset is loaded
# and --compare-data was supplied.
if len(_all_datasets) > 1 and args.compare_data:
    _ds_select_row = [dataset_select]
else:
    _ds_select_row = []

left_panel = column(*_ds_select_row, controls, umap_fig, feature_list_panel)

middle_panel = column(
    status_div,
    stats_div,
    name_panel,
    row(
        view_select,
        column(
            Div(text="<b>Images:</b>", width=60, height=15, styles={"padding-top":"5px"}),
            nsd_subset_toggle,
        ),
        column(zoom_slider, heatmap_alpha_slider),
    ),
    compare_agg_div,
    top_heatmap_div,
    mean_heatmap_div,
    brain_div,
)

# The DynaDiff section opens by default when available; otherwise a 1px
# placeholder Div is used in its place.
if HAS_DYNADIFF:
    dd_section = _make_collapsible("DynaDiff Brain Steering", _dd_panel_body, initially_open=True)
else:
    dd_section = Div(text="", width=1)

right_panel = column(summary_section, patch_section, clip_section, dd_section)

layout = row(left_panel, middle_panel, right_panel)
curdoc().add_root(layout)
curdoc().title = "SAE Feature Explorer"
print("Explorer app ready!")
# Start building the GPU runner on a daemon thread as soon as the app is up,
# so the first patch-explorer request does not have to wait for it.
if args.sae_path:
    def _warmup_gpu():
        """Call ``_get_gpu_runner()`` once; a failure is printed, not raised."""
        try:
            _get_gpu_runner()
        except Exception as warmup_err:
            print(f"[GPU runner] Warmup failed: {warmup_err}")

    threading.Thread(target=_warmup_gpu, daemon=True).start()