Jac-Zac committed on
Commit
c607869
·
1 Parent(s): db3d901

Performance improvements

Browse files
.env.example CHANGED
@@ -18,3 +18,8 @@ ARTIFACTS_DIR=artifacts
18
  # Default model IDs shown in the sidebar (optional — change to override the built-in defaults)
19
  # DEFAULT_MODEL=google/gemma-2-2b-it
20
  # REMOTE_DEFAULT_MODEL=google/gemma-2-9b-it
 
 
 
 
 
 
18
  # Default model IDs shown in the sidebar (optional — change to override the built-in defaults)
19
  # DEFAULT_MODEL=google/gemma-2-2b-it
20
  # REMOTE_DEFAULT_MODEL=google/gemma-2-9b-it
21
+
22
+ # Cache sizing knobs (optional)
23
+ # Keep model cache at 1 unless you have enough RAM for multiple loaded models.
24
+ # PERSONA_UI_MODEL_CACHE_ENTRIES=1
25
+ # PERSONA_UI_STORE_CACHE_ENTRIES=4
app.py CHANGED
@@ -5,6 +5,7 @@ import streamlit as st
5
  from dotenv import load_dotenv
6
 
7
  from utils.helpers import DATASET_SOURCES, session_key
 
8
  from utils.runtime import list_remote_models
9
  from utils.theme import install_catppuccin_theme
10
 
@@ -28,6 +29,15 @@ _SIDEBAR_DATASET_SOURCE_KEY = session_key("sidebar", "dataset_source")
28
 
29
  _TABS = ["Chat", "Analysis", "Extract"]
30
  _TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  @dataclass(frozen=True)
@@ -181,6 +191,12 @@ def main() -> None:
181
 
182
  render_chat_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
183
 
 
 
 
 
 
 
184
 
185
  if __name__ == "__main__":
186
  main()
 
5
  from dotenv import load_dotenv
6
 
7
  from utils.helpers import DATASET_SOURCES, session_key
8
+ from utils.preload import preload_once
9
  from utils.runtime import list_remote_models
10
  from utils.theme import install_catppuccin_theme
11
 
 
29
 
30
  _TABS = ["Chat", "Analysis", "Extract"]
31
  _TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
32
+ _TAB_PRELOAD_MODULES = {
33
+ "Chat": ("tabs.analysis_core", "tabs.extract", "tabs.compare_chat"),
34
+ "Analysis": ("tabs.chat", "tabs.extract"),
35
+ "Extract": ("tabs.chat", "tabs.analysis_core"),
36
+ }
37
+ _TAB_PRELOAD_FUNCTIONS = {
38
+ "Chat": ("utils.analysis_metadata:synth_persona_attribute_names",),
39
+ "Extract": ("utils.analysis_metadata:synth_persona_attribute_names",),
40
+ }
41
 
42
 
43
  @dataclass(frozen=True)
 
191
 
192
  render_chat_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
193
 
194
+ preload_once(
195
+ f"after-{sidebar.active_tab.lower()}",
196
+ modules=_TAB_PRELOAD_MODULES.get(sidebar.active_tab, ()),
197
+ functions=_TAB_PRELOAD_FUNCTIONS.get(sidebar.active_tab, ()),
198
+ )
199
+
200
 
201
  if __name__ == "__main__":
202
  main()
tabs/analysis_core.py CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
7
  import plotly.graph_objects as go
8
  import streamlit as st
9
  from persona_data.environment import get_artifacts_dir
