Spaces:
Running
Running
Marlin Lee commited on
Commit ·
cf75c2d
1
Parent(s): cdaf9dc
Sync explorer_app.py and clip_utils.py from main repo
Browse files- scripts/explorer_app.py +46 -5
scripts/explorer_app.py
CHANGED
|
@@ -100,6 +100,11 @@ parser.add_argument("--clip-model", type=str, default="openai/clip-vit-large-pat
|
|
| 100 |
parser.add_argument("--google-api-key", type=str, default=None,
|
| 101 |
help="Google API key for Gemini auto-interp button "
|
| 102 |
"(default: GOOGLE_API_KEY env var)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
args = parser.parse_args()
|
| 104 |
|
| 105 |
|
|
@@ -120,7 +125,7 @@ def _get_clip():
|
|
| 120 |
|
| 121 |
# ---------- Load all datasets into a unified list ----------
|
| 122 |
|
| 123 |
-
def _load_dataset_dict(path, label):
|
| 124 |
"""Load one explorer_data.pt file and return a unified dataset dict."""
|
| 125 |
print(f"Loading [{label}] from {path} ...")
|
| 126 |
d = torch.load(path, map_location='cpu', weights_only=False)
|
|
@@ -203,6 +208,8 @@ def _load_dataset_dict(path, label):
|
|
| 203 |
else:
|
| 204 |
entry['patch_acts'] = None
|
| 205 |
|
|
|
|
|
|
|
| 206 |
print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
|
| 207 |
f"backbone={entry['backbone']}, clip={'yes' if cs is not None else 'no'}, "
|
| 208 |
f"heatmaps={has_hm}, patch_acts={'yes' if entry['patch_acts'] else 'no'}")
|
|
@@ -213,14 +220,17 @@ _all_datasets = []
|
|
| 213 |
_active = [0] # index of the currently displayed dataset
|
| 214 |
|
| 215 |
# Primary dataset — always loaded eagerly
|
| 216 |
-
_all_datasets.append(_load_dataset_dict(args.data, args.primary_label))
|
| 217 |
|
| 218 |
# Compare datasets — stored as lazy placeholders; loaded on first access
|
| 219 |
for _ci, _cpath in enumerate(args.compare_data):
|
| 220 |
_clabel = (args.compare_labels[_ci]
|
| 221 |
if args.compare_labels and _ci < len(args.compare_labels)
|
| 222 |
else os.path.basename(_cpath))
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
|
| 226 |
def _ensure_loaded(idx):
|
|
@@ -228,7 +238,7 @@ def _ensure_loaded(idx):
|
|
| 228 |
ds = _all_datasets[idx]
|
| 229 |
if ds.get('_lazy', False):
|
| 230 |
print(f"[Lazy load] Loading '{ds['label']}' on first access ...")
|
| 231 |
-
_all_datasets[idx] = _load_dataset_dict(ds['path'], ds['label'])
|
| 232 |
|
| 233 |
|
| 234 |
def _apply_dataset_globals(idx):
|
|
@@ -1499,6 +1509,37 @@ def _make_summary_html():
|
|
| 1499 |
|
| 1500 |
summary_div = Div(text=_make_summary_html(), width=700)
|
| 1501 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1502 |
|
| 1503 |
# ---------- Patch Explorer ----------
|
| 1504 |
# Click patches of an image to find the top active SAE features for that region.
|
|
@@ -1876,7 +1917,7 @@ patch_explorer_panel = column(
|
|
| 1876 |
patch_feat_table,
|
| 1877 |
)
|
| 1878 |
|
| 1879 |
-
summary_section = _make_collapsible("SAE Summary", summary_div)
|
| 1880 |
patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
|
| 1881 |
clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)
|
| 1882 |
|
|
|
|
| 100 |
parser.add_argument("--google-api-key", type=str, default=None,
|
| 101 |
help="Google API key for Gemini auto-interp button "
|
| 102 |
"(default: GOOGLE_API_KEY env var)")
|
| 103 |
+
parser.add_argument("--sae-path", type=str, default=None,
|
| 104 |
+
help="Path to SAE weights (.pth) for the primary dataset — "
|
| 105 |
+
"enables the Download SAE weights button in the summary panel")
|
| 106 |
+
parser.add_argument("--compare-sae-paths", type=str, nargs="*", default=[],
|
| 107 |
+
help="SAE weight paths for each --compare-data dataset (in order)")
|
| 108 |
args = parser.parse_args()
|
| 109 |
|
| 110 |
|
|
|
|
| 125 |
|
| 126 |
# ---------- Load all datasets into a unified list ----------
|
| 127 |
|
| 128 |
+
def _load_dataset_dict(path, label, sae_path=None):
|
| 129 |
"""Load one explorer_data.pt file and return a unified dataset dict."""
|
| 130 |
print(f"Loading [{label}] from {path} ...")
|
| 131 |
d = torch.load(path, map_location='cpu', weights_only=False)
|
|
|
|
| 208 |
else:
|
| 209 |
entry['patch_acts'] = None
|
| 210 |
|
| 211 |
+
entry['sae_path'] = sae_path
|
| 212 |
+
|
| 213 |
print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
|
| 214 |
f"backbone={entry['backbone']}, clip={'yes' if cs is not None else 'no'}, "
|
| 215 |
f"heatmaps={has_hm}, patch_acts={'yes' if entry['patch_acts'] else 'no'}")
|
|
|
|
| 220 |
_active = [0] # index of the currently displayed dataset
|
| 221 |
|
| 222 |
# Primary dataset — always loaded eagerly
|
| 223 |
+
_all_datasets.append(_load_dataset_dict(args.data, args.primary_label, sae_path=args.sae_path))
|
| 224 |
|
| 225 |
# Compare datasets — stored as lazy placeholders; loaded on first access
|
| 226 |
for _ci, _cpath in enumerate(args.compare_data):
|
| 227 |
_clabel = (args.compare_labels[_ci]
|
| 228 |
if args.compare_labels and _ci < len(args.compare_labels)
|
| 229 |
else os.path.basename(_cpath))
|
| 230 |
+
_csae = (args.compare_sae_paths[_ci]
|
| 231 |
+
if args.compare_sae_paths and _ci < len(args.compare_sae_paths)
|
| 232 |
+
else None)
|
| 233 |
+
_all_datasets.append({'label': _clabel, 'path': _cpath, '_lazy': True, 'sae_path': _csae})
|
| 234 |
|
| 235 |
|
| 236 |
def _ensure_loaded(idx):
|
|
|
|
| 238 |
ds = _all_datasets[idx]
|
| 239 |
if ds.get('_lazy', False):
|
| 240 |
print(f"[Lazy load] Loading '{ds['label']}' on first access ...")
|
| 241 |
+
_all_datasets[idx] = _load_dataset_dict(ds['path'], ds['label'], sae_path=ds.get('sae_path'))
|
| 242 |
|
| 243 |
|
| 244 |
def _apply_dataset_globals(idx):
|
|
|
|
| 1509 |
|
| 1510 |
summary_div = Div(text=_make_summary_html(), width=700)
|
| 1511 |
|
| 1512 |
+
# --- SAE weights download button ---
|
| 1513 |
+
_download_source = ColumnDataSource(data=dict(b64=[''], filename=['']))
|
| 1514 |
+
_download_source.js_on_change('data', CustomJS(args=dict(src=_download_source), code="""
|
| 1515 |
+
const b64 = src.data['b64'][0];
|
| 1516 |
+
const fname = src.data['filename'][0];
|
| 1517 |
+
if (!b64) return;
|
| 1518 |
+
const bytes = Uint8Array.from(atob(b64), c => c.charCodeAt(0));
|
| 1519 |
+
const blob = new Blob([bytes], {type: 'application/octet-stream'});
|
| 1520 |
+
const url = URL.createObjectURL(blob);
|
| 1521 |
+
const a = document.createElement('a');
|
| 1522 |
+
a.href = url; a.download = fname; a.click();
|
| 1523 |
+
URL.revokeObjectURL(url);
|
| 1524 |
+
src.data = {b64: [''], filename: ['']};
|
| 1525 |
+
"""))
|
| 1526 |
+
|
| 1527 |
+
sae_download_btn = Button(label="\u2b07 Download SAE weights", button_type="default", width=220)
|
| 1528 |
+
|
| 1529 |
+
def _on_sae_download():
|
| 1530 |
+
ds = _all_datasets[_active[0]]
|
| 1531 |
+
sae_path = ds.get('sae_path')
|
| 1532 |
+
if not sae_path or not os.path.exists(sae_path):
|
| 1533 |
+
status_div.text = "<b style='color:red'>No SAE path set for this model. Pass --sae-path.</b>"
|
| 1534 |
+
return
|
| 1535 |
+
status_div.text = f"<b>Reading {os.path.basename(sae_path)}…</b>"
|
| 1536 |
+
with open(sae_path, 'rb') as f:
|
| 1537 |
+
b64 = base64.b64encode(f.read()).decode('ascii')
|
| 1538 |
+
_download_source.data = dict(b64=[b64], filename=[os.path.basename(sae_path)])
|
| 1539 |
+
status_div.text = ""
|
| 1540 |
+
|
| 1541 |
+
sae_download_btn.on_click(lambda: _on_sae_download())
|
| 1542 |
+
|
| 1543 |
|
| 1544 |
# ---------- Patch Explorer ----------
|
| 1545 |
# Click patches of an image to find the top active SAE features for that region.
|
|
|
|
| 1917 |
patch_feat_table,
|
| 1918 |
)
|
| 1919 |
|
| 1920 |
+
summary_section = _make_collapsible("SAE Summary", column(summary_div, sae_download_btn))
|
| 1921 |
patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
|
| 1922 |
clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)
|
| 1923 |
|