import plotly.graph_objects as go
import streamlit as st

from persona_data.synth_persona import BASELINE_PERSONA_ID
from persona_vectors.extraction import MaskStrategy
from persona_vectors.plots import save_plot_html
from tabs.analysis._state import (
    _DEFAULT_LAYER_FRAMES,
    _HIGHLIGHT_OTHER_COLOR,
    _HIGHLIGHT_OTHER_LABEL,
    _LAST_LAYER_FRAMES_KEY,
    _LAST_MASK_STRATEGY_KEY,
    PersonaOptions,
    _is_assistant_persona,
    _persona_names_state_key,
    _personas_empty_message,
    _remembered_selectbox,
    _sequence_to_list,
)
from utils.analysis_sources import (
    Store,
    available_variants,
    load_persona_vectors_cached,
    load_variant_vectors_cached,
    persona_names_cached,
    personas_cached,
    store_cache_parts,
    store_id,
    store_layers_cached,
)
from utils.controls import render_mask_strategy_select
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
from utils.theme import active_base, style_plotly_layer_controls


def _gray_out_unselected_personas(fig: go.Figure) -> None:
    """Recolor the "other" points/traces of ``fig`` in place.

    Applies to both the figure's base traces and the traces of every
    animation frame.
    """

    def _gray_trace(trace: object) -> None:
        marker = getattr(trace, "marker", None)
        if marker is None:
            return
        colors = _sequence_to_list(getattr(marker, "color", None))
        labels = _sequence_to_list(getattr(trace, "customdata", None))
        # Per-point case: one color and one label per point, so only the
        # points labelled as "other" are recolored.
        if colors is not None and labels is not None and len(colors) == len(labels):
            trace.marker.color = [
                (
                    _HIGHLIGHT_OTHER_COLOR
                    if str(label) == _HIGHLIGHT_OTHER_LABEL
                    else color
                )
                for label, color in zip(labels, colors, strict=True)
            ]
            return
        # Per-trace case: the whole trace is the "other" group.
        if getattr(trace, "name", None) == _HIGHLIGHT_OTHER_LABEL:
            trace.marker.color = _HIGHLIGHT_OTHER_COLOR
            trace.opacity = 0.28

    for trace in fig.data:
        _gray_trace(trace)
    for frame in fig.frames:
        for trace in frame.data:
            _gray_trace(trace)


def _layers_for_variant(
    store: Store,
    variant: str,
    persona_ids: list[str],
    mask_strategy: MaskStrategy,
) -> list[int]:
    """Return the result of ``store_layers_cached`` for one variant.

    Arguments are converted to hashable values (tuples, the enum's
    ``.value``) before being handed to the cached helper.
    """
    source, location, model_name = store_cache_parts(store)
    return store_layers_cached(
        source,
        location,
        model_name,
        mask_strategy.value,
        (variant,),
        tuple(persona_ids),
    )


def _load_persona_vectors(
    store: Store,
    variant: str,
    mask_strategy: MaskStrategy,
    persona_ids: list[str],
):
    """Load vectors for ``persona_ids`` in a single ``variant`` via the cached loader."""
    source, location, model_name = store_cache_parts(store)
    return load_persona_vectors_cached(
        source,
        location,
        model_name,
        mask_strategy.value,
        variant,
        tuple(persona_ids),
    )


def _load_variant_vectors(
    store: Store,
    variants: list[str] | tuple[str, ...],
    mask_strategy: MaskStrategy,
    persona_ids: list[str],
):
    """Load vectors for ``persona_ids`` across several ``variants`` via the cached loader."""
    source, location, model_name = store_cache_parts(store)
    return load_variant_vectors_cached(
        source,
        location,
        model_name,
        mask_strategy.value,
        tuple(variants),
        tuple(persona_ids),
    )


def _evenly_spaced_layers(layers: list[int], max_count: int) -> list[int]:
    """Pick at most ``max_count`` entries of ``layers``, evenly spread by index.

    Returns ``layers`` itself when it is empty or already short enough, only
    the first entry when ``max_count <= 1``, and otherwise a selection that
    always includes the first and last entry.
    """
    # The emptiness guard keeps a non-positive ``max_count`` from reaching
    # ``layers[0]`` on an empty list.
    if not layers or max_count >= len(layers):
        return layers
    if max_count <= 1:
        return [layers[0]]
    last = len(layers) - 1
    indices = [round(i * last / (max_count - 1)) for i in range(max_count)]
    # dict.fromkeys drops duplicate indices while preserving order.
    return [layers[index] for index in dict.fromkeys(indices)]