10
- from persona_data.synth_persona import BASELINE_PERSONA_ID, SynthPersonaDataset
11
  from persona_vectors.attributes import (
12
  DEFAULT_MAX_ATTRIBUTE_CATEGORIES,
13
  attribute_color_kwargs,
@@ -45,6 +45,10 @@ from utils.analysis_sources import (
45
  store_id,
46
  store_layers_cached,
47
  )
 
 
 
 
48
  from utils.controls import render_mask_strategy_select
49
  from utils.helpers import (
50
  ANALYSIS_HELP_TEXT,
@@ -99,9 +103,8 @@ _PROJECTION_COLOR_MODES = ["Persona", "K-means clusters", "Persona attribute"]
99
  _MAX_ATTRIBUTE_CATEGORIES = DEFAULT_MAX_ATTRIBUTE_CATEGORIES
100
 
101
 
102
- @st.cache_resource(show_spinner=False)
103
- def _synth_persona_dataset() -> SynthPersonaDataset:
104
- return SynthPersonaDataset()
105
 
106
 
107
  def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
@@ -983,7 +986,7 @@ def _render_projection_color_config(
983
 
984
  if color_mode == "Persona attribute":
985
  persona_dataset = _synth_persona_dataset()
986
- attribute_options = list(persona_dataset.attribute_names)
987
  if not attribute_options:
988
  st.info("No persona attributes are available for this dataset.")
989
  return None
 
7
  import plotly.graph_objects as go
8
  import streamlit as st
9
  from persona_data.environment import get_artifacts_dir
10
+ from persona_data.synth_persona import BASELINE_PERSONA_ID
11
  from persona_vectors.attributes import (
12
  DEFAULT_MAX_ATTRIBUTE_CATEGORIES,
13
  attribute_color_kwargs,
 
45
  store_id,
46
  store_layers_cached,
47
  )
48
+ from utils.analysis_metadata import (
49
+ synth_persona_attribute_names,
50
+ synth_persona_dataset_cached,
51
+ )
52
  from utils.controls import render_mask_strategy_select
53
  from utils.helpers import (
54
  ANALYSIS_HELP_TEXT,
 
103
  _MAX_ATTRIBUTE_CATEGORIES = DEFAULT_MAX_ATTRIBUTE_CATEGORIES
104
 
105
 
106
def _synth_persona_dataset():
    """Return the dataset object provided by ``synth_persona_dataset_cached``.

    Kept as a module-private alias so existing call sites in this module keep
    using the same name.
    """
    dataset = synth_persona_dataset_cached()
    return dataset
 
108
 
109
 
110
  def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
 
986
 
987
  if color_mode == "Persona attribute":
988
  persona_dataset = _synth_persona_dataset()
989
+ attribute_options = list(synth_persona_attribute_names())
990
  if not attribute_options:
991
  st.info("No persona attributes are available for this dataset.")
992
  return None
tabs/chat.py CHANGED
@@ -1,9 +1,8 @@
1
  from __future__ import annotations
2
 
3
- from typing import cast
4
 
5
  import streamlit as st
6
- from persona_data.synth_persona import PersonaData
7
 
8
  from state import (
9
  ChatState,
@@ -29,6 +28,9 @@ from utils.chat_export import save_chat_export
29
  from utils.helpers import session_key, widget_key
30
  from utils.runtime import cached_model
31
 
 
 
 
32
  _LAST_PERSONA_ID_KEY = session_key("chat", "last_persona_id")
33
  _LAST_PROMPT_MODE_KEY = session_key("chat", "last_prompt_mode")
34
  _LAST_COMPARE_MODE_KEY = session_key("chat", "last_compare_mode")
 
1
  from __future__ import annotations
2
 
3
+ from typing import TYPE_CHECKING, cast
4
 
5
  import streamlit as st
 
6
 
7
  from state import (
8
  ChatState,
 
28
  from utils.helpers import session_key, widget_key
29
  from utils.runtime import cached_model
30
 
31
+ if TYPE_CHECKING:
32
+ from persona_data.synth_persona import PersonaData
33
+
34
  _LAST_PERSONA_ID_KEY = session_key("chat", "last_persona_id")
35
  _LAST_PROMPT_MODE_KEY = session_key("chat", "last_prompt_mode")
36
  _LAST_COMPARE_MODE_KEY = session_key("chat", "last_compare_mode")
tabs/chat_shared.py CHANGED
@@ -2,9 +2,9 @@ from __future__ import annotations
2
 
3
  from collections.abc import Callable
4
  from dataclasses import dataclass
 
5
 
6
  import streamlit as st
7
- from persona_data.synth_persona import PersonaData
8
 
9
  from state import ChatState
10
  from tabs.chat_ui import GenerationConfig, render_persona_prompt_controls
@@ -12,6 +12,9 @@ from utils.chat import ChatReply, generate_chat_reply
12
  from utils.datasets import load_persona_list
13
  from utils.helpers import session_key
14
 
 
 
 
15
 
16
  @dataclass(frozen=True)
17
  class ChatSelection:
 
2
 
3
  from collections.abc import Callable
4
  from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING
6
 
7
  import streamlit as st
 
8
 
9
  from state import ChatState
10
  from tabs.chat_ui import GenerationConfig, render_persona_prompt_controls
 
12
  from utils.datasets import load_persona_list
13
  from utils.helpers import session_key
14
 
15
+ if TYPE_CHECKING:
16
+ from persona_data.synth_persona import PersonaData
17
+
18
 
19
  @dataclass(frozen=True)
20
  class ChatSelection:
tabs/chat_ui.py CHANGED
@@ -5,7 +5,6 @@ from dataclasses import asdict, dataclass
5
  from typing import TYPE_CHECKING, Any
6
 
7
  import streamlit as st
8
- from persona_data.synth_persona import PersonaData
9
 
10
  from utils.helpers import (
11
  CHAT_PROMPT_MODE_LABEL_TO_KEY,
@@ -16,6 +15,7 @@ from utils.helpers import (
16
  )
17
 
18
  if TYPE_CHECKING:
 
19
  from utils.contrast import TokenContrast
20
 
21
  GENERATION_DEFAULTS = {
 
5
  from typing import TYPE_CHECKING, Any
6
 
7
  import streamlit as st
 
8
 
9
  from utils.helpers import (
10
  CHAT_PROMPT_MODE_LABEL_TO_KEY,
 
15
  )
16
 
17
  if TYPE_CHECKING:
18
+ from persona_data.synth_persona import PersonaData
19
  from utils.contrast import TokenContrast
20
 
21
  GENERATION_DEFAULTS = {
tabs/compare_chat.py CHANGED
@@ -1,9 +1,9 @@
 
 
1
  from dataclasses import dataclass
2
- from typing import Any
3
 
4
  import streamlit as st
5
- from nnterp import StandardizedTransformer
6
- from persona_data.synth_persona import PersonaData
7
 
8
  from state import ChatState, default_chat_state, reset_chat_context_state
9
  from tabs.chat_shared import (
@@ -24,6 +24,10 @@ from .chat_ui import (
24
  render_system_prompt,
25
  )
26
 
 
 
 
 
27
 
28
  @dataclass(frozen=True)
29
  class ComparePanel:
 
1
+ from __future__ import annotations
2
+
3
  from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Any
5
 
6
  import streamlit as st
 
 
7
 
8
  from state import ChatState, default_chat_state, reset_chat_context_state
9
  from tabs.chat_shared import (
 
24
  render_system_prompt,
25
  )
26
 
27
+ if TYPE_CHECKING:
28
+ from nnterp import StandardizedTransformer
29
+ from persona_data.synth_persona import PersonaData
30
+
31
 
32
  @dataclass(frozen=True)
33
  class ComparePanel:
utils/analysis_metadata.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from functools import lru_cache
4
+ from typing import Any
5
+
6
+
7
@lru_cache(maxsize=1)
def synth_persona_dataset_cached() -> Any:
    """Build the ``SynthPersonaDataset`` on first call and reuse it afterwards.

    The ``persona_data`` import happens inside the function, so it is only
    paid when the dataset is first requested (presumably to keep importing
    this module cheap -- confirm against the preload design).
    """
    from persona_data.synth_persona import SynthPersonaDataset

    dataset = SynthPersonaDataset()
    return dataset
12
+
13
+
14
@lru_cache(maxsize=1)
def synth_persona_attribute_names() -> tuple[str, ...]:
    """Return the cached dataset's ``attribute_names`` as an immutable tuple.

    The result is memoized, so the dataset attribute is read only on the
    first call.
    """
    dataset = synth_persona_dataset_cached()
    return tuple(dataset.attribute_names)
utils/analysis_sources.py CHANGED
@@ -11,6 +11,8 @@ from persona_vectors.artifacts import (
11
  from persona_vectors.extraction import MaskStrategy
12
  from persona_vectors.hub import list_hub_vector_models
13
 
 
 
14
  Store = ActivationStore | HFActivationStore
15
 
16
  DEFAULT_HUB_REPO = os.environ.get(
@@ -23,7 +25,10 @@ SOURCE_LOCAL = "Local activations"
23
  SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
24
 
25
 
26
- @st.cache_resource(show_spinner=False, max_entries=1)
 
 
 
27
  def activation_store_cached(
28
  source: str,
29
  location: str,
 
11
  from persona_vectors.extraction import MaskStrategy
12
  from persona_vectors.hub import list_hub_vector_models
13
 
14
+ from utils.helpers import env_int
15
+
16
  Store = ActivationStore | HFActivationStore
17
 
18
  DEFAULT_HUB_REPO = os.environ.get(
 
25
  SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
26
 
27
 
28
+ _STORE_CACHE_ENTRIES = env_int("PERSONA_UI_STORE_CACHE_ENTRIES", 4)
29
+
30
+
31
+ @st.cache_resource(show_spinner=False, max_entries=_STORE_CACHE_ENTRIES)
32
  def activation_store_cached(
33
  source: str,
34
  location: str,
utils/chat.py CHANGED
@@ -3,14 +3,14 @@ from __future__ import annotations
3
  import logging
4
  from contextlib import contextmanager, nullcontext
5
  from dataclasses import dataclass
6
- from typing import TYPE_CHECKING, Literal
7
 
8
- import torch
9
  from persona_data.prompts import format_messages, format_prompt, normalize_messages
10
- from persona_data.synth_persona import PersonaData
11
 
12
  if TYPE_CHECKING:
 
13
  from nnterp import StandardizedTransformer
 
14
 
15
  logger = logging.getLogger(__name__)
16
  SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
@@ -19,7 +19,7 @@ SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
19
  @dataclass
20
  class ChatReply:
21
  text: str
22
- generated_ids: torch.Tensor | None = None
23
 
24
 
25
  def build_chat_messages(
@@ -133,6 +133,8 @@ def format_generation_prompt(
133
 
134
  def resolve_saved_tensor(value: object) -> torch.Tensor:
135
  """Resolve an nnsight ``.save()`` proxy (or raw tensor) to a CPU tensor."""
 
 
136
  resolved = value.value if getattr(value, "value", None) is not None else value
137
  if not isinstance(resolved, torch.Tensor):
138
  raise TypeError(f"Trace result did not resolve to a tensor: {type(resolved)!r}")
@@ -158,6 +160,8 @@ def _seeded_rng(seed: int | None):
158
  yield
159
  return
160
 
 
 
161
  cuda_ctx = torch.random.fork_rng(devices=range(torch.cuda.device_count()))
162
  mps_ctx = (
163
  torch.random.fork_rng(devices=range(1), device_type="mps")
@@ -203,6 +207,8 @@ def generate_chat_reply(
203
  ChatReply with generated text and token ids.
204
  """
205
 
 
 
206
  tokenizer = model.tokenizer
207
  prompt, prompt_token_count = format_generation_prompt(messages, tokenizer)
208
 
 
3
  import logging
4
  from contextlib import contextmanager, nullcontext
5
  from dataclasses import dataclass
6
+ from typing import TYPE_CHECKING, Any, Literal
7
 
 
8
  from persona_data.prompts import format_messages, format_prompt, normalize_messages
 
9
 
10
  if TYPE_CHECKING:
11
+ import torch
12
  from nnterp import StandardizedTransformer
13
+ from persona_data.synth_persona import PersonaData
14
 
15
  logger = logging.getLogger(__name__)
16
  SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
 
19
  @dataclass
20
  class ChatReply:
21
  text: str
22
+ generated_ids: Any | None = None
23
 
24
 
25
  def build_chat_messages(
 
133
 
134
  def resolve_saved_tensor(value: object) -> torch.Tensor:
135
  """Resolve an nnsight ``.save()`` proxy (or raw tensor) to a CPU tensor."""
136
+ import torch
137
+
138
  resolved = value.value if getattr(value, "value", None) is not None else value
139
  if not isinstance(resolved, torch.Tensor):
140
  raise TypeError(f"Trace result did not resolve to a tensor: {type(resolved)!r}")
 
160
  yield
161
  return
162
 
163
+ import torch
164
+
165
  cuda_ctx = torch.random.fork_rng(devices=range(torch.cuda.device_count()))
166
  mps_ctx = (
167
  torch.random.fork_rng(devices=range(1), device_type="mps")
 
207
  ChatReply with generated text and token ids.
208
  """
209
 
210
+ import torch
211
+
212
  tokenizer = model.tokenizer
213
  prompt, prompt_token_count = format_generation_prompt(messages, tokenizer)
214
 
utils/helpers.py CHANGED
@@ -1,9 +1,17 @@
 
 
1
  import hashlib
 
 
2
  import re
3
  from collections.abc import Iterable
4
  from enum import Enum
 
 
 
 
5
 
6
- from persona_data.synth_persona import PersonaData
7
 
8
 
9
  class DatasetSource(str, Enum):
@@ -74,6 +82,16 @@ def session_key(*parts: str) -> str:
74
  return ":".join(parts)
75
 
76
 
 
 
 
 
 
 
 
 
 
 
77
  def personas_fingerprint(persona_ids: Iterable[str]) -> str:
78
  """Stable short fingerprint for a set of persona ids.
79
 
 
1
+ from __future__ import annotations
2
+
3
  import hashlib
4
+ import logging
5
+ import os
6
  import re
7
  from collections.abc import Iterable
8
  from enum import Enum
9
+ from typing import TYPE_CHECKING
10
+
11
+ if TYPE_CHECKING:
12
+ from persona_data.synth_persona import PersonaData
13
 
14
+ logger = logging.getLogger(__name__)
15
 
16
 
17
  class DatasetSource(str, Enum):
 
82
  return ":".join(parts)
83
 
84
 
85
def env_int(name: str, default: int, *, minimum: int = 1) -> int:
    """Read a bounded integer from the environment.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset, blank, or not a
            valid integer.
        minimum: Lower bound applied to every returned value, including the
            fallback derived from ``default``.

    Returns:
        The parsed integer, or ``default``, never smaller than ``minimum``.
    """

    # Clamp the fallback too, so the "bounded" contract holds on every path.
    fallback = max(minimum, default)

    raw = os.environ.get(name)
    # A blank value (e.g. ``NAME=`` left in a .env file) means "not set";
    # it is not a malformed integer and should not trigger a warning.
    if raw is None or not raw.strip():
        return fallback

    try:
        return max(minimum, int(raw))
    except ValueError:
        logger.warning("Ignoring invalid integer for %s", name)
        return fallback
93
+
94
+
95
  def personas_fingerprint(persona_ids: Iterable[str]) -> str:
96
  """Stable short fingerprint for a set of persona ids.
97
 
utils/preload.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import logging
5
+ import threading
6
+ import time
7
+ from collections.abc import Iterable
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ _started: set[tuple[str, ...]] = set()
12
+ _lock = threading.Lock()
13
+
14
+
15
+ def _warm_imports(
16
+ modules: tuple[str, ...],
17
+ functions: tuple[str, ...],
18
+ delay_seconds: float,
19
+ ) -> None:
20
+ if delay_seconds > 0:
21
+ time.sleep(delay_seconds)
22
+ for module in modules:
23
+ try:
24
+ importlib.import_module(module)
25
+ except Exception:
26
+ logger.debug("Background preload failed for %s", module, exc_info=True)
27
+ for function_path in functions:
28
+ try:
29
+ module_name, function_name = function_path.split(":", 1)
30
+ function = getattr(importlib.import_module(module_name), function_name)
31
+ function()
32
+ except Exception:
33
+ logger.debug(
34
+ "Background preload failed for %s", function_path, exc_info=True
35
+ )
36
+
37
+
38
def preload_once(
    name: str,
    *,
    modules: Iterable[str] = (),
    functions: Iterable[str] = (),
    delay_seconds: float = 0.25,
) -> None:
    """Start a one-time background warm-up on a daemon thread.

    Intended for imports and tiny local metadata only. Avoid model
    construction, Hub requests, and Streamlit cache population here because
    those can steal enough CPU or I/O to make the visible page feel slower.

    Args:
        name: Label for the request; part of the de-duplication key and of
            the thread name.
        modules: Module names to import; duplicates are dropped, order kept.
        functions: ``"module:attribute"`` paths to call; duplicates are
            dropped, order kept.
        delay_seconds: Delay handed to the worker before it starts.

    Does nothing when both ``modules`` and ``functions`` are empty, or when
    the same request has already been started in this process.
    """

    unique_modules = tuple(dict.fromkeys(modules))
    unique_functions = tuple(dict.fromkeys(functions))
    if not (unique_modules or unique_functions):
        return

    request = (name, *unique_modules, *unique_functions)
    # Check-and-record under the lock so concurrent callers cannot both
    # start a thread for the same request.
    with _lock:
        already_started = request in _started
        if not already_started:
            _started.add(request)
    if already_started:
        return

    threading.Thread(
        target=_warm_imports,
        args=(unique_modules, unique_functions, delay_seconds),
        name=f"persona-ui-preload-{name}",
        daemon=True,
    ).start()
utils/runtime.py CHANGED
@@ -4,9 +4,12 @@ from collections.abc import Iterable
4
 
5
  import streamlit as st
6
 
 
 
7
  logger = logging.getLogger(__name__)
8
  _LANGUAGE_MODEL_CLASSES = {"LanguageModel", "StandardizedTransformer"}
9
  _EXPECTED_NDIF_STATES = {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}
 
10
 
11
 
12
  def _iter_deployments(raw: object) -> Iterable[dict]:
@@ -91,16 +94,17 @@ def list_remote_models() -> list[str]:
91
  return sorted(set(model_names))
92
 
93
 
94
- @st.cache_resource(show_spinner=False, max_entries=1)
95
  def cached_model(model_name: str):
96
  """Load and cache a standardized nnterp model.
97
 
98
  Streamlit reruns this app on every interaction, so caching keeps one loaded
99
- model instance per model name instead of reloading weights on every widget
100
- change. ``remote`` is intentionally not part of the cache key: it matters
101
- at generation/trace time, but the current ``StandardizedTransformer``
102
  constructor ignores it, and excluding it avoids loading duplicate local
103
- model objects when toggling NDIF.
 
104
  """
105
 
106
  import torch
 
4
 
5
  import streamlit as st
6
 
7
+ from utils.helpers import env_int
8
+
9
  logger = logging.getLogger(__name__)
10
  _LANGUAGE_MODEL_CLASSES = {"LanguageModel", "StandardizedTransformer"}
11
  _EXPECTED_NDIF_STATES = {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}
12
+ _MODEL_CACHE_ENTRIES = env_int("PERSONA_UI_MODEL_CACHE_ENTRIES", 1)
13
 
14
 
15
  def _iter_deployments(raw: object) -> Iterable[dict]:
 
94
  return sorted(set(model_names))
95
 
96
 
97
+ @st.cache_resource(show_spinner=False, max_entries=_MODEL_CACHE_ENTRIES)
98
  def cached_model(model_name: str):
99
  """Load and cache a standardized nnterp model.
100
 
101
  Streamlit reruns this app on every interaction, so caching keeps one loaded
102
+ model instance instead of reloading weights on every widget change.
103
+ ``remote`` is intentionally not part of the cache key: it matters at
104
+ generation/trace time, but the current ``StandardizedTransformer``
105
  constructor ignores it, and excluding it avoids loading duplicate local
106
+ model objects when toggling NDIF. The cache defaults to one model to avoid
107
+ keeping multiple large models in RAM.
108
  """
109
 
110
  import torch