Marlin Lee commited on
Commit
1361d88
·
1 Parent(s): 62166b9

Comment out new features (CLIP×φ sort, λ input, brain MEI sidecar) until ready

Browse files
scripts/explorer/datasets.py CHANGED
@@ -150,16 +150,16 @@ def _load_dataset(path: str, label: str, *,
150
  'umap_backup': umap_backup,
151
  }
152
 
153
- # Brain MEI sidecar (brain-response-sorted MEI indices)
154
- brain_sidecar = stem + '_brain_meis.pt'
155
- if os.path.exists(brain_sidecar):
156
- print(f" Loading brain MEI sidecar {os.path.basename(brain_sidecar)} ...")
157
- bm = torch.load(brain_sidecar, map_location='cpu', weights_only=False)
158
- entry['brain_top_img_idx'] = bm.get('brain_top_img_idx')
159
- entry['brain_top_img_act'] = bm.get('brain_top_img_act')
160
- else:
161
- entry['brain_top_img_idx'] = None
162
- entry['brain_top_img_act'] = None
163
 
164
  # Heatmaps sidecar
165
  sidecar = stem + '_heatmaps.pt'
 
150
  'umap_backup': umap_backup,
151
  }
152
 
153
+ # Brain MEI sidecar (disabled — re-enable after running precompute_brain_response_meis.py)
154
+ # brain_sidecar = stem + '_brain_meis.pt'
155
+ # if os.path.exists(brain_sidecar):
156
+ # print(f" Loading brain MEI sidecar {os.path.basename(brain_sidecar)} ...")
157
+ # bm = torch.load(brain_sidecar, map_location='cpu', weights_only=False)
158
+ # entry['brain_top_img_idx'] = bm.get('brain_top_img_idx')
159
+ # entry['brain_top_img_act'] = bm.get('brain_top_img_act')
160
+ # else:
161
+ # entry['brain_top_img_idx'] = None
162
+ # entry['brain_top_img_act'] = None
163
 
164
  # Heatmaps sidecar
165
  sidecar = stem + '_heatmaps.pt'
scripts/explorer/main.py CHANGED
@@ -276,7 +276,7 @@ lower_left = column(
276
  _search_header,
277
  _clip_row,
278
  _clip_results,
279
- flist_panel.sort_select,
280
  flist_panel.gallery_div,
281
  flist_panel.gallery_bridge_input,
282
  flist_panel.gallery_page_input,
 
276
  _search_header,
277
  _clip_row,
278
  _clip_results,
279
+ # flist_panel.sort_select, # re-enable for CLIP × φ sort
280
  flist_panel.gallery_div,
281
  flist_panel.gallery_bridge_input,
282
  flist_panel.gallery_page_input,
scripts/explorer/panels/clip_search.py CHANGED
@@ -55,7 +55,7 @@ else:
55
  clip_m, clip_p, clip_dev = get_clip()
56
  from clip_utils import compute_text_embeddings
57
  q_embed = compute_text_embeddings([query], clip_m, clip_p, clip_dev)
58
- _S.clip_query_embed = q_embed # store for combined CLIP × φ sort
59
  scores_vec = (embeds.float() @ q_embed.T).squeeze(-1)
60
  except Exception as exc:
61
  clip_results_div.text = f"<span style='color:#c00'>CLIP error: {exc}</span>"
 
55
  clip_m, clip_p, clip_dev = get_clip()
56
  from clip_utils import compute_text_embeddings
57
  q_embed = compute_text_embeddings([query], clip_m, clip_p, clip_dev)
58
+ # _S.clip_query_embed = q_embed # store for combined CLIP × φ sort
59
  scores_vec = (embeds.float() @ q_embed.T).squeeze(-1)
60
  except Exception as exc:
61
  clip_results_div.text = f"<span style='color:#c00'>CLIP error: {exc}</span>"
scripts/explorer/panels/dynadiff.py CHANGED
@@ -239,24 +239,25 @@ if HAS_DYNADIFF:
239
  for k, vals in dd_source.data.items()}
240
  dd_source.data = new_data
241
  dd_status.text = ''
242
- elif msg.startswith('set_lam:'):
243
- parts = msg.split(':', 2)
244
- if len(parts) != 3:
245
- return
246
- try:
247
- feat = int(parts[1])
248
- new_val = float(parts[2])
249
- except ValueError:
250
- return
251
- feats = list(dd_source.data['feat'])
252
- if feat not in feats:
253
- return
254
- idx = feats.index(feat)
255
- new_lams = list(dd_source.data['lam'])
256
- new_lams[idx] = new_val
257
- new_data = dict(dd_source.data)
258
- new_data['lam'] = new_lams
259
- dd_source.data = new_data
 
