Marlin Lee committed on
Commit
f169dfb
·
1 Parent(s): 1068a69

Sync space code

Browse files
Files changed (1) hide show
  1. scripts/explorer_app.py +39 -120
scripts/explorer_app.py CHANGED
@@ -4,29 +4,13 @@ Interactive SAE Feature Explorer - Bokeh Server App.
4
  Visualizes SAE features with:
5
  - UMAP scatter plot of features (activation-based and dictionary-based)
6
  - Click a feature to see its top-activating images with heatmap overlays
7
- - 75th percentile images for distribution understanding
8
- - Patch explorer: click patches of any image to find active features
9
  - Feature naming: assign names to features, saved to JSON, searchable
 
 
10
 
11
- All display is driven by pre-computed sidecars (_heatmaps.pt, _patch_acts.pt).
12
- No GPU or model weights are required at serve time.
13
-
14
- Launch:
15
- bokeh serve explorer_app.py --port 5006 --allow-websocket-origin="*" \
16
- --session-token-expiration 86400 \
17
- --args \
18
- --data ../../smart_init_stability_SAE/explorer_data_d32000_k160_val.pt \
19
- --image-dir /scratch.global/lee02328/val \
20
- --extra-image-dir /scratch.global/lee02328/coco/val2017 \
21
- --primary-label "DINOv3 L24 Spatial (d=32K)" \
22
- --compare-data ../../smart_init_stability_SAE/explorer_data_18.pt \
23
- --compare-labels "DINOv3 L18 Spatial (d=20K)" \
24
- --phi-dir /path/to/phis \
25
- --brain-data /path/to/brain_meis_dinov3.pt \
26
- --brain-thumbnails /path/to/nsd_thumbs
27
-
28
- Then SSH tunnel: ssh -L 5006:<node>:5006 <user>@<login-node>
29
- Open: http://localhost:5006/explorer_app
30
  """
31
 
32
  import argparse
@@ -37,7 +21,6 @@ import base64
37
  import random
38
  import threading
39
  from collections import OrderedDict
40
- from functools import partial
41
 
42
  import cv2
43
  import numpy as np
@@ -56,7 +39,7 @@ from bokeh.layouts import column, row
56
  from bokeh.events import MouseMove
57
  from bokeh.models import (
58
  ColumnDataSource, HoverTool, Div, Select, TextInput, Button,
59
- DataTable, TableColumn, NumberFormatter, IntEditor, NumberEditor,
60
  Slider, Toggle, RadioButtonGroup, CustomJS,
61
  )
62
  from bokeh.plotting import figure
@@ -77,11 +60,6 @@ parser.add_argument("--inference-cache-size", type=int, default=64,
77
  parser.add_argument("--names-file", type=str, default=None,
78
  help="Path to JSON file for saving feature names "
79
  "(default: <data>_feature_names.json)")
80
- parser.add_argument("--compare-data", type=str, nargs="*", default=[],
81
- help="Additional explorer_data.pt files to show in cross-dataset "
82
- "comparison panel (e.g. layer 18, CLS SAE)")
83
- parser.add_argument("--compare-labels", type=str, nargs="*", default=[],
84
- help="Display labels for each --compare-data file")
85
  parser.add_argument("--primary-label", type=str, default="Primary",
86
  help="Display label for the primary --data file")
87
  parser.add_argument("--clip-model", type=str, default="openai/clip-vit-large-patch14",
@@ -91,10 +69,7 @@ parser.add_argument("--google-api-key", type=str, default=None,
91
  help="Google API key for Gemini auto-interp button "
92
  "(default: GOOGLE_API_KEY env var)")
93
  parser.add_argument("--sae-url", type=str, default=None,
94
- help="Download URL for the primary dataset's SAE weights — "
95
- "shown as a link in the summary panel")
96
- parser.add_argument("--compare-sae-urls", type=str, nargs="*", default=[],
97
- help="Download URLs for each --compare-data dataset's SAE weights (in order)")
98
  parser.add_argument("--phi-dir", type=str, default=None,
99
  help="Directory containing Phi_cv_*.npy, phi_c_*.npy, voxel_coords.npy "
100
  "(brain-alignment data; enables cortical profile and brain leverage features)")
@@ -133,27 +108,29 @@ args = parser.parse_args()
133
 
134
 
135
  # ---------- Lazy CLIP model (loaded on first free-text query) ----------
136
- # _clip_handle[0] is None until the first out-of-vocab query is issued.
137
- _clip_handle = [None] # (model, processor, device)
138
 
139
  def _get_clip():
140
  """Load CLIP once and cache it."""
141
- if _clip_handle[0] is None:
 
142
  _dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
143
  print(f"[CLIP] Loading {args.clip_model} on {_dev} (first free-text query)...")
144
  _m, _p = load_clip(_dev, model_name=args.clip_model)
145
- _clip_handle[0] = (_m, _p, _dev)
146
  print("[CLIP] Ready.")
147
- return _clip_handle[0]
148
 
149
 
150
  # ---------- GPU backbone + SAE runner (optional, lazy-loaded) ----------
151
- _gpu_runner = [None] # (forward_fn, sae, transform_fn, n_reg, extract_tokens_fn, backbone_name, device) or None
 
152
 
153
  def _get_gpu_runner():
154
- """Load backbone + SAE on GPU once; return (forward_fn, sae, transform_fn, device) or None."""
155
- if _gpu_runner[0] is not None:
156
- return _gpu_runner[0]
 
157
  if not args.sae_path:
158
  return None
159
  if not torch.cuda.is_available():
@@ -167,9 +144,9 @@ def _get_gpu_runner():
167
  print(f"[GPU runner] Loading {args.backbone} layer {args.layer} + SAE on {_dev} ...")
168
  _fwd, _d_hidden, _n_reg, _tfm = load_batched_backbone(args.backbone, args.layer, _dev)
169
  _sae = load_sae(args.sae_path, _d_hidden, d_model, args.top_k, _dev)
170
- _gpu_runner[0] = (_fwd, _sae, _tfm, _n_reg, _et, args.backbone, _dev)
171
  print("[GPU runner] Ready.")
172
- return _gpu_runner[0]
173
 
174
 
175
  def _run_gpu_inference(pil_img):
@@ -372,27 +349,11 @@ def _load_dataset_dict(path, label, sae_url=None):
372
  entry['heatmap_patch_grid'] = d['patch_grid']
373
  has_hm = 'no'
374
 
375
- # Load pre-computed patch activations sidecar if present.
376
- # Enables complete GPU-free patch exploration for any image covered by the file.
377
- pa_sidecar = os.path.splitext(path)[0] + '_patch_acts.pt'
378
- if os.path.exists(pa_sidecar):
379
- print(f" Loading pre-computed patch acts from {os.path.basename(pa_sidecar)} ...")
380
- pa = torch.load(pa_sidecar, map_location='cpu', weights_only=True)
381
- img_to_row = {int(idx): row for row, idx in enumerate(pa['img_indices'].tolist())}
382
- entry['patch_acts'] = {
383
- 'feat_indices': pa['feat_indices'], # (n_unique, n_patches, top_k) int16
384
- 'feat_values': pa['feat_values'], # (n_unique, n_patches, top_k) float16
385
- 'img_to_row': img_to_row,
386
- }
387
- print(f" patch_acts: {len(img_to_row)} images covered (GPU-free patch explorer)")
388
- else:
389
- entry['patch_acts'] = None
390
-
391
  entry['sae_url'] = sae_url
392
 
393
  print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
394
  f"backbone={entry['backbone']}, clip={'yes' if (cs is not None or entry.get('clip_embeds') is not None) else 'no'}, "
395
- f"heatmaps={has_hm}, patch_acts={'yes' if entry['patch_acts'] else 'no'}")
396
  return entry
397
 
398
 
@@ -410,7 +371,7 @@ class _S:
410
  render_token: int = 0 # incremented on each feature selection; stale renders bail out
411
  search_filter = None # set of feature indices matching the current name search, or None
412
  color_by: str = "Log Frequency" # which field drives UMAP point colour
413
- hf_push: object = None # active Bokeh timeout handle for debounced HuggingFace upload
414
  patch_img = None # image index currently loaded in the patch explorer
415
  patch_z = None # cached (n_patches, d_model) float32 for the loaded image
416
 
@@ -423,16 +384,6 @@ def _ds():
423
  # Primary dataset — always loaded eagerly
424
  _all_datasets.append(_load_dataset_dict(args.data, args.primary_label, sae_url=args.sae_url))
425
 
426
- # Compare datasets — stored as lazy placeholders; loaded on first access
427
- for _ci, _cpath in enumerate(args.compare_data):
428
- _clabel = (args.compare_labels[_ci]
429
- if args.compare_labels and _ci < len(args.compare_labels)
430
- else os.path.basename(_cpath))
431
- _csae = (args.compare_sae_urls[_ci]
432
- if args.compare_sae_urls and _ci < len(args.compare_sae_urls)
433
- else None)
434
- _all_datasets.append({'label': _clabel, 'path': _cpath, '_lazy': True, 'sae_url': _csae})
435
-
436
  def _load_brain_dataset_dict(path, label, thumb_dir):
437
  """Load a brain_meis.pt file and return a dataset entry dict.
