Spaces:
Sleeping
Sleeping
Jac-Zac commited on
Commit ·
c607869
1
Parent(s): db3d901
Performance improvements
Browse files- .env.example +5 -0
- app.py +16 -0
- tabs/analysis_core.py +8 -5
- tabs/chat.py +4 -2
- tabs/chat_shared.py +4 -1
- tabs/chat_ui.py +1 -1
- tabs/compare_chat.py +7 -3
- utils/analysis_metadata.py +16 -0
- utils/analysis_sources.py +6 -1
- utils/chat.py +10 -4
- utils/helpers.py +19 -1
- utils/preload.py +69 -0
- utils/runtime.py +9 -5
.env.example
CHANGED
|
@@ -18,3 +18,8 @@ ARTIFACTS_DIR=artifacts
|
|
| 18 |
# Default model IDs shown in the sidebar (optional — change to override the built-in defaults)
|
| 19 |
# DEFAULT_MODEL=google/gemma-2-2b-it
|
| 20 |
# REMOTE_DEFAULT_MODEL=google/gemma-2-9b-it
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# Default model IDs shown in the sidebar (optional — change to override the built-in defaults)
|
| 19 |
# DEFAULT_MODEL=google/gemma-2-2b-it
|
| 20 |
# REMOTE_DEFAULT_MODEL=google/gemma-2-9b-it
|
| 21 |
+
|
| 22 |
+
# Cache sizing knobs (optional)
|
| 23 |
+
# Keep model cache at 1 unless you have enough RAM for multiple loaded models.
|
| 24 |
+
# PERSONA_UI_MODEL_CACHE_ENTRIES=1
|
| 25 |
+
# PERSONA_UI_STORE_CACHE_ENTRIES=4
|
app.py
CHANGED
|
@@ -5,6 +5,7 @@ import streamlit as st
|
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
|
| 7 |
from utils.helpers import DATASET_SOURCES, session_key
|
|
|
|
| 8 |
from utils.runtime import list_remote_models
|
| 9 |
from utils.theme import install_catppuccin_theme
|
| 10 |
|
|
@@ -28,6 +29,15 @@ _SIDEBAR_DATASET_SOURCE_KEY = session_key("sidebar", "dataset_source")
|
|
| 28 |
|
| 29 |
_TABS = ["Chat", "Analysis", "Extract"]
|
| 30 |
_TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
@dataclass(frozen=True)
|
|
@@ -181,6 +191,12 @@ def main() -> None:
|
|
| 181 |
|
| 182 |
render_chat_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
if __name__ == "__main__":
|
| 186 |
main()
|
|
|
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
|
| 7 |
from utils.helpers import DATASET_SOURCES, session_key
|
| 8 |
+
from utils.preload import preload_once
|
| 9 |
from utils.runtime import list_remote_models
|
| 10 |
from utils.theme import install_catppuccin_theme
|
| 11 |
|
|
|
|
| 29 |
|
| 30 |
_TABS = ["Chat", "Analysis", "Extract"]
|
| 31 |
_TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
|
| 32 |
+
_TAB_PRELOAD_MODULES = {
|
| 33 |
+
"Chat": ("tabs.analysis_core", "tabs.extract", "tabs.compare_chat"),
|
| 34 |
+
"Analysis": ("tabs.chat", "tabs.extract"),
|
| 35 |
+
"Extract": ("tabs.chat", "tabs.analysis_core"),
|
| 36 |
+
}
|
| 37 |
+
_TAB_PRELOAD_FUNCTIONS = {
|
| 38 |
+
"Chat": ("utils.analysis_metadata:synth_persona_attribute_names",),
|
| 39 |
+
"Extract": ("utils.analysis_metadata:synth_persona_attribute_names",),
|
| 40 |
+
}
|
| 41 |
|
| 42 |
|
| 43 |
@dataclass(frozen=True)
|
|
|
|
| 191 |
|
| 192 |
render_chat_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
|
| 193 |
|
| 194 |
+
preload_once(
|
| 195 |
+
f"after-{sidebar.active_tab.lower()}",
|
| 196 |
+
modules=_TAB_PRELOAD_MODULES.get(sidebar.active_tab, ()),
|
| 197 |
+
functions=_TAB_PRELOAD_FUNCTIONS.get(sidebar.active_tab, ()),
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
|
| 201 |
if __name__ == "__main__":
|
| 202 |
main()
|
tabs/analysis_core.py
CHANGED
|
@@ -7,7 +7,7 @@ from pathlib import Path
|
|
| 7 |
import plotly.graph_objects as go
|
| 8 |
import streamlit as st
|
| 9 |
from persona_data.environment import get_artifacts_dir
|
| 10 |
-
from persona_data.synth_persona import BASELINE_PERSONA_ID
|
| 11 |
from persona_vectors.attributes import (
|
| 12 |
DEFAULT_MAX_ATTRIBUTE_CATEGORIES,
|
| 13 |
attribute_color_kwargs,
|
|
@@ -45,6 +45,10 @@ from utils.analysis_sources import (
|
|
| 45 |
store_id,
|
| 46 |
store_layers_cached,
|
| 47 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
from utils.controls import render_mask_strategy_select
|
| 49 |
from utils.helpers import (
|
| 50 |
ANALYSIS_HELP_TEXT,
|
|
@@ -99,9 +103,8 @@ _PROJECTION_COLOR_MODES = ["Persona", "K-means clusters", "Persona attribute"]
|
|
| 99 |
_MAX_ATTRIBUTE_CATEGORIES = DEFAULT_MAX_ATTRIBUTE_CATEGORIES
|
| 100 |
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
return SynthPersonaDataset()
|
| 105 |
|
| 106 |
|
| 107 |
def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
|
|
@@ -983,7 +986,7 @@ def _render_projection_color_config(
|
|
| 983 |
|
| 984 |
if color_mode == "Persona attribute":
|
| 985 |
persona_dataset = _synth_persona_dataset()
|
| 986 |
-
attribute_options = list(
|
| 987 |
if not attribute_options:
|
| 988 |
st.info("No persona attributes are available for this dataset.")
|
| 989 |
return None
|
|
|
|
| 7 |
import plotly.graph_objects as go
|
| 8 |
import streamlit as st
|
| 9 |
from persona_data.environment import get_artifacts_dir
|
| 10 |
+
from persona_data.synth_persona import BASELINE_PERSONA_ID
|
| 11 |
from persona_vectors.attributes import (
|
| 12 |
DEFAULT_MAX_ATTRIBUTE_CATEGORIES,
|
| 13 |
attribute_color_kwargs,
|
|
|
|
| 45 |
store_id,
|
| 46 |
store_layers_cached,
|
| 47 |
)
|
| 48 |
+
from utils.analysis_metadata import (
|
| 49 |
+
synth_persona_attribute_names,
|
| 50 |
+
synth_persona_dataset_cached,
|
| 51 |
+
)
|
| 52 |
from utils.controls import render_mask_strategy_select
|
| 53 |
from utils.helpers import (
|
| 54 |
ANALYSIS_HELP_TEXT,
|
|
|
|
| 103 |
_MAX_ATTRIBUTE_CATEGORIES = DEFAULT_MAX_ATTRIBUTE_CATEGORIES
|
| 104 |
|
| 105 |
|
| 106 |
+
def _synth_persona_dataset():
|
| 107 |
+
return synth_persona_dataset_cached()
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
|
|
|
|
| 986 |
|
| 987 |
if color_mode == "Persona attribute":
|
| 988 |
persona_dataset = _synth_persona_dataset()
|
| 989 |
+
attribute_options = list(synth_persona_attribute_names())
|
| 990 |
if not attribute_options:
|
| 991 |
st.info("No persona attributes are available for this dataset.")
|
| 992 |
return None
|
tabs/chat.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
from typing import cast
|
| 4 |
|
| 5 |
import streamlit as st
|
| 6 |
-
from persona_data.synth_persona import PersonaData
|
| 7 |
|
| 8 |
from state import (
|
| 9 |
ChatState,
|
|
@@ -29,6 +28,9 @@ from utils.chat_export import save_chat_export
|
|
| 29 |
from utils.helpers import session_key, widget_key
|
| 30 |
from utils.runtime import cached_model
|
| 31 |
|
|
|
|
|
|
|
|
|
|
| 32 |
_LAST_PERSONA_ID_KEY = session_key("chat", "last_persona_id")
|
| 33 |
_LAST_PROMPT_MODE_KEY = session_key("chat", "last_prompt_mode")
|
| 34 |
_LAST_COMPARE_MODE_KEY = session_key("chat", "last_compare_mode")
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
from typing import TYPE_CHECKING, cast
|
| 4 |
|
| 5 |
import streamlit as st
|
|
|
|
| 6 |
|
| 7 |
from state import (
|
| 8 |
ChatState,
|
|
|
|
| 28 |
from utils.helpers import session_key, widget_key
|
| 29 |
from utils.runtime import cached_model
|
| 30 |
|
| 31 |
+
if TYPE_CHECKING:
|
| 32 |
+
from persona_data.synth_persona import PersonaData
|
| 33 |
+
|
| 34 |
_LAST_PERSONA_ID_KEY = session_key("chat", "last_persona_id")
|
| 35 |
_LAST_PROMPT_MODE_KEY = session_key("chat", "last_prompt_mode")
|
| 36 |
_LAST_COMPARE_MODE_KEY = session_key("chat", "last_compare_mode")
|
tabs/chat_shared.py
CHANGED
|
@@ -2,9 +2,9 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
from collections.abc import Callable
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
import streamlit as st
|
| 7 |
-
from persona_data.synth_persona import PersonaData
|
| 8 |
|
| 9 |
from state import ChatState
|
| 10 |
from tabs.chat_ui import GenerationConfig, render_persona_prompt_controls
|
|
@@ -12,6 +12,9 @@ from utils.chat import ChatReply, generate_chat_reply
|
|
| 12 |
from utils.datasets import load_persona_list
|
| 13 |
from utils.helpers import session_key
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
@dataclass(frozen=True)
|
| 17 |
class ChatSelection:
|
|
|
|
| 2 |
|
| 3 |
from collections.abc import Callable
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import TYPE_CHECKING
|
| 6 |
|
| 7 |
import streamlit as st
|
|
|
|
| 8 |
|
| 9 |
from state import ChatState
|
| 10 |
from tabs.chat_ui import GenerationConfig, render_persona_prompt_controls
|
|
|
|
| 12 |
from utils.datasets import load_persona_list
|
| 13 |
from utils.helpers import session_key
|
| 14 |
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from persona_data.synth_persona import PersonaData
|
| 17 |
+
|
| 18 |
|
| 19 |
@dataclass(frozen=True)
|
| 20 |
class ChatSelection:
|
tabs/chat_ui.py
CHANGED
|
@@ -5,7 +5,6 @@ from dataclasses import asdict, dataclass
|
|
| 5 |
from typing import TYPE_CHECKING, Any
|
| 6 |
|
| 7 |
import streamlit as st
|
| 8 |
-
from persona_data.synth_persona import PersonaData
|
| 9 |
|
| 10 |
from utils.helpers import (
|
| 11 |
CHAT_PROMPT_MODE_LABEL_TO_KEY,
|
|
@@ -16,6 +15,7 @@ from utils.helpers import (
|
|
| 16 |
)
|
| 17 |
|
| 18 |
if TYPE_CHECKING:
|
|
|
|
| 19 |
from utils.contrast import TokenContrast
|
| 20 |
|
| 21 |
GENERATION_DEFAULTS = {
|
|
|
|
| 5 |
from typing import TYPE_CHECKING, Any
|
| 6 |
|
| 7 |
import streamlit as st
|
|
|
|
| 8 |
|
| 9 |
from utils.helpers import (
|
| 10 |
CHAT_PROMPT_MODE_LABEL_TO_KEY,
|
|
|
|
| 15 |
)
|
| 16 |
|
| 17 |
if TYPE_CHECKING:
|
| 18 |
+
from persona_data.synth_persona import PersonaData
|
| 19 |
from utils.contrast import TokenContrast
|
| 20 |
|
| 21 |
GENERATION_DEFAULTS = {
|
tabs/compare_chat.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
|
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
-
from typing import Any
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
-
from nnterp import StandardizedTransformer
|
| 6 |
-
from persona_data.synth_persona import PersonaData
|
| 7 |
|
| 8 |
from state import ChatState, default_chat_state, reset_chat_context_state
|
| 9 |
from tabs.chat_shared import (
|
|
@@ -24,6 +24,10 @@ from .chat_ui import (
|
|
| 24 |
render_system_prompt,
|
| 25 |
)
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
@dataclass(frozen=True)
|
| 29 |
class ComparePanel:
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
from dataclasses import dataclass
|
| 4 |
+
from typing import TYPE_CHECKING, Any
|
| 5 |
|
| 6 |
import streamlit as st
|
|
|
|
|
|
|
| 7 |
|
| 8 |
from state import ChatState, default_chat_state, reset_chat_context_state
|
| 9 |
from tabs.chat_shared import (
|
|
|
|
| 24 |
render_system_prompt,
|
| 25 |
)
|
| 26 |
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from nnterp import StandardizedTransformer
|
| 29 |
+
from persona_data.synth_persona import PersonaData
|
| 30 |
+
|
| 31 |
|
| 32 |
@dataclass(frozen=True)
|
| 33 |
class ComparePanel:
|
utils/analysis_metadata.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@lru_cache(maxsize=1)
|
| 8 |
+
def synth_persona_dataset_cached() -> Any:
|
| 9 |
+
from persona_data.synth_persona import SynthPersonaDataset
|
| 10 |
+
|
| 11 |
+
return SynthPersonaDataset()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@lru_cache(maxsize=1)
|
| 15 |
+
def synth_persona_attribute_names() -> tuple[str, ...]:
|
| 16 |
+
return tuple(synth_persona_dataset_cached().attribute_names)
|
utils/analysis_sources.py
CHANGED
|
@@ -11,6 +11,8 @@ from persona_vectors.artifacts import (
|
|
| 11 |
from persona_vectors.extraction import MaskStrategy
|
| 12 |
from persona_vectors.hub import list_hub_vector_models
|
| 13 |
|
|
|
|
|
|
|
| 14 |
Store = ActivationStore | HFActivationStore
|
| 15 |
|
| 16 |
DEFAULT_HUB_REPO = os.environ.get(
|
|
@@ -23,7 +25,10 @@ SOURCE_LOCAL = "Local activations"
|
|
| 23 |
SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
|
| 24 |
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
| 27 |
def activation_store_cached(
|
| 28 |
source: str,
|
| 29 |
location: str,
|
|
|
|
| 11 |
from persona_vectors.extraction import MaskStrategy
|
| 12 |
from persona_vectors.hub import list_hub_vector_models
|
| 13 |
|
| 14 |
+
from utils.helpers import env_int
|
| 15 |
+
|
| 16 |
Store = ActivationStore | HFActivationStore
|
| 17 |
|
| 18 |
DEFAULT_HUB_REPO = os.environ.get(
|
|
|
|
| 25 |
SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
|
| 26 |
|
| 27 |
|
| 28 |
+
_STORE_CACHE_ENTRIES = env_int("PERSONA_UI_STORE_CACHE_ENTRIES", 4)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@st.cache_resource(show_spinner=False, max_entries=_STORE_CACHE_ENTRIES)
|
| 32 |
def activation_store_cached(
|
| 33 |
source: str,
|
| 34 |
location: str,
|
utils/chat.py
CHANGED
|
@@ -3,14 +3,14 @@ from __future__ import annotations
|
|
| 3 |
import logging
|
| 4 |
from contextlib import contextmanager, nullcontext
|
| 5 |
from dataclasses import dataclass
|
| 6 |
-
from typing import TYPE_CHECKING, Literal
|
| 7 |
|
| 8 |
-
import torch
|
| 9 |
from persona_data.prompts import format_messages, format_prompt, normalize_messages
|
| 10 |
-
from persona_data.synth_persona import PersonaData
|
| 11 |
|
| 12 |
if TYPE_CHECKING:
|
|
|
|
| 13 |
from nnterp import StandardizedTransformer
|
|
|
|
| 14 |
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
|
|
@@ -19,7 +19,7 @@ SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
|
|
| 19 |
@dataclass
|
| 20 |
class ChatReply:
|
| 21 |
text: str
|
| 22 |
-
generated_ids:
|
| 23 |
|
| 24 |
|
| 25 |
def build_chat_messages(
|
|
@@ -133,6 +133,8 @@ def format_generation_prompt(
|
|
| 133 |
|
| 134 |
def resolve_saved_tensor(value: object) -> torch.Tensor:
|
| 135 |
"""Resolve an nnsight ``.save()`` proxy (or raw tensor) to a CPU tensor."""
|
|
|
|
|
|
|
| 136 |
resolved = value.value if getattr(value, "value", None) is not None else value
|
| 137 |
if not isinstance(resolved, torch.Tensor):
|
| 138 |
raise TypeError(f"Trace result did not resolve to a tensor: {type(resolved)!r}")
|
|
@@ -158,6 +160,8 @@ def _seeded_rng(seed: int | None):
|
|
| 158 |
yield
|
| 159 |
return
|
| 160 |
|
|
|
|
|
|
|
| 161 |
cuda_ctx = torch.random.fork_rng(devices=range(torch.cuda.device_count()))
|
| 162 |
mps_ctx = (
|
| 163 |
torch.random.fork_rng(devices=range(1), device_type="mps")
|
|
@@ -203,6 +207,8 @@ def generate_chat_reply(
|
|
| 203 |
ChatReply with generated text and token ids.
|
| 204 |
"""
|
| 205 |
|
|
|
|
|
|
|
| 206 |
tokenizer = model.tokenizer
|
| 207 |
prompt, prompt_token_count = format_generation_prompt(messages, tokenizer)
|
| 208 |
|
|
|
|
| 3 |
import logging
|
| 4 |
from contextlib import contextmanager, nullcontext
|
| 5 |
from dataclasses import dataclass
|
| 6 |
+
from typing import TYPE_CHECKING, Any, Literal
|
| 7 |
|
|
|
|
| 8 |
from persona_data.prompts import format_messages, format_prompt, normalize_messages
|
|
|
|
| 9 |
|
| 10 |
if TYPE_CHECKING:
|
| 11 |
+
import torch
|
| 12 |
from nnterp import StandardizedTransformer
|
| 13 |
+
from persona_data.synth_persona import PersonaData
|
| 14 |
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
|
|
|
|
| 19 |
@dataclass
|
| 20 |
class ChatReply:
|
| 21 |
text: str
|
| 22 |
+
generated_ids: Any | None = None
|
| 23 |
|
| 24 |
|
| 25 |
def build_chat_messages(
|
|
|
|
| 133 |
|
| 134 |
def resolve_saved_tensor(value: object) -> torch.Tensor:
|
| 135 |
"""Resolve an nnsight ``.save()`` proxy (or raw tensor) to a CPU tensor."""
|
| 136 |
+
import torch
|
| 137 |
+
|
| 138 |
resolved = value.value if getattr(value, "value", None) is not None else value
|
| 139 |
if not isinstance(resolved, torch.Tensor):
|
| 140 |
raise TypeError(f"Trace result did not resolve to a tensor: {type(resolved)!r}")
|
|
|
|
| 160 |
yield
|
| 161 |
return
|
| 162 |
|
| 163 |
+
import torch
|
| 164 |
+
|
| 165 |
cuda_ctx = torch.random.fork_rng(devices=range(torch.cuda.device_count()))
|
| 166 |
mps_ctx = (
|
| 167 |
torch.random.fork_rng(devices=range(1), device_type="mps")
|
|
|
|
| 207 |
ChatReply with generated text and token ids.
|
| 208 |
"""
|
| 209 |
|
| 210 |
+
import torch
|
| 211 |
+
|
| 212 |
tokenizer = model.tokenizer
|
| 213 |
prompt, prompt_token_count = format_generation_prompt(messages, tokenizer)
|
| 214 |
|
utils/helpers.py
CHANGED
|
@@ -1,9 +1,17 @@
|
|
|
|
|
|
|
|
| 1 |
import hashlib
|
|
|
|
|
|
|
| 2 |
import re
|
| 3 |
from collections.abc import Iterable
|
| 4 |
from enum import Enum
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
|
| 8 |
|
| 9 |
class DatasetSource(str, Enum):
|
|
@@ -74,6 +82,16 @@ def session_key(*parts: str) -> str:
|
|
| 74 |
return ":".join(parts)
|
| 75 |
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
def personas_fingerprint(persona_ids: Iterable[str]) -> str:
|
| 78 |
"""Stable short fingerprint for a set of persona ids.
|
| 79 |
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import hashlib
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
import re
|
| 7 |
from collections.abc import Iterable
|
| 8 |
from enum import Enum
|
| 9 |
+
from typing import TYPE_CHECKING
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from persona_data.synth_persona import PersonaData
|
| 13 |
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
|
| 17 |
class DatasetSource(str, Enum):
|
|
|
|
| 82 |
return ":".join(parts)
|
| 83 |
|
| 84 |
|
| 85 |
+
def env_int(name: str, default: int, *, minimum: int = 1) -> int:
|
| 86 |
+
"""Read a bounded integer from the environment."""
|
| 87 |
+
|
| 88 |
+
try:
|
| 89 |
+
return max(minimum, int(os.environ.get(name, str(default))))
|
| 90 |
+
except ValueError:
|
| 91 |
+
logger.warning("Ignoring invalid integer for %s", name)
|
| 92 |
+
return default
|
| 93 |
+
|
| 94 |
+
|
| 95 |
def personas_fingerprint(persona_ids: Iterable[str]) -> str:
|
| 96 |
"""Stable short fingerprint for a set of persona ids.
|
| 97 |
|
utils/preload.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import importlib
|
| 4 |
+
import logging
|
| 5 |
+
import threading
|
| 6 |
+
import time
|
| 7 |
+
from collections.abc import Iterable
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
_started: set[tuple[str, ...]] = set()
|
| 12 |
+
_lock = threading.Lock()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _warm_imports(
|
| 16 |
+
modules: tuple[str, ...],
|
| 17 |
+
functions: tuple[str, ...],
|
| 18 |
+
delay_seconds: float,
|
| 19 |
+
) -> None:
|
| 20 |
+
if delay_seconds > 0:
|
| 21 |
+
time.sleep(delay_seconds)
|
| 22 |
+
for module in modules:
|
| 23 |
+
try:
|
| 24 |
+
importlib.import_module(module)
|
| 25 |
+
except Exception:
|
| 26 |
+
logger.debug("Background preload failed for %s", module, exc_info=True)
|
| 27 |
+
for function_path in functions:
|
| 28 |
+
try:
|
| 29 |
+
module_name, function_name = function_path.split(":", 1)
|
| 30 |
+
function = getattr(importlib.import_module(module_name), function_name)
|
| 31 |
+
function()
|
| 32 |
+
except Exception:
|
| 33 |
+
logger.debug(
|
| 34 |
+
"Background preload failed for %s", function_path, exc_info=True
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def preload_once(
|
| 39 |
+
name: str,
|
| 40 |
+
*,
|
| 41 |
+
modules: Iterable[str] = (),
|
| 42 |
+
functions: Iterable[str] = (),
|
| 43 |
+
delay_seconds: float = 0.25,
|
| 44 |
+
) -> None:
|
| 45 |
+
"""Warm small predictable costs on a daemon thread after the visible render.
|
| 46 |
+
|
| 47 |
+
Keep this limited to imports and tiny local metadata. Avoid model
|
| 48 |
+
construction, Hub requests, and Streamlit cache population because those can
|
| 49 |
+
steal enough CPU or I/O to make the visible page feel slower.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
module_tuple = tuple(dict.fromkeys(modules))
|
| 53 |
+
function_tuple = tuple(dict.fromkeys(functions))
|
| 54 |
+
if not module_tuple and not function_tuple:
|
| 55 |
+
return
|
| 56 |
+
|
| 57 |
+
key = (name, *module_tuple, *function_tuple)
|
| 58 |
+
with _lock:
|
| 59 |
+
if key in _started:
|
| 60 |
+
return
|
| 61 |
+
_started.add(key)
|
| 62 |
+
|
| 63 |
+
thread = threading.Thread(
|
| 64 |
+
target=_warm_imports,
|
| 65 |
+
args=(module_tuple, function_tuple, delay_seconds),
|
| 66 |
+
name=f"persona-ui-preload-{name}",
|
| 67 |
+
daemon=True,
|
| 68 |
+
)
|
| 69 |
+
thread.start()
|
utils/runtime.py
CHANGED
|
@@ -4,9 +4,12 @@ from collections.abc import Iterable
|
|
| 4 |
|
| 5 |
import streamlit as st
|
| 6 |
|
|
|
|
|
|
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
_LANGUAGE_MODEL_CLASSES = {"LanguageModel", "StandardizedTransformer"}
|
| 9 |
_EXPECTED_NDIF_STATES = {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def _iter_deployments(raw: object) -> Iterable[dict]:
|
|
@@ -91,16 +94,17 @@ def list_remote_models() -> list[str]:
|
|
| 91 |
return sorted(set(model_names))
|
| 92 |
|
| 93 |
|
| 94 |
-
@st.cache_resource(show_spinner=False, max_entries=
|
| 95 |
def cached_model(model_name: str):
|
| 96 |
"""Load and cache a standardized nnterp model.
|
| 97 |
|
| 98 |
Streamlit reruns this app on every interaction, so caching keeps one loaded
|
| 99 |
-
model instance
|
| 100 |
-
|
| 101 |
-
|
| 102 |
constructor ignores it, and excluding it avoids loading duplicate local
|
| 103 |
-
model objects when toggling NDIF.
|
|
|
|
| 104 |
"""
|
| 105 |
|
| 106 |
import torch
|
|
|
|
| 4 |
|
| 5 |
import streamlit as st
|
| 6 |
|
| 7 |
+
from utils.helpers import env_int
|
| 8 |
+
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
_LANGUAGE_MODEL_CLASSES = {"LanguageModel", "StandardizedTransformer"}
|
| 11 |
_EXPECTED_NDIF_STATES = {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}
|
| 12 |
+
_MODEL_CACHE_ENTRIES = env_int("PERSONA_UI_MODEL_CACHE_ENTRIES", 1)
|
| 13 |
|
| 14 |
|
| 15 |
def _iter_deployments(raw: object) -> Iterable[dict]:
|
|
|
|
| 94 |
return sorted(set(model_names))
|
| 95 |
|
| 96 |
|
| 97 |
+
@st.cache_resource(show_spinner=False, max_entries=_MODEL_CACHE_ENTRIES)
|
| 98 |
def cached_model(model_name: str):
|
| 99 |
"""Load and cache a standardized nnterp model.
|
| 100 |
|
| 101 |
Streamlit reruns this app on every interaction, so caching keeps one loaded
|
| 102 |
+
model instance instead of reloading weights on every widget change.
|
| 103 |
+
``remote`` is intentionally not part of the cache key: it matters at
|
| 104 |
+
generation/trace time, but the current ``StandardizedTransformer``
|
| 105 |
constructor ignores it, and excluding it avoids loading duplicate local
|
| 106 |
+
model objects when toggling NDIF. The cache defaults to one model to avoid
|
| 107 |
+
keeping multiple large models in RAM.
|
| 108 |
"""
|
| 109 |
|
| 110 |
import torch
|