260
 
261
  dd_feat_bridge.on_change('value', _on_feat_bridge)
262
 
 
239
  for k, vals in dd_source.data.items()}
240
  dd_source.data = new_data
241
  dd_status.text = ''
242
+ # set_lam disabled — re-enable with lambda input in rendering.py
243
+ # elif msg.startswith('set_lam:'):
244
+ # parts = msg.split(':', 2)
245
+ # if len(parts) != 3:
246
+ # return
247
+ # try:
248
+ # feat = int(parts[1])
249
+ # new_val = float(parts[2])
250
+ # except ValueError:
251
+ # return
252
+ # feats = list(dd_source.data['feat'])
253
+ # if feat not in feats:
254
+ # return
255
+ # idx = feats.index(feat)
256
+ # new_lams = list(dd_source.data['lam'])
257
+ # new_lams[idx] = new_val
258
+ # new_data = dict(dd_source.data)
259
+ # new_data['lam'] = new_lams
260
+ # dd_source.data = new_data
261
 
262
  dd_feat_bridge.on_change('value', _on_feat_bridge)
263
 
scripts/explorer/panels/feature_list.py CHANGED
@@ -21,7 +21,8 @@ from bokeh.io import curdoc
21
  from bokeh.layouts import column, row
22
  from bokeh.models import (
23
  Button, ColumnDataSource, DataTable, Div, NumberFormatter,
24
- Select, TableColumn, TextInput,
 
25
  )
26
 
27
  from ..args import args
@@ -83,28 +84,30 @@ def _schedule_hf_push(file_path: str):
83
 
84
  # ---------- Feature list table ----------
85
 
86
- sort_select = Select(
87
- title="Sort by", value="Frequency",
88
- options=["Frequency", "CLIP × Brain φ"],
89
- width=160,
90
- )
 
91
 
92
 
93
  def _get_sorted_order():
94
- ds = active_ds()
95
- if (sort_select.value == "CLIP × Brain φ"
96
- and _S.clip_query_embed is not None):
97
- embeds = ds.get('nsd_clip_embeds') if ds.get('nsd_clip_embeds') is not None else ds.get('clip_embeds')
98
- if embeds is not None:
99
- clip_sims = (embeds.float() @ _S.clip_query_embed.T).squeeze(-1).cpu().numpy()
100
- phi_vals = np.array(phi_c_vals(list(range(len(clip_sims)))))
101
- c = clip_sims - clip_sims.min(); c = c / (c.max() + 1e-8)
102
- p = phi_vals / (phi_vals.max() + 1e-8)
103
- order = np.argsort(-(c * p))
104
- else:
105
- order = np.argsort(-ds['freq'])
106
- else:
107
- order = np.argsort(-ds['freq'])
 
108
  if _S.search_filter is not None:
109
  mask = np.isin(order, list(_S.search_filter))
110
  order = order[mask]
@@ -123,11 +126,10 @@ def _build_list_data(order) -> dict:
123
  )
124
 
125
 
126
- def _on_sort_change(attr, old, new):
127
- apply_order(_get_sorted_order())
128
- rebuild_gallery()
129
-
130
- sort_select.on_change('value', _on_sort_change)
131
 
132
  feature_list_source = ColumnDataSource(data=_build_list_data(_get_sorted_order()))
133
 
 
21
  from bokeh.layouts import column, row
22
  from bokeh.models import (
23
  Button, ColumnDataSource, DataTable, Div, NumberFormatter,
24
+ # Select, # re-enable for CLIP × φ sort
25
+ TableColumn, TextInput,
26
  )
27
 
28
  from ..args import args
 
84
 
85
  # ---------- Feature list table ----------
86
 
87
+ # sort_select = Select( # re-enable for CLIP × φ sort
88
+ # title="Sort by", value="Frequency",
89
+ # options=["Frequency", "CLIP × Brain φ"],
90
+ # width=160,
91
+ # )
92
+ sort_select = None
93
 
94
 
95
  def _get_sorted_order():
