import html
from dataclasses import dataclass

import streamlit as st
from catppuccin import PALETTE
from persona_data.prompts import format_prompt
from persona_data.synth_persona import BASELINE_PERSONA_ID, PersonaData, QAPair
from persona_vectors.artifacts import SUPPORTED_VARIANTS
from persona_vectors.extraction import (
    MaskStrategy,
    prepare_inputs_for_strategy,
    run_extraction,
)
from persona_vectors.preview import TokenSegment, preview_token_segments
from utils.controls import render_mask_strategy_select
from utils.datasets import (
    load_dataset,
    load_persona_list_from_dataset,
    warm_qa_in_background,
)
from utils.helpers import (
    format_ndif_status,
    persona_label,
    prompt_variant_label,
    session_key,
    widget_key,
)
from utils.runtime import cached_model, remote_backend, session_ndif_api_key
from utils.theme import active_base

# Session-state keys that remember the last choices made on this tab, so the
# widgets can be re-seeded when their (model/dataset-scoped) widget key changes.
_LAST_VARIANTS_KEY = "extract:last_variants"
_LAST_BASELINE_KEY = "extract:last_include_baseline"
_LAST_PERSONA_IDS_KEY = "extract:last_persona_ids"
_LAST_MAX_QUESTIONS_KEY = "extract:last_max_questions"
_LAST_MASK_STRATEGY_KEY = "extract:last_mask_strategy"
_PERSONAS_FILE_KEY = session_key("extract", "personas_file")
_QA_FILE_KEY = session_key("extract", "qa_file")
_DEFAULT_MAX_QUESTIONS = 50


@dataclass(frozen=True)
class ExtractSettings:
    """Immutable bundle of the per-run options chosen in the UI."""

    # Strategy passed through to `prepare_inputs_for_strategy` / `run_extraction`.
    mask_strategy: MaskStrategy
    # Upper bound used to slice each persona's QA list (`qa_pairs[:max_questions]`).
    max_questions: int


def _build_run_plan(
    selected_variants: list[str],
    runs: list[tuple[PersonaData, list[QAPair]]],
) -> list[tuple[PersonaData, list[QAPair], str]]:
    """Cartesian product of personas x variants.

    The result is ordered variant-major: every persona for the first variant,
    then every persona for the second variant, and so on.
    """
    return [(p, qa, v) for v in selected_variants for p, qa in runs]


def _row_label(persona: PersonaData, variant: str) -> str:
    """Return the "<persona name> · <variant label>" caption for one plan row."""
    return f"{persona.name} · {prompt_variant_label(variant)}"


def _extract_widget_key(
    model_name: str, remote: bool, dataset_source: str, suffix: str
) -> str:
    """Build a widget key scoped to this tab, backend mode, model and dataset."""
    return widget_key("extract", str(remote), model_name, dataset_source, suffix)


def _render_local_dataset_upload(dataset_source: str) -> None:
    """Render the two JSONL uploaders when the local-upload source is selected.

    Does nothing for any other `dataset_source`. The uploaders are registered
    under `_PERSONAS_FILE_KEY` and `_QA_FILE_KEY`, which is where
    `_load_qa_dataset_personas` reads them from session state.
    """
    if dataset_source != "Local JSONL upload":
        return

    with st.expander("Local dataset upload", expanded=True):
        st.file_uploader(
            "personas.jsonl",
            type=["jsonl"],
            key=_PERSONAS_FILE_KEY,
            help="Expected fields: id, persona, templated_view, biography_view",
        )
        st.file_uploader(
            "qa.jsonl",
            type=["jsonl"],
            key=_QA_FILE_KEY,
            help="Expected fields: id, qid, type, item_type, scope, question, answer",
        )


