Spaces:
Running
Running
Marlin Lee committed on
Commit ·
5891413
1
Parent(s): 5b87894
Sync explorer_app.py and clip_utils.py from main repo
Browse files- scripts/explorer_app.py +113 -1
scripts/explorer_app.py
CHANGED
|
@@ -516,6 +516,43 @@ def render_zoomed_patch(img_idx, heatmap_16x16, size=THUMB, pg=None):
|
|
| 516 |
return img.crop((x0, y0, x1, y1)).resize((size, size), Image.BILINEAR)
|
| 517 |
|
| 518 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
def pil_to_data_url(img):
|
| 520 |
buf = io.BytesIO()
|
| 521 |
img.save(buf, format="JPEG", quality=85)
|
|
@@ -612,6 +649,48 @@ def make_compare_aggregations_html(top_infos, mean_infos, p75_infos, feat, n_eac
|
|
| 612 |
return html
|
| 613 |
|
| 614 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 615 |
# ---------- UMAP data source ----------
|
| 616 |
# live_mask / live_indices / freq / mean_act / log_freq / umap_backup are all
|
| 617 |
# already set by _apply_dataset_globals(0) above — just build the source from them.
|
|
@@ -1795,7 +1874,40 @@ middle_panel = column(
|
|
| 1795 |
p75_heatmap_div, p75_zoom_div,
|
| 1796 |
)
|
| 1797 |
|
| 1798 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1799 |
|
| 1800 |
layout = row(left_panel, middle_panel, right_panel)
|
| 1801 |
curdoc().add_root(layout)
|
|
|
|
| 516 |
return img.crop((x0, y0, x1, y1)).resize((size, size), Image.BILINEAR)
|
| 517 |
|
| 518 |
|
| 519 |
+
def _load_image_from_ds(ds, img_i):
    """Open image ``img_i`` of dataset ``ds`` and return it as an RGB PIL image.

    The basename of the path stored in ``ds['image_paths']`` is looked up
    first under ``args.image_dir`` and then, when it is set, under
    ``args.extra_image_dir``; the first candidate that exists on disk wins.
    If neither directory contains the file, the stored path is opened as-is.
    """
    stored_path = ds['image_paths'][img_i]
    basename = os.path.basename(stored_path)

    # Directories to probe, in priority order.
    search_dirs = [args.image_dir]
    if args.extra_image_dir:
        search_dirs.append(args.extra_image_dir)

    resolved = stored_path  # fallback when no relocated copy is found
    for directory in search_dirs:
        relocated = os.path.join(directory, basename)
        if os.path.exists(relocated):
            resolved = relocated
            break

    return Image.open(resolved).convert("RGB")
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def _render_overlay_from_ds(ds, feat, slot, size=THUMB, alpha=None):
    """Return (PIL overlay image, caption) for ds/feat/slot, or None on failure.

    Args:
        ds: dataset dict; reads ``ds['top_img_idx']`` and, when present,
            ``ds['top_heatmaps']`` and ``ds['heatmap_patch_grid']``.
        feat: feature index (first axis of ``top_img_idx`` / ``top_heatmaps``).
        slot: slot index (second axis of the same arrays).
        size: edge length in pixels of the square output image.
        alpha: heatmap opacity multiplier; when None the current value of
            ``heatmap_alpha_slider`` is used.

    Returns:
        ``(image, caption)`` where ``caption`` is ``"img <index>"``.  The image
        is the resized source image, blended with the heatmap when the dataset
        has one.  Returns None when the stored image index is negative or when
        anything raises while loading/rendering (the failure is logged).
    """
    import logging  # function-scope import so the block stays self-contained

    if alpha is None:
        alpha = heatmap_alpha_slider.value
    try:
        img_i = int(ds['top_img_idx'][feat, slot].item())
        if img_i < 0:
            # NOTE(review): a negative index presumably marks an empty slot — confirm.
            return None
        caption = f"img {img_i}"
        plain = _load_image_from_ds(ds, img_i).resize((size, size), Image.BILINEAR)

        hm_tensor = ds.get('top_heatmaps')
        if hm_tensor is None:
            # No heatmaps stored for this dataset: show the plain thumbnail.
            return plain, caption

        pg = ds.get('heatmap_patch_grid', 16)
        # Assumes a torch-like tensor (supports .float().numpy()) — TODO confirm.
        hmap = hm_tensor[feat, slot].float().numpy().reshape(pg, pg)
        img_arr = np.array(plain).astype(np.float32) / 255.0
        hmap_up = cv2.resize(hmap, (size, size), interpolation=cv2.INTER_CUBIC)

        # Scale to a peak of 1 only when there is a positive peak to scale by.
        hmax = hmap_up.max()
        hmap_norm = hmap_up / hmax if hmax > 0 else hmap_up

        # ALPHA_JET output is indexed as (H, W, >=4): channels 0-2 are the
        # colour, channel 3 is used as per-pixel opacity scaled by ``alpha``.
        overlay = ALPHA_JET(hmap_norm)
        ov_a = overlay[:, :, 3:4] * alpha
        blended = np.clip(
            (img_arr * (1 - ov_a) + overlay[:, :, :3] * ov_a) * 255, 0, 255
        ).astype(np.uint8)
        return Image.fromarray(blended), caption
    except Exception:
        # Best-effort rendering: callers treat None as "skip this slot", but
        # record why so missing files / bad shapes are not silently hidden.
        logging.getLogger(__name__).warning(
            "Could not render overlay for feat=%s slot=%s", feat, slot, exc_info=True
        )
        return None
|
| 554 |
+
|
| 555 |
+
|
| 556 |
def pil_to_data_url(img):
|
| 557 |
buf = io.BytesIO()
|
| 558 |
img.save(buf, format="JPEG", quality=85)
|
|
|
|
| 649 |
return html
|
| 650 |
|
| 651 |
|
| 652 |
+
def make_cross_sae_comparison_html(ds_a, feat_a, ds_b, feat_b, n=4, size=160):
    """
    Build two side-by-side image grids: left = SAE A / feat_a, right = SAE B / feat_b.

    Each side shows up to ``n`` rendered overlays laid out two per row (a 2×2
    grid at the default ``n=4``) under a coloured header reading
    ``"<dataset label> — feat <index>"``.

    Args:
        ds_a, ds_b: dataset dicts; ``ds['label']`` and ``ds['top_img_idx']``
            are read here, the rest is consumed by ``_render_overlay_from_ds``.
        feat_a, feat_b: feature index to show for each dataset.
        n: maximum number of images per side.
        size: edge length in pixels of each square thumbnail.

    Returns:
        An HTML string containing both grids inside one flex container.
    """
    from html import escape  # function-scope import; stdlib only

    def _collect(ds, feat):
        """Render slot by slot until ``n`` overlays succeed or slots run out."""
        rendered = []
        # Scan every available slot, not just the first ``n``: a slot that
        # fails to render should be replaced by the next usable one.
        for slot in range(ds['top_img_idx'].shape[1]):
            if len(rendered) >= n:
                break
            result = _render_overlay_from_ds(ds, feat, slot, size=size)
            if result:
                rendered.append(result)
        return rendered

    def _grid_html(items, label, color):
        """Return the header bar plus a two-column grid for one side."""
        # The label is plain text from the dataset config; escape it so
        # characters such as '<' or '&' cannot break the markup.
        header = (f'<div style="background:{color};color:#fff;font-size:11px;font-weight:bold;'
                  f'text-align:center;padding:4px;border-radius:4px;margin-bottom:6px">'
                  f'{escape(label)}</div>')
        cells = [
            (f'<div style="text-align:center">'
             f'<img src="{pil_to_data_url(img)}" width="{size}" height="{size}"'
             f' style="border:1px solid #ccc;border-radius:3px;display:block"/>'
             f'<div style="font-size:9px;color:#555;margin-top:2px">{cap}</div></div>')
            for img, cap in items
        ]
        grid_open = f'<div style="display:grid;grid-template-columns:repeat(2,{size}px);gap:4px">'
        return header + grid_open + ''.join(cells) + '</div>'

    items_a = _collect(ds_a, feat_a)
    items_b = _collect(ds_b, feat_b)

    label_a = f"{ds_a['label']} — feat {feat_a}"
    label_b = f"{ds_b['label']} — feat {feat_b}"
    col_a = _grid_html(items_a, label_a, "#2563a8")
    col_b = _grid_html(items_b, label_b, "#b85c00")

    return (
        '<div style="display:flex;gap:16px;padding:8px;background:#fafafa;'
        'border:1px solid #ddd;border-radius:6px">'
        + col_a + col_b + '</div>'
    )
|
| 692 |
+
|
| 693 |
+
|
| 694 |
# ---------- UMAP data source ----------
|
| 695 |
# live_mask / live_indices / freq / mean_act / log_freq / umap_backup are all
|
| 696 |
# already set by _apply_dataset_globals(0) above — just build the source from them.
|
|
|
|
| 1874 |
p75_heatmap_div, p75_zoom_div,
|
| 1875 |
)
|
| 1876 |
|
| 1877 |
+
# --- Cross-SAE comparison section ---
# One (value, label) pair per entry of _all_datasets; the value is the
# dataset's position rendered as a string.
_cmp_ds_options = [
    (str(position), entry['label'])
    for position, entry in enumerate(_all_datasets)
]