438
 
@@ -497,7 +448,6 @@ def _load_brain_dataset_dict(path, label, thumb_dir):
497
  'feature_names': {},
498
  'auto_interp_names': {},
499
  'sae_url': None,
500
- 'patch_acts': None,
501
  }
502
 
503
  # Load pre-computed heatmaps sidecar if present.
@@ -700,43 +650,23 @@ def _display_name(feat: int) -> str:
700
 
701
 
702
  def compute_patch_activations(img_idx):
703
- """Return (n_patches, d_sae) float32 for the active dataset, or None.
704
 
705
- Priority order:
706
- 1. LRU cache
707
- 2. Pre-computed patch_acts lookup — complete activations for covered images
708
- 3. GPU live inference — full activations via backbone + SAE (requires --sae-path)
709
- Uses a per-dataset LRU cache.
710
  """
711
  ds = _all_datasets[_S.active]
712
  cache = ds['inference_cache']
713
 
714
- # 1. LRU cache
715
  if img_idx in cache:
716
  cache.move_to_end(img_idx)
717
  return cache[img_idx]
718
 
719
- z_np = None
720
-
721
- # 2. Try patch_acts lookup (complete activations for covered images)
722
- pa = ds.get('patch_acts')
723
- if pa is not None:
724
- row = pa['img_to_row'].get(img_idx)
725
- if row is not None:
726
- fi = pa['feat_indices'][row].numpy() # (n_patches, top_k) int16
727
- fv = pa['feat_values'][row].float().numpy() # (n_patches, top_k) float32
728
- n_p = fi.shape[0]
729
- z_np = np.zeros((n_p, ds['d_model']), dtype=np.float32)
730
- z_np[np.arange(n_p)[:, None], fi.astype(np.int32)] = fv
731
-
732
- # 3. GPU live inference
733
- if z_np is None:
734
- try:
735
- pil = load_image(img_idx)
736
- z_np = _run_gpu_inference(pil)
737
- except Exception as _e:
738
- print(f"[GPU runner] inference failed for img {img_idx}: {_e}")
739
- z_np = None
740
 
