Marlin Lee committed on
Commit
bef3be1
·
1 Parent(s): fc53679

Sync explorer_app: add auto_interp, fix image loading, remove cross-dataset panel

Browse files
Files changed (1) hide show
  1. scripts/explorer_app.py +37 -10
scripts/explorer_app.py CHANGED
@@ -111,6 +111,13 @@ def _load_dataset_dict(path, label):
111
  if os.path.exists(names_file):
112
  with open(names_file) as _nf:
113
  feat_names = {int(k): v for k, v in json.load(_nf).items()}
 
 
 
 
 
 
 
114
  entry = {
115
  'label': label,
116
  'path': path,
@@ -139,6 +146,7 @@ def _load_dataset_dict(path, label):
139
  'inference_cache': OrderedDict(),
140
  'names_file': names_file,
141
  'feature_names': feat_names,
 
142
  }
143
  # Load pre-computed heatmaps sidecar if present
144
  sidecar = os.path.splitext(path)[0] + '_heatmaps.pt'
@@ -212,7 +220,7 @@ def _apply_dataset_globals(idx):
212
  global umap_backup
213
  global _clip_scores, _clip_vocab, _clip_embeds, _clip_scores_f32, HAS_CLIP
214
  global _compare_datasets
215
- global feature_names, _names_file
216
 
217
  ds = _all_datasets[idx]
218
  image_paths = ds['image_paths']
@@ -242,6 +250,7 @@ def _apply_dataset_globals(idx):
242
  _compare_datasets = [d for i, d in enumerate(_all_datasets) if i != idx]
243
  feature_names = ds['feature_names']
244
  _names_file = ds['names_file']
 
245
 
246
  # Derived arrays used by UMAP, feature list, and callbacks
247
  freq = feature_frequency.numpy()
@@ -274,6 +283,15 @@ def _save_names():
274
  print(f"Saved {len(feature_names)} feature names to {_names_file}")
275
 
276
 
 
 
 
 
 
 
 
 
 
277
  # Live inference has been removed — all feature display and patch exploration
278
  # is driven entirely by pre-computed sidecars (_heatmaps.pt, _patch_acts.pt).
279
  HAS_CLIP_MODEL = False
@@ -637,11 +655,20 @@ def update_feature_display(feature_idx):
637
  dead = "DEAD FEATURE" if freq_val == 0 else ""
638
 
639
  feat_name = feature_names.get(feat, "")
640
- name_display = (
641
- f'<div style="color:#1a6faf;font-style:italic;margin:2px 0 6px 0">'
642
- f'&#x1F3F7;&#xFE0E; {feat_name}</div>'
643
- if feat_name else ""
644
- )
 
 
 
 
 
 
 
 
 
645
 
