File size: 4,130 Bytes
a89a7f1
7ad2026
fee1567
 
 
 
 
 
b279884
fee1567
 
 
 
b279884
 
 
 
 
 
 
330d092
 
db3d901
ecd19ae
a89a7f1
ecd19ae
d39b2dd
db3d901
d39b2dd
a89a7f1
b279884
a89a7f1
 
 
 
 
 
 
 
 
 
 
330d092
2bf3d21
9ba2da4
b279884
 
 
 
 
 
 
 
 
 
 
 
a89a7f1
 
7ad2026
 
 
9ba2da4
 
 
 
 
 
 
330d092
9ba2da4
 
4df7d97
 
b279884
 
9ba2da4
a89a7f1
 
ecd19ae
 
 
 
db3d901
 
 
 
 
 
 
 
0ba2e45
 
db3d901
 
 
0ba2e45
 
db3d901
 
0ba2e45
 
9ba2da4
 
 
0ba2e45
9ba2da4
0ba2e45
9ba2da4
0ba2e45
9ba2da4
0ba2e45
4df7d97
9ba2da4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st

from tabs.analysis._shared import _render_mask_strategy_select
from tabs.analysis._state import (
    _DEFAULT_PERSONA_LIMITS,
    _LAST_PROJECTION_DIMS_KEY,
    _LAST_SIMILARITY_PERSONAS_KEY,
    _LAST_SOURCE_KEY,
    _MAX_PERSONA_COUNTS,
)
from tabs.analysis.cosine import _render_cosine_similarity
from tabs.analysis.dendrogram import _render_dendrogram_analysis
from tabs.analysis.layered import _render_layered_figure_analysis
from utils.helpers import (
    ANALYSIS_HELP_TEXT,
    ANALYSIS_MODES,
    prompt_variant_label,
    widget_key,
)
from utils.source_controls import render_source_select, render_store_select


def render_analysis_tab() -> None:
    """Draw the Analysis tab and hand off to the renderer for the chosen mode.

    Widgets are created top to bottom in this order: page title and caption,
    the source selector (``render_source_select``), the analysis-mode
    segmented control with its help caption, and a "Source settings"
    expander that yields the mask strategy and the store.

    The selected mode then decides which renderer runs:

    * "Cosine similarity" -> ``_render_cosine_similarity``
    * "Similarity matrix" -> ``_render_layered_figure_analysis`` configured
      for a similarity figure
    * "Dendrogram" -> ``_render_dendrogram_analysis``
    * any other mode -> a 2D/3D selector followed by
      ``_render_layered_figure_analysis`` configured for that projection
    """
    st.title("Analysis")
    st.caption(
        "Analyse persona vectors by cosine similarity, PCA, UMAP, Isomap, or hierarchical clustering."
    )

    selected_source = render_source_select(
        widget_scope="load",
        last_source_key=_LAST_SOURCE_KEY,
    )

    fallback_mode = ANALYSIS_MODES[0]
    picked_mode = st.segmented_control(
        "Analysis mode",
        options=ANALYSIS_MODES,
        default=fallback_mode,
        key=widget_key("load", "analysis_mode"),
        label_visibility="collapsed",
    )
    # The control may report no selection; use the first mode in that case so
    # the help-text lookup and the dispatch below always have a concrete mode.
    analysis_mode = fallback_mode if picked_mode is None else picked_mode
    st.caption(ANALYSIS_HELP_TEXT[analysis_mode])

    with st.expander("Source settings", expanded=True):
        mask_strategy = _render_mask_strategy_select(analysis_mode)
        store = render_store_select(
            selected_source,
            mask_strategy,
            state_prefix="analysis",
            widget_scope="load",
            artifacts_root_key="analysis:artifacts_root",
            model_label="Hub model",
            local_model_label="Local model",
            allow_custom_local_model=True,
            repo_help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
            fallback_help="Analysis-only model id to use if Hub config discovery is unavailable.",
        )

    if analysis_mode == "Cosine similarity":
        _render_cosine_similarity(store, mask_strategy)
    elif analysis_mode == "Similarity matrix":
        _render_layered_figure_analysis(
            store,
            mask_strategy,
            scope="similarity_matrix",
            figure_kind="similarity",
            button_label="Generate similarity matrix",
            title_fn=lambda v: f"Centered similarity - {prompt_variant_label(v)} - persona vectors",
            include_pair_trajectories=True,
            remember_key=_LAST_SIMILARITY_PERSONAS_KEY,
            default_count_limit=_DEFAULT_PERSONA_LIMITS["similarity"],
            max_count_limit=_MAX_PERSONA_COUNTS["similarity"],
            allow_specific_personas=True,
        )
    elif analysis_mode == "Dendrogram":
        _render_dendrogram_analysis(store, mask_strategy)
    else:
        # Every remaining mode is rendered as a projection with a 2D/3D choice.
        projection_choices = ["2D", "3D"]
        dims_widget_key = widget_key("load", "projection_dims", analysis_mode)

        # Prefer this mode's own widget value; otherwise reuse the dimension
        # last stored under the shared key, and finally default to "2D".
        last_used_dim = st.session_state.get(_LAST_PROJECTION_DIMS_KEY, "2D")
        preselected_dim = st.session_state.get(dims_widget_key, last_used_dim)
        if preselected_dim not in projection_choices:
            preselected_dim = "2D"

        dimension_choice = st.segmented_control(
            "Projection dimensions",
            options=projection_choices,
            default=preselected_dim,
            key=dims_widget_key,
            label_visibility="collapsed",
        )
        # Only an actual selection is remembered; a cleared control leaves the
        # previously stored value in place.
        if dimension_choice is not None:
            st.session_state[_LAST_PROJECTION_DIMS_KEY] = dimension_choice

        # Anything other than an explicit "3D" (including no selection) is 2D.
        wants_3d = dimension_choice == "3D"
        n_components = 3 if wants_3d else 2
        label_suffix = " (3D)" if wants_3d else ""
        mode_slug = analysis_mode.lower()

        _render_layered_figure_analysis(
            store,
            mask_strategy,
            scope=mode_slug + ("_3d" if wants_3d else ""),
            figure_kind=mode_slug,
            button_label=f"Generate {analysis_mode}{label_suffix} projection",
            title_fn=lambda v: (
                f"{analysis_mode}{label_suffix} - {prompt_variant_label(v)} - persona vectors"
            ),
            n_components=n_components,
            default_count_limit=_DEFAULT_PERSONA_LIMITS[mode_slug],
        )