Marlin Lee commited on
Commit
41be4ef
·
1 Parent(s): 68497cf

Sync explorer_app.py and clip_utils.py from main repo

Browse files
Files changed (1) hide show
  1. scripts/explorer_app.py +25 -5
scripts/explorer_app.py CHANGED
@@ -649,14 +649,14 @@ def make_compare_aggregations_html(top_infos, mean_infos, p75_infos, feat, n_eac
649
  return html
650
 
651
 
652
- def make_cross_sae_comparison_html(ds_a, feat_a, ds_b, feat_b, n=4, size=160):
653
  """
654
  Two side-by-side 2×2 grids: left = SAE A / feat_a, right = SAE B / feat_b.
655
  """
656
  def _collect(ds, feat):
657
  items = []
658
  for slot in range(min(n, ds['top_img_idx'].shape[1])):
659
- result = _render_overlay_from_ds(ds, feat, slot, size=size)
660
  if result:
661
  items.append(result)
662
  if len(items) == n:
@@ -1889,10 +1889,28 @@ cmp_feat_a = TextInput(title="Feature (SAE A):", value="0", width=100)
1889
  cmp_ds_b = Select(title="SAE B:", value=str(min(1, len(_all_datasets)-1)),
1890
  options=[(str(i), ds['label']) for i, ds in enumerate(_all_datasets)])
1891
  cmp_feat_b = TextInput(title="Feature (SAE B):", value="0", width=100)
 
1892
  cmp_btn = Button(label="Generate Comparison", button_type="primary", width=200)
1893
  cmp_output_div = Div(text="", width=400)
1894
 
1895
  def _on_cmp_generate():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1896
  try:
1897
  idx_a = int(cmp_ds_a.value)
1898
  idx_b = int(cmp_ds_b.value)
@@ -1902,16 +1920,18 @@ def _on_cmp_generate():
1902
  _ensure_loaded(idx_b)
1903
  ds_a = _all_datasets[idx_a]
1904
  ds_b = _all_datasets[idx_b]
1905
- cmp_output_div.text = make_cross_sae_comparison_html(ds_a, fa, ds_b, fb)
 
1906
  except Exception as e:
1907
  cmp_output_div.text = f'<p style="color:red">Error: {e}</p>'
1908
 
1909
- cmp_btn.on_click(lambda: _on_cmp_generate())
 
1910
 
1911
  cmp_section = _make_collapsible("Cross-SAE Comparison", column(
1912
  row(cmp_ds_a, cmp_feat_a),
1913
  row(cmp_ds_b, cmp_feat_b),
1914
- cmp_btn,
1915
  cmp_output_div,
1916
  ))
1917
 
 
649
  return html
650
 
651
 
652
+ def make_cross_sae_comparison_html(ds_a, feat_a, ds_b, feat_b, n=4, size=160, alpha=1.0):
653
  """
654
  Two side-by-side 2×2 grids: left = SAE A / feat_a, right = SAE B / feat_b.
655
  """
656
  def _collect(ds, feat):
657
  items = []
658
  for slot in range(min(n, ds['top_img_idx'].shape[1])):
659
+ result = _render_overlay_from_ds(ds, feat, slot, size=size, alpha=alpha)
660
  if result:
661
  items.append(result)
662
  if len(items) == n:
 
1889
  cmp_ds_b = Select(title="SAE B:", value=str(min(1, len(_all_datasets)-1)),
1890
  options=[(str(i), ds['label']) for i, ds in enumerate(_all_datasets)])
1891
  cmp_feat_b = TextInput(title="Feature (SAE B):", value="0", width=100)
1892
+ cmp_alpha_slider = Slider(title="Heatmap opacity", value=1.0, start=0.0, end=1.0, step=0.05, width=220)
1893
  cmp_btn = Button(label="Generate Comparison", button_type="primary", width=200)
1894
  cmp_output_div = Div(text="", width=400)
1895
 
1896
  def _on_cmp_generate():
1897
+ if not cmp_output_div.text:
1898
+ return
1899
+ try:
1900
+ idx_a = int(cmp_ds_a.value)
1901
+ idx_b = int(cmp_ds_b.value)
1902
+ fa = int(cmp_feat_a.value)
1903
+ fb = int(cmp_feat_b.value)
1904
+ _ensure_loaded(idx_a)
1905
+ _ensure_loaded(idx_b)
1906
+ ds_a = _all_datasets[idx_a]
1907
+ ds_b = _all_datasets[idx_b]
1908
+ cmp_output_div.text = make_cross_sae_comparison_html(
1909
+ ds_a, fa, ds_b, fb, alpha=cmp_alpha_slider.value)
1910
+ except Exception as e:
1911
+ cmp_output_div.text = f'<p style="color:red">Error: {e}</p>'
1912
+
1913
+ def _on_cmp_btn():
1914
  try:
1915
  idx_a = int(cmp_ds_a.value)
1916
  idx_b = int(cmp_ds_b.value)
 
1920
  _ensure_loaded(idx_b)
1921
  ds_a = _all_datasets[idx_a]
1922
  ds_b = _all_datasets[idx_b]
1923
+ cmp_output_div.text = make_cross_sae_comparison_html(
1924
+ ds_a, fa, ds_b, fb, alpha=cmp_alpha_slider.value)
1925
  except Exception as e:
1926
  cmp_output_div.text = f'<p style="color:red">Error: {e}</p>'
1927
 
1928
+ cmp_btn.on_click(lambda: _on_cmp_btn())
1929
+ cmp_alpha_slider.on_change('value', lambda attr, old, new: _on_cmp_generate())
1930
 
1931
  cmp_section = _make_collapsible("Cross-SAE Comparison", column(
1932
  row(cmp_ds_a, cmp_feat_a),
1933
  row(cmp_ds_b, cmp_feat_b),
1934
+ row(cmp_alpha_slider, cmp_btn),
1935
  cmp_output_div,
1936
  ))
1937