def _render_layer_frame_controls(
    store: Store,
    scope: str,
    layers: list[int],
) -> list[int]:
    """Render the "Layer frames" slider and return the layers to animate.

    When there are no more than ``_DEFAULT_LAYER_FRAMES`` layers, no slider is
    shown and all layers are returned. Otherwise the chosen frame count is
    remembered in session state and an evenly spaced subset is returned.
    """
    if len(layers) <= _DEFAULT_LAYER_FRAMES:
        st.caption(f"Using all {len(layers)} available layer(s).")
        return layers
    frame_count = st.slider(
        "Layer frames",
        min_value=2,
        max_value=len(layers),
        # Clamp the remembered value into the slider's [2, len(layers)] range.
        value=min(
            max(
                int(
                    st.session_state.get(
                        _LAST_LAYER_FRAMES_KEY,
                        _DEFAULT_LAYER_FRAMES,
                    )
                ),
                2,
            ),
            len(layers),
        ),
        key=widget_key("load", "layer_frames", scope, store_id(store)),
        help="Limit animated Plotly frames to keep browser and RAM usage bounded.",
    )
    st.session_state[_LAST_LAYER_FRAMES_KEY] = frame_count
    selected = _evenly_spaced_layers(layers, frame_count)
    st.caption(f"Using {len(selected)} of {len(layers)} layers.")
    return selected


def _load_persona_options(
    store: Store,
    variants: list[str],
    mask_strategy: MaskStrategy,
    *,
    empty_message: str,
) -> PersonaOptions | None:
    """Collect the selectable personas for ``variants``.

    Shows ``empty_message`` and returns ``None`` when no persona ids are
    found. Otherwise splits the ids into regular personas and a single
    assistant persona (preferring ``BASELINE_PERSONA_ID`` when several ids
    qualify as assistant).
    """
    source, location, model_name = store_cache_parts(store)
    variant_key = tuple(variants)
    persona_ids = personas_cached(
        source,
        location,
        model_name,
        mask_strategy.value,
        variant_key,
        include_baseline=True,
    )
    if not persona_ids:
        st.info(empty_message)
        return None
    persona_names = persona_names_cached(
        source,
        location,
        model_name,
        mask_strategy.value,
        variant_key,
        tuple(persona_ids),
    )
    assistant_ids = [
        persona_id
        for persona_id in persona_ids
        if _is_assistant_persona(persona_id, persona_names.get(persona_id))
    ]
    assistant_id = next(
        (
            persona_id
            for persona_id in assistant_ids
            if persona_id == BASELINE_PERSONA_ID
        ),
        assistant_ids[0] if assistant_ids else None,
    )
    # Set membership keeps the regular/assistant split linear in the number
    # of personas.
    assistant_id_set = set(assistant_ids)
    regular_ids = [
        persona_id for persona_id in persona_ids if persona_id not in assistant_id_set
    ]
    if not regular_ids and assistant_id is None:
        st.info("No personas found for this model and variant.")
        return None
    return PersonaOptions(
        regular_ids=regular_ids,
        assistant_id=assistant_id,
        persona_names=persona_names,
    )


def _seed_persona_memory(
    remember_key: str,
    options: PersonaOptions,
    *,
    default_all: bool,
    default_count_limit: int | None = None,
) -> tuple[int, bool]:
    """Compute the initial persona count and assistant flag from session state.

    If a non-empty id list is stored under ``remember_key``, it seeds the
    ``:count`` and ``:include_assistant`` entries (without overwriting
    existing ones). The returned count is clamped to
    ``[0, len(options.regular_ids)]``.
    """
    remembered_count_key = f"{remember_key}:count"
    remembered_assistant_key = f"{remember_key}:include_assistant"
    legacy_ids = st.session_state.get(remember_key, [])
    if isinstance(legacy_ids, list) and legacy_ids:
        st.session_state.setdefault(
            remembered_count_key,
            sum(persona_id in options.regular_ids for persona_id in legacy_ids),
        )
        st.session_state.setdefault(
            remembered_assistant_key,
            options.assistant_id in legacy_ids,
        )
    if default_count_limit is not None:
        default_count = min(default_count_limit, len(options.regular_ids))
    elif default_all:
        default_count = len(options.regular_ids)
    else:
        default_count = min(1, len(options.regular_ids))
    remembered_count = int(st.session_state.get(remembered_count_key, default_count))
    persona_count = min(max(remembered_count, 0), len(options.regular_ids))
    include_assistant = bool(st.session_state.get(remembered_assistant_key, False))
    return persona_count, include_assistant