def _render_variant_controls(
    *,
    model_name: str,
    remote: bool,
    dataset_source: str,
) -> tuple[list[str], bool] | None:
    """Render the variant multiselect and the baseline checkbox.

    Returns:
        `(selected_variants, include_baseline)`, or `None` (after showing an
        info message) when no variant is selected.
    """
    default_variants = st.session_state.get(
        _LAST_VARIANTS_KEY, list(SUPPORTED_VARIANTS)
    )
    selected_variants = st.multiselect(
        "Persona variants",
        options=SUPPORTED_VARIANTS,
        # Drop remembered variants that are no longer supported; fall back to
        # everything when nothing valid is left.
        default=[v for v in default_variants if v in SUPPORTED_VARIANTS]
        or list(SUPPORTED_VARIANTS),
        format_func=prompt_variant_label,
        key=_extract_widget_key(model_name, remote, dataset_source, "persona_variants"),
        help="Extract these variants for each selected persona.",
    )
    include_baseline = st.checkbox(
        "Extract Assistant baseline",
        value=st.session_state.get(_LAST_BASELINE_KEY, False),
        key=_extract_widget_key(model_name, remote, dataset_source, "baseline"),
        help="Also extract the Assistant baseline persona using the first persona's QA set.",
    )

    st.session_state[_LAST_VARIANTS_KEY] = selected_variants
    st.session_state[_LAST_BASELINE_KEY] = include_baseline

    if not selected_variants:
        st.info("Select at least one persona variant.")
        return None
    return selected_variants, include_baseline


def _load_qa_dataset_personas(
    dataset_source: str,
) -> tuple[object, list[PersonaData]] | None:
    """Load the dataset and its persona list, reporting problems in the UI.

    Returns:
        `(dataset, personas)` on success. `None` when loading raises, when the
        dataset reports `supports_qa` as falsy, or when it has no personas; in
        each case a message has already been rendered.
    """
    try:
        dataset, dataset_status = load_dataset(
            dataset_source,
            personas_file=st.session_state.get(_PERSONAS_FILE_KEY),
            qa_file=st.session_state.get(_QA_FILE_KEY),
        )
        personas = load_persona_list_from_dataset(dataset)
        st.caption(dataset_status)
    except Exception as exc:
        # UI boundary: surface any loader failure as a message, not a traceback.
        st.error(f"Could not load data: {exc}")
        st.info(
            "Upload both JSONL files or switch to the built-in SynthPersona source."
        )
        return None

    # Datasets without a `supports_qa` attribute are treated as QA-capable.
    if not getattr(dataset, "supports_qa", True):
        st.info("This dataset is persona-only for now. Use Chat to browse personas.")
        return None

    if not personas:
        st.warning("No personas found in the selected dataset.")
        st.info(
            "Try another dataset source or check that the personas file is not empty."
        )
        return None

    # Extract is the only tab that needs QA; warm it now so the parse overlaps
    # with the user configuring the run instead of blocking the first extract.
    warm_qa_in_background(dataset)
    return dataset, personas


def _render_persona_select(
    *,
    personas: list[PersonaData],
    model_name: str,
    remote: bool,
    dataset_source: str,
) -> list[PersonaData] | None:
    """Render the persona multiselect.

    Defaults to the personas remembered from the previous selection, or to the
    first persona when none of them are present. `personas` must be non-empty.

    Returns:
        The selected personas, or `None` (after an info message) when the
        selection is empty.
    """
    last_persona_ids: set[str] = set(st.session_state.get(_LAST_PERSONA_IDS_KEY, []))
    default_personas = [p for p in personas if p.id in last_persona_ids] or [
        personas[0]
    ]
    selected_personas = st.multiselect(
        "Personas",
        options=personas,
        default=default_personas,
        format_func=persona_label,
        key=_extract_widget_key(model_name, remote, dataset_source, "persona_select"),
    )
    st.session_state[_LAST_PERSONA_IDS_KEY] = [p.id for p in selected_personas]

    if not selected_personas:
        st.info("Select at least one persona.")
        return None
    return selected_personas


# Number of prepared samples shown per (persona, variant) row in the preview.
_MAX_PREVIEW_SAMPLES = 3


def _preview_palette():
    """Return the Catppuccin colors for the active theme (latte if light, else mocha)."""
    flavor = PALETTE.latte if active_base() == "light" else PALETTE.mocha
    return flavor.colors