741
  if z_np is not None:
742
  cache[img_idx] = z_np
@@ -1306,21 +1236,16 @@ def _on_dataset_switch(attr, old, new):
1306
  # Update summary panel
1307
  summary_div.text = _make_summary_html()
1308
 
1309
- # Show/hide patch explorer depending on token type and data availability.
1310
  ds = _all_datasets[idx]
1311
- has_heatmaps = ds.get('top_heatmaps') is not None
1312
- has_patch_acts = ds.get('patch_acts') is not None
1313
  can_explore = (
1314
  ds.get('token_type', 'spatial') == 'spatial'
1315
- and (has_heatmaps or has_patch_acts)
1316
  )
1317
  patch_fig.visible = can_explore
1318
  patch_info_div.visible = can_explore
1319
  if not can_explore:
1320
- if ds.get('token_type') == 'cls':
1321
- reason = "CLS token — no patch grid"
1322
- else:
1323
- reason = "no pre-computed heatmaps or patch_acts for this model"
1324
  patch_info_div.text = (
1325
  f'<p style="color:#888;font-style:italic">Patch explorer unavailable: {reason}.</p>')
1326
  patch_info_div.visible = True
@@ -2223,8 +2148,6 @@ def _make_summary_html():
2223
  backbone_label = ds.get('backbone', 'dinov3').upper()