646
  stats_div.text = f"""
647
  <h2 style="margin:4px 0">Feature {feat} <span style="color:red">{dead}</span></h2>
@@ -900,7 +927,7 @@ feature_list_source = ColumnDataSource(data=dict(
900
  frequency=freq_np[_init_order].tolist(),
901
  mean_act=mean_act_np[_init_order].tolist(),
902
  p75_val=p75_np[_init_order].tolist(),
903
- name=[feature_names.get(int(i), "") for i in _init_order],
904
  ))
905
 
906
  feature_table = DataTable(
@@ -954,7 +981,7 @@ def _apply_order(order):
954
  frequency=freq_np[order].tolist(),
955
  mean_act=mean_act_np[order].tolist(),
956
  p75_val=p75_np[order].tolist(),
957
- name=[feature_names.get(int(i), "") for i in order],
958
  )
959
 
960
 
@@ -970,7 +997,7 @@ def _update_table_names():
970
  frequency=freq_np[order].tolist(),
971
  mean_act=mean_act_np[order].tolist(),
972
  p75_val=p75_np[order].tolist(),
973
- name=[feature_names.get(int(i), "") for i in order],
974
  )
975
 
976
 
@@ -1361,7 +1388,7 @@ if HAS_CLIP:
1361
  clip_score=[float(scores_vec[i]) for i in top_indices],
1362
  frequency=[int(feature_frequency[i].item()) for i in top_indices],
1363
  mean_act=[float(feature_mean_act[i].item()) for i in top_indices],
1364
- name=[feature_names.get(int(i), "") for i in top_indices],
1365
  )
1366
  clip_result_div.text = (
1367
  f'<span style="color:#1a6faf"><b>{len(top_indices)}</b> features for '
 
111
  if os.path.exists(names_file):
112
  with open(names_file) as _nf:
113
  feat_names = {int(k): v for k, v in json.load(_nf).items()}
114
+ auto_interp_file = os.path.splitext(path)[0] + '_auto_interp.json'
115
+ auto_interp = {}
116
+ if os.path.exists(auto_interp_file):
117
+ with open(auto_interp_file) as _af:
118
+ auto_interp = {int(k): v for k, v in json.load(_af).items()}
119
+ print(f" Loaded {len(auto_interp)} auto-interp labels from "
120
+ f"{os.path.basename(auto_interp_file)}")
121
  entry = {
122
  'label': label,
123
  'path': path,
 
146
  'inference_cache': OrderedDict(),
147
  'names_file': names_file,
148
  'feature_names': feat_names,
149
+ 'auto_interp_names': auto_interp,
150
  }
151
  # Load pre-computed heatmaps sidecar if present
152
  sidecar = os.path.splitext(path)[0] + '_heatmaps.pt'
 
220
  global umap_backup
221
  global _clip_scores, _clip_vocab, _clip_embeds, _clip_scores_f32, HAS_CLIP
222
  global _compare_datasets
223
+ global feature_names, _names_file, auto_interp_names
224
 
225
  ds = _all_datasets[idx]
226
  image_paths = ds['image_paths']
 
250
  _compare_datasets = [d for i, d in enumerate(_all_datasets) if i != idx]
251
  feature_names = ds['feature_names']
252
  _names_file = ds['names_file']
253
+ auto_interp_names = ds['auto_interp_names']
254
 
255
  # Derived arrays used by UMAP, feature list, and callbacks
256
  freq = feature_frequency.numpy()
 
283
  print(f"Saved {len(feature_names)} feature names to {_names_file}")
284
 
285
 
286
+ def _display_name(feat: int) -> str:
287
+ """Return the label to show in tables: manual label takes priority over auto-interp."""
288
+ m = feature_names.get(feat)
289
+ if m:
290
+ return m
291
+ a = auto_interp_names.get(feat)
292
+ return f"[auto] {a}" if a else ""
293
+
294
+
295
  # Live inference has been removed — all feature display and patch exploration
296
  # is driven entirely by pre-computed sidecars (_heatmaps.pt, _patch_acts.pt).
297
  HAS_CLIP_MODEL = False
 
655
  dead = "DEAD FEATURE" if freq_val == 0 else ""
656
 
657
  feat_name = feature_names.get(feat, "")
658
+ auto_name = auto_interp_names.get(feat, "")
659
+ if feat_name:
660
+ name_display = (
661
+ f'<div style="color:#1a6faf;font-style:italic;margin:2px 0 6px 0">'
662
+ f'&#x1F3F7;&#xFE0E; {feat_name}</div>'
663
+ )
664
+ elif auto_name:
665
+ name_display = (
666
+ f'<div style="color:#5a9a5a;font-style:italic;margin:2px 0 6px 0">'
667
+ f'&#x1F916; {auto_name}'
668
+ f'<span style="font-size:10px;color:#999;margin-left:6px">(auto-interp)</span></div>'
669
+ )
670
+ else:
671
+ name_display = ""
672
 
673
  stats_div.text = f"""
674
  <h2 style="margin:4px 0">Feature {feat} <span style="color:red">{dead}</span></h2>
 
927
  frequency=freq_np[_init_order].tolist(),
928
  mean_act=mean_act_np[_init_order].tolist(),
929
  p75_val=p75_np[_init_order].tolist(),
930
+ name=[_display_name(int(i)) for i in _init_order],
931
  ))
932
 
933
  feature_table = DataTable(
 
981
  frequency=freq_np[order].tolist(),
982
  mean_act=mean_act_np[order].tolist(),
983
  p75_val=p75_np[order].tolist(),
984
+ name=[_display_name(int(i)) for i in order],
985
  )
986
 
987
 
 
997
  frequency=freq_np[order].tolist(),
998
  mean_act=mean_act_np[order].tolist(),
999
  p75_val=p75_np[order].tolist(),
1000
+ name=[_display_name(int(i)) for i in order],
1001
  )
1002
 
1003
 
 
1388
  clip_score=[float(scores_vec[i]) for i in top_indices],
1389
  frequency=[int(feature_frequency[i].item()) for i in top_indices],
1390
  mean_act=[float(feature_mean_act[i].item()) for i in top_indices],
1391
+ name=[_display_name(int(i)) for i in top_indices],
1392
  )
1393
  clip_result_div.text = (
1394
  f'<span style="color:#1a6faf"><b>{len(top_indices)}</b> features for '