Marlin Lee committed on
Commit
7bcc77f
·
1 Parent(s): 8f6db74

Sync explorer_app.py and clip_utils.py from main repo

Browse files
Files changed (1) hide show
  1. scripts/explorer_app.py +19 -48
scripts/explorer_app.py CHANGED
@@ -100,11 +100,11 @@ parser.add_argument("--clip-model", type=str, default="openai/clip-vit-large-pat
100
  parser.add_argument("--google-api-key", type=str, default=None,
101
  help="Google API key for Gemini auto-interp button "
102
  "(default: GOOGLE_API_KEY env var)")
103
- parser.add_argument("--sae-path", type=str, default=None,
104
- help="Path to SAE weights (.pth) for the primary dataset — "
105
- "enables the Download SAE weights button in the summary panel")
106
- parser.add_argument("--compare-sae-paths", type=str, nargs="*", default=[],
107
- help="SAE weight paths for each --compare-data dataset (in order)")
108
  args = parser.parse_args()
109
 
110
 
@@ -125,7 +125,7 @@ def _get_clip():
125
 
126
  # ---------- Load all datasets into a unified list ----------
127
 
128
- def _load_dataset_dict(path, label, sae_path=None):
129
  """Load one explorer_data.pt file and return a unified dataset dict."""
130
  print(f"Loading [{label}] from {path} ...")
131
  d = torch.load(path, map_location='cpu', weights_only=False)
@@ -208,7 +208,7 @@ def _load_dataset_dict(path, label, sae_path=None):
208
  else:
209
  entry['patch_acts'] = None
210
 
211
- entry['sae_path'] = sae_path
212
 
213
  print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
214
  f"backbone={entry['backbone']}, clip={'yes' if cs is not None else 'no'}, "
@@ -220,17 +220,17 @@ _all_datasets = []
220
  _active = [0] # index of the currently displayed dataset
221
 
222
  # Primary dataset — always loaded eagerly
223
- _all_datasets.append(_load_dataset_dict(args.data, args.primary_label, sae_path=args.sae_path))
224
 
225
  # Compare datasets — stored as lazy placeholders; loaded on first access
226
  for _ci, _cpath in enumerate(args.compare_data):
227
  _clabel = (args.compare_labels[_ci]
228
  if args.compare_labels and _ci < len(args.compare_labels)
229
  else os.path.basename(_cpath))
230
- _csae = (args.compare_sae_paths[_ci]
231
- if args.compare_sae_paths and _ci < len(args.compare_sae_paths)
232
  else None)
233
- _all_datasets.append({'label': _clabel, 'path': _cpath, '_lazy': True, 'sae_path': _csae})
234
 
235
 
236
  def _ensure_loaded(idx):
@@ -238,7 +238,7 @@ def _ensure_loaded(idx):
238
  ds = _all_datasets[idx]
239
  if ds.get('_lazy', False):
240
  print(f"[Lazy load] Loading '{ds['label']}' on first access ...")
241
- _all_datasets[idx] = _load_dataset_dict(ds['path'], ds['label'], sae_path=ds.get('sae_path'))
242
 
243
 
244
  def _apply_dataset_globals(idx):
@@ -808,9 +808,8 @@ def _on_dataset_switch(attr, old, new):
808
  # Rebuild active-feature pool for random button
809
  _active_feats = [int(i) for i in range(d_model) if feature_frequency[i].item() > 0]
810
 
811
- # Update summary panel and SAE download source
812
  summary_div.text = _make_summary_html()
813
- _update_download_source()
814
 
815
  # Show/hide patch explorer depending on token type and data availability.
816
  ds = _all_datasets[idx]
@@ -1493,6 +1492,10 @@ def _make_summary_html():
1493
  hm_label = "yes" if ds.get('top_heatmaps') is not None else "no"
1494
  pa = ds.get('patch_acts')
1495
  pa_label = f"yes ({len(pa['img_to_row'])} images)" if pa is not None else "no — run --save-patch-acts"
 
 
 
 
1496
  return f"""
1497
  <div style="background:#f0f4f8;padding:12px;border-radius:6px;margin-bottom:8px;">
1498
  <h2 style="margin:0 0 8px 0">SAE Feature Explorer</h2>
@@ -1505,44 +1508,12 @@ def _make_summary_html():
1505
  <tr><td><b>Dead:</b></td><td>{n_dead:,} ({100*n_dead/d_model:.1f}%)</td></tr>
1506
  <tr><td><b>Images:</b></td><td>{n_images:,}</td></tr>
1507
  <tr><td><b>Tokens/image:</b></td><td>{tok_label}</td></tr>
 
1508
  </table>
1509
  </div>"""
1510
 
1511
  summary_div = Div(text=_make_summary_html(), width=700)
1512
 
1513
- # --- SAE weights download button ---
1514
- # Data is pre-loaded into the source so the JS download fires synchronously
1515
- # with the user click (bypassing browser gesture-context restrictions).
1516
-
1517
- def _sae_b64(sae_path):
1518
- """Return (b64_str, filename) for a SAE weights file, or ('', '') if unavailable."""
1519
- if not sae_path or not os.path.exists(sae_path):
1520
- return '', ''
1521
- with open(sae_path, 'rb') as f:
1522
- return base64.b64encode(f.read()).decode('ascii'), os.path.basename(sae_path)
1523
-
1524
- _init_b64, _init_fname = _sae_b64(_all_datasets[0].get('sae_path'))
1525
- _download_source = ColumnDataSource(data=dict(b64=[_init_b64], filename=[_init_fname]))
1526
-
1527
- sae_download_btn = Button(label="\u2b07 Download SAE weights", button_type="default", width=220)
1528
- sae_download_btn.js_on_click(CustomJS(args=dict(src=_download_source), code="""
1529
- const b64 = src.data['b64'][0];
1530
- const fname = src.data['filename'][0];
1531
- if (!b64) { alert('No SAE weights available for this model.'); return; }
1532
- const bytes = Uint8Array.from(atob(b64), c => c.charCodeAt(0));
1533
- const blob = new Blob([bytes], {type: 'application/octet-stream'});
1534
- const url = URL.createObjectURL(blob);
1535
- const a = document.createElement('a');
1536
- a.href = url; a.download = fname; a.click();
1537
- URL.revokeObjectURL(url);
1538
- """))
1539
-
1540
- def _update_download_source():
1541
- """Reload SAE weights for the active dataset into _download_source."""
1542
- ds = _all_datasets[_active[0]]
1543
- b64, fname = _sae_b64(ds.get('sae_path'))
1544
- _download_source.data = dict(b64=[b64], filename=[fname])
1545
-
1546
 
1547
  # ---------- Patch Explorer ----------
1548
  # Click patches of an image to find the top active SAE features for that region.
@@ -1920,7 +1891,7 @@ patch_explorer_panel = column(
1920
  patch_feat_table,
1921
  )
1922
 
1923
- summary_section = _make_collapsible("SAE Summary", column(summary_div, sae_download_btn))
1924
  patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
1925
  clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)
