Marlin Lee commited on
Commit
5891413
·
1 Parent(s): 5b87894

Sync explorer_app.py and clip_utils.py from main repo

Browse files
Files changed (1) hide show
  1. scripts/explorer_app.py +113 -1
scripts/explorer_app.py CHANGED
@@ -516,6 +516,43 @@ def render_zoomed_patch(img_idx, heatmap_16x16, size=THUMB, pg=None):
516
  return img.crop((x0, y0, x1, y1)).resize((size, size), Image.BILINEAR)
517
 
518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
  def pil_to_data_url(img):
520
  buf = io.BytesIO()
521
  img.save(buf, format="JPEG", quality=85)
@@ -612,6 +649,48 @@ def make_compare_aggregations_html(top_infos, mean_infos, p75_infos, feat, n_eac
612
  return html
613
 
614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
615
  # ---------- UMAP data source ----------
616
  # live_mask / live_indices / freq / mean_act / log_freq / umap_backup are all
617
  # already set by _apply_dataset_globals(0) above — just build the source from them.
@@ -1795,7 +1874,40 @@ middle_panel = column(
1795
  p75_heatmap_div, p75_zoom_div,
1796
  )
1797
 
1798
- right_panel = column(summary_section, patch_section, clip_section)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1799
 
1800
  layout = row(left_panel, middle_panel, right_panel)
1801
  curdoc().add_root(layout)
 
516
  return img.crop((x0, y0, x1, y1)).resize((size, size), Image.BILINEAR)
517
 
518
 
519
+ def _load_image_from_ds(ds, img_i):
520
+ """Like load_image() but uses the given dataset's image_paths."""
521
+ path = ds['image_paths'][img_i]
522
+ fname = os.path.basename(path)
523
+ for base in [args.image_dir] + ([args.extra_image_dir] if args.extra_image_dir else []):
524
+ candidate = os.path.join(base, fname)
525
+ if os.path.exists(candidate):
526
+ return Image.open(candidate).convert("RGB")
527
+ return Image.open(path).convert("RGB")
528
+
529
+
530
+ def _render_overlay_from_ds(ds, feat, slot, size=THUMB, alpha=None):
531
+ """Return (PIL overlay image, caption) for ds/feat/slot, or None on failure."""
532
+ if alpha is None:
533
+ alpha = heatmap_alpha_slider.value
534
+ try:
535
+ img_i = int(ds['top_img_idx'][feat, slot].item())
536
+ if img_i < 0:
537
+ return None
538
+ plain = _load_image_from_ds(ds, img_i).resize((size, size), Image.BILINEAR)
539
+ hm_tensor = ds.get('top_heatmaps')
540
+ if hm_tensor is not None:
541
+ pg = ds.get('heatmap_patch_grid', 16)
542
+ hmap = hm_tensor[feat, slot].float().numpy().reshape(pg, pg)
543
+ img_arr = np.array(plain).astype(np.float32) / 255.0
544
+ hmap_up = cv2.resize(hmap, (size, size), interpolation=cv2.INTER_CUBIC)
545
+ hmax = hmap_up.max()
546
+ hmap_norm = hmap_up / hmax if hmax > 0 else hmap_up
547
+ overlay = ALPHA_JET(hmap_norm)
548
+ ov_a = overlay[:, :, 3:4] * alpha
549
+ blended = np.clip((img_arr * (1 - ov_a) + overlay[:, :, :3] * ov_a) * 255, 0, 255).astype(np.uint8)
550
+ return Image.fromarray(blended), f"img {img_i}"
551
+ return plain, f"img {img_i}"
552
+ except Exception:
553
+ return None
554
+
555
+
556
  def pil_to_data_url(img):
557
  buf = io.BytesIO()
558
  img.save(buf, format="JPEG", quality=85)
 
649
  return html
650
 
651
 
652
+ def make_cross_sae_comparison_html(ds_a, feat_a, ds_b, feat_b, n=4, size=160):
653
+ """
654
+ Two side-by-side 2×2 grids: left = SAE A / feat_a, right = SAE B / feat_b.
655
+ """
656
+ def _collect(ds, feat):
657
+ items = []
658
+ for slot in range(min(n, ds['top_img_idx'].shape[1])):
659
+ result = _render_overlay_from_ds(ds, feat, slot, size=size)
660
+ if result:
661
+ items.append(result)
662
+ if len(items) == n:
663
+ break
664
+ return items
665
+
666
+ items_a = _collect(ds_a, feat_a)
667
+ items_b = _collect(ds_b, feat_b)
668
+
669
+ def _grid_html(items, label, color):
670
+ header = (f'<div style="background:{color};color:#fff;font-size:11px;font-weight:bold;'
671
+ f'text-align:center;padding:4px;border-radius:4px;margin-bottom:6px">{label}</div>')
672
+ grid = '<div style="display:grid;grid-template-columns:repeat(2,{s}px);gap:4px">'.format(s=size)
673
+ for img, cap in items:
674
+ url = pil_to_data_url(img)
675
+ grid += (f'<div style="text-align:center">'
676
+ f'<img src="{url}" width="{size}" height="{size}"'
677
+ f' style="border:1px solid #ccc;border-radius:3px;display:block"/>'
678
+ f'<div style="font-size:9px;color:#555;margin-top:2px">{cap}</div></div>')
679
+ grid += '</div>'
680
+ return header + grid
681
+
682
+ label_a = f"{ds_a['label']} — feat {feat_a}"
683
+ label_b = f"{ds_b['label']} — feat {feat_b}"
684
+ col_a = _grid_html(items_a, label_a, "#2563a8")
685
+ col_b = _grid_html(items_b, label_b, "#b85c00")
686
+
687
+ return (
688
+ '<div style="display:flex;gap:16px;padding:8px;background:#fafafa;'
689
+ 'border:1px solid #ddd;border-radius:6px">'
690
+ + col_a + col_b + '</div>'
691
+ )
692
+
693
+
694
  # ---------- UMAP data source ----------
695
  # live_mask / live_indices / freq / mean_act / log_freq / umap_backup are all
696
  # already set by _apply_dataset_globals(0) above — just build the source from them.
 
1874
  p75_heatmap_div, p75_zoom_div,
1875
  )