2224
  clip_label = "yes" if (ds['clip_scores'] is not None or ds.get('clip_embeds') is not None) else "no"
2225
  hm_label = "yes" if ds.get('top_heatmaps') is not None else "no"
2226
- pa = ds.get('patch_acts')
2227
- pa_label = f"yes ({len(pa['img_to_row'])} images)" if pa is not None else "no — run --save-patch-acts"
2228
  sae_url = ds.get('sae_url')
2229
  dl_row = (f'<tr><td><b>SAE weights:</b></td>'
2230
  f'<td><a href="{sae_url}" download style="color:#1a6faf">⬇ Download</a></td></tr>'
@@ -2250,7 +2173,7 @@ summary_div = Div(text=_make_summary_html(), width=700)
2250
 
2251
  # ---------- Patch Explorer ----------
2252
  # Click patches of an image to find the top active SAE features for that region.
2253
- # Activations are served from pre-computed sidecars (no GPU required at serve time).
2254
 
2255
  _PATCH_FIG_PX = 400
2256
 
@@ -2383,7 +2306,7 @@ def _do_load_patch_image():
2383
  patch_info_div.text = (
2384
  "<span style='color:#1a6faf'>&#x23F3; Computing patch activations"
2385
  + (" (running GPU inference — first image may take ~10 s)…"
2386
- if _gpu_runner[0] is None and args.sae_path else "…")
2387
  + "</span>"
2388
  )
2389
 
@@ -2410,18 +2333,14 @@ def _do_load_patch_image():
2410
  if z_np is None:
2411
  patch_feat_table.visible = False
2412
  patch_info_div.text = (
2413
- f"<b style='color:#888'>Image {img_idx} has no pre-computed patch activations "
2414
- f"and no GPU runner is available. Pass --sae-path to the explorer to enable "
2415
- f"live GPU inference for any image.</b>"
2416
  )
2417
  return
2418
 
2419
  patch_feat_table.visible = True
2420
- _ds = _all_datasets[_S.active]
2421
- _pa = _ds.get('patch_acts')
2422
- source = "patch_acts" if (_pa is not None and img_idx in _pa['img_to_row']) else "GPU inference"
2423
  patch_info_div.text = (
2424
- f"Image {img_idx} loaded ({source}). "
2425
  f"Drag to select a region, or click individual patches."
2426
  )
2427
 
@@ -2667,7 +2586,7 @@ summary_section = _make_collapsible("SAE Summary", summary_div)
2667
  patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
2668
  clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)
2669
 
2670
- _ds_select_row = ([dataset_select] if len(_all_datasets) > 1 and args.compare_data else [])
2671
  left_panel = column(*_ds_select_row, controls, umap_fig, feature_list_panel)
2672
 
