Jac-Zac commited on
Commit ·
1b16c40
1
Parent(s): d39b2dd
Small ui cleanups
Browse files- tabs/compare.py +22 -7
- utils/theme.py +35 -2
tabs/compare.py
CHANGED
|
@@ -3,6 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
from itertools import combinations
|
| 4 |
from pathlib import Path
|
| 5 |
|
|
|
|
| 6 |
import streamlit as st
|
| 7 |
from persona_data.environment import get_artifacts_dir
|
| 8 |
from persona_data.synth_persona import BASELINE_PERSONA_ID
|
|
@@ -44,6 +45,7 @@ from utils.helpers import (
|
|
| 44 |
slugify,
|
| 45 |
widget_key,
|
| 46 |
)
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
def _filename(*parts: str) -> str:
|
|
@@ -295,12 +297,25 @@ def _render_save_buttons(
|
|
| 295 |
"""Render the Save HTML button for one or more figures."""
|
| 296 |
if st.button("Save HTML", key=widget_key("load", "save_html", key_suffix)):
|
| 297 |
try:
|
|
|
|
| 298 |
paths = [save_plot_html(fig, fn) for fig, fn in zip(figs, filenames)]
|
| 299 |
st.success(f"Saved {len(paths)} HTML file(s) to `artifacts/plots`.")
|
| 300 |
except Exception as exc:
|
| 301 |
st.error(f"Could not save HTML: {exc}")
|
| 302 |
|
| 303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
def _render_mask_strategy_select(scope: str) -> MaskStrategy:
|
| 305 |
return render_mask_strategy_select(
|
| 306 |
key=widget_key("load", "mask_strategy", scope),
|
|
@@ -494,12 +509,12 @@ def _render_cosine_similarity(
|
|
| 494 |
|
| 495 |
if cosine_fig_key in st.session_state:
|
| 496 |
fig, pair_fig, n_traces, n_pair_traces = st.session_state[cosine_fig_key]
|
| 497 |
-
|
| 498 |
figs = [fig]
|
| 499 |
filenames = [filename]
|
| 500 |
if pair_fig is not None:
|
| 501 |
st.subheader("Variant pairs")
|
| 502 |
-
|
| 503 |
figs.append(pair_fig)
|
| 504 |
filenames.append(pairs_filename)
|
| 505 |
_render_save_buttons(figs, filenames, "cosine")
|
|
@@ -662,12 +677,12 @@ def _render_layered_figure_analysis(
|
|
| 662 |
|
| 663 |
if fig_key in st.session_state:
|
| 664 |
main_fig, extra_fig, n_samples = st.session_state[fig_key]
|
| 665 |
-
|
| 666 |
figs = [main_fig]
|
| 667 |
filenames = [filename]
|
| 668 |
if extra_fig is not None:
|
| 669 |
st.subheader("Pair trajectories")
|
| 670 |
-
|
| 671 |
figs.append(extra_fig)
|
| 672 |
filenames.append(f"{filename}__pair_trajectories")
|
| 673 |
_render_save_buttons(figs, filenames, scope)
|
|
@@ -816,12 +831,12 @@ def _render_dendrogram_analysis(
|
|
| 816 |
col_a, col_b = st.columns(2)
|
| 817 |
with col_a:
|
| 818 |
st.subheader(prompt_variant_label(va))
|
| 819 |
-
|
| 820 |
with col_b:
|
| 821 |
st.subheader(prompt_variant_label(vb))
|
| 822 |
-
|
| 823 |
else:
|
| 824 |
-
|
| 825 |
|
| 826 |
figs = [fig_a] + ([fig_b] if fig_b else [])
|
| 827 |
filenames = [
|
|
|
|
| 3 |
from itertools import combinations
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
+
import plotly.graph_objects as go
|
| 7 |
import streamlit as st
|
| 8 |
from persona_data.environment import get_artifacts_dir
|
| 9 |
from persona_data.synth_persona import BASELINE_PERSONA_ID
|
|
|
|
| 45 |
slugify,
|
| 46 |
widget_key,
|
| 47 |
)
|
| 48 |
+
from utils.theme import style_plotly_layer_controls
|
| 49 |
|
| 50 |
|
| 51 |
def _filename(*parts: str) -> str:
|
|
|
|
| 297 |
"""Render the Save HTML button for one or more figures."""
|
| 298 |
if st.button("Save HTML", key=widget_key("load", "save_html", key_suffix)):
|
| 299 |
try:
|
| 300 |
+
_style_plotly_figures(figs)
|
| 301 |
paths = [save_plot_html(fig, fn) for fig, fn in zip(figs, filenames)]
|
| 302 |
st.success(f"Saved {len(paths)} HTML file(s) to `artifacts/plots`.")
|
| 303 |
except Exception as exc:
|
| 304 |
st.error(f"Could not save HTML: {exc}")
|
| 305 |
|
| 306 |
|
| 307 |
+
def _style_plotly_figures(figs: list[object]) -> None:
|
| 308 |
+
base = st.get_option("theme.base")
|
| 309 |
+
for fig in figs:
|
| 310 |
+
if isinstance(fig, go.Figure):
|
| 311 |
+
style_plotly_layer_controls(fig, base)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def _plotly_chart(fig: object) -> None:
|
| 315 |
+
_style_plotly_figures([fig])
|
| 316 |
+
st.plotly_chart(fig, width="stretch")
|
| 317 |
+
|
| 318 |
+
|
| 319 |
def _render_mask_strategy_select(scope: str) -> MaskStrategy:
|
| 320 |
return render_mask_strategy_select(
|
| 321 |
key=widget_key("load", "mask_strategy", scope),
|
|
|
|
| 509 |
|
| 510 |
if cosine_fig_key in st.session_state:
|
| 511 |
fig, pair_fig, n_traces, n_pair_traces = st.session_state[cosine_fig_key]
|
| 512 |
+
_plotly_chart(fig)
|
| 513 |
figs = [fig]
|
| 514 |
filenames = [filename]
|
| 515 |
if pair_fig is not None:
|
| 516 |
st.subheader("Variant pairs")
|
| 517 |
+
_plotly_chart(pair_fig)
|
| 518 |
figs.append(pair_fig)
|
| 519 |
filenames.append(pairs_filename)
|
| 520 |
_render_save_buttons(figs, filenames, "cosine")
|
|
|
|
| 677 |
|
| 678 |
if fig_key in st.session_state:
|
| 679 |
main_fig, extra_fig, n_samples = st.session_state[fig_key]
|
| 680 |
+
_plotly_chart(main_fig)
|
| 681 |
figs = [main_fig]
|
| 682 |
filenames = [filename]
|
| 683 |
if extra_fig is not None:
|
| 684 |
st.subheader("Pair trajectories")
|
| 685 |
+
_plotly_chart(extra_fig)
|
| 686 |
figs.append(extra_fig)
|
| 687 |
filenames.append(f"{filename}__pair_trajectories")
|
| 688 |
_render_save_buttons(figs, filenames, scope)
|
|
|
|
| 831 |
col_a, col_b = st.columns(2)
|
| 832 |
with col_a:
|
| 833 |
st.subheader(prompt_variant_label(va))
|
| 834 |
+
_plotly_chart(fig_a)
|
| 835 |
with col_b:
|
| 836 |
st.subheader(prompt_variant_label(vb))
|
| 837 |
+
_plotly_chart(fig_b)
|
| 838 |
else:
|
| 839 |
+
_plotly_chart(fig_a)
|
| 840 |
|
| 841 |
figs = [fig_a] + ([fig_b] if fig_b else [])
|
| 842 |
filenames = [
|
utils/theme.py
CHANGED
|
@@ -5,6 +5,10 @@ import plotly.io as pio
|
|
| 5 |
from catppuccin import PALETTE
|
| 6 |
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
def install_catppuccin_theme(base: str | None = None) -> None:
|
| 9 |
"""Register a Catppuccin template and alias it as ``plotly_white``.
|
| 10 |
|
|
@@ -12,8 +16,7 @@ def install_catppuccin_theme(base: str | None = None) -> None:
|
|
| 12 |
every figure, so replacing that entry themes all plots without any
|
| 13 |
per-figure code.
|
| 14 |
"""
|
| 15 |
-
|
| 16 |
-
c = flavor.colors
|
| 17 |
bg, surface, line = c.base.hex, c.surface0.hex, c.surface1.hex
|
| 18 |
text, subtext = c.text.hex, c.subtext1.hex
|
| 19 |
|
|
@@ -65,3 +68,33 @@ def install_catppuccin_theme(base: str | None = None) -> None:
|
|
| 65 |
pio.templates["catppuccin"] = template
|
| 66 |
pio.templates["plotly_white"] = template
|
| 67 |
pio.templates.default = "catppuccin"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from catppuccin import PALETTE
|
| 6 |
|
| 7 |
|
| 8 |
+
def _flavor(base: str | None):
|
| 9 |
+
return PALETTE.latte if base == "light" else PALETTE.mocha
|
| 10 |
+
|
| 11 |
+
|
| 12 |
def install_catppuccin_theme(base: str | None = None) -> None:
|
| 13 |
"""Register a Catppuccin template and alias it as ``plotly_white``.
|
| 14 |
|
|
|
|
| 16 |
every figure, so replacing that entry themes all plots without any
|
| 17 |
per-figure code.
|
| 18 |
"""
|
| 19 |
+
c = _flavor(base).colors
|
|
|
|
| 20 |
bg, surface, line = c.base.hex, c.surface0.hex, c.surface1.hex
|
| 21 |
text, subtext = c.text.hex, c.subtext1.hex
|
| 22 |
|
|
|
|
| 68 |
pio.templates["catppuccin"] = template
|
| 69 |
pio.templates["plotly_white"] = template
|
| 70 |
pio.templates.default = "catppuccin"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def style_plotly_layer_controls(fig: go.Figure, base: str | None = None) -> go.Figure:
|
| 74 |
+
"""Theme persona-vectors layer sliders/buttons for this Streamlit app."""
|
| 75 |
+
c = _flavor(base).colors
|
| 76 |
+
surface = c.surface0.hex
|
| 77 |
+
overlay = c.surface1.hex
|
| 78 |
+
text = c.text.hex
|
| 79 |
+
primary = c.blue.hex
|
| 80 |
+
|
| 81 |
+
for slider in fig.layout.sliders:
|
| 82 |
+
slider.bgcolor = surface
|
| 83 |
+
slider.activebgcolor = primary
|
| 84 |
+
slider.bordercolor = overlay
|
| 85 |
+
slider.borderwidth = 1
|
| 86 |
+
slider.font = dict(color=text, size=11)
|
| 87 |
+
slider.tickcolor = primary
|
| 88 |
+
slider.currentvalue = dict(
|
| 89 |
+
slider.currentvalue.to_plotly_json(),
|
| 90 |
+
font=dict(color=text, size=13),
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
for menu in fig.layout.updatemenus:
|
| 94 |
+
if menu.type != "buttons":
|
| 95 |
+
continue
|
| 96 |
+
menu.bgcolor = surface
|
| 97 |
+
menu.bordercolor = overlay
|
| 98 |
+
menu.font = dict(color=text, size=13)
|
| 99 |
+
|
| 100 |
+
return fig
|