1876
 
1877
+ # --- Cross-SAE comparison section ---
1878
+ cmp_ds_a = Select(title="SAE A:", value="0",
1879
+ options=[(str(i), ds['label']) for i, ds in enumerate(_all_datasets)])
1880
+ cmp_feat_a = TextInput(title="Feature (SAE A):", value="0", width=100)
1881
+ cmp_ds_b = Select(title="SAE B:", value=str(min(1, len(_all_datasets)-1)),
1882
+ options=[(str(i), ds['label']) for i, ds in enumerate(_all_datasets)])
1883
+ cmp_feat_b = TextInput(title="Feature (SAE B):", value="0", width=100)
1884
+ cmp_btn = Button(label="Generate Comparison", button_type="primary", width=200)
1885
+ cmp_output_div = Div(text="", width=400)
1886
+
1887
+ def _on_cmp_generate():
1888
+ try:
1889
+ idx_a = int(cmp_ds_a.value)
1890
+ idx_b = int(cmp_ds_b.value)
1891
+ fa = int(cmp_feat_a.value)
1892
+ fb = int(cmp_feat_b.value)
1893
+ _ensure_loaded(idx_a)
1894
+ _ensure_loaded(idx_b)
1895
+ ds_a = _all_datasets[idx_a]
1896
+ ds_b = _all_datasets[idx_b]
1897
+ cmp_output_div.text = make_cross_sae_comparison_html(ds_a, fa, ds_b, fb)
1898
+ except Exception as e:
1899
+ cmp_output_div.text = f'<p style="color:red">Error: {e}</p>'
1900
+
1901
+ cmp_btn.on_click(lambda: _on_cmp_generate())
1902
+
1903
+ cmp_section = _make_collapsible("Cross-SAE Comparison", column(
1904
+ row(cmp_ds_a, cmp_feat_a),
1905
+ row(cmp_ds_b, cmp_feat_b),
1906
+ cmp_btn,
1907
+ cmp_output_div,
1908
+ ))
1909
+
1910
+ right_panel = column(summary_section, patch_section, clip_section, cmp_section)
1911
 
1912
  layout = row(left_panel, middle_panel, right_panel)
1913
  curdoc().add_root(layout)