Marlin Lee committed on
Commit ·
f169dfb
1
Parent(s): 1068a69
Sync space code
Browse files- scripts/explorer_app.py +39 -120
scripts/explorer_app.py
CHANGED
|
@@ -4,29 +4,13 @@ Interactive SAE Feature Explorer - Bokeh Server App.
|
|
| 4 |
Visualizes SAE features with:
|
| 5 |
- UMAP scatter plot of features (activation-based and dictionary-based)
|
| 6 |
- Click a feature to see its top-activating images with heatmap overlays
|
| 7 |
-
-
|
| 8 |
-
|
| 9 |
- Feature naming: assign names to features, saved to JSON, searchable
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
No GPU or model weights are required at serve time.
|
| 13 |
-
|
| 14 |
-
Launch:
|
| 15 |
-
bokeh serve explorer_app.py --port 5006 --allow-websocket-origin="*" \
|
| 16 |
-
--session-token-expiration 86400 \
|
| 17 |
-
--args \
|
| 18 |
-
--data ../../smart_init_stability_SAE/explorer_data_d32000_k160_val.pt \
|
| 19 |
-
--image-dir /scratch.global/lee02328/val \
|
| 20 |
-
--extra-image-dir /scratch.global/lee02328/coco/val2017 \
|
| 21 |
-
--primary-label "DINOv3 L24 Spatial (d=32K)" \
|
| 22 |
-
--compare-data ../../smart_init_stability_SAE/explorer_data_18.pt \
|
| 23 |
-
--compare-labels "DINOv3 L18 Spatial (d=20K)" \
|
| 24 |
-
--phi-dir /path/to/phis \
|
| 25 |
-
--brain-data /path/to/brain_meis_dinov3.pt \
|
| 26 |
-
--brain-thumbnails /path/to/nsd_thumbs
|
| 27 |
-
|
| 28 |
-
Then SSH tunnel: ssh -L 5006:<node>:5006 <user>@<login-node>
|
| 29 |
-
Open: http://localhost:5006/explorer_app
|
| 30 |
"""
|
| 31 |
|
| 32 |
import argparse
|
|
@@ -37,7 +21,6 @@ import base64
|
|
| 37 |
import random
|
| 38 |
import threading
|
| 39 |
from collections import OrderedDict
|
| 40 |
-
from functools import partial
|
| 41 |
|
| 42 |
import cv2
|
| 43 |
import numpy as np
|
|
@@ -56,7 +39,7 @@ from bokeh.layouts import column, row
|
|
| 56 |
from bokeh.events import MouseMove
|
| 57 |
from bokeh.models import (
|
| 58 |
ColumnDataSource, HoverTool, Div, Select, TextInput, Button,
|
| 59 |
-
DataTable, TableColumn, NumberFormatter,
|
| 60 |
Slider, Toggle, RadioButtonGroup, CustomJS,
|
| 61 |
)
|
| 62 |
from bokeh.plotting import figure
|
|
@@ -77,11 +60,6 @@ parser.add_argument("--inference-cache-size", type=int, default=64,
|
|
| 77 |
parser.add_argument("--names-file", type=str, default=None,
|
| 78 |
help="Path to JSON file for saving feature names "
|
| 79 |
"(default: <data>_feature_names.json)")
|
| 80 |
-
parser.add_argument("--compare-data", type=str, nargs="*", default=[],
|
| 81 |
-
help="Additional explorer_data.pt files to show in cross-dataset "
|
| 82 |
-
"comparison panel (e.g. layer 18, CLS SAE)")
|
| 83 |
-
parser.add_argument("--compare-labels", type=str, nargs="*", default=[],
|
| 84 |
-
help="Display labels for each --compare-data file")
|
| 85 |
parser.add_argument("--primary-label", type=str, default="Primary",
|
| 86 |
help="Display label for the primary --data file")
|
| 87 |
parser.add_argument("--clip-model", type=str, default="openai/clip-vit-large-patch14",
|
|
@@ -91,10 +69,7 @@ parser.add_argument("--google-api-key", type=str, default=None,
|
|
| 91 |
help="Google API key for Gemini auto-interp button "
|
| 92 |
"(default: GOOGLE_API_KEY env var)")
|
| 93 |
parser.add_argument("--sae-url", type=str, default=None,
|
| 94 |
-
help="Download URL for the
|
| 95 |
-
"shown as a link in the summary panel")
|
| 96 |
-
parser.add_argument("--compare-sae-urls", type=str, nargs="*", default=[],
|
| 97 |
-
help="Download URLs for each --compare-data dataset's SAE weights (in order)")
|
| 98 |
parser.add_argument("--phi-dir", type=str, default=None,
|
| 99 |
help="Directory containing Phi_cv_*.npy, phi_c_*.npy, voxel_coords.npy "
|
| 100 |
"(brain-alignment data; enables cortical profile and brain leverage features)")
|
|
@@ -133,27 +108,29 @@ args = parser.parse_args()
|
|
| 133 |
|
| 134 |
|
| 135 |
# ---------- Lazy CLIP model (loaded on first free-text query) ----------
|
| 136 |
-
|
| 137 |
-
_clip_handle = [None] # (model, processor, device)
|
| 138 |
|
| 139 |
def _get_clip():
|
| 140 |
"""Load CLIP once and cache it."""
|
| 141 |
-
|
|
|
|
| 142 |
_dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 143 |
print(f"[CLIP] Loading {args.clip_model} on {_dev} (first free-text query)...")
|
| 144 |
_m, _p = load_clip(_dev, model_name=args.clip_model)
|
| 145 |
-
_clip_handle[0] = (_m, _p, _dev)
|
| 146 |
print("[CLIP] Ready.")
|
| 147 |
-
return _clip_handle[0]
|
| 148 |
|
| 149 |
|
| 150 |
# ---------- GPU backbone + SAE runner (optional, lazy-loaded) ----------
|
| 151 |
-
|
|
|
|
| 152 |
|
| 153 |
def _get_gpu_runner():
|
| 154 |
-
"""Load backbone + SAE on GPU once; return
|
| 155 |
-
|
| 156 |
-
|
|
|
|
| 157 |
if not args.sae_path:
|
| 158 |
return None
|
| 159 |
if not torch.cuda.is_available():
|
|
@@ -167,9 +144,9 @@ def _get_gpu_runner():
|
|
| 167 |
print(f"[GPU runner] Loading {args.backbone} layer {args.layer} + SAE on {_dev} ...")
|
| 168 |
_fwd, _d_hidden, _n_reg, _tfm = load_batched_backbone(args.backbone, args.layer, _dev)
|
| 169 |
_sae = load_sae(args.sae_path, _d_hidden, d_model, args.top_k, _dev)
|
| 170 |
-
_gpu_runner
|
| 171 |
print("[GPU runner] Ready.")
|
| 172 |
-
return _gpu_runner
|
| 173 |
|
| 174 |
|
| 175 |
def _run_gpu_inference(pil_img):
|
|
@@ -372,27 +349,11 @@ def _load_dataset_dict(path, label, sae_url=None):
|
|
| 372 |
entry['heatmap_patch_grid'] = d['patch_grid']
|
| 373 |
has_hm = 'no'
|
| 374 |
|
| 375 |
-
# Load pre-computed patch activations sidecar if present.
|
| 376 |
-
# Enables complete GPU-free patch exploration for any image covered by the file.
|
| 377 |
-
pa_sidecar = os.path.splitext(path)[0] + '_patch_acts.pt'
|
| 378 |
-
if os.path.exists(pa_sidecar):
|
| 379 |
-
print(f" Loading pre-computed patch acts from {os.path.basename(pa_sidecar)} ...")
|
| 380 |
-
pa = torch.load(pa_sidecar, map_location='cpu', weights_only=True)
|
| 381 |
-
img_to_row = {int(idx): row for row, idx in enumerate(pa['img_indices'].tolist())}
|
| 382 |
-
entry['patch_acts'] = {
|
| 383 |
-
'feat_indices': pa['feat_indices'], # (n_unique, n_patches, top_k) int16
|
| 384 |
-
'feat_values': pa['feat_values'], # (n_unique, n_patches, top_k) float16
|
| 385 |
-
'img_to_row': img_to_row,
|
| 386 |
-
}
|
| 387 |
-
print(f" patch_acts: {len(img_to_row)} images covered (GPU-free patch explorer)")
|
| 388 |
-
else:
|
| 389 |
-
entry['patch_acts'] = None
|
| 390 |
-
|
| 391 |
entry['sae_url'] = sae_url
|
| 392 |
|
| 393 |
print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
|
| 394 |
f"backbone={entry['backbone']}, clip={'yes' if (cs is not None or entry.get('clip_embeds') is not None) else 'no'}, "
|
| 395 |
-
f"heatmaps={has_hm}
|
| 396 |
return entry
|
| 397 |
|
| 398 |
|
|
@@ -410,7 +371,7 @@ class _S:
|
|
| 410 |
render_token: int = 0 # incremented on each feature selection; stale renders bail out
|
| 411 |
search_filter = None # set of feature indices matching the current name search, or None
|
| 412 |
color_by: str = "Log Frequency" # which field drives UMAP point colour
|
| 413 |
-
hf_push
|
| 414 |
patch_img = None # image index currently loaded in the patch explorer
|
| 415 |
patch_z = None # cached (n_patches, d_model) float32 for the loaded image
|
| 416 |
|
|
@@ -423,16 +384,6 @@ def _ds():
|
|
| 423 |
# Primary dataset — always loaded eagerly
|
| 424 |
_all_datasets.append(_load_dataset_dict(args.data, args.primary_label, sae_url=args.sae_url))
|
| 425 |
|
| 426 |
-
# Compare datasets — stored as lazy placeholders; loaded on first access
|
| 427 |
-
for _ci, _cpath in enumerate(args.compare_data):
|
| 428 |
-
_clabel = (args.compare_labels[_ci]
|
| 429 |
-
if args.compare_labels and _ci < len(args.compare_labels)
|
| 430 |
-
else os.path.basename(_cpath))
|
| 431 |
-
_csae = (args.compare_sae_urls[_ci]
|
| 432 |
-
if args.compare_sae_urls and _ci < len(args.compare_sae_urls)
|
| 433 |
-
else None)
|
| 434 |
-
_all_datasets.append({'label': _clabel, 'path': _cpath, '_lazy': True, 'sae_url': _csae})
|
| 435 |
-
|
| 436 |
def _load_brain_dataset_dict(path, label, thumb_dir):
|
| 437 |
"""Load a brain_meis.pt file and return a dataset entry dict.
|
| 438 |
|
|
@@ -497,7 +448,6 @@ def _load_brain_dataset_dict(path, label, thumb_dir):
|
|
| 497 |
'feature_names': {},
|
| 498 |
'auto_interp_names': {},
|
| 499 |
'sae_url': None,
|
| 500 |
-
'patch_acts': None,
|
| 501 |
}
|
| 502 |
|
| 503 |
# Load pre-computed heatmaps sidecar if present.
|
|
@@ -700,43 +650,23 @@ def _display_name(feat: int) -> str:
|
|
| 700 |
|
| 701 |
|
| 702 |
def compute_patch_activations(img_idx):
|
| 703 |
-
"""Return (n_patches, d_sae) float32
|
| 704 |
|
| 705 |
-
|
| 706 |
-
1. LRU cache
|
| 707 |
-
2. Pre-computed patch_acts lookup — complete activations for covered images
|
| 708 |
-
3. GPU live inference — full activations via backbone + SAE (requires --sae-path)
|
| 709 |
-
Uses a per-dataset LRU cache.
|
| 710 |
"""
|
| 711 |
ds = _all_datasets[_S.active]
|
| 712 |
cache = ds['inference_cache']
|
| 713 |
|
| 714 |
-
# 1. LRU cache
|
| 715 |
if img_idx in cache:
|
| 716 |
cache.move_to_end(img_idx)
|
| 717 |
return cache[img_idx]
|
| 718 |
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
if row is not None:
|
| 726 |
-
fi = pa['feat_indices'][row].numpy() # (n_patches, top_k) int16
|
| 727 |
-
fv = pa['feat_values'][row].float().numpy() # (n_patches, top_k) float32
|
| 728 |
-
n_p = fi.shape[0]
|
| 729 |
-
z_np = np.zeros((n_p, ds['d_model']), dtype=np.float32)
|
| 730 |
-
z_np[np.arange(n_p)[:, None], fi.astype(np.int32)] = fv
|
| 731 |
-
|
| 732 |
-
# 3. GPU live inference
|
| 733 |
-
if z_np is None:
|
| 734 |
-
try:
|
| 735 |
-
pil = load_image(img_idx)
|
| 736 |
-
z_np = _run_gpu_inference(pil)
|
| 737 |
-
except Exception as _e:
|
| 738 |
-
print(f"[GPU runner] inference failed for img {img_idx}: {_e}")
|
| 739 |
-
z_np = None
|
| 740 |
|
| 741 |
if z_np is not None:
|
| 742 |
cache[img_idx] = z_np
|
|
@@ -1306,21 +1236,16 @@ def _on_dataset_switch(attr, old, new):
|
|
| 1306 |
# Update summary panel
|
| 1307 |
summary_div.text = _make_summary_html()
|
| 1308 |
|
| 1309 |
-
# Show/hide patch explorer depending on token type and
|
| 1310 |
ds = _all_datasets[idx]
|
| 1311 |
-
has_heatmaps = ds.get('top_heatmaps') is not None
|
| 1312 |
-
has_patch_acts = ds.get('patch_acts') is not None
|
| 1313 |
can_explore = (
|
| 1314 |
ds.get('token_type', 'spatial') == 'spatial'
|
| 1315 |
-
and (
|
| 1316 |
)
|
| 1317 |
patch_fig.visible = can_explore
|
| 1318 |
patch_info_div.visible = can_explore
|
| 1319 |
if not can_explore:
|
| 1320 |
-
if ds.get('token_type') == 'cls':
|
| 1321 |
-
reason = "CLS token — no patch grid"
|
| 1322 |
-
else:
|
| 1323 |
-
reason = "no pre-computed heatmaps or patch_acts for this model"
|
| 1324 |
patch_info_div.text = (
|
| 1325 |
f'<p style="color:#888;font-style:italic">Patch explorer unavailable: {reason}.</p>')
|
| 1326 |
patch_info_div.visible = True
|
|
@@ -2223,8 +2148,6 @@ def _make_summary_html():
|
|
| 2223 |
backbone_label = ds.get('backbone', 'dinov3').upper()
|
| 2224 |
clip_label = "yes" if (ds['clip_scores'] is not None or ds.get('clip_embeds') is not None) else "no"
|
| 2225 |
hm_label = "yes" if ds.get('top_heatmaps') is not None else "no"
|
| 2226 |
-
pa = ds.get('patch_acts')
|
| 2227 |
-
pa_label = f"yes ({len(pa['img_to_row'])} images)" if pa is not None else "no — run --save-patch-acts"
|
| 2228 |
sae_url = ds.get('sae_url')
|
| 2229 |
dl_row = (f'<tr><td><b>SAE weights:</b></td>'
|
| 2230 |
f'<td><a href="{sae_url}" download style="color:#1a6faf">⬇ Download</a></td></tr>'
|
|
@@ -2250,7 +2173,7 @@ summary_div = Div(text=_make_summary_html(), width=700)
|
|
| 2250 |
|
| 2251 |
# ---------- Patch Explorer ----------
|
| 2252 |
# Click patches of an image to find the top active SAE features for that region.
|
| 2253 |
-
# Activations are
|
| 2254 |
|
| 2255 |
_PATCH_FIG_PX = 400
|
| 2256 |
|
|
@@ -2383,7 +2306,7 @@ def _do_load_patch_image():
|
|
| 2383 |
patch_info_div.text = (
|
| 2384 |
"<span style='color:#1a6faf'>⏳ Computing patch activations"
|
| 2385 |
+ (" (running GPU inference — first image may take ~10 s)…"
|
| 2386 |
-
if _gpu_runner
|
| 2387 |
+ "</span>"
|
| 2388 |
)
|
| 2389 |
|
|
@@ -2410,18 +2333,14 @@ def _do_load_patch_image():
|
|
| 2410 |
if z_np is None:
|
| 2411 |
patch_feat_table.visible = False
|
| 2412 |
patch_info_div.text = (
|
| 2413 |
-
f"<b style='color:#888'>
|
| 2414 |
-
f"
|
| 2415 |
-
f"live GPU inference for any image.</b>"
|
| 2416 |
)
|
| 2417 |
return
|
| 2418 |
|
| 2419 |
patch_feat_table.visible = True
|
| 2420 |
-
_ds = _all_datasets[_S.active]
|
| 2421 |
-
_pa = _ds.get('patch_acts')
|
| 2422 |
-
source = "patch_acts" if (_pa is not None and img_idx in _pa['img_to_row']) else "GPU inference"
|
| 2423 |
patch_info_div.text = (
|
| 2424 |
-
f"Image {img_idx} loaded
|
| 2425 |
f"Drag to select a region, or click individual patches."
|
| 2426 |
)
|
| 2427 |
|
|
@@ -2667,7 +2586,7 @@ summary_section = _make_collapsible("SAE Summary", summary_div)
|
|
| 2667 |
patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
|
| 2668 |
clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)
|
| 2669 |
|
| 2670 |
-
_ds_select_row = ([dataset_select] if len(_all_datasets) > 1 else [])
|
| 2671 |
left_panel = column(*_ds_select_row, controls, umap_fig, feature_list_panel)
|
| 2672 |
|
| 2673 |
middle_panel = column(
|
|
|
|
| 4 |
Visualizes SAE features with:
|
| 5 |
- UMAP scatter plot of features (activation-based and dictionary-based)
|
| 6 |
- Click a feature to see its top-activating images with heatmap overlays
|
| 7 |
+
- Patch explorer: click patches of any image to find active SAE features
|
| 8 |
+
(uses live GPU inference via the backbone + SAE loaded from --sae-path)
|
| 9 |
- Feature naming: assign names to features, saved to JSON, searchable
|
| 10 |
+
- CLIP text search, Gemini auto-interp, DynaDiff brain steering panel
|
| 11 |
+
- Optional NSD brain MEI dataset (--brain-data) shown in the dataset dropdown
|
| 12 |
|
| 13 |
+
Launch: see run_explorer.sh
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
|
| 16 |
import argparse
|
|
|
|
| 21 |
import random
|
| 22 |
import threading
|
| 23 |
from collections import OrderedDict
|
|
|
|
| 24 |
|
| 25 |
import cv2
|
| 26 |
import numpy as np
|
|
|
|
| 39 |
from bokeh.events import MouseMove
|
| 40 |
from bokeh.models import (
|
| 41 |
ColumnDataSource, HoverTool, Div, Select, TextInput, Button,
|
| 42 |
+
DataTable, TableColumn, NumberFormatter, NumberEditor,
|
| 43 |
Slider, Toggle, RadioButtonGroup, CustomJS,
|
| 44 |
)
|
| 45 |
from bokeh.plotting import figure
|
|
|
|
| 60 |
parser.add_argument("--names-file", type=str, default=None,
|
| 61 |
help="Path to JSON file for saving feature names "
|
| 62 |
"(default: <data>_feature_names.json)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
parser.add_argument("--primary-label", type=str, default="Primary",
|
| 64 |
help="Display label for the primary --data file")
|
| 65 |
parser.add_argument("--clip-model", type=str, default="openai/clip-vit-large-patch14",
|
|
|
|
| 69 |
help="Google API key for Gemini auto-interp button "
|
| 70 |
"(default: GOOGLE_API_KEY env var)")
|
| 71 |
parser.add_argument("--sae-url", type=str, default=None,
|
| 72 |
+
help="Download URL for the SAE weights — shown as a link in the summary panel")
|
|
|
|
|
|
|
|
|
|
| 73 |
parser.add_argument("--phi-dir", type=str, default=None,
|
| 74 |
help="Directory containing Phi_cv_*.npy, phi_c_*.npy, voxel_coords.npy "
|
| 75 |
"(brain-alignment data; enables cortical profile and brain leverage features)")
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
# ---------- Lazy CLIP model (loaded on first free-text query) ----------
|
| 111 |
+
_clip_handle = None # (model, processor, device), set on first use
|
|
|
|
| 112 |
|
| 113 |
def _get_clip():
|
| 114 |
"""Load CLIP once and cache it."""
|
| 115 |
+
global _clip_handle
|
| 116 |
+
if _clip_handle is None:
|
| 117 |
_dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 118 |
print(f"[CLIP] Loading {args.clip_model} on {_dev} (first free-text query)...")
|
| 119 |
_m, _p = load_clip(_dev, model_name=args.clip_model)
|
| 120 |
+
_clip_handle = (_m, _p, _dev)
|
| 121 |
print("[CLIP] Ready.")
|
| 122 |
+
return _clip_handle
|
| 123 |
|
| 124 |
|
| 125 |
# ---------- GPU backbone + SAE runner (optional, lazy-loaded) ----------
|
| 126 |
+
# Tuple of (forward_fn, sae, transform_fn, n_reg, extract_tokens_fn, backbone_name, device)
|
| 127 |
+
_gpu_runner = None
|
| 128 |
|
| 129 |
def _get_gpu_runner():
|
| 130 |
+
"""Load backbone + SAE on GPU once; return the runner tuple or None."""
|
| 131 |
+
global _gpu_runner
|
| 132 |
+
if _gpu_runner is not None:
|
| 133 |
+
return _gpu_runner
|
| 134 |
if not args.sae_path:
|
| 135 |
return None
|
| 136 |
if not torch.cuda.is_available():
|
|
|
|
| 144 |
print(f"[GPU runner] Loading {args.backbone} layer {args.layer} + SAE on {_dev} ...")
|
| 145 |
_fwd, _d_hidden, _n_reg, _tfm = load_batched_backbone(args.backbone, args.layer, _dev)
|
| 146 |
_sae = load_sae(args.sae_path, _d_hidden, d_model, args.top_k, _dev)
|
| 147 |
+
_gpu_runner = (_fwd, _sae, _tfm, _n_reg, _et, args.backbone, _dev)
|
| 148 |
print("[GPU runner] Ready.")
|
| 149 |
+
return _gpu_runner
|
| 150 |
|
| 151 |
|
| 152 |
def _run_gpu_inference(pil_img):
|
|
|
|
| 349 |
entry['heatmap_patch_grid'] = d['patch_grid']
|
| 350 |
has_hm = 'no'
|
| 351 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
entry['sae_url'] = sae_url
|
| 353 |
|
| 354 |
print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
|
| 355 |
f"backbone={entry['backbone']}, clip={'yes' if (cs is not None or entry.get('clip_embeds') is not None) else 'no'}, "
|
| 356 |
+
f"heatmaps={has_hm}")
|
| 357 |
return entry
|
| 358 |
|
| 359 |
|
|
|
|
| 371 |
render_token: int = 0 # incremented on each feature selection; stale renders bail out
|
| 372 |
search_filter = None # set of feature indices matching the current name search, or None
|
| 373 |
color_by: str = "Log Frequency" # which field drives UMAP point colour
|
| 374 |
+
hf_push = None # active Bokeh timeout handle for debounced HuggingFace upload
|
| 375 |
patch_img = None # image index currently loaded in the patch explorer
|
| 376 |
patch_z = None # cached (n_patches, d_model) float32 for the loaded image
|
| 377 |
|
|
|
|
| 384 |
# Primary dataset — always loaded eagerly
|
| 385 |
_all_datasets.append(_load_dataset_dict(args.data, args.primary_label, sae_url=args.sae_url))
|
| 386 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
def _load_brain_dataset_dict(path, label, thumb_dir):
|
| 388 |
"""Load a brain_meis.pt file and return a dataset entry dict.
|
| 389 |
|
|
|
|
| 448 |
'feature_names': {},
|
| 449 |
'auto_interp_names': {},
|
| 450 |
'sae_url': None,
|
|
|
|
| 451 |
}
|
| 452 |
|
| 453 |
# Load pre-computed heatmaps sidecar if present.
|
|
|
|
| 650 |
|
| 651 |
|
| 652 |
def compute_patch_activations(img_idx):
|
| 653 |
+
"""Return (n_patches, d_sae) float32 via GPU inference, or None if unavailable.
|
| 654 |
|
| 655 |
+
Results are cached in a per-dataset LRU cache keyed by image index.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
"""
|
| 657 |
ds = _all_datasets[_S.active]
|
| 658 |
cache = ds['inference_cache']
|
| 659 |
|
|
|
|
| 660 |
if img_idx in cache:
|
| 661 |
cache.move_to_end(img_idx)
|
| 662 |
return cache[img_idx]
|
| 663 |
|
| 664 |
+
try:
|
| 665 |
+
pil = load_image(img_idx)
|
| 666 |
+
z_np = _run_gpu_inference(pil)
|
| 667 |
+
except Exception as _e:
|
| 668 |
+
print(f"[GPU runner] inference failed for img {img_idx}: {_e}")
|
| 669 |
+
z_np = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 670 |
|
| 671 |
if z_np is not None:
|
| 672 |
cache[img_idx] = z_np
|
|
|
|
| 1236 |
# Update summary panel
|
| 1237 |
summary_div.text = _make_summary_html()
|
| 1238 |
|
| 1239 |
+
# Show/hide patch explorer depending on token type (spatial required) and GPU availability.
|
| 1240 |
ds = _all_datasets[idx]
|
|
|
|
|
|
|
| 1241 |
can_explore = (
|
| 1242 |
ds.get('token_type', 'spatial') == 'spatial'
|
| 1243 |
+
and bool(args.sae_path)
|
| 1244 |
)
|
| 1245 |
patch_fig.visible = can_explore
|
| 1246 |
patch_info_div.visible = can_explore
|
| 1247 |
if not can_explore:
|
| 1248 |
+
reason = "CLS token — no patch grid" if ds.get('token_type') == 'cls' else "no --sae-path provided"
|
|
|
|
|
|
|
|
|
|
| 1249 |
patch_info_div.text = (
|
| 1250 |
f'<p style="color:#888;font-style:italic">Patch explorer unavailable: {reason}.</p>')
|
| 1251 |
patch_info_div.visible = True
|
|
|
|
| 2148 |
backbone_label = ds.get('backbone', 'dinov3').upper()
|
| 2149 |
clip_label = "yes" if (ds['clip_scores'] is not None or ds.get('clip_embeds') is not None) else "no"
|
| 2150 |
hm_label = "yes" if ds.get('top_heatmaps') is not None else "no"
|
|
|
|
|
|
|
| 2151 |
sae_url = ds.get('sae_url')
|
| 2152 |
dl_row = (f'<tr><td><b>SAE weights:</b></td>'
|
| 2153 |
f'<td><a href="{sae_url}" download style="color:#1a6faf">⬇ Download</a></td></tr>'
|
|
|
|
| 2173 |
|
| 2174 |
# ---------- Patch Explorer ----------
|
| 2175 |
# Click patches of an image to find the top active SAE features for that region.
|
| 2176 |
+
# Activations are computed on-the-fly via GPU inference (backbone + SAE from --sae-path).
|
| 2177 |
|
| 2178 |
_PATCH_FIG_PX = 400
|
| 2179 |
|
|
|
|
| 2306 |
patch_info_div.text = (
|
| 2307 |
"<span style='color:#1a6faf'>⏳ Computing patch activations"
|
| 2308 |
+ (" (running GPU inference — first image may take ~10 s)…"
|
| 2309 |
+
if _gpu_runner is None and args.sae_path else "…")
|
| 2310 |
+ "</span>"
|
| 2311 |
)
|
| 2312 |
|
|
|
|
| 2333 |
if z_np is None:
|
| 2334 |
patch_feat_table.visible = False
|
| 2335 |
patch_info_div.text = (
|
| 2336 |
+
f"<b style='color:#888'>GPU inference unavailable for image {img_idx}. "
|
| 2337 |
+
f"Ensure --sae-path is set and the GPU runner loaded successfully.</b>"
|
|
|
|
| 2338 |
)
|
| 2339 |
return
|
| 2340 |
|
| 2341 |
patch_feat_table.visible = True
|
|
|
|
|
|
|
|
|
|
| 2342 |
patch_info_div.text = (
|
| 2343 |
+
f"Image {img_idx} loaded. "
|
| 2344 |
f"Drag to select a region, or click individual patches."
|
| 2345 |
)
|
| 2346 |
|
|
|
|
| 2586 |
patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
|
| 2587 |
clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)
|
| 2588 |
|
| 2589 |
+
_ds_select_row = ([dataset_select] if len(_all_datasets) > 1 else [])
|
| 2590 |
left_panel = column(*_ds_select_row, controls, umap_fig, feature_list_panel)
|
| 2591 |
|
| 2592 |
middle_panel = column(
|