Marlin Lee commited on
Commit
b424e01
·
1 Parent(s): 3a91ce0

Make feature tables sortable

Browse files
Files changed (1) hide show
  1. scripts/explorer_app.py +5 -12
scripts/explorer_app.py CHANGED
@@ -123,13 +123,6 @@ def _get_clip():
123
  return _clip_handle[0]
124
 
125
 
126
- def _center_umap(coords: np.ndarray) -> np.ndarray:
127
- """Shift UMAP coordinates so the mean of live (non-NaN) points is (0, 0)."""
128
- live = ~np.isnan(coords[:, 0])
129
- coords[live] -= coords[live].mean(axis=0)
130
- return coords
131
-
132
-
133
  # ---------- Load all datasets into a unified list ----------
134
 
135
  def _load_dataset_dict(path, label, sae_url=None):
@@ -169,8 +162,8 @@ def _load_dataset_dict(path, label, sae_url=None):
169
  'feature_frequency': d['feature_frequency'],
170
  'feature_mean_act': d['feature_mean_act'],
171
  'feature_p75_val': d['feature_p75_val'],
172
- 'umap_coords': _center_umap(d['umap_coords'].numpy()),
173
- 'dict_umap_coords': _center_umap(d['dict_umap_coords'].numpy()),
174
  'clip_scores': cs,
175
  'clip_vocab': d.get('clip_text_vocab', None),
176
  'clip_embeds': d.get('clip_feature_embeds', None),
@@ -1215,7 +1208,7 @@ feature_table = DataTable(
1215
  formatter=NumberFormatter(format="0.0000")),
1216
  TableColumn(field="name", title="Name", width=200),
1217
  ],
1218
- width=500, height=500, sortable=False, index_position=None,
1219
  )
1220
 
1221
  # Search state: None = no filter, otherwise a set of matching feature indices
@@ -1611,7 +1604,7 @@ patch_feat_table = DataTable(
1611
  TableColumn(field="mean_act", title="Mean Act", width=80,
1612
  formatter=NumberFormatter(format="0.0000")),
1613
  ],
1614
- width=310, height=350, index_position=None, sortable=False, visible=False,
1615
  )
1616
  patch_info_div = Div(
1617
  text="<i>Load an image, then click patches to find top features.</i>",
@@ -1780,7 +1773,7 @@ if HAS_CLIP:
1780
  formatter=NumberFormatter(format="0.0000")),
1781
  TableColumn(field="name", title="Name", width=160),
1782
  ],
1783
- width=470, height=300, index_position=None, sortable=False,
1784
  )
1785
 
1786
  def _do_clip_search():
 
123
  return _clip_handle[0]
124
 
125
 
 
 
 
 
 
 
 
126
  # ---------- Load all datasets into a unified list ----------
127
 
128
  def _load_dataset_dict(path, label, sae_url=None):
 
162
  'feature_frequency': d['feature_frequency'],
163
  'feature_mean_act': d['feature_mean_act'],
164
  'feature_p75_val': d['feature_p75_val'],
165
+ 'umap_coords': d['umap_coords'].numpy(),
166
+ 'dict_umap_coords': d['dict_umap_coords'].numpy(),
167
  'clip_scores': cs,
168
  'clip_vocab': d.get('clip_text_vocab', None),
169
  'clip_embeds': d.get('clip_feature_embeds', None),
 
1208
  formatter=NumberFormatter(format="0.0000")),
1209
  TableColumn(field="name", title="Name", width=200),
1210
  ],
1211
+ width=500, height=500, sortable=True, index_position=None,
1212
  )
1213
 
1214
  # Search state: None = no filter, otherwise a set of matching feature indices
 
1604
  TableColumn(field="mean_act", title="Mean Act", width=80,
1605
  formatter=NumberFormatter(format="0.0000")),
1606
  ],
1607
+ width=310, height=350, index_position=None, sortable=True, visible=False,
1608
  )
1609
  patch_info_div = Div(
1610
  text="<i>Load an image, then click patches to find top features.</i>",
 
1773
  formatter=NumberFormatter(format="0.0000")),
1774
  TableColumn(field="name", title="Name", width=160),
1775
  ],
1776
+ width=470, height=300, index_position=None, sortable=True,
1777
  )
1778
 
1779
  def _do_clip_search():