Jac-Zac commited on
Commit
9ac8f1c
·
1 Parent(s): d8ae160

Adding color theme option

Browse files

Support light and dark color themes

.streamlit/config.toml CHANGED
@@ -1,9 +1,20 @@
1
- # Catppuccin Mocha theme. Switch base to "light" and swap the four colors
2
- # below to the Latte equivalents (see utils/theme.py) for the light flavor.
 
3
  [theme]
4
  base = "dark"
5
- primaryColor = "#89b4fa" # Mocha blue
6
- backgroundColor = "#1e1e2e" # base
7
- secondaryBackgroundColor = "#313244" # surface0
8
- textColor = "#cdd6f4" # text
9
  font = "sans serif"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Catppuccin theme. `base` is the default flavor; the built-in Streamlit
2
+ # Settings menu (Light / Dark / System) switches between the two variants
3
+ # below. Latte = light, Mocha = dark. Keep in sync with utils/theme.py.
4
  [theme]
5
  base = "dark"
 
 
 
 
6
  font = "sans serif"
7
+
8
+ [theme.light] # Catppuccin Latte (style-guide mapping)
9
+ primaryColor = "#8839ef" # mauve - accent (Catppuccin default)
10
+ backgroundColor = "#eff1f5" # base - background pane
11
+ secondaryBackgroundColor = "#dce0e8" # crust - secondary panes
12
+ textColor = "#4c4f69" # text - body copy
13
+ linkColor = "#1e66f5" # blue - links/URLs
14
+ borderColor = "#9ca0b0" # overlay0 - inactive border
15
+
16
+ [theme.dark] # Catppuccin Mocha (unchanged)
17
+ primaryColor = "#89b4fa" # blue
18
+ backgroundColor = "#1e1e2e" # base
19
+ secondaryBackgroundColor = "#313244" # surface0
20
+ textColor = "#cdd6f4" # text
app.py CHANGED
@@ -12,7 +12,7 @@ from utils.analysis_sources import (
12
  from utils.helpers import DATASET_SOURCES, session_key, widget_key
13
  from utils.preload import preload_once
14
  from utils.runtime import list_remote_models
15
- from utils.theme import install_catppuccin_theme
16
 
17
  load_dotenv()
18
  DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
@@ -233,7 +233,7 @@ def main() -> None:
233
  """Run the Streamlit app."""
234
 
235
  st.set_page_config(page_title="Persona UI", layout="wide")
236
- install_catppuccin_theme(st.get_option("theme.base"))
237
 
238
  sidebar = _sidebar_controls()
239
 
 
12
  from utils.helpers import DATASET_SOURCES, session_key, widget_key
13
  from utils.preload import preload_once
14
  from utils.runtime import list_remote_models
15
+ from utils.theme import active_base, install_catppuccin_theme
16
 
17
  load_dotenv()
18
  DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
 
233
  """Run the Streamlit app."""
234
 
235
  st.set_page_config(page_title="Persona UI", layout="wide")
236
+ install_catppuccin_theme(active_base())
237
 
238
  sidebar = _sidebar_controls()
239
 
tabs/analysis_core.py CHANGED
@@ -58,7 +58,7 @@ from utils.helpers import (
58
  slugify,
59
  widget_key,
60
  )
61
- from utils.theme import style_plotly_layer_controls
62
 
63
 
64
  def _filename(*parts: str) -> str:
@@ -586,7 +586,7 @@ def _render_save_buttons(
586
 
587
 
588
  def _style_plotly_figures(figs: list[object]) -> None:
589
- base = st.get_option("theme.base")
590
  for fig in figs:
591
  if isinstance(fig, go.Figure):
592
  style_plotly_layer_controls(fig, base)
 
58
  slugify,
59
  widget_key,
60
  )
61
+ from utils.theme import active_base, style_plotly_layer_controls
62
 
63
 
64
  def _filename(*parts: str) -> str:
 
586
 
587
 
588
  def _style_plotly_figures(figs: list[object]) -> None:
589
+ base = active_base()
590
  for fig in figs:
591
  if isinstance(fig, go.Figure):
592
  style_plotly_layer_controls(fig, base)
tabs/extract.py CHANGED
@@ -14,7 +14,11 @@ from persona_vectors.extraction import (
14
  from persona_vectors.preview import TokenSegment, preview_token_segments
15
 
16
  from utils.controls import render_mask_strategy_select
17
- from utils.datasets import load_dataset, load_persona_list_from_dataset
 
 
 
 
18
  from utils.helpers import (
19
  NDIF_STATUS_ICONS,
20
  persona_label,
@@ -23,6 +27,7 @@ from utils.helpers import (
23
  widget_key,
24
  )
25
  from utils.runtime import cached_model
 
26
 
27
  _LAST_VARIANTS_KEY = "extract:last_variants"
28
  _LAST_BASELINE_KEY = "extract:last_include_baseline"
@@ -138,6 +143,10 @@ def _load_qa_dataset_personas(
138
  "Try another dataset source or check that the personas file is not empty."
139
  )
140
  return None
 
 
 
 
141
  return dataset, personas
142
 
143
 
@@ -171,7 +180,7 @@ _MAX_PREVIEW_SAMPLES = 3
171
 
172
 
173
  def _preview_palette():
174
- flavor = PALETTE.latte if st.get_option("theme.base") == "light" else PALETTE.mocha
175
  return flavor.colors
176
 
177
 
 
14
  from persona_vectors.preview import TokenSegment, preview_token_segments
15
 
16
  from utils.controls import render_mask_strategy_select
17
+ from utils.datasets import (
18
+ load_dataset,
19
+ load_persona_list_from_dataset,
20
+ warm_qa_in_background,
21
+ )
22
  from utils.helpers import (
23
  NDIF_STATUS_ICONS,
24
  persona_label,
 
27
  widget_key,
28
  )
29
  from utils.runtime import cached_model
30
+ from utils.theme import active_base
31
 
32
  _LAST_VARIANTS_KEY = "extract:last_variants"
33
  _LAST_BASELINE_KEY = "extract:last_include_baseline"
 
143
  "Try another dataset source or check that the personas file is not empty."
144
  )
145
  return None
146
+
147
+ # Extract is the only tab that needs QA; warm it now so the parse overlaps
148
+ # with the user configuring the run instead of blocking the first extract.
149
+ warm_qa_in_background(dataset)
150
  return dataset, personas
151
 
152
 
 
180
 
181
 
182
  def _preview_palette():
183
+ flavor = PALETTE.latte if active_base() == "light" else PALETTE.mocha
184
  return flavor.colors
185
 
186
 
utils/contrast.py CHANGED
@@ -243,10 +243,24 @@ def render_contrast_html(result: TokenContrast) -> str:
243
  Render each token with a colored background reflecting how A- or B-specific
244
  it is, with a hover tooltip showing the raw Δlog P, plus a legend.
245
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  spans: list[str] = []
247
- for token, weight, raw in zip(
248
- result.tokens, result.weights, result.raw_diffs, strict=True
249
- ):
250
  bg = _weight_to_bg(weight)
251
  tip = escape(f"Δlog P(A−B): {raw:+.3f}")
252
  text = escape(token)
 
243
  Render each token with a colored background reflecting how A- or B-specific
244
  it is, with a hover tooltip showing the raw Δlog P, plus a legend.
245
  """
246
+ # The model often opens a response with newline tokens; under pre-wrap
247
+ # those render as blank lines before the first word. Drop leading
248
+ # whitespace-only tokens (and left-trim the first visible one) so the
249
+ # contrast starts at real content. Display-only — weights stay aligned.
250
+ items = list(
251
+ zip(result.tokens, result.weights, result.raw_diffs, strict=True)
252
+ )
253
+ start = 0
254
+ while start < len(items) and not items[start][0].strip():
255
+ start += 1
256
+ if start >= len(items):
257
+ start = 0 # all-whitespace response: render as-is, not blank
258
+ items = items[start:]
259
+
260
  spans: list[str] = []
261
+ for idx, (token, weight, raw) in enumerate(items):
262
+ if idx == 0:
263
+ token = token.lstrip()
264
  bg = _weight_to_bg(weight)
265
  tip = escape(f"Δlog P(A−B): {raw:+.3f}")
266
  text = escape(token)
utils/datasets.py CHANGED
@@ -1,6 +1,7 @@
1
  import atexit
2
  import hashlib
3
  import shutil
 
4
  from pathlib import Path
5
  from tempfile import mkdtemp
6
  from typing import Any
@@ -23,6 +24,30 @@ def _cached_dataset(cls: type) -> Any:
23
  return cls()
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  @st.cache_resource(show_spinner=False)
27
  def _cached_local_dataset(personas_path: str, qa_path: str) -> LocalPersonaDataset:
28
  """Instantiate and cache a local upload dataset for stable temp paths."""
 
1
  import atexit
2
  import hashlib
3
  import shutil
4
+ import threading
5
  from pathlib import Path
6
  from tempfile import mkdtemp
7
  from typing import Any
 
24
  return cls()
25
 
26
 
27
+ _qa_warm_lock = threading.Lock()
28
+
29
+
30
+ def warm_qa_in_background(dataset: Any) -> None:
31
+ """Trigger the dataset's lazy QA parse on a daemon thread, once.
32
+
33
+ QA loading is deferred in persona-data (large, unused outside Extract).
34
+ Kicking it off when the Extract tab opens means the parse overlaps with
35
+ the user picking personas/options instead of blocking the first run.
36
+ Idempotent across Streamlit reruns: guarded per cached dataset instance.
37
+ """
38
+
39
+ warm = getattr(dataset, "_load_qa", None)
40
+ if warm is None:
41
+ return # persona-only dataset (e.g. Nemotron) has no QA
42
+ with _qa_warm_lock:
43
+ if getattr(dataset, "_qa_warm_started", False):
44
+ return
45
+ dataset._qa_warm_started = True
46
+ threading.Thread(
47
+ target=warm, name="persona-ui-warm-qa", daemon=True
48
+ ).start()
49
+
50
+
51
  @st.cache_resource(show_spinner=False)
52
  def _cached_local_dataset(personas_path: str, qa_path: str) -> LocalPersonaDataset:
53
  """Instantiate and cache a local upload dataset for stable temp paths."""
utils/theme.py CHANGED
@@ -2,9 +2,19 @@
2
 
3
  import plotly.graph_objects as go
4
  import plotly.io as pio
 
5
  from catppuccin import PALETTE
6
 
7
 
 
 
 
 
 
 
 
 
 
8
  def _flavor(base: str | None):
9
  return PALETTE.latte if base == "light" else PALETTE.mocha
10
 
@@ -17,8 +27,14 @@ def install_catppuccin_theme(base: str | None = None) -> None:
17
  per-figure code.
18
  """
19
  c = _flavor(base).colors
20
- bg, surface, line = c.base.hex, c.surface0.hex, c.surface1.hex
21
- text, subtext = c.text.hex, c.subtext1.hex
 
 
 
 
 
 
22
 
23
  axis = dict(
24
  gridcolor=line,
@@ -61,7 +77,7 @@ def install_catppuccin_theme(base: str | None = None) -> None:
61
  scene=dict(xaxis=scene_axis, yaxis=scene_axis, zaxis=scene_axis),
62
  legend=dict(bgcolor=surface, bordercolor=line, font=dict(color=text)),
63
  colorscale=dict(
64
- diverging=[[0.0, c.blue.hex], [0.5, surface], [1.0, c.red.hex]],
65
  ),
66
  )
67
  )
 
2
 
3
  import plotly.graph_objects as go
4
  import plotly.io as pio
5
+ import streamlit as st
6
  from catppuccin import PALETTE
7
 
8
 
9
+ def active_base() -> str:
10
+ """Return the live theme flavor ("light"/"dark").
11
+
12
+ Reflects the user's choice in Streamlit's built-in Settings-menu theme
13
+ toggle, falling back to the configured ``theme.base`` default.
14
+ """
15
+ return st.context.theme.type or st.get_option("theme.base")
16
+
17
+
18
  def _flavor(base: str | None):
19
  return PALETTE.latte if base == "light" else PALETTE.mocha
20
 
 
27
  per-figure code.
28
  """
29
  c = _flavor(base).colors
30
+ bg, text, subtext = c.base.hex, c.text.hex, c.subtext1.hex
31
+ if base == "light":
32
+ # Latte's surface0/1 are heavy grays; use lighter slots so the grid
33
+ # stays subtle, the legend isn't a gray block, and diverging plots
34
+ # fade through near-white instead of mud.
35
+ surface, line, mid = c.mantle.hex, c.surface0.hex, c.base.hex
36
+ else:
37
+ surface, line, mid = c.surface0.hex, c.surface1.hex, c.surface0.hex
38
 
39
  axis = dict(
40
  gridcolor=line,
 
77
  scene=dict(xaxis=scene_axis, yaxis=scene_axis, zaxis=scene_axis),
78
  legend=dict(bgcolor=surface, bordercolor=line, font=dict(color=text)),
79
  colorscale=dict(
80
+ diverging=[[0.0, c.blue.hex], [0.5, mid], [1.0, c.red.hex]],
81
  ),
82
  )
83
  )