import streamlit as st
from tabs.analysis._shared import _render_mask_strategy_select
from tabs.analysis._state import (
_DEFAULT_PERSONA_LIMITS,
_LAST_PROJECTION_DIMS_KEY,
_LAST_SIMILARITY_PERSONAS_KEY,
_LAST_SOURCE_KEY,
_MAX_PERSONA_COUNTS,
)
from tabs.analysis.cosine import _render_cosine_similarity
from tabs.analysis.dendrogram import _render_dendrogram_analysis
from tabs.analysis.layered import _render_layered_figure_analysis
from utils.helpers import (
ANALYSIS_HELP_TEXT,
ANALYSIS_MODES,
prompt_variant_label,
widget_key,
)
from utils.source_controls import render_source_select, render_store_select
def render_analysis_tab() -> None:
    """Draw the Analysis tab.

    Renders the page header, the source and analysis-mode selectors and the
    "Source settings" expander, then hands off to the renderer matching the
    selected mode. Any mode other than "Cosine similarity", "Similarity
    matrix" or "Dendrogram" is treated as a projection and additionally gets
    a 2D/3D selector.
    """
    st.title("Analysis")
    st.caption(
        "Analyse persona vectors by cosine similarity, PCA, UMAP, Isomap, or hierarchical clustering."
    )

    source = render_source_select(widget_scope="load", last_source_key=_LAST_SOURCE_KEY)

    picked_mode = st.segmented_control(
        "Analysis mode",
        options=ANALYSIS_MODES,
        default=ANALYSIS_MODES[0],
        key=widget_key("load", "analysis_mode"),
        label_visibility="collapsed",
    )
    # The control can report no selection; fall back to the first mode.
    analysis_mode = ANALYSIS_MODES[0] if picked_mode is None else picked_mode
    st.caption(ANALYSIS_HELP_TEXT[analysis_mode])

    # NOTE(review): both selectors are assumed to be rendered inside the
    # expander - confirm against the intended layout.
    with st.expander("Source settings", expanded=True):
        mask_strategy = _render_mask_strategy_select(analysis_mode)
        store = render_store_select(
            source,
            mask_strategy,
            state_prefix="analysis",
            widget_scope="load",
            artifacts_root_key="analysis:artifacts_root",
            model_label="Hub model",
            local_model_label="Local model",
            allow_custom_local_model=True,
            repo_help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
            fallback_help="Analysis-only model id to use if Hub config discovery is unavailable.",
        )

    if analysis_mode == "Cosine similarity":
        _render_cosine_similarity(store, mask_strategy)
    elif analysis_mode == "Dendrogram":
        _render_dendrogram_analysis(store, mask_strategy)
    elif analysis_mode == "Similarity matrix":

        def _similarity_title(variant):
            return f"Centered similarity - {prompt_variant_label(variant)} - persona vectors"

        _render_layered_figure_analysis(
            store,
            mask_strategy,
            scope="similarity_matrix",
            figure_kind="similarity",
            button_label="Generate similarity matrix",
            title_fn=_similarity_title,
            include_pair_trajectories=True,
            remember_key=_LAST_SIMILARITY_PERSONAS_KEY,
            default_count_limit=_DEFAULT_PERSONA_LIMITS["similarity"],
            max_count_limit=_MAX_PERSONA_COUNTS["similarity"],
            allow_specific_personas=True,
        )
    else:
        # Projection modes: choose 2D or 3D, then render the layered figure.
        dim_options = ["2D", "3D"]
        dim_key = widget_key("load", "projection_dims", analysis_mode)

        # Prefer this mode's own widget value, then the last choice made in
        # any projection mode, then "2D".
        last_used_dim = st.session_state.get(_LAST_PROJECTION_DIMS_KEY, "2D")
        initial_dim = st.session_state.get(dim_key, last_used_dim)
        if initial_dim not in dim_options:
            initial_dim = "2D"

        dimension_choice = st.segmented_control(
            "Projection dimensions",
            options=dim_options,
            default=initial_dim,
            key=dim_key,
            label_visibility="collapsed",
        )
        if dimension_choice is not None:
            st.session_state[_LAST_PROJECTION_DIMS_KEY] = dimension_choice

        # Anything other than an explicit "3D" (including no selection) is 2D.
        is_3d = dimension_choice == "3D"
        n_components = 3 if is_3d else 2
        dim_suffix = " (3D)" if is_3d else ""
        mode_slug = analysis_mode.lower()

        def _projection_title(variant):
            return (
                f"{analysis_mode}{dim_suffix} - {prompt_variant_label(variant)} - persona vectors"
            )

        _render_layered_figure_analysis(
            store,
            mask_strategy,
            scope=f"{mode_slug}{'_3d' if is_3d else ''}",
            figure_kind=mode_slug,
            button_label=f"Generate {analysis_mode}{dim_suffix} projection",
            title_fn=_projection_title,
            n_components=n_components,
            default_count_limit=_DEFAULT_PERSONA_LIMITS[mode_slug],
        )
|