2673
  middle_panel = column(
 
4
  Visualizes SAE features with:
5
  - UMAP scatter plot of features (activation-based and dictionary-based)
6
  - Click a feature to see its top-activating images with heatmap overlays
7
+ - Patch explorer: click patches of any image to find active SAE features
8
+ (uses live GPU inference via the backbone + SAE loaded from --sae-path)
9
  - Feature naming: assign names to features, saved to JSON, searchable
10
+ - CLIP text search, Gemini auto-interp, DynaDiff brain steering panel
11
+ - Optional NSD brain MEI dataset (--brain-data) shown in the dataset dropdown
12
 
13
+ Launch: see run_explorer.sh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """
15
 
16
  import argparse
 
21
  import random
22
  import threading
23
  from collections import OrderedDict
 
24
 
25
  import cv2
26
  import numpy as np
 
39
  from bokeh.events import MouseMove
40
  from bokeh.models import (
41
  ColumnDataSource, HoverTool, Div, Select, TextInput, Button,
42
+ DataTable, TableColumn, NumberFormatter, NumberEditor,
43
  Slider, Toggle, RadioButtonGroup, CustomJS,
44
  )
45
  from bokeh.plotting import figure
 
60
  parser.add_argument("--names-file", type=str, default=None,
61
  help="Path to JSON file for saving feature names "
62
  "(default: <data>_feature_names.json)")
 
 
 
 
 
63
  parser.add_argument("--primary-label", type=str, default="Primary",
64
  help="Display label for the primary --data file")
65
  parser.add_argument("--clip-model", type=str, default="openai/clip-vit-large-patch14",
 
69
  help="Google API key for Gemini auto-interp button "
70
  "(default: GOOGLE_API_KEY env var)")
71
  parser.add_argument("--sae-url", type=str, default=None,
72
+ help="Download URL for the SAE weights — shown as a link in the summary panel")
 
 
 
73
  parser.add_argument("--phi-dir", type=str, default=None,
74
  help="Directory containing Phi_cv_*.npy, phi_c_*.npy, voxel_coords.npy "
75
  "(brain-alignment data; enables cortical profile and brain leverage features)")
 
108
 
109
 
110
  # ---------- Lazy CLIP model (loaded on first free-text query) ----------
111
+ _clip_handle = None # (model, processor, device), set on first use
 
112
 
113
  def _get_clip():
114
  """Load CLIP once and cache it."""
115
+ global _clip_handle
116
+ if _clip_handle is None:
117
  _dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
118
  print(f"[CLIP] Loading {args.clip_model} on {_dev} (first free-text query)...")
119
  _m, _p = load_clip(_dev, model_name=args.clip_model)
120
+ _clip_handle = (_m, _p, _dev)
121
  print("[CLIP] Ready.")
122
+ return _clip_handle
123
 
124
 
125
  # ---------- GPU backbone + SAE runner (optional, lazy-loaded) ----------
126
+ # Tuple of (forward_fn, sae, transform_fn, n_reg, extract_tokens_fn, backbone_name, device)
127
+ _gpu_runner = None
128
 
129
  def _get_gpu_runner():
130
+ """Load backbone + SAE on GPU once; return the runner tuple or None."""
131
+ global _gpu_runner
132
+ if _gpu_runner is not None:
133
+ return _gpu_runner
134
  if not args.sae_path:
135
  return None
136
  if not torch.cuda.is_available():
 
144
  print(f"[GPU runner] Loading {args.backbone} layer {args.layer} + SAE on {_dev} ...")
145
  _fwd, _d_hidden, _n_reg, _tfm = load_batched_backbone(args.backbone, args.layer, _dev)
146
  _sae = load_sae(args.sae_path, _d_hidden, d_model, args.top_k, _dev)
147
+ _gpu_runner = (_fwd, _sae, _tfm, _n_reg, _et, args.backbone, _dev)
148
  print("[GPU runner] Ready.")
149
+ return _gpu_runner
150
 
151
 
152
  def _run_gpu_inference(pil_img):
 
349
  entry['heatmap_patch_grid'] = d['patch_grid']
350
  has_hm = 'no'
351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  entry['sae_url'] = sae_url
353
 
354
  print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
355
  f"backbone={entry['backbone']}, clip={'yes' if (cs is not None or entry.get('clip_embeds') is not None) else 'no'}, "
356
+ f"heatmaps={has_hm}")
357
  return entry
358
 
359
 
 
371
  render_token: int = 0 # incremented on each feature selection; stale renders bail out
372
  search_filter = None # set of feature indices matching the current name search, or None
373
  color_by: str = "Log Frequency" # which field drives UMAP point colour
374
+ hf_push = None # active Bokeh timeout handle for debounced HuggingFace upload
375
  patch_img = None # image index currently loaded in the patch explorer
376
  patch_z = None # cached (n_patches, d_model) float32 for the loaded image
377
 
 
384
  # Primary dataset — always loaded eagerly
385
  _all_datasets.append(_load_dataset_dict(args.data, args.primary_label, sae_url=args.sae_url))
386
 
 
 
 
 
 
 
 
 
 
 
387
  def _load_brain_dataset_dict(path, label, thumb_dir):
388
  """Load a brain_meis.pt file and return a dataset entry dict.
389
 
 
448
  'feature_names': {},
449
  'auto_interp_names': {},
450
  'sae_url': None,
 
451
  }
452
 
453
  # Load pre-computed heatmaps sidecar if present.
 
650
 
651
 
652
  def compute_patch_activations(img_idx):
653
+ """Return (n_patches, d_sae) float32 via GPU inference, or None if unavailable.
654
 
655
+ Results are cached in a per-dataset LRU cache keyed by image index.
 
 
 
 
656
  """
657
  ds = _all_datasets[_S.active]
658
  cache = ds['inference_cache']
659
 
 
660
  if img_idx in cache:
661
  cache.move_to_end(img_idx)
662
  return cache[img_idx]
663
 
664
+ try:
665
+ pil = load_image(img_idx)
666
+ z_np = _run_gpu_inference(pil)
667
+ except Exception as _e:
668
+ print(f"[GPU runner] inference failed for img {img_idx}: {_e}")
669
+ z_np = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
670
 
671
  if z_np is not None:
672
  cache[img_idx] = z_np
 
1236
  # Update summary panel
1237
  summary_div.text = _make_summary_html()
1238
 
1239
+ # Show/hide patch explorer depending on token type (spatial required) and GPU availability.
1240
  ds = _all_datasets[idx]
 
 
1241
  can_explore = (
1242
  ds.get('token_type', 'spatial') == 'spatial'
1243
+ and bool(args.sae_path)
1244
  )
1245
  patch_fig.visible = can_explore
1246
  patch_info_div.visible = can_explore
1247
  if not can_explore:
1248
+ reason = "CLS token — no patch grid" if ds.get('token_type') == 'cls' else "no --sae-path provided"
 
 
 
1249
  patch_info_div.text = (
1250
  f'<p style="color:#888;font-style:italic">Patch explorer unavailable: {reason}.</p>')
1251
  patch_info_div.visible = True
 
2148
  backbone_label = ds.get('backbone', 'dinov3').upper()
2149
  clip_label = "yes" if (ds['clip_scores'] is not None or ds.get('clip_embeds') is not None) else "no"
2150
  hm_label = "yes" if ds.get('top_heatmaps') is not None else "no"
 
 
2151
  sae_url = ds.get('sae_url')
2152
  dl_row = (f'<tr><td><b>SAE weights:</b></td>'
2153
  f'<td><a href="{sae_url}" download style="color:#1a6faf">⬇ Download</a></td></tr>'
 
2173
 
2174
  # ---------- Patch Explorer ----------
2175
  # Click patches of an image to find the top active SAE features for that region.
2176
+ # Activations are computed on-the-fly via GPU inference (backbone + SAE from --sae-path).
2177
 
2178
  _PATCH_FIG_PX = 400
2179
 
 
2306
  patch_info_div.text = (
2307
  "<span style='color:#1a6faf'>&#x23F3; Computing patch activations"
2308
  + (" (running GPU inference — first image may take ~10 s)…"
2309
+ if _gpu_runner is None and args.sae_path else "…")
2310
  + "</span>"
2311
  )
2312
 
 
2333
  if z_np is None:
2334
  patch_feat_table.visible = False
2335
  patch_info_div.text = (
2336
+ f"<b style='color:#888'>GPU inference unavailable for image {img_idx}. "
2337
+ f"Ensure --sae-path is set and the GPU runner loaded successfully.</b>"
 
2338
  )
2339
  return
2340
 
2341
  patch_feat_table.visible = True
 
 
 
2342
  patch_info_div.text = (
2343
+ f"Image {img_idx} loaded. "
2344
  f"Drag to select a region, or click individual patches."
2345
  )
2346
 
 
2586
  patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
2587
  clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)
2588
 
2589
+ _ds_select_row = ([dataset_select] if len(_all_datasets) > 1 else [])
2590
  left_panel = column(*_ds_select_row, controls, umap_fig, feature_list_panel)
2591
 
2592
  middle_panel = column(