def _render_token_legend_html() -> str:
    """Return an HTML legend explaining the colors used in the token preview.

    NOTE(review): the inline markup of this function was lost in the copy of
    the file this was rewritten from; the tags and styles below are a
    reconstruction. The colors mirror `_token_style` so legend and preview
    agree — confirm the layout styles against the intended design.
    """
    c = _preview_palette()
    return (
        '<div style="display:flex;flex-wrap:wrap;gap:0.75rem;'
        'font-family:monospace;font-size:0.85em;margin-bottom:0.5rem">'
        f'<span style="background:{c.green.hex};color:{c.base.hex};'
        f'border-radius:2px;padding:0 1px">masked</span>'
        f'<span style="color:{c.yellow.hex}">question</span>'
        f'<span style="color:{c.sky.hex}">response</span>'
        f'<span style="color:{c.mauve.hex};font-weight:bold">special</span>'
        f'<span style="color:{c.subtext1.hex}">template</span>'
        "</div>"
    )


def _token_style(segment: TokenSegment) -> str:
    """Return the inline CSS for one token segment.

    Precedence: the role color is the base; special tokens replace it with a
    bold mauve; masked tokens then append a green highlight whose `color`
    declaration comes last and therefore overrides the earlier text color.
    """
    c = _preview_palette()
    style = {
        "response": f"color:{c.sky.hex}",
        "question": f"color:{c.yellow.hex}",
    }.get(segment.role, f"color:{c.subtext1.hex}")
    if segment.is_special:
        style = f"color:{c.mauve.hex};font-weight:bold"
    if segment.is_masked:
        style = (
            f"{style};background:{c.green.hex};color:{c.base.hex};"
            "border-radius:2px;padding:0 1px"
        )
    return style


def _render_sample_tokens_html(p, tokenizer, *, max_tokens: int = 200) -> str:
    """Return HTML showing one prepared sample as colored token spans.

    Each segment yielded by `preview_token_segments` becomes a `<span>` styled
    by `_token_style`; the segment text is HTML-escaped before embedding.

    NOTE(review): the wrapper markup was lost in the copy of the file this was
    rewritten from; the `<div>`/`<pre>` wrapper and its styles are a
    reconstruction — confirm against the intended design.
    """
    spans = [
        f'<span style="{_token_style(segment)}">{html.escape(segment.text)}</span>'
        for segment in preview_token_segments(p, tokenizer, max_tokens=max_tokens)
    ]
    return (
        '<div style="overflow-x:auto">'
        '<pre style="white-space:pre-wrap;word-break:break-word;margin:0">'
        f"{''.join(spans)}</pre></div>"
    )


def _render_mask_strategy_select(
    *,
    model_name: str,
    remote: bool,
    dataset_source: str,
) -> MaskStrategy:
    """Render the shared mask-strategy selector with this tab's widget key."""
    return render_mask_strategy_select(
        key=_extract_widget_key(model_name, remote, dataset_source, "mask_strategy"),
        last_key=_LAST_MASK_STRATEGY_KEY,
        help_text="Which tokens contribute to the averaged hidden state.",
    )


def _collect_runs(
    *,
    dataset,
    selected_personas: list[PersonaData],
) -> list[tuple[PersonaData, list[QAPair]]] | None:
    """Pair each persona with its QA items, skipping personas that have none.

    QA lookup per persona:
      * the baseline persona uses `get_qa(..., item_type="mcq", scope="shared")`;
      * datasets exposing `train_test_split` contribute the first element of
        the split;
      * otherwise all of `get_qa(persona.id)` is used.

    Returns:
        The `(persona, qa)` pairs, or `None` (after an info message) when no
        persona has any QA items. Skipped personas are listed in a warning.
    """
    runs, skipped = [], []
    for persona in selected_personas:
        if persona.id == BASELINE_PERSONA_ID:
            qa = list(
                dataset.get_qa(BASELINE_PERSONA_ID, item_type="mcq", scope="shared")
            )
        elif hasattr(dataset, "train_test_split"):
            qa, _ = dataset.train_test_split(persona.id)
        else:
            qa = list(dataset.get_qa(persona.id))

        if qa:
            runs.append((persona, qa))
        else:
            skipped.append(persona)

    if skipped:
        names = ", ".join(p.name for p in skipped)
        st.warning(f"No train QA pairs found for: {names}. They will be skipped.")

    if not runs:
        st.info("No personas have matching QA pairs.")
        return None
    return runs


