Spaces:
Sleeping
Sleeping
Marlin Lee commited on
Commit ·
1361d88
1
Parent(s): 62166b9
Comment out new features (CLIP×φ sort, λ input, brain MEI sidecar) until ready
Browse files- scripts/explorer/datasets.py +10 -10
- scripts/explorer/main.py +1 -1
- scripts/explorer/panels/clip_search.py +1 -1
- scripts/explorer/panels/dynadiff.py +19 -18
- scripts/explorer/panels/feature_list.py +27 -25
- scripts/explorer/rendering.py +14 -13
- scripts/explorer/state.py +1 -1
scripts/explorer/datasets.py
CHANGED
|
@@ -150,16 +150,16 @@ def _load_dataset(path: str, label: str, *,
|
|
| 150 |
'umap_backup': umap_backup,
|
| 151 |
}
|
| 152 |
|
| 153 |
-
# Brain MEI sidecar (
|
| 154 |
-
brain_sidecar = stem + '_brain_meis.pt'
|
| 155 |
-
if os.path.exists(brain_sidecar):
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
else:
|
| 161 |
-
|
| 162 |
-
|
| 163 |
|
| 164 |
# Heatmaps sidecar
|
| 165 |
sidecar = stem + '_heatmaps.pt'
|
|
|
|
| 150 |
'umap_backup': umap_backup,
|
| 151 |
}
|
| 152 |
|
| 153 |
+
# Brain MEI sidecar (disabled — re-enable after running precompute_brain_response_meis.py)
|
| 154 |
+
# brain_sidecar = stem + '_brain_meis.pt'
|
| 155 |
+
# if os.path.exists(brain_sidecar):
|
| 156 |
+
# print(f" Loading brain MEI sidecar {os.path.basename(brain_sidecar)} ...")
|
| 157 |
+
# bm = torch.load(brain_sidecar, map_location='cpu', weights_only=False)
|
| 158 |
+
# entry['brain_top_img_idx'] = bm.get('brain_top_img_idx')
|
| 159 |
+
# entry['brain_top_img_act'] = bm.get('brain_top_img_act')
|
| 160 |
+
# else:
|
| 161 |
+
# entry['brain_top_img_idx'] = None
|
| 162 |
+
# entry['brain_top_img_act'] = None
|
| 163 |
|
| 164 |
# Heatmaps sidecar
|
| 165 |
sidecar = stem + '_heatmaps.pt'
|
scripts/explorer/main.py
CHANGED
|
@@ -276,7 +276,7 @@ lower_left = column(
|
|
| 276 |
_search_header,
|
| 277 |
_clip_row,
|
| 278 |
_clip_results,
|
| 279 |
-
flist_panel.sort_select,
|
| 280 |
flist_panel.gallery_div,
|
| 281 |
flist_panel.gallery_bridge_input,
|
| 282 |
flist_panel.gallery_page_input,
|
|
|
|
| 276 |
_search_header,
|
| 277 |
_clip_row,
|
| 278 |
_clip_results,
|
| 279 |
+
# flist_panel.sort_select, # re-enable for CLIP × φ sort
|
| 280 |
flist_panel.gallery_div,
|
| 281 |
flist_panel.gallery_bridge_input,
|
| 282 |
flist_panel.gallery_page_input,
|
scripts/explorer/panels/clip_search.py
CHANGED
|
@@ -55,7 +55,7 @@ else:
|
|
| 55 |
clip_m, clip_p, clip_dev = get_clip()
|
| 56 |
from clip_utils import compute_text_embeddings
|
| 57 |
q_embed = compute_text_embeddings([query], clip_m, clip_p, clip_dev)
|
| 58 |
-
_S.clip_query_embed = q_embed # store for combined CLIP × φ sort
|
| 59 |
scores_vec = (embeds.float() @ q_embed.T).squeeze(-1)
|
| 60 |
except Exception as exc:
|
| 61 |
clip_results_div.text = f"<span style='color:#c00'>CLIP error: {exc}</span>"
|
|
|
|
| 55 |
clip_m, clip_p, clip_dev = get_clip()
|
| 56 |
from clip_utils import compute_text_embeddings
|
| 57 |
q_embed = compute_text_embeddings([query], clip_m, clip_p, clip_dev)
|
| 58 |
+
# _S.clip_query_embed = q_embed # store for combined CLIP × φ sort
|
| 59 |
scores_vec = (embeds.float() @ q_embed.T).squeeze(-1)
|
| 60 |
except Exception as exc:
|
| 61 |
clip_results_div.text = f"<span style='color:#c00'>CLIP error: {exc}</span>"
|
scripts/explorer/panels/dynadiff.py
CHANGED
|
@@ -239,24 +239,25 @@ if HAS_DYNADIFF:
|
|
| 239 |
for k, vals in dd_source.data.items()}
|
| 240 |
dd_source.data = new_data
|
| 241 |
dd_status.text = ''
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
|
|
|
| 260 |
|
| 261 |
dd_feat_bridge.on_change('value', _on_feat_bridge)
|
| 262 |
|
|
|
|
| 239 |
for k, vals in dd_source.data.items()}
|
| 240 |
dd_source.data = new_data
|
| 241 |
dd_status.text = ''
|
| 242 |
+
# set_lam disabled — re-enable with lambda input in rendering.py
|
| 243 |
+
# elif msg.startswith('set_lam:'):
|
| 244 |
+
# parts = msg.split(':', 2)
|
| 245 |
+
# if len(parts) != 3:
|
| 246 |
+
# return
|
| 247 |
+
# try:
|
| 248 |
+
# feat = int(parts[1])
|
| 249 |
+
# new_val = float(parts[2])
|
| 250 |
+
# except ValueError:
|
| 251 |
+
# return
|
| 252 |
+
# feats = list(dd_source.data['feat'])
|
| 253 |
+
# if feat not in feats:
|
| 254 |
+
# return
|
| 255 |
+
# idx = feats.index(feat)
|
| 256 |
+
# new_lams = list(dd_source.data['lam'])
|
| 257 |
+
# new_lams[idx] = new_val
|
| 258 |
+
# new_data = dict(dd_source.data)
|
| 259 |
+
# new_data['lam'] = new_lams
|
| 260 |
+
# dd_source.data = new_data
|
| 261 |
|
| 262 |
dd_feat_bridge.on_change('value', _on_feat_bridge)
|
| 263 |
|
scripts/explorer/panels/feature_list.py
CHANGED
|
@@ -21,7 +21,8 @@ from bokeh.io import curdoc
|
|
| 21 |
from bokeh.layouts import column, row
|
| 22 |
from bokeh.models import (
|
| 23 |
Button, ColumnDataSource, DataTable, Div, NumberFormatter,
|
| 24 |
-
Select,
|
|
|
|
| 25 |
)
|
| 26 |
|
| 27 |
from ..args import args
|
|
@@ -83,28 +84,30 @@ def _schedule_hf_push(file_path: str):
|
|
| 83 |
|
| 84 |
# ---------- Feature list table ----------
|
| 85 |
|
| 86 |
-
sort_select = Select(
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
)
|
|
|
|
| 91 |
|
| 92 |
|
| 93 |
def _get_sorted_order():
|
| 94 |
-
ds
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
| 108 |
if _S.search_filter is not None:
|
| 109 |
mask = np.isin(order, list(_S.search_filter))
|
| 110 |
order = order[mask]
|
|
@@ -123,11 +126,10 @@ def _build_list_data(order) -> dict:
|
|
| 123 |
)
|
| 124 |
|
| 125 |
|
| 126 |
-
def _on_sort_change(attr, old, new):
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
sort_select.on_change('value', _on_sort_change)
|
| 131 |
|
| 132 |
feature_list_source = ColumnDataSource(data=_build_list_data(_get_sorted_order()))
|
| 133 |
|
|
|
|
| 21 |
from bokeh.layouts import column, row
|
| 22 |
from bokeh.models import (
|
| 23 |
Button, ColumnDataSource, DataTable, Div, NumberFormatter,
|
| 24 |
+
# Select, # re-enable for CLIP × φ sort
|
| 25 |
+
TableColumn, TextInput,
|
| 26 |
)
|
| 27 |
|
| 28 |
from ..args import args
|
|
|
|
| 84 |
|
| 85 |
# ---------- Feature list table ----------
|
| 86 |
|
| 87 |
+
# sort_select = Select( # re-enable for CLIP × φ sort
|
| 88 |
+
# title="Sort by", value="Frequency",
|
| 89 |
+
# options=["Frequency", "CLIP × Brain φ"],
|
| 90 |
+
# width=160,
|
| 91 |
+
# )
|
| 92 |
+
sort_select = None
|
| 93 |
|
| 94 |
|
| 95 |
def _get_sorted_order():
|
| 96 |
+
ds = active_ds()
|
| 97 |
+
# CLIP × φ sort (disabled):
|
| 98 |
+
# if (sort_select is not None and sort_select.value == "CLIP × Brain φ"
|
| 99 |
+
# and _S.clip_query_embed is not None):
|
| 100 |
+
# embeds = ds.get('nsd_clip_embeds') if ds.get('nsd_clip_embeds') is not None else ds.get('clip_embeds')
|
| 101 |
+
# if embeds is not None:
|
| 102 |
+
# clip_sims = (embeds.float() @ _S.clip_query_embed.T).squeeze(-1).cpu().numpy()
|
| 103 |
+
# phi_vals = np.array(phi_c_vals(list(range(len(clip_sims)))))
|
| 104 |
+
# c = clip_sims - clip_sims.min(); c = c / (c.max() + 1e-8)
|
| 105 |
+
# p = phi_vals / (phi_vals.max() + 1e-8)
|
| 106 |
+
# order = np.argsort(-(c * p))
|
| 107 |
+
# else:
|
| 108 |
+
# order = np.argsort(-ds['freq'])
|
| 109 |
+
# else:
|
| 110 |
+
order = np.argsort(-ds['freq'])
|
| 111 |
if _S.search_filter is not None:
|
| 112 |
mask = np.isin(order, list(_S.search_filter))
|
| 113 |
order = order[mask]
|
|
|
|
| 126 |
)
|
| 127 |
|
| 128 |
|
| 129 |
+
# def _on_sort_change(attr, old, new): # re-enable for CLIP × φ sort
|
| 130 |
+
# apply_order(_get_sorted_order())
|
| 131 |
+
# rebuild_gallery()
|
| 132 |
+
# sort_select.on_change('value', _on_sort_change)
|
|
|
|
| 133 |
|
| 134 |
feature_list_source = ColumnDataSource(data=_build_list_data(_get_sorted_order()))
|
| 135 |
|
scripts/explorer/rendering.py
CHANGED
|
@@ -480,19 +480,20 @@ def make_active_features_tile_html(feats: list, ds: dict, mei_size: int = 72,
|
|
| 480 |
f'margin-top:2px;text-align:center;overflow:hidden;'
|
| 481 |
f'text-overflow:ellipsis;white-space:nowrap">{label}</div>') if label else ''
|
| 482 |
|
| 483 |
-
# Lambda number input (
|
| 484 |
-
if removable:
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
else:
|
| 495 |
-
|
|
|
|
| 496 |
|
| 497 |
cards.append(
|
| 498 |
f'<div onclick="window._sae_select_feature({feat})" '
|
|
|
|
| 480 |
f'margin-top:2px;text-align:center;overflow:hidden;'
|
| 481 |
f'text-overflow:ellipsis;white-space:nowrap">{label}</div>') if label else ''
|
| 482 |
|
| 483 |
+
# Lambda number input (disabled — re-enable when set_lam handler is active)
|
| 484 |
+
# if removable:
|
| 485 |
+
# lam_val = lam_map.get(feat, 3.0)
|
| 486 |
+
# lam_html = (
|
| 487 |
+
# f'<input type="number" value="{lam_val:.2g}" step="0.5" '
|
| 488 |
+
# f'style="width:54px;font-size:10px;text-align:center;'
|
| 489 |
+
# f'border:1px solid #ccc;border-radius:3px;margin-top:2px" '
|
| 490 |
+
# f'onclick="event.stopPropagation()" '
|
| 491 |
+
# f'onchange="window._dd_feat_action(\'set_lam:{feat}:\'+this.value)" '
|
| 492 |
+
# f'title="Steering strength λ"/>'
|
| 493 |
+
# )
|
| 494 |
+
# else:
|
| 495 |
+
# lam_html = ''
|
| 496 |
+
lam_html = ''
|
| 497 |
|
| 498 |
cards.append(
|
| 499 |
f'<div onclick="window._sae_select_feature({feat})" '
|
scripts/explorer/state.py
CHANGED
|
@@ -23,7 +23,7 @@ class _S:
|
|
| 23 |
patch_img = None # image index currently loaded in patch explorer
|
| 24 |
patch_z = None # (n_patches, d_sae) float32 for the loaded patch image
|
| 25 |
gallery_page: int = 0 # current page in the MEI thumbnail gallery
|
| 26 |
-
clip_query_embed = None # last CLIP text query embedding (1, d_clip) tensor for combined sort
|
| 27 |
|
| 28 |
|
| 29 |
_all_datasets: list[dict] = []
|
|
|
|
| 23 |
patch_img = None # image index currently loaded in patch explorer
|
| 24 |
patch_z = None # (n_patches, d_sae) float32 for the loaded patch image
|
| 25 |
gallery_page: int = 0 # current page in the MEI thumbnail gallery
|
| 26 |
+
# clip_query_embed = None # last CLIP text query embedding (1, d_clip) tensor for combined sort
|
| 27 |
|
| 28 |
|
| 29 |
_all_datasets: list[dict] = []
|