96
+ ds = active_ds()
97
+ # CLIP × φ sort (disabled):
98
+ # if (sort_select is not None and sort_select.value == "CLIP × Brain φ"
99
+ # and _S.clip_query_embed is not None):
100
+ # embeds = ds.get('nsd_clip_embeds') if ds.get('nsd_clip_embeds') is not None else ds.get('clip_embeds')
101
+ # if embeds is not None:
102
+ # clip_sims = (embeds.float() @ _S.clip_query_embed.T).squeeze(-1).cpu().numpy()
103
+ # phi_vals = np.array(phi_c_vals(list(range(len(clip_sims)))))
104
+ # c = clip_sims - clip_sims.min(); c = c / (c.max() + 1e-8)
105
+ # p = phi_vals / (phi_vals.max() + 1e-8)
106
+ # order = np.argsort(-(c * p))
107
+ # else:
108
+ # order = np.argsort(-ds['freq'])
109
+ # else:
110
+ order = np.argsort(-ds['freq'])
111
  if _S.search_filter is not None:
112
  mask = np.isin(order, list(_S.search_filter))
113
  order = order[mask]
 
126
  )
127
 
128
 
129
+ # def _on_sort_change(attr, old, new): # re-enable for CLIP × φ sort
130
+ # apply_order(_get_sorted_order())
131
+ # rebuild_gallery()
132
+ # sort_select.on_change('value', _on_sort_change)
 
133
 
134
  feature_list_source = ColumnDataSource(data=_build_list_data(_get_sorted_order()))
135
 
scripts/explorer/rendering.py CHANGED
@@ -480,19 +480,20 @@ def make_active_features_tile_html(feats: list, ds: dict, mei_size: int = 72,
480
  f'margin-top:2px;text-align:center;overflow:hidden;'
481
  f'text-overflow:ellipsis;white-space:nowrap">{label}</div>') if label else ''
482
 
483
- # Lambda number input (only when removable/editable)
484
- if removable:
485
- lam_val = lam_map.get(feat, 3.0)
486
- lam_html = (
487
- f'<input type="number" value="{lam_val:.2g}" step="0.5" '
488
- f'style="width:54px;font-size:10px;text-align:center;'
489
- f'border:1px solid #ccc;border-radius:3px;margin-top:2px" '
490
- f'onclick="event.stopPropagation()" '
491
- f'onchange="window._dd_feat_action(\'set_lam:{feat}:\'+this.value)" '
492
- f'title="Steering strength λ"/>'
493
- )
494
- else:
495
- lam_html = ''
 
496
 
497
  cards.append(
498
  f'<div onclick="window._sae_select_feature({feat})" '
 
480
  f'margin-top:2px;text-align:center;overflow:hidden;'
481
  f'text-overflow:ellipsis;white-space:nowrap">{label}</div>') if label else ''
482
 
483
+ # Lambda number input (disabled — re-enable when set_lam handler is active)
484
+ # if removable:
485
+ # lam_val = lam_map.get(feat, 3.0)
486
+ # lam_html = (
487
+ # f'<input type="number" value="{lam_val:.2g}" step="0.5" '
488
+ # f'style="width:54px;font-size:10px;text-align:center;'
489
+ # f'border:1px solid #ccc;border-radius:3px;margin-top:2px" '
490
+ # f'onclick="event.stopPropagation()" '
491
+ # f'onchange="window._dd_feat_action(\'set_lam:{feat}:\'+this.value)" '
492
+ # f'title="Steering strength λ"/>'
493
+ # )
494
+ # else:
495
+ # lam_html = ''
496
+ lam_html = ''
497
 
498
  cards.append(
499
  f'<div onclick="window._sae_select_feature({feat})" '
scripts/explorer/state.py CHANGED
@@ -23,7 +23,7 @@ class _S:
23
  patch_img = None # image index currently loaded in patch explorer
24
  patch_z = None # (n_patches, d_sae) float32 for the loaded patch image
25
  gallery_page: int = 0 # current page in the MEI thumbnail gallery
26
- clip_query_embed = None # last CLIP text query embedding (1, d_clip) tensor for combined sort
27
 
28
 
29
  _all_datasets: list[dict] = []
 
23
  patch_img = None # image index currently loaded in patch explorer
24
  patch_z = None # (n_patches, d_sae) float32 for the loaded patch image
25
  gallery_page: int = 0 # current page in the MEI thumbnail gallery
26
+ # clip_query_embed = None # last CLIP text query embedding (1, d_clip) tensor for combined sort
27
 
28
 
29
  _all_datasets: list[dict] = []