Jac-Zac commited on
Commit
1b16c40
·
1 Parent(s): d39b2dd

Small ui cleanups

Browse files
Files changed (2) hide show
  1. tabs/compare.py +22 -7
  2. utils/theme.py +35 -2
tabs/compare.py CHANGED
@@ -3,6 +3,7 @@ from dataclasses import dataclass
3
  from itertools import combinations
4
  from pathlib import Path
5
 
 
6
  import streamlit as st
7
  from persona_data.environment import get_artifacts_dir
8
  from persona_data.synth_persona import BASELINE_PERSONA_ID
@@ -44,6 +45,7 @@ from utils.helpers import (
44
  slugify,
45
  widget_key,
46
  )
 
47
 
48
 
49
  def _filename(*parts: str) -> str:
@@ -295,12 +297,25 @@ def _render_save_buttons(
295
  """Render the Save HTML button for one or more figures."""
296
  if st.button("Save HTML", key=widget_key("load", "save_html", key_suffix)):
297
  try:
 
298
  paths = [save_plot_html(fig, fn) for fig, fn in zip(figs, filenames)]
299
  st.success(f"Saved {len(paths)} HTML file(s) to `artifacts/plots`.")
300
  except Exception as exc:
301
  st.error(f"Could not save HTML: {exc}")
302
 
303
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  def _render_mask_strategy_select(scope: str) -> MaskStrategy:
305
  return render_mask_strategy_select(
306
  key=widget_key("load", "mask_strategy", scope),
@@ -494,12 +509,12 @@ def _render_cosine_similarity(
494
 
495
  if cosine_fig_key in st.session_state:
496
  fig, pair_fig, n_traces, n_pair_traces = st.session_state[cosine_fig_key]
497
- st.plotly_chart(fig, width="stretch")
498
  figs = [fig]
499
  filenames = [filename]
500
  if pair_fig is not None:
501
  st.subheader("Variant pairs")
502
- st.plotly_chart(pair_fig, width="stretch")
503
  figs.append(pair_fig)
504
  filenames.append(pairs_filename)
505
  _render_save_buttons(figs, filenames, "cosine")
@@ -662,12 +677,12 @@ def _render_layered_figure_analysis(
662
 
663
  if fig_key in st.session_state:
664
  main_fig, extra_fig, n_samples = st.session_state[fig_key]
665
- st.plotly_chart(main_fig, width="stretch")
666
  figs = [main_fig]
667
  filenames = [filename]
668
  if extra_fig is not None:
669
  st.subheader("Pair trajectories")
670
- st.plotly_chart(extra_fig, width="stretch")
671
  figs.append(extra_fig)
672
  filenames.append(f"{filename}__pair_trajectories")
673
  _render_save_buttons(figs, filenames, scope)
@@ -816,12 +831,12 @@ def _render_dendrogram_analysis(
816
  col_a, col_b = st.columns(2)
817
  with col_a:
818
  st.subheader(prompt_variant_label(va))
819
- st.plotly_chart(fig_a, width="stretch")
820
  with col_b:
821
  st.subheader(prompt_variant_label(vb))
822
- st.plotly_chart(fig_b, width="stretch")
823
  else:
824
- st.plotly_chart(fig_a, width="stretch")
825
 
826
  figs = [fig_a] + ([fig_b] if fig_b else [])
827
  filenames = [
 
3
  from itertools import combinations
4
  from pathlib import Path
5
 
6
+ import plotly.graph_objects as go
7
  import streamlit as st
8
  from persona_data.environment import get_artifacts_dir
9
  from persona_data.synth_persona import BASELINE_PERSONA_ID
 
45
  slugify,
46
  widget_key,
47
  )
48
+ from utils.theme import style_plotly_layer_controls
49
 
50
 
51
  def _filename(*parts: str) -> str:
 
297
  """Render the Save HTML button for one or more figures."""
298
  if st.button("Save HTML", key=widget_key("load", "save_html", key_suffix)):
299
  try:
300
+ _style_plotly_figures(figs)
301
  paths = [save_plot_html(fig, fn) for fig, fn in zip(figs, filenames)]
302
  st.success(f"Saved {len(paths)} HTML file(s) to `artifacts/plots`.")
303
  except Exception as exc:
304
  st.error(f"Could not save HTML: {exc}")
305
 
306
 
307
+ def _style_plotly_figures(figs: list[object]) -> None:
308
+ base = st.get_option("theme.base")
309
+ for fig in figs:
310
+ if isinstance(fig, go.Figure):
311
+ style_plotly_layer_controls(fig, base)
312
+
313
+
314
+ def _plotly_chart(fig: object) -> None:
315
+ _style_plotly_figures([fig])
316
+ st.plotly_chart(fig, width="stretch")
317
+
318
+
319
  def _render_mask_strategy_select(scope: str) -> MaskStrategy:
320
  return render_mask_strategy_select(
321
  key=widget_key("load", "mask_strategy", scope),
 
509
 
510
  if cosine_fig_key in st.session_state:
511
  fig, pair_fig, n_traces, n_pair_traces = st.session_state[cosine_fig_key]
512
+ _plotly_chart(fig)
513
  figs = [fig]
514
  filenames = [filename]
515
  if pair_fig is not None:
516
  st.subheader("Variant pairs")
517
+ _plotly_chart(pair_fig)
518
  figs.append(pair_fig)
519
  filenames.append(pairs_filename)
520
  _render_save_buttons(figs, filenames, "cosine")
 
677
 
678
  if fig_key in st.session_state:
679
  main_fig, extra_fig, n_samples = st.session_state[fig_key]
680
+ _plotly_chart(main_fig)
681
  figs = [main_fig]
682
  filenames = [filename]
683
  if extra_fig is not None:
684
  st.subheader("Pair trajectories")
685
+ _plotly_chart(extra_fig)
686
  figs.append(extra_fig)
687
  filenames.append(f"{filename}__pair_trajectories")
688
  _render_save_buttons(figs, filenames, scope)
 
831
  col_a, col_b = st.columns(2)
832
  with col_a:
833
  st.subheader(prompt_variant_label(va))
834
+ _plotly_chart(fig_a)
835
  with col_b:
836
  st.subheader(prompt_variant_label(vb))
837
+ _plotly_chart(fig_b)
838
  else:
839
+ _plotly_chart(fig_a)
840
 
841
  figs = [fig_a] + ([fig_b] if fig_b else [])
842
  filenames = [
utils/theme.py CHANGED
@@ -5,6 +5,10 @@ import plotly.io as pio
5
  from catppuccin import PALETTE
6
 
7
 
 
 
 
 
8
  def install_catppuccin_theme(base: str | None = None) -> None:
9
  """Register a Catppuccin template and alias it as ``plotly_white``.
10
 
@@ -12,8 +16,7 @@ def install_catppuccin_theme(base: str | None = None) -> None:
12
  every figure, so replacing that entry themes all plots without any
13
  per-figure code.
14
  """
15
- flavor = PALETTE.latte if base == "light" else PALETTE.mocha
16
- c = flavor.colors
17
  bg, surface, line = c.base.hex, c.surface0.hex, c.surface1.hex
18
  text, subtext = c.text.hex, c.subtext1.hex
19
 
@@ -65,3 +68,33 @@ def install_catppuccin_theme(base: str | None = None) -> None:
65
  pio.templates["catppuccin"] = template
66
  pio.templates["plotly_white"] = template
67
  pio.templates.default = "catppuccin"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from catppuccin import PALETTE
6
 
7
 
8
+ def _flavor(base: str | None):
9
+ return PALETTE.latte if base == "light" else PALETTE.mocha
10
+
11
+
12
  def install_catppuccin_theme(base: str | None = None) -> None:
13
  """Register a Catppuccin template and alias it as ``plotly_white``.
14
 
 
16
  every figure, so replacing that entry themes all plots without any
17
  per-figure code.
18
  """
19
+ c = _flavor(base).colors
 
20
  bg, surface, line = c.base.hex, c.surface0.hex, c.surface1.hex
21
  text, subtext = c.text.hex, c.subtext1.hex
22
 
 
68
  pio.templates["catppuccin"] = template
69
  pio.templates["plotly_white"] = template
70
  pio.templates.default = "catppuccin"
71
+
72
+
73
+ def style_plotly_layer_controls(fig: go.Figure, base: str | None = None) -> go.Figure:
74
+ """Theme persona-vectors layer sliders/buttons for this Streamlit app."""
75
+ c = _flavor(base).colors
76
+ surface = c.surface0.hex
77
+ overlay = c.surface1.hex
78
+ text = c.text.hex
79
+ primary = c.blue.hex
80
+
81
+ for slider in fig.layout.sliders:
82
+ slider.bgcolor = surface
83
+ slider.activebgcolor = primary
84
+ slider.bordercolor = overlay
85
+ slider.borderwidth = 1
86
+ slider.font = dict(color=text, size=11)
87
+ slider.tickcolor = primary
88
+ slider.currentvalue = dict(
89
+ slider.currentvalue.to_plotly_json(),
90
+ font=dict(color=text, size=13),
91
+ )
92
+
93
+ for menu in fig.layout.updatemenus:
94
+ if menu.type != "buttons":
95
+ continue
96
+ menu.bgcolor = surface
97
+ menu.bordercolor = overlay
98
+ menu.font = dict(color=text, size=13)
99
+
100
+ return fig