Spaces:
Running
Running
Marlin Lee commited on
Commit Β·
7bcc77f
1
Parent(s): 8f6db74
Sync explorer_app.py and clip_utils.py from main repo
Browse files- scripts/explorer_app.py +19 -48
scripts/explorer_app.py
CHANGED
|
@@ -100,11 +100,11 @@ parser.add_argument("--clip-model", type=str, default="openai/clip-vit-large-pat
|
|
| 100 |
parser.add_argument("--google-api-key", type=str, default=None,
|
| 101 |
help="Google API key for Gemini auto-interp button "
|
| 102 |
"(default: GOOGLE_API_KEY env var)")
|
| 103 |
-
parser.add_argument("--sae-
|
| 104 |
-
help="
|
| 105 |
-
"
|
| 106 |
-
parser.add_argument("--compare-sae-
|
| 107 |
-
help="
|
| 108 |
args = parser.parse_args()
|
| 109 |
|
| 110 |
|
|
@@ -125,7 +125,7 @@ def _get_clip():
|
|
| 125 |
|
| 126 |
# ---------- Load all datasets into a unified list ----------
|
| 127 |
|
| 128 |
-
def _load_dataset_dict(path, label,
|
| 129 |
"""Load one explorer_data.pt file and return a unified dataset dict."""
|
| 130 |
print(f"Loading [{label}] from {path} ...")
|
| 131 |
d = torch.load(path, map_location='cpu', weights_only=False)
|
|
@@ -208,7 +208,7 @@ def _load_dataset_dict(path, label, sae_path=None):
|
|
| 208 |
else:
|
| 209 |
entry['patch_acts'] = None
|
| 210 |
|
| 211 |
-
entry['
|
| 212 |
|
| 213 |
print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
|
| 214 |
f"backbone={entry['backbone']}, clip={'yes' if cs is not None else 'no'}, "
|
|
@@ -220,17 +220,17 @@ _all_datasets = []
|
|
| 220 |
_active = [0] # index of the currently displayed dataset
|
| 221 |
|
| 222 |
# Primary dataset β always loaded eagerly
|
| 223 |
-
_all_datasets.append(_load_dataset_dict(args.data, args.primary_label,
|
| 224 |
|
| 225 |
# Compare datasets β stored as lazy placeholders; loaded on first access
|
| 226 |
for _ci, _cpath in enumerate(args.compare_data):
|
| 227 |
_clabel = (args.compare_labels[_ci]
|
| 228 |
if args.compare_labels and _ci < len(args.compare_labels)
|
| 229 |
else os.path.basename(_cpath))
|
| 230 |
-
_csae = (args.
|
| 231 |
-
if args.
|
| 232 |
else None)
|
| 233 |
-
_all_datasets.append({'label': _clabel, 'path': _cpath, '_lazy': True, '
|
| 234 |
|
| 235 |
|
| 236 |
def _ensure_loaded(idx):
|
|
@@ -238,7 +238,7 @@ def _ensure_loaded(idx):
|
|
| 238 |
ds = _all_datasets[idx]
|
| 239 |
if ds.get('_lazy', False):
|
| 240 |
print(f"[Lazy load] Loading '{ds['label']}' on first access ...")
|
| 241 |
-
_all_datasets[idx] = _load_dataset_dict(ds['path'], ds['label'],
|
| 242 |
|
| 243 |
|
| 244 |
def _apply_dataset_globals(idx):
|
|
@@ -808,9 +808,8 @@ def _on_dataset_switch(attr, old, new):
|
|
| 808 |
# Rebuild active-feature pool for random button
|
| 809 |
_active_feats = [int(i) for i in range(d_model) if feature_frequency[i].item() > 0]
|
| 810 |
|
| 811 |
-
# Update summary panel
|
| 812 |
summary_div.text = _make_summary_html()
|
| 813 |
-
_update_download_source()
|
| 814 |
|
| 815 |
# Show/hide patch explorer depending on token type and data availability.
|
| 816 |
ds = _all_datasets[idx]
|
|
@@ -1493,6 +1492,10 @@ def _make_summary_html():
|
|
| 1493 |
hm_label = "yes" if ds.get('top_heatmaps') is not None else "no"
|
| 1494 |
pa = ds.get('patch_acts')
|
| 1495 |
pa_label = f"yes ({len(pa['img_to_row'])} images)" if pa is not None else "no β run --save-patch-acts"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1496 |
return f"""
|
| 1497 |
<div style="background:#f0f4f8;padding:12px;border-radius:6px;margin-bottom:8px;">
|
| 1498 |
<h2 style="margin:0 0 8px 0">SAE Feature Explorer</h2>
|
|
@@ -1505,44 +1508,12 @@ def _make_summary_html():
|
|
| 1505 |
<tr><td><b>Dead:</b></td><td>{n_dead:,} ({100*n_dead/d_model:.1f}%)</td></tr>
|
| 1506 |
<tr><td><b>Images:</b></td><td>{n_images:,}</td></tr>
|
| 1507 |
<tr><td><b>Tokens/image:</b></td><td>{tok_label}</td></tr>
|
|
|
|
| 1508 |
</table>
|
| 1509 |
</div>"""
|
| 1510 |
|
| 1511 |
summary_div = Div(text=_make_summary_html(), width=700)
|
| 1512 |
|
| 1513 |
-
# --- SAE weights download button ---
|
| 1514 |
-
# Data is pre-loaded into the source so the JS download fires synchronously
|
| 1515 |
-
# with the user click (bypassing browser gesture-context restrictions).
|
| 1516 |
-
|
| 1517 |
-
def _sae_b64(sae_path):
|
| 1518 |
-
"""Return (b64_str, filename) for a SAE weights file, or ('', '') if unavailable."""
|
| 1519 |
-
if not sae_path or not os.path.exists(sae_path):
|
| 1520 |
-
return '', ''
|
| 1521 |
-
with open(sae_path, 'rb') as f:
|
| 1522 |
-
return base64.b64encode(f.read()).decode('ascii'), os.path.basename(sae_path)
|
| 1523 |
-
|
| 1524 |
-
_init_b64, _init_fname = _sae_b64(_all_datasets[0].get('sae_path'))
|
| 1525 |
-
_download_source = ColumnDataSource(data=dict(b64=[_init_b64], filename=[_init_fname]))
|
| 1526 |
-
|
| 1527 |
-
sae_download_btn = Button(label="\u2b07 Download SAE weights", button_type="default", width=220)
|
| 1528 |
-
sae_download_btn.js_on_click(CustomJS(args=dict(src=_download_source), code="""
|
| 1529 |
-
const b64 = src.data['b64'][0];
|
| 1530 |
-
const fname = src.data['filename'][0];
|
| 1531 |
-
if (!b64) { alert('No SAE weights available for this model.'); return; }
|
| 1532 |
-
const bytes = Uint8Array.from(atob(b64), c => c.charCodeAt(0));
|
| 1533 |
-
const blob = new Blob([bytes], {type: 'application/octet-stream'});
|
| 1534 |
-
const url = URL.createObjectURL(blob);
|
| 1535 |
-
const a = document.createElement('a');
|
| 1536 |
-
a.href = url; a.download = fname; a.click();
|
| 1537 |
-
URL.revokeObjectURL(url);
|
| 1538 |
-
"""))
|
| 1539 |
-
|
| 1540 |
-
def _update_download_source():
|
| 1541 |
-
"""Reload SAE weights for the active dataset into _download_source."""
|
| 1542 |
-
ds = _all_datasets[_active[0]]
|
| 1543 |
-
b64, fname = _sae_b64(ds.get('sae_path'))
|
| 1544 |
-
_download_source.data = dict(b64=[b64], filename=[fname])
|
| 1545 |
-
|
| 1546 |
|
| 1547 |
# ---------- Patch Explorer ----------
|
| 1548 |
# Click patches of an image to find the top active SAE features for that region.
|
|
@@ -1920,7 +1891,7 @@ patch_explorer_panel = column(
|
|
| 1920 |
patch_feat_table,
|
| 1921 |
)
|
| 1922 |
|
| 1923 |
-
summary_section = _make_collapsible("SAE Summary",
|
| 1924 |
patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
|
| 1925 |
clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)
|
| 1926 |
|
|
|
|
| 100 |
parser.add_argument("--google-api-key", type=str, default=None,
|
| 101 |
help="Google API key for Gemini auto-interp button "
|
| 102 |
"(default: GOOGLE_API_KEY env var)")
|
| 103 |
+
parser.add_argument("--sae-url", type=str, default=None,
|
| 104 |
+
help="Download URL for the primary dataset's SAE weights β "
|
| 105 |
+
"shown as a link in the summary panel")
|
| 106 |
+
parser.add_argument("--compare-sae-urls", type=str, nargs="*", default=[],
|
| 107 |
+
help="Download URLs for each --compare-data dataset's SAE weights (in order)")
|
| 108 |
args = parser.parse_args()
|
| 109 |
|
| 110 |
|
|
|
|
| 125 |
|
| 126 |
# ---------- Load all datasets into a unified list ----------
|
| 127 |
|
| 128 |
+
def _load_dataset_dict(path, label, sae_url=None):
|
| 129 |
"""Load one explorer_data.pt file and return a unified dataset dict."""
|
| 130 |
print(f"Loading [{label}] from {path} ...")
|
| 131 |
d = torch.load(path, map_location='cpu', weights_only=False)
|
|
|
|
| 208 |
else:
|
| 209 |
entry['patch_acts'] = None
|
| 210 |
|
| 211 |
+
entry['sae_url'] = sae_url
|
| 212 |
|
| 213 |
print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
|
| 214 |
f"backbone={entry['backbone']}, clip={'yes' if cs is not None else 'no'}, "
|
|
|
|
| 220 |
_active = [0] # index of the currently displayed dataset
|
| 221 |
|
| 222 |
# Primary dataset β always loaded eagerly
|
| 223 |
+
_all_datasets.append(_load_dataset_dict(args.data, args.primary_label, sae_url=args.sae_url))
|
| 224 |
|
| 225 |
# Compare datasets β stored as lazy placeholders; loaded on first access
|
| 226 |
for _ci, _cpath in enumerate(args.compare_data):
|
| 227 |
_clabel = (args.compare_labels[_ci]
|
| 228 |
if args.compare_labels and _ci < len(args.compare_labels)
|
| 229 |
else os.path.basename(_cpath))
|
| 230 |
+
_csae = (args.compare_sae_urls[_ci]
|
| 231 |
+
if args.compare_sae_urls and _ci < len(args.compare_sae_urls)
|
| 232 |
else None)
|
| 233 |
+
_all_datasets.append({'label': _clabel, 'path': _cpath, '_lazy': True, 'sae_url': _csae})
|
| 234 |
|
| 235 |
|
| 236 |
def _ensure_loaded(idx):
|
|
|
|
| 238 |
ds = _all_datasets[idx]
|
| 239 |
if ds.get('_lazy', False):
|
| 240 |
print(f"[Lazy load] Loading '{ds['label']}' on first access ...")
|
| 241 |
+
_all_datasets[idx] = _load_dataset_dict(ds['path'], ds['label'], sae_url=ds.get('sae_url'))
|
| 242 |
|
| 243 |
|
| 244 |
def _apply_dataset_globals(idx):
|
|
|
|
| 808 |
# Rebuild active-feature pool for random button
|
| 809 |
_active_feats = [int(i) for i in range(d_model) if feature_frequency[i].item() > 0]
|
| 810 |
|
| 811 |
+
# Update summary panel
|
| 812 |
summary_div.text = _make_summary_html()
|
|
|
|
| 813 |
|
| 814 |
# Show/hide patch explorer depending on token type and data availability.
|
| 815 |
ds = _all_datasets[idx]
|
|
|
|
| 1492 |
hm_label = "yes" if ds.get('top_heatmaps') is not None else "no"
|
| 1493 |
pa = ds.get('patch_acts')
|
| 1494 |
pa_label = f"yes ({len(pa['img_to_row'])} images)" if pa is not None else "no β run --save-patch-acts"
|
| 1495 |
+
sae_url = ds.get('sae_url')
|
| 1496 |
+
dl_row = (f'<tr><td><b>SAE weights:</b></td>'
|
| 1497 |
+
f'<td><a href="{sae_url}" download style="color:#1a6faf">β¬ Download</a></td></tr>'
|
| 1498 |
+
if sae_url else '')
|
| 1499 |
return f"""
|
| 1500 |
<div style="background:#f0f4f8;padding:12px;border-radius:6px;margin-bottom:8px;">
|
| 1501 |
<h2 style="margin:0 0 8px 0">SAE Feature Explorer</h2>
|
|
|
|
| 1508 |
<tr><td><b>Dead:</b></td><td>{n_dead:,} ({100*n_dead/d_model:.1f}%)</td></tr>
|
| 1509 |
<tr><td><b>Images:</b></td><td>{n_images:,}</td></tr>
|
| 1510 |
<tr><td><b>Tokens/image:</b></td><td>{tok_label}</td></tr>
|
| 1511 |
+
{dl_row}
|
| 1512 |
</table>
|
| 1513 |
</div>"""
|
| 1514 |
|
| 1515 |
summary_div = Div(text=_make_summary_html(), width=700)
|
| 1516 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1517 |
|
| 1518 |
# ---------- Patch Explorer ----------
|
| 1519 |
# Click patches of an image to find the top active SAE features for that region.
|
|
|
|
| 1891 |
patch_feat_table,
|
| 1892 |
)
|
| 1893 |
|
| 1894 |
+
summary_section = _make_collapsible("SAE Summary", summary_div)
|
| 1895 |
patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
|
| 1896 |
clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)
|
| 1897 |
|