Spaces:
Running
Running
Marlin Lee commited on
Commit ·
41be4ef
1
Parent(s): 68497cf
Sync explorer_app.py and clip_utils.py from main repo
Browse files- scripts/explorer_app.py +25 -5
scripts/explorer_app.py
CHANGED
|
@@ -649,14 +649,14 @@ def make_compare_aggregations_html(top_infos, mean_infos, p75_infos, feat, n_eac
|
|
| 649 |
return html
|
| 650 |
|
| 651 |
|
| 652 |
-
def make_cross_sae_comparison_html(ds_a, feat_a, ds_b, feat_b, n=4, size=160):
|
| 653 |
"""
|
| 654 |
Two side-by-side 2×2 grids: left = SAE A / feat_a, right = SAE B / feat_b.
|
| 655 |
"""
|
| 656 |
def _collect(ds, feat):
|
| 657 |
items = []
|
| 658 |
for slot in range(min(n, ds['top_img_idx'].shape[1])):
|
| 659 |
-
result = _render_overlay_from_ds(ds, feat, slot, size=size)
|
| 660 |
if result:
|
| 661 |
items.append(result)
|
| 662 |
if len(items) == n:
|
|
@@ -1889,10 +1889,28 @@ cmp_feat_a = TextInput(title="Feature (SAE A):", value="0", width=100)
|
|
| 1889 |
cmp_ds_b = Select(title="SAE B:", value=str(min(1, len(_all_datasets)-1)),
|
| 1890 |
options=[(str(i), ds['label']) for i, ds in enumerate(_all_datasets)])
|
| 1891 |
cmp_feat_b = TextInput(title="Feature (SAE B):", value="0", width=100)
|
|
|
|
| 1892 |
cmp_btn = Button(label="Generate Comparison", button_type="primary", width=200)
|
| 1893 |
cmp_output_div = Div(text="", width=400)
|
| 1894 |
|
| 1895 |
def _on_cmp_generate():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1896 |
try:
|
| 1897 |
idx_a = int(cmp_ds_a.value)
|
| 1898 |
idx_b = int(cmp_ds_b.value)
|
|
@@ -1902,16 +1920,18 @@ def _on_cmp_generate():
|
|
| 1902 |
_ensure_loaded(idx_b)
|
| 1903 |
ds_a = _all_datasets[idx_a]
|
| 1904 |
ds_b = _all_datasets[idx_b]
|
| 1905 |
-
cmp_output_div.text = make_cross_sae_comparison_html(
|
|
|
|
| 1906 |
except Exception as e:
|
| 1907 |
cmp_output_div.text = f'<p style="color:red">Error: {e}</p>'
|
| 1908 |
|
| 1909 |
-
cmp_btn.on_click(lambda:
|
|
|
|
| 1910 |
|
| 1911 |
cmp_section = _make_collapsible("Cross-SAE Comparison", column(
|
| 1912 |
row(cmp_ds_a, cmp_feat_a),
|
| 1913 |
row(cmp_ds_b, cmp_feat_b),
|
| 1914 |
-
cmp_btn,
|
| 1915 |
cmp_output_div,
|
| 1916 |
))
|
| 1917 |
|
|
|
|
| 649 |
return html
|
| 650 |
|
| 651 |
|
| 652 |
+
def make_cross_sae_comparison_html(ds_a, feat_a, ds_b, feat_b, n=4, size=160, alpha=1.0):
|
| 653 |
"""
|
| 654 |
Two side-by-side 2×2 grids: left = SAE A / feat_a, right = SAE B / feat_b.
|
| 655 |
"""
|
| 656 |
def _collect(ds, feat):
|
| 657 |
items = []
|
| 658 |
for slot in range(min(n, ds['top_img_idx'].shape[1])):
|
| 659 |
+
result = _render_overlay_from_ds(ds, feat, slot, size=size, alpha=alpha)
|
| 660 |
if result:
|
| 661 |
items.append(result)
|
| 662 |
if len(items) == n:
|
|
|
|
| 1889 |
cmp_ds_b = Select(title="SAE B:", value=str(min(1, len(_all_datasets)-1)),
|
| 1890 |
options=[(str(i), ds['label']) for i, ds in enumerate(_all_datasets)])
|
| 1891 |
cmp_feat_b = TextInput(title="Feature (SAE B):", value="0", width=100)
|
| 1892 |
+
cmp_alpha_slider = Slider(title="Heatmap opacity", value=1.0, start=0.0, end=1.0, step=0.05, width=220)
|
| 1893 |
cmp_btn = Button(label="Generate Comparison", button_type="primary", width=200)
|
| 1894 |
cmp_output_div = Div(text="", width=400)
|
| 1895 |
|
| 1896 |
def _on_cmp_generate():
|
| 1897 |
+
if not cmp_output_div.text:
|
| 1898 |
+
return
|
| 1899 |
+
try:
|
| 1900 |
+
idx_a = int(cmp_ds_a.value)
|
| 1901 |
+
idx_b = int(cmp_ds_b.value)
|
| 1902 |
+
fa = int(cmp_feat_a.value)
|
| 1903 |
+
fb = int(cmp_feat_b.value)
|
| 1904 |
+
_ensure_loaded(idx_a)
|
| 1905 |
+
_ensure_loaded(idx_b)
|
| 1906 |
+
ds_a = _all_datasets[idx_a]
|
| 1907 |
+
ds_b = _all_datasets[idx_b]
|
| 1908 |
+
cmp_output_div.text = make_cross_sae_comparison_html(
|
| 1909 |
+
ds_a, fa, ds_b, fb, alpha=cmp_alpha_slider.value)
|
| 1910 |
+
except Exception as e:
|
| 1911 |
+
cmp_output_div.text = f'<p style="color:red">Error: {e}</p>'
|
| 1912 |
+
|
| 1913 |
+
def _on_cmp_btn():
|
| 1914 |
try:
|
| 1915 |
idx_a = int(cmp_ds_a.value)
|
| 1916 |
idx_b = int(cmp_ds_b.value)
|
|
|
|
| 1920 |
_ensure_loaded(idx_b)
|
| 1921 |
ds_a = _all_datasets[idx_a]
|
| 1922 |
ds_b = _all_datasets[idx_b]
|
| 1923 |
+
cmp_output_div.text = make_cross_sae_comparison_html(
|
| 1924 |
+
ds_a, fa, ds_b, fb, alpha=cmp_alpha_slider.value)
|
| 1925 |
except Exception as e:
|
| 1926 |
cmp_output_div.text = f'<p style="color:red">Error: {e}</p>'
|
| 1927 |
|
| 1928 |
+
cmp_btn.on_click(lambda: _on_cmp_btn())
|
| 1929 |
+
cmp_alpha_slider.on_change('value', lambda attr, old, new: _on_cmp_generate())
|
| 1930 |
|
| 1931 |
cmp_section = _make_collapsible("Cross-SAE Comparison", column(
|
| 1932 |
row(cmp_ds_a, cmp_feat_a),
|
| 1933 |
row(cmp_ds_b, cmp_feat_b),
|
| 1934 |
+
row(cmp_alpha_slider, cmp_btn),
|
| 1935 |
cmp_output_div,
|
| 1936 |
))
|
| 1937 |
|