# Side A starts on dataset 0.
cmp_ds_a = Select(
    title="SAE A:",
    value="0",
    options=list(_cmp_ds_options),
)
cmp_feat_a = TextInput(title="Feature (SAE A):", value="0", width=100)

# Side B starts on dataset 1 when there are at least two datasets,
# otherwise on the last available index.
cmp_ds_b = Select(
    title="SAE B:",
    value=str(min(1, len(_all_datasets) - 1)),
    options=list(_cmp_ds_options),
)
cmp_feat_b = TextInput(title="Feature (SAE B):", value="0", width=100)

cmp_btn = Button(label="Generate Comparison", button_type="primary", width=200)
cmp_output_div = Div(text="", width=400)
|
| 1886 |
+
|
| 1887 |
+
def _on_cmp_generate():
    """Render the cross-SAE comparison for the current widget values.

    Reads the two dataset selectors and the two feature text inputs, makes
    sure both datasets are loaded via ``_ensure_loaded``, and writes the HTML
    from ``make_cross_sae_comparison_html`` into ``cmp_output_div``.  Any
    failure (non-integer input, unknown dataset, out-of-range feature, ...)
    is shown in the same Div as a red error message instead of raising.
    """
    from html import escape  # function-scope import; stdlib only

    def _check_feature(feat, ds, side):
        """Raise ValueError unless ``feat`` indexes a row of ds['top_img_idx']."""
        n_feats = ds['top_img_idx'].shape[0]
        if not 0 <= feat < n_feats:
            raise ValueError(
                f"feature {feat} is out of range for SAE {side} "
                f"(valid: 0 to {n_feats - 1})"
            )

    try:
        # Parse all user input first so bad text fails before any loading.
        idx_a = int(cmp_ds_a.value)
        idx_b = int(cmp_ds_b.value)
        fa = int(cmp_feat_a.value)
        fb = int(cmp_feat_b.value)

        _ensure_loaded(idx_a)
        _ensure_loaded(idx_b)
        ds_a = _all_datasets[idx_a]
        ds_b = _all_datasets[idx_b]

        # Without this check an out-of-range index yields an empty grid with
        # no explanation, and a negative one silently wraps to another feature.
        _check_feature(fa, ds_a, "A")
        _check_feature(fb, ds_b, "B")

        cmp_output_div.text = make_cross_sae_comparison_html(ds_a, fa, ds_b, fb)
    except Exception as e:
        # The message can echo raw text-box content, so escape it before
        # placing it inside the Div's HTML.
        cmp_output_div.text = f'<p style="color:red">Error: {escape(str(e))}</p>'
|
| 1900 |
+
|
| 1901 |
+
# Clicking the button regenerates the comparison HTML.
cmp_btn.on_click(lambda: _on_cmp_generate())

# Controls stacked top to bottom: dataset/feature pickers for each side,
# then the trigger button, then the output area.
_cmp_controls = column(
    row(cmp_ds_a, cmp_feat_a),
    row(cmp_ds_b, cmp_feat_b),
    cmp_btn,
    cmp_output_div,
)
cmp_section = _make_collapsible("Cross-SAE Comparison", _cmp_controls)

# The comparison section is the last entry of the right-hand panel.
right_panel = column(
    summary_section,
    patch_section,
    clip_section,
    cmp_section,
)

# Assemble the three panels side by side and attach them to the document.
layout = row(left_panel, middle_panel, right_panel)
curdoc().add_root(layout)
|