File size: 3,973 Bytes
c607869
 
d39b2dd
c607869
 
f4259c0
d39b2dd
db3d901
c607869
 
 
 
f4259c0
c607869
a89a7f1
db3d901
 
 
 
 
 
 
 
 
 
 
a89a7f1
 
 
73d72c1
a89a7f1
 
 
 
 
77c2d62
 
c30bbc5
db3d901
 
 
 
 
 
 
e75684b
a89a7f1
 
 
330d092
0ba2e45
 
db3d901
ecd19ae
a89a7f1
 
f4259c0
 
 
 
 
 
 
 
a89a7f1
 
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4259c0
 
a89a7f1
 
 
 
 
 
 
 
 
 
db3d901
 
 
 
 
 
c607869
 
 
 
 
 
 
 
 
 
d39b2dd
 
 
 
 
 
 
 
 
 
 
 
a89a7f1
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from __future__ import annotations

import hashlib
import logging
import os
import re
from collections.abc import Iterable
from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from persona_data.synth_persona import PersonaData

logger = logging.getLogger(__name__)


class DatasetSource(str, Enum):
    """Enumerate the persona dataset sources.

    Members also subclass ``str``, so each one compares equal to its own
    string value.
    """

    # Datasets identified by a "HuggingFace: ..." value.
    SYNTH_PERSONA = "HuggingFace: synth-persona"
    NEMOTRON_FRANCE = "HuggingFace: nemotron-france"
    NEMOTRON_USA = "HuggingFace: nemotron-usa"
    # Source identified by the "Local JSONL upload" value.
    LOCAL_UPLOAD = "Local JSONL upload"


# Plain string values of every DatasetSource member, in definition order.
DATASET_SOURCES = [member.value for member in DatasetSource]


# Maps each prompt-variant key to its human-readable label.
VARIANT_LABELS = dict(
    empty="None",
    baseline="Baseline",
    templated="Template",
    biography="Biography",
    custom="Custom",
)

# Variant keys offered as chat prompt modes ("baseline" is not one of them).
CHAT_PROMPT_MODES = ("empty", "templated", "biography", "custom")
# Labels of the chat prompt modes, in the same order as CHAT_PROMPT_MODES.
CHAT_PROMPT_MODE_LABELS = list(map(VARIANT_LABELS.__getitem__, CHAT_PROMPT_MODES))
# Reverse lookup from a label back to its variant key.
CHAT_PROMPT_MODE_LABEL_TO_KEY = dict(zip(CHAT_PROMPT_MODE_LABELS, CHAT_PROMPT_MODES))
# Names of the available analysis modes, in listing order.
ANALYSIS_MODES = [
    "Cosine similarity",
    "Similarity matrix",
    "PCA",
    "UMAP",
    "Isomap",
    "Dendrogram",
]

# One-sentence help text per analysis mode; keys match ANALYSIS_MODES.
# The em dash is written as the escape \u2014 so the source stays ASCII-only
# and cannot be corrupted by a tool that mis-decodes the file.
ANALYSIS_HELP_TEXT = {
    "Cosine similarity": "Compare layer-wise alignment between variants.",
    "Similarity matrix": "Compare centered pairwise similarity between persona vectors by layer, with pair trajectories across layers.",
    "PCA": "Project per-persona vectors into a 2D or 3D global view.",
    "UMAP": "Project per-persona vectors into a 2D or 3D local-neighborhood view.",
    "Isomap": "Project per-persona vectors with graph-geodesic distances to probe manifold-like geometry.",
    "Dendrogram": "Hierarchical clustering of persona vectors \u2014 shows biography and templated side by side for direct comparison.",
}

# Icon shown for each known NDIF job status name.
# Glyphs are written as \u escapes so the source stays ASCII-only and cannot
# be corrupted by a tool that mis-decodes the file.
NDIF_STATUS_ICONS = {
    "RECEIVED": "\u25c9",  # U+25C9 FISHEYE
    "QUEUED": "\u25ce",  # U+25CE BULLSEYE
    "DISPATCHED": "\u25c8",  # U+25C8 WHITE DIAMOND CONTAINING BLACK SMALL DIAMOND
    "RUNNING": "\u25cf",  # U+25CF BLACK CIRCLE
    "COMPLETED": "\u2713",  # U+2713 CHECK MARK
    "ERROR": "\u2717",  # U+2717 BALLOT X
}