def _render_max_questions(
    *,
    model_name: str,
    remote: bool,
    dataset_source: str,
    runs: list[tuple[PersonaData, list[QAPair]]],
) -> int:
    """Render the "max questions" slider and return the chosen value.

    The upper bound is the size of the smallest QA list in `runs`, so the same
    limit can be applied to every persona. `runs` must be non-empty.
    """
    max_q = min(len(qa_pairs) for _, qa_pairs in runs)

    if max_q <= 1:
        # A slider with min_value == max_value offers no choice.
        # NOTE(review): Streamlit is believed to reject equal slider bounds in
        # some releases — confirm against the pinned version. The remembered
        # preference is left untouched so it survives this degenerate case.
        st.caption("Max questions (train split): 1")
        return 1

    default = min(_DEFAULT_MAX_QUESTIONS, max_q)
    # Clamp the remembered value into [1, max_q]; the bound changes whenever
    # the selected personas change.
    remembered = st.session_state.get(_LAST_MAX_QUESTIONS_KEY, default)
    max_questions = st.slider(
        "Max questions (train split)",
        min_value=1,
        max_value=max_q,
        value=min(max(remembered, 1), max_q),
        key=_extract_widget_key(model_name, remote, dataset_source, "max_questions"),
    )
    st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
    return max_questions


def _render_extract_actions() -> tuple[bool, bool]:
    """Render the action buttons and return `(run_clicked, preview_clicked)`."""
    run_col, preview_col, _spacer = st.columns([1, 1, 4], gap="small")
    with run_col:
        run_clicked = st.button(
            "Run extraction",
            type="primary",
            width="stretch",
        )
    with preview_col:
        preview_clicked = st.button("Preview tokens", width="stretch")
    return run_clicked, preview_clicked


def _render_token_preview(
    *,
    model_name: str,
    run_plan: list[tuple[PersonaData, list[QAPair], str]],
    settings: ExtractSettings,
) -> None:
    """Show, per plan row, the first few prepared samples as colored tokens.

    Inputs are prepared with the same prompt, QA slice and mask strategy the
    extraction would use; at most `_MAX_PREVIEW_SAMPLES` samples per row are
    rendered, each inside its own expander.
    """
    with st.spinner("Loading tokenizer..."):
        model = cached_model(model_name=model_name)

    st.markdown(_render_token_legend_html(), unsafe_allow_html=True)

    for persona, qa_pairs, variant in run_plan:
        system_prompt = format_prompt(persona, variant)  # type: ignore[arg-type]
        prepared = prepare_inputs_for_strategy(
            tokenizer=model.tokenizer,
            system_prompt=system_prompt,
            qa_pairs=qa_pairs[: settings.max_questions],
            mask_strategy=settings.mask_strategy,
        )
        st.caption(_row_label(persona, variant))

        for i, p in enumerate(prepared[:_MAX_PREVIEW_SAMPLES]):
            # Keep expander titles short: truncate long questions to 60 chars.
            question = p.question if len(p.question) <= 60 else p.question[:57] + "..."
            seq_len = int(p.input_ids.shape[0])
            masked = int(p.token_mask.sum())
            label = f"sample {i} — {question} (len={seq_len}, masked={masked})"
            with st.expander(label):
                st.markdown(
                    _render_sample_tokens_html(p, model.tokenizer),
                    unsafe_allow_html=True,
                )

        if len(prepared) > _MAX_PREVIEW_SAMPLES:
            remaining = len(prepared) - _MAX_PREVIEW_SAMPLES
            st.caption(f"… and {remaining} more sample(s) not shown.")


