Spaces:
Running
Running
Marlin Lee committed on
Commit ·
bef3be1
1
Parent(s): fc53679
Sync explorer_app: add auto_interp, fix image loading, remove cross-dataset panel
Browse files- scripts/explorer_app.py +37 -10
scripts/explorer_app.py
CHANGED
|
@@ -111,6 +111,13 @@ def _load_dataset_dict(path, label):
|
|
| 111 |
if os.path.exists(names_file):
|
| 112 |
with open(names_file) as _nf:
|
| 113 |
feat_names = {int(k): v for k, v in json.load(_nf).items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
entry = {
|
| 115 |
'label': label,
|
| 116 |
'path': path,
|
|
@@ -139,6 +146,7 @@ def _load_dataset_dict(path, label):
|
|
| 139 |
'inference_cache': OrderedDict(),
|
| 140 |
'names_file': names_file,
|
| 141 |
'feature_names': feat_names,
|
|
|
|
| 142 |
}
|
| 143 |
# Load pre-computed heatmaps sidecar if present
|
| 144 |
sidecar = os.path.splitext(path)[0] + '_heatmaps.pt'
|
|
@@ -212,7 +220,7 @@ def _apply_dataset_globals(idx):
|
|
| 212 |
global umap_backup
|
| 213 |
global _clip_scores, _clip_vocab, _clip_embeds, _clip_scores_f32, HAS_CLIP
|
| 214 |
global _compare_datasets
|
| 215 |
-
global feature_names, _names_file
|
| 216 |
|
| 217 |
ds = _all_datasets[idx]
|
| 218 |
image_paths = ds['image_paths']
|
|
@@ -242,6 +250,7 @@ def _apply_dataset_globals(idx):
|
|
| 242 |
_compare_datasets = [d for i, d in enumerate(_all_datasets) if i != idx]
|
| 243 |
feature_names = ds['feature_names']
|
| 244 |
_names_file = ds['names_file']
|
|
|
|
| 245 |
|
| 246 |
# Derived arrays used by UMAP, feature list, and callbacks
|
| 247 |
freq = feature_frequency.numpy()
|
|
@@ -274,6 +283,15 @@ def _save_names():
|
|
| 274 |
print(f"Saved {len(feature_names)} feature names to {_names_file}")
|
| 275 |
|
| 276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
# Live inference has been removed — all feature display and patch exploration
|
| 278 |
# is driven entirely by pre-computed sidecars (_heatmaps.pt, _patch_acts.pt).
|
| 279 |
HAS_CLIP_MODEL = False
|
|
@@ -637,11 +655,20 @@ def update_feature_display(feature_idx):
|
|
| 637 |
dead = "DEAD FEATURE" if freq_val == 0 else ""
|
| 638 |
|
| 639 |
feat_name = feature_names.get(feat, "")
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
|
| 646 |
stats_div.text = f"""
|
| 647 |
<h2 style="margin:4px 0">Feature {feat} <span style="color:red">{dead}</span></h2>
|
|
@@ -900,7 +927,7 @@ feature_list_source = ColumnDataSource(data=dict(
|
|
| 900 |
frequency=freq_np[_init_order].tolist(),
|
| 901 |
mean_act=mean_act_np[_init_order].tolist(),
|
| 902 |
p75_val=p75_np[_init_order].tolist(),
|
| 903 |
-
name=[
|
| 904 |
))
|
| 905 |
|
| 906 |
feature_table = DataTable(
|
|
@@ -954,7 +981,7 @@ def _apply_order(order):
|
|
| 954 |
frequency=freq_np[order].tolist(),
|
| 955 |
mean_act=mean_act_np[order].tolist(),
|
| 956 |
p75_val=p75_np[order].tolist(),
|
| 957 |
-
name=[
|
| 958 |
)
|
| 959 |
|
| 960 |
|
|
@@ -970,7 +997,7 @@ def _update_table_names():
|
|
| 970 |
frequency=freq_np[order].tolist(),
|
| 971 |
mean_act=mean_act_np[order].tolist(),
|
| 972 |
p75_val=p75_np[order].tolist(),
|
| 973 |
-
name=[
|
| 974 |
)
|
| 975 |
|
| 976 |
|
|
@@ -1361,7 +1388,7 @@ if HAS_CLIP:
|
|
| 1361 |
clip_score=[float(scores_vec[i]) for i in top_indices],
|
| 1362 |
frequency=[int(feature_frequency[i].item()) for i in top_indices],
|
| 1363 |
mean_act=[float(feature_mean_act[i].item()) for i in top_indices],
|
| 1364 |
-
name=[
|
| 1365 |
)
|
| 1366 |
clip_result_div.text = (
|
| 1367 |
f'<span style="color:#1a6faf"><b>{len(top_indices)}</b> features for '
|
|
|
|
| 111 |
if os.path.exists(names_file):
|
| 112 |
with open(names_file) as _nf:
|
| 113 |
feat_names = {int(k): v for k, v in json.load(_nf).items()}
|
| 114 |
+
auto_interp_file = os.path.splitext(path)[0] + '_auto_interp.json'
|
| 115 |
+
auto_interp = {}
|
| 116 |
+
if os.path.exists(auto_interp_file):
|
| 117 |
+
with open(auto_interp_file) as _af:
|
| 118 |
+
auto_interp = {int(k): v for k, v in json.load(_af).items()}
|
| 119 |
+
print(f" Loaded {len(auto_interp)} auto-interp labels from "
|
| 120 |
+
f"{os.path.basename(auto_interp_file)}")
|
| 121 |
entry = {
|
| 122 |
'label': label,
|
| 123 |
'path': path,
|
|
|
|
| 146 |
'inference_cache': OrderedDict(),
|
| 147 |
'names_file': names_file,
|
| 148 |
'feature_names': feat_names,
|
| 149 |
+
'auto_interp_names': auto_interp,
|
| 150 |
}
|
| 151 |
# Load pre-computed heatmaps sidecar if present
|
| 152 |
sidecar = os.path.splitext(path)[0] + '_heatmaps.pt'
|
|
|
|
| 220 |
global umap_backup
|
| 221 |
global _clip_scores, _clip_vocab, _clip_embeds, _clip_scores_f32, HAS_CLIP
|
| 222 |
global _compare_datasets
|
| 223 |
+
global feature_names, _names_file, auto_interp_names
|
| 224 |
|
| 225 |
ds = _all_datasets[idx]
|
| 226 |
image_paths = ds['image_paths']
|
|
|
|
| 250 |
_compare_datasets = [d for i, d in enumerate(_all_datasets) if i != idx]
|
| 251 |
feature_names = ds['feature_names']
|
| 252 |
_names_file = ds['names_file']
|
| 253 |
+
auto_interp_names = ds['auto_interp_names']
|
| 254 |
|
| 255 |
# Derived arrays used by UMAP, feature list, and callbacks
|
| 256 |
freq = feature_frequency.numpy()
|
|
|
|
| 283 |
print(f"Saved {len(feature_names)} feature names to {_names_file}")
|
| 284 |
|
| 285 |
|
| 286 |
+
def _display_name(feat: int) -> str:
    """Label to show in tables; a manual label takes priority over an auto-interp one."""
    manual = feature_names.get(feat)
    if manual:
        return manual
    # Fall back to the auto-interp label, tagged so the origin is visible.
    auto = auto_interp_names.get(feat)
    if auto:
        return f"[auto] {auto}"
    return ""
|
| 293 |
+
|
| 294 |
+
|
| 295 |
# Live inference has been removed — all feature display and patch exploration
|
| 296 |
# is driven entirely by pre-computed sidecars (_heatmaps.pt, _patch_acts.pt).
|
| 297 |
HAS_CLIP_MODEL = False
|
|
|
|
| 655 |
dead = "DEAD FEATURE" if freq_val == 0 else ""
|
| 656 |
|
| 657 |
feat_name = feature_names.get(feat, "")
|
| 658 |
+
auto_name = auto_interp_names.get(feat, "")
|
| 659 |
+
if feat_name:
|
| 660 |
+
name_display = (
|
| 661 |
+
f'<div style="color:#1a6faf;font-style:italic;margin:2px 0 6px 0">'
|
| 662 |
+
f'🏷︎ {feat_name}</div>'
|
| 663 |
+
)
|
| 664 |
+
elif auto_name:
|
| 665 |
+
name_display = (
|
| 666 |
+
f'<div style="color:#5a9a5a;font-style:italic;margin:2px 0 6px 0">'
|
| 667 |
+
f'🤖 {auto_name}'
|
| 668 |
+
f'<span style="font-size:10px;color:#999;margin-left:6px">(auto-interp)</span></div>'
|
| 669 |
+
)
|
| 670 |
+
else:
|
| 671 |
+
name_display = ""
|
| 672 |
|
| 673 |
stats_div.text = f"""
|
| 674 |
<h2 style="margin:4px 0">Feature {feat} <span style="color:red">{dead}</span></h2>
|
|
|
|
| 927 |
frequency=freq_np[_init_order].tolist(),
|
| 928 |
mean_act=mean_act_np[_init_order].tolist(),
|
| 929 |
p75_val=p75_np[_init_order].tolist(),
|
| 930 |
+
name=[_display_name(int(i)) for i in _init_order],
|
| 931 |
))
|
| 932 |
|
| 933 |
feature_table = DataTable(
|
|
|
|
| 981 |
frequency=freq_np[order].tolist(),
|
| 982 |
mean_act=mean_act_np[order].tolist(),
|
| 983 |
p75_val=p75_np[order].tolist(),
|
| 984 |
+
name=[_display_name(int(i)) for i in order],
|
| 985 |
)
|
| 986 |
|
| 987 |
|
|
|
|
| 997 |
frequency=freq_np[order].tolist(),
|
| 998 |
mean_act=mean_act_np[order].tolist(),
|
| 999 |
p75_val=p75_np[order].tolist(),
|
| 1000 |
+
name=[_display_name(int(i)) for i in order],
|
| 1001 |
)
|
| 1002 |
|
| 1003 |
|
|
|
|
| 1388 |
clip_score=[float(scores_vec[i]) for i in top_indices],
|
| 1389 |
frequency=[int(feature_frequency[i].item()) for i in top_indices],
|
| 1390 |
mean_act=[float(feature_mean_act[i].item()) for i in top_indices],
|
| 1391 |
+
name=[_display_name(int(i)) for i in top_indices],
|
| 1392 |
)
|
| 1393 |
clip_result_div.text = (
|
| 1394 |
f'<span style="color:#1a6faf"><b>{len(top_indices)}</b> features for '
|