def _render_persona_count_controls(
    store: Store,
    variants: list[str],
    mask_strategy: MaskStrategy,
    widget_scope: str,
    options: PersonaOptions,
    *,
    default_count: int,
    include_assistant_default: bool,
    max_count_limit: int | None = None,
) -> tuple[int, bool]:
    """Render the "Personas" slider and the assistant checkbox.

    Returns ``(persona_count, include_assistant)``. The slider is only shown
    when regular personas exist; the checkbox only when an assistant persona
    exists. ``max_count_limit`` caps the slider's upper bound.
    """
    count_key = widget_key(
        "load",
        "persona_count",
        widget_scope,
        store.model_name,
        mask_strategy.value,
        *variants,
    )
    assistant_key = widget_key(
        "load",
        "include_assistant",
        widget_scope,
        store.model_name,
        mask_strategy.value,
        *variants,
    )
    if options.regular_ids:
        max_count = (
            min(max_count_limit, len(options.regular_ids))
            if max_count_limit is not None
            else len(options.regular_ids)
        )
        # Zero regular personas is only a valid choice when the assistant
        # persona can be selected instead.
        min_count = 0 if options.assistant_id is not None else 1
        persona_count = st.slider(
            "Personas",
            min_value=min_count,
            max_value=max_count,
            # Clamp the default into [min_count, max_count]. The seeded
            # default can be 0 (e.g. remembered from a selection that had an
            # assistant persona), which is below min_count when no assistant
            # exists; st.slider rejects an out-of-range default.
            value=max(min_count, min(default_count, max_count)),
            key=count_key,
            help="Use the first N available non-assistant personas.",
        )
    else:
        persona_count = 0
        st.caption("No non-assistant personas are available for this selection.")
    include_assistant = False
    if options.assistant_id is not None:
        include_assistant = st.checkbox(
            "Include Assistant persona",
            value=include_assistant_default,
            key=assistant_key,
        )
    return persona_count, include_assistant


def _select_artifact_personas(
    store: Store,
    variants: list[str],
    mask_strategy: MaskStrategy,
    *,
    widget_scope: str,
    remember_key: str,
    default_all: bool = False,
    default_count_limit: int | None = None,
    max_count_limit: int | None = None,
) -> list[str]:
    """Let the user pick the first N personas (plus optionally the assistant).

    Returns the selected persona ids, or an empty list when nothing is
    available or nothing was selected. The selection and the persona-name
    mapping are written to session state.
    """
    empty_message = _personas_empty_message(variants)
    options = _load_persona_options(
        store,
        variants,
        mask_strategy,
        empty_message=empty_message,
    )
    if options is None:
        # Drop any stale name mapping stored for this scope.
        st.session_state.pop(_persona_names_state_key(widget_scope), None)
        return []
    default_count, include_assistant_default = _seed_persona_memory(
        remember_key,
        options,
        default_all=default_all,
        default_count_limit=default_count_limit,
    )
    persona_count, include_assistant = _render_persona_count_controls(
        store,
        variants,
        mask_strategy,
        widget_scope,
        options,
        default_count=default_count,
        include_assistant_default=include_assistant_default,
        max_count_limit=max_count_limit,
    )
    # The slice is a copy, so appending does not mutate options.regular_ids.
    persona_ids = options.regular_ids[:persona_count]
    if include_assistant and options.assistant_id is not None:
        persona_ids.append(options.assistant_id)
    remembered_count_key = f"{remember_key}:count"
    remembered_assistant_key = f"{remember_key}:include_assistant"
    st.session_state[remembered_count_key] = persona_count
    st.session_state[remembered_assistant_key] = include_assistant
    st.session_state[remember_key] = persona_ids
    st.session_state[_persona_names_state_key(widget_scope)] = options.persona_names
    if not persona_ids:
        st.info("Select at least one persona or include the Assistant persona.")
        return []
    regular_label = f"{persona_count} persona{'s' if persona_count != 1 else ''}"
    assistant_label = (
        " plus Assistant" if include_assistant and options.assistant_id else ""
    )
    st.caption(f"Using {regular_label}{assistant_label}.")
    return persona_ids


def _render_persona_select_controls(
    options: PersonaOptions,
    widget_scope: str,
    *,
    max_selections: int | None = None,
) -> list[str]:
    """Render a searchable multiselect of regular personas.

    Returns the chosen persona ids in selection order, with the assistant id
    appended when its checkbox is ticked. Shows a hint (and returns an empty
    list) when nothing is selected.
    """
    select_key = widget_key("load", "persona_select", widget_scope)
    assistant_key = widget_key("load", "persona_select_assistant", widget_scope)
    # Labels embed the persona id, so they are unique per persona.
    label_map = {
        persona_id: f"{options.persona_names.get(persona_id, persona_id)} ({persona_id})"
        for persona_id in options.regular_ids
    }
    sorted_labels = sorted(label_map.values())
    selected_labels = st.multiselect(
        "Select personas",
        options=sorted_labels,
        key=select_key,
        placeholder="Search and select personas...",
        max_selections=max_selections,
    )
    label_to_id = {label: persona_id for persona_id, label in label_map.items()}
    selected_ids = [label_to_id[label] for label in selected_labels]
    if options.assistant_id is not None:
        include_assistant = st.checkbox(
            "Include Assistant persona",
            key=assistant_key,
        )
        if include_assistant:
            selected_ids.append(options.assistant_id)
    st.session_state[_persona_names_state_key(widget_scope)] = dict(
        options.persona_names
    )
    if not selected_ids:
        st.info("Select at least one persona.")
    return selected_ids