def _run_extraction_plan(
    *,
    remote: bool,
    model_name: str,
    run_plan: list[tuple[PersonaData, list[QAPair], str]],
    settings: ExtractSettings,
) -> None:
    """Run `run_extraction` for every plan row and report progress and results.

    Any exception raised while loading the model or extracting is shown as an
    error message and aborts the run; the transient status/progress
    placeholders are cleared in every case.
    """
    status_box = st.empty()
    status_box.info("Extraction in progress...")
    progress = st.progress(0, text="Preparing extraction...")
    ndif_status_box = st.empty()

    def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
        ndif_status_box.caption(format_ndif_status(job_id, status_name, description))

    try:
        # Load inside the `try` so a failed model load is reported like any
        # other failure and the placeholders above are still cleaned up.
        with st.spinner("Loading model..."):
            model = cached_model(model_name=model_name)

        # Loop-invariant remote-only hooks, built once.
        on_status = _on_ndif_status if remote else None
        backend_factory = (
            (
                lambda: remote_backend(
                    model,
                    session_ndif_api_key(),
                    on_status=_on_ndif_status,
                )
            )
            if remote
            else None
        )

        total_steps = len(run_plan)
        results = []
        for step, (persona, qa_pairs, variant) in enumerate(run_plan):
            # total_steps > 0 whenever the loop body executes.
            progress.progress(
                step / total_steps,
                text=f"{_row_label(persona, variant)} ({step + 1}/{total_steps})",
            )
            selected_qa = qa_pairs[: settings.max_questions]
            results.extend(
                run_extraction(
                    model=model,
                    model_name=model_name,
                    qa_pairs=selected_qa,
                    variants=(variant,),
                    persona=persona,
                    mask_strategy=settings.mask_strategy,
                    remote=remote,
                    on_status=on_status,
                    backend_factory=backend_factory,
                )
            )
        progress.progress(1.0, text="Extraction complete")
    except Exception as exc:
        # UI boundary: show the failure instead of a raw traceback.
        st.error(f"Extraction failed: {exc}")
        return
    finally:
        progress.empty()
        ndif_status_box.empty()
        status_box.empty()

    st.success(f"Saved {len(results)} artifact set(s)")
    for result in results:
        st.markdown(
            f"- **{result.persona_name}** · {prompt_variant_label(result.variant)}: "
            f"{result.n_questions} questions"
        )


def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> None:
    """Render the extraction tab.

    Walks through the controls in order (upload, variants, dataset, personas,
    question limit, mask strategy) and stops early whenever a step cannot
    produce a usable selection. "Preview tokens" takes precedence over
    "Run extraction" if both are reported as clicked.
    """
    st.title("Extract")
    st.caption("Extract per-persona activation vectors from train QA pairs.")

    _render_local_dataset_upload(dataset_source)

    variant_choice = _render_variant_controls(
        model_name=model_name,
        remote=remote,
        dataset_source=dataset_source,
    )
    if variant_choice is None:
        return
    selected_variants, include_baseline = variant_choice

    loaded = _load_qa_dataset_personas(dataset_source)
    if loaded is None:
        return
    dataset, personas = loaded

    selected_personas = _render_persona_select(
        personas=personas,
        model_name=model_name,
        remote=remote,
        dataset_source=dataset_source,
    )
    if selected_personas is None:
        return

    # The baseline is appended only when requested and the dataset provides one.
    personas_for_runs = list(selected_personas)
    baseline = getattr(dataset, "baseline", None)
    if include_baseline and baseline is not None:
        personas_for_runs.append(baseline)

    runs = _collect_runs(dataset=dataset, selected_personas=personas_for_runs)
    if runs is None:
        return

    max_questions = _render_max_questions(
        model_name=model_name,
        remote=remote,
        dataset_source=dataset_source,
        runs=runs,
    )

    with st.expander("Advanced", expanded=False):
        mask_strategy = _render_mask_strategy_select(
            model_name=model_name,
            remote=remote,
            dataset_source=dataset_source,
        )

    settings = ExtractSettings(
        mask_strategy=mask_strategy,
        max_questions=max_questions,
    )

    run_clicked, preview_clicked = _render_extract_actions()
    run_plan = _build_run_plan(selected_variants, runs)

    if preview_clicked:
        _render_token_preview(
            model_name=model_name,
            run_plan=run_plan,
            settings=settings,
        )
        return

    if not run_clicked:
        return

    _run_extraction_plan(
        remote=remote,
        model_name=model_name,
        run_plan=run_plan,
        settings=settings,
    )