def format_ndif_status(
    job_id: str,
    status_name: str,
    description: str,
    *,
    prefix: str | None = None,
    completed_detail: str | None = None,
) -> str:
    """Build the one-line NDIF status label.

    The label has the form ``<icon> `<job_id>` **<status_name>** <em dash>
    <detail>``, optionally preceded by ``"<prefix>: "``.

    Args:
        job_id: Job identifier, rendered between backticks.
        status_name: Status name, rendered between double asterisks and used
            to look up the icon in ``NDIF_STATUS_ICONS``.
        description: Detail text used for every status, unless overridden by
            ``completed_detail``.
        prefix: Optional text placed before the label, followed by ``": "``.
            ``None`` or an empty string adds nothing.
        completed_detail: Replaces ``description`` only when it is not
            ``None`` and ``status_name`` is ``"COMPLETED"``.

    Returns:
        The formatted status line.
    """

    # Statuses missing from the table fall back to a bullet (U+2022).
    icon = NDIF_STATUS_ICONS.get(status_name, "\u2022")
    detail = (
        completed_detail
        if completed_detail is not None and status_name == "COMPLETED"
        else description
    )
    # \u2014 is the em dash separating the status from its detail text.
    label = f"{icon} `{job_id}` **{status_name}** \u2014 {detail}"
    return f"{prefix}: {label}" if prefix else label


def slugify(value: str) -> str:
    """Return a lowercase slug of *value* made of ``a-z``, ``0-9`` and ``_``.

    Every run of other characters collapses to a single underscore, leading
    and trailing underscores are removed, and ``"unknown"`` is returned when
    nothing remains.
    """

    lowered = value.lower()
    collapsed = re.sub(r"[^a-z0-9]+", "_", lowered)
    trimmed = collapsed.strip("_")
    if not trimmed:
        return "unknown"
    return trimmed


def widget_key(*parts: str) -> str:
    """Join *parts* with ``"::"`` to form a namespaced Streamlit widget key."""

    separator = "::"
    return separator.join(parts)


def session_key(*parts: str) -> str:
    """Join *parts* with ``":"`` to form a Streamlit session-state key."""

    separator = ":"
    return separator.join(parts)


def env_int(name: str, default: int, *, minimum: int = 1) -> int:
    """Read a bounded integer from the environment.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset or is not a valid
            integer.
        minimum: Lower bound applied to the result on every path.

    Returns:
        The parsed value, or ``default`` as a fallback, never smaller than
        ``minimum``.
    """

    raw = os.environ.get(name)
    if raw is None:
        value = default
    else:
        try:
            value = int(raw)
        except ValueError:
            logger.warning("Ignoring invalid integer for %s", name)
            value = default
    # Clamp after choosing the value so the fallback honours the bound too;
    # previously an invalid variable returned the raw default unclamped.
    return max(minimum, value)


def personas_fingerprint(persona_ids: Iterable[str]) -> str:
    """Return a short, order-independent fingerprint of the given persona ids.

    The ids are sorted, joined with ``"|"`` and hashed with SHA-1; the first
    16 hex characters of the digest are returned. This fixed-length value
    serves as a discriminator in widget keys and session-state keys: at ~1k
    personas the joined ids alone would be a ~20 KB string, while the digest
    prefix keeps tracebacks readable.
    """

    ordered_ids = sorted(persona_ids)
    payload = "|".join(ordered_ids).encode()
    digest = hashlib.sha1(payload)
    return digest.hexdigest()[:16]


def prompt_variant_label(variant: str) -> str:
    """Return the display label for a prompt variant.

    Variants listed in ``VARIANT_LABELS`` use their mapped label; any other
    variant is title-cased.
    """

    try:
        return VARIANT_LABELS[variant]
    except KeyError:
        return variant.title()


def persona_label(persona: PersonaData) -> str:
    """Return ``"<name> (<id>)"`` for *persona*, for use in selection widgets."""

    display_name, identifier = persona.name, persona.id
    return f"{display_name} ({identifier})"