def _render_save_buttons(
    figs: list[object],
    filenames: list[str],
    key_suffix: str,
) -> None:
    """Render the Save HTML button for one or more figures."""
    if st.button("Save HTML", key=widget_key("load", "save_html", key_suffix)):
        # UI boundary: any failure is reported in the page instead of
        # aborting the script run.
        try:
            _style_plotly_figures(figs)
            paths = [
                save_plot_html(fig, fn)
                for fig, fn in zip(figs, filenames, strict=True)
            ]
            st.success(f"Saved {len(paths)} HTML file(s) to `artifacts/plots`.")
        except Exception as exc:
            st.error(f"Could not save HTML: {exc}")


def _style_plotly_figures(figs: list[object]) -> None:
    """Apply the layer-control styling to every Plotly figure in ``figs``.

    Objects that are not ``go.Figure`` instances are skipped.
    """
    base = active_base()
    for fig in figs:
        if isinstance(fig, go.Figure):
            style_plotly_layer_controls(fig, base)


def _plotly_chart(fig: object) -> None:
    """Style ``fig`` and render it as a stretched Plotly chart."""
    _style_plotly_figures([fig])
    st.plotly_chart(
        fig,
        width="stretch",
        config={"responsive": True, "displaylogo": False},
    )


def _render_mask_strategy_select(scope: str) -> MaskStrategy:
    """Render the mask-strategy selector for ``scope`` and return the choice."""
    return render_mask_strategy_select(
        key=widget_key("load", "mask_strategy", scope),
        last_key=_LAST_MASK_STRATEGY_KEY,
        remember_key="source:last_mask_strategy",
        help_text="Which extracted activation set to load.",
    )


def _select_single_variant_samples(
    store: Store,
    mask_strategy: MaskStrategy,
    scope: str,
    *,
    remember_key: str,
    variant_remember_key: str,
    default_count_limit: int,
    max_count_limit: int | None = None,
    allow_specific_personas: bool = False,
) -> tuple[str, list[str], str, list[int]] | None:
    """Render variant, persona and layer controls for a single variant.

    Returns ``(variant, persona_ids, persona_key, selected_layers)``, or
    ``None`` when no variants, personas or shared layers are available (an
    informational message is shown in those cases).
    """
    variants = available_variants(store, mask_strategy)
    if not variants:
        st.info("No variants with saved vectors for this model.")
        return None
    variant_key = widget_key("load", "variant", scope, store_id(store))
    default_variant = "biography" if "biography" in variants else variants[0]
    variant = _remembered_selectbox(
        "Variant",
        key=variant_key,
        remember_key=variant_remember_key,
        options=variants,
        default=default_variant,
        format_func=prompt_variant_label,
    )
    widget_scope = f"{scope}:{store_id(store)}"
    select_specific = False
    if allow_specific_personas:
        select_specific = st.toggle(
            "Select specific personas",
            value=False,
            key=widget_key("load", "select_specific_personas", scope, store_id(store)),
            help="Search and select specific personas instead of using the first N.",
        )
    if select_specific:
        options = _load_persona_options(
            store,
            [variant],
            mask_strategy,
            empty_message=_personas_empty_message([variant]),
        )
        if options is None:
            # Drop any stale name mapping stored for this scope.
            st.session_state.pop(_persona_names_state_key(widget_scope), None)
            return None
        persona_ids = _render_persona_select_controls(
            options,
            widget_scope,
            max_selections=max_count_limit,
        )
    else:
        persona_ids = _select_artifact_personas(
            store,
            [variant],
            mask_strategy,
            widget_scope=widget_scope,
            remember_key=remember_key,
            default_count_limit=default_count_limit,
            max_count_limit=max_count_limit,
        )
    if not persona_ids:
        return None
    persona_key = personas_fingerprint(persona_ids)
    layer_options = _layers_for_variant(store, variant, persona_ids, mask_strategy)
    if not layer_options:
        st.info("No shared layers are available for the selected personas.")
        return None
    selected_layers = _render_layer_frame_controls(store, scope, layer_options)
    return variant, persona_ids, persona_key, selected_layers