1926
 
 
100
  parser.add_argument("--google-api-key", type=str, default=None,
101
  help="Google API key for Gemini auto-interp button "
102
  "(default: GOOGLE_API_KEY env var)")
103
+ parser.add_argument("--sae-url", type=str, default=None,
104
+ help="Download URL for the primary dataset's SAE weights — "
105
+ "shown as a link in the summary panel")
106
+ parser.add_argument("--compare-sae-urls", type=str, nargs="*", default=[],
107
+ help="Download URLs for each --compare-data dataset's SAE weights (in order)")
108
  args = parser.parse_args()
109
 
110
 
 
125
 
126
  # ---------- Load all datasets into a unified list ----------
127
 
128
+ def _load_dataset_dict(path, label, sae_url=None):
129
  """Load one explorer_data.pt file and return a unified dataset dict."""
130
  print(f"Loading [{label}] from {path} ...")
131
  d = torch.load(path, map_location='cpu', weights_only=False)
 
208
  else:
209
  entry['patch_acts'] = None
210
 
211
+ entry['sae_url'] = sae_url
212
 
213
  print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
214
  f"backbone={entry['backbone']}, clip={'yes' if cs is not None else 'no'}, "
 
220
  _active = [0] # index of the currently displayed dataset
221
 
222
  # Primary dataset — always loaded eagerly
223
+ _all_datasets.append(_load_dataset_dict(args.data, args.primary_label, sae_url=args.sae_url))
224
 
225
  # Compare datasets — stored as lazy placeholders; loaded on first access
226
  for _ci, _cpath in enumerate(args.compare_data):
227
  _clabel = (args.compare_labels[_ci]
228
  if args.compare_labels and _ci < len(args.compare_labels)
229
  else os.path.basename(_cpath))
230
+ _csae = (args.compare_sae_urls[_ci]
231
+ if args.compare_sae_urls and _ci < len(args.compare_sae_urls)
232
  else None)
233
+ _all_datasets.append({'label': _clabel, 'path': _cpath, '_lazy': True, 'sae_url': _csae})
234
 
235
 
236
  def _ensure_loaded(idx):
 
238
  ds = _all_datasets[idx]
239
  if ds.get('_lazy', False):
240
  print(f"[Lazy load] Loading '{ds['label']}' on first access ...")
241
+ _all_datasets[idx] = _load_dataset_dict(ds['path'], ds['label'], sae_url=ds.get('sae_url'))
242
 
243
 
244
  def _apply_dataset_globals(idx):
 
808
  # Rebuild active-feature pool for random button
809
  _active_feats = [int(i) for i in range(d_model) if feature_frequency[i].item() > 0]
810
 
811
+ # Update summary panel
812
  summary_div.text = _make_summary_html()
 
813
 
814
  # Show/hide patch explorer depending on token type and data availability.
815
  ds = _all_datasets[idx]
 
1492
  hm_label = "yes" if ds.get('top_heatmaps') is not None else "no"
1493
  pa = ds.get('patch_acts')
1494
  pa_label = f"yes ({len(pa['img_to_row'])} images)" if pa is not None else "no — run --save-patch-acts"
1495
+ sae_url = ds.get('sae_url')
1496
+ dl_row = (f'<tr><td><b>SAE weights:</b></td>'
1497
+ f'<td><a href="{sae_url}" download style="color:#1a6faf">⬇ Download</a></td></tr>'
1498
+ if sae_url else '')
1499
  return f"""
1500
  <div style="background:#f0f4f8;padding:12px;border-radius:6px;margin-bottom:8px;">
1501
  <h2 style="margin:0 0 8px 0">SAE Feature Explorer</h2>
 
1508
  <tr><td><b>Dead:</b></td><td>{n_dead:,} ({100*n_dead/d_model:.1f}%)</td></tr>
1509
  <tr><td><b>Images:</b></td><td>{n_images:,}</td></tr>
1510
  <tr><td><b>Tokens/image:</b></td><td>{tok_label}</td></tr>
1511
+ {dl_row}
1512
  </table>
1513
  </div>"""
1514
 
1515
  summary_div = Div(text=_make_summary_html(), width=700)
1516
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1517
 
1518
  # ---------- Patch Explorer ----------
1519
  # Click patches of an image to find the top active SAE features for that region.
 
1891
  patch_feat_table,
1892
  )
1893
 
1894
+ summary_section = _make_collapsible("SAE Summary", summary_div)
1895
  patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
1896
  clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)
1897