"""Streamlit probe inspector.

Lets the user pick a probe (local artifact, Hugging Face repo, or an uploaded
``.pt`` file), traces the current conversation through the model, and attaches
the resulting probe overlays to the chat messages.
"""

from __future__ import annotations

from pathlib import Path

import streamlit as st
import torch

from utils.chat import build_chat_messages
from utils.helpers import env_int, session_key, widget_key
from utils.probe_files import (
    DEFAULT_LOCAL_PROBE_DIR,
    DEFAULT_PROBE_REPO,
    list_local_probe_files,
    list_probe_files,
    model_probe_dir_name,
    parse_probe_filename,
)
from utils.probe_overlay import (
    attach_overlays,
    build_classification_overlays,
    build_regression_overlays,
    clear_overlays,
)
from utils.probe_trace import ConversationTrace, trace_conversation
from utils.probes import (
    LoadedProbe,
    load_local_probe,
    load_probe,
    load_probe_from_bytes,
)
from utils.runtime import cached_model, session_ndif_api_key
from utils.selection_controls import remembered_segmented_control

# Session-state slots that remember the user's most recent selections.
_LAST_SOURCE_KEY = session_key("probe", "last_source")
_LAST_LOCAL_FILE_KEY = session_key("probe", "last_local_file")
_LAST_HUB_FILE_KEY = session_key("probe", "last_hub_file")

# Options of the "Probe source" control; the first entry is the default.
_PROBE_SOURCES = ("Local artifact", "Hugging Face repo", "Upload .pt")

# Session-state slot holding the ordered list of derived-cache keys.
_DERIVED_CACHE_TRACKER_KEY = session_key("probe", "derived_cache_keys")

# Keep enough room for the three retained traces plus a few recently explored
# probes per trace. Derived outputs are much smaller than the trace activations
# themselves, so this avoids needless recomputation without reopening
# unbounded growth.
_DERIVED_CACHE_ENTRIES = env_int("PERSONA_UI_PROBE_DERIVED_CACHE_ENTRIES", 12)


# ---------------------------------------------------------------------------
# Probe selection
# ---------------------------------------------------------------------------


def _probe_label(filename: str) -> str:
    """Return the display label for *filename*: ``"<model> / <label>"``.

    The model prefix is omitted when the parsed filename has no model name.
    """
    metadata = parse_probe_filename(filename)
    prefix = f"{metadata.model_name} / " if metadata.model_name else ""
    return f"{prefix}{metadata.label}"


def _model_compatible_files(files: list[str], model_name: str) -> list[str]:
    """Narrow *files* to those whose first path component matches the model.

    The expected component is ``model_probe_dir_name(model_name)``. When no
    file matches, the unfiltered list is returned so a probe can still be
    picked.
    """
    model_dir = model_probe_dir_name(model_name)
    # Slicing keeps the comparison safe for paths with no components.
    compatible = [
        filename for filename in files if Path(filename).parts[:1] == (model_dir,)
    ]
    return compatible or files


def _default_file(files: list[str], remembered: str | None) -> str:
    """Return *remembered* if it is still offered, else the first file.

    *files* must be non-empty; an empty list raises ``IndexError``.
    """
    if remembered and remembered in files:
        return remembered
    return files[0]


def _select_probe_file(
    *, files: list[str], model_name: str, remember_key: str, select_key: str
) -> str:
    """Render the "Probe" selectbox and remember the choice in session state.

    Args:
        files: Non-empty list of candidate probe files.
        model_name: Model used to prefer compatible files.
        remember_key: Session-state key that stores the last selection.
        select_key: Widget key for the selectbox.

    Returns:
        The selected file name.
    """
    candidates = _model_compatible_files(files, model_name)
    default = _default_file(candidates, st.session_state.get(remember_key))
    selected = st.selectbox(
        "Probe",
        options=candidates,
        index=candidates.index(default),
        format_func=_probe_label,
        key=select_key,
    )
    st.session_state[remember_key] = selected
    return selected


def _render_probe_selector(*, context_key: str, model_name: str) -> LoadedProbe | None:
    """Inline source + file selector. Returns the loaded probe or None."""
    source = remembered_segmented_control(
        "Probe source",
        options=_PROBE_SOURCES,
        key=widget_key(context_key, "probe_source"),
        remember_key=_LAST_SOURCE_KEY,
        default=_PROBE_SOURCES[0],
        label_visibility="collapsed",
    )
    if source == "Local artifact":
        return _render_local_probe(context_key=context_key, model_name=model_name)
    if source == "Hugging Face repo":
        return _render_hub_probe(context_key=context_key, model_name=model_name)
    return _render_upload_probe(context_key=context_key)


def _render_local_probe(*, context_key: str, model_name: str) -> LoadedProbe | None:
    """Pick and load a probe from a local directory.

    Returns None (after showing a warning or error in the UI) when the
    directory cannot be listed, holds no probe files, or the probe fails to
    load.
    """
    dir_key = widget_key(context_key, "probe_local_dir")
    root_dir = st.text_input(
        "Probe directory",
        value=st.session_state.get(dir_key, DEFAULT_LOCAL_PROBE_DIR),
        key=dir_key,
    ).strip()

    # The directory is user-typed, so report listing failures the same way
    # load failures are reported instead of letting them abort the page.
    try:
        files = list_local_probe_files(root_dir)
    except Exception as exc:
        st.error(f"Could not list probe files: {exc}")
        return None
    if not files:
        st.warning("No probe files found in that directory.")
        return None

    selected = _select_probe_file(
        files=files,
        model_name=model_name,
        remember_key=_LAST_LOCAL_FILE_KEY,
        select_key=widget_key(context_key, "probe_local_file"),
    )
    try:
        return load_local_probe(root_dir, selected)
    except Exception as exc:
        st.error(f"Could not load probe: {exc}")
        return None


def _render_hub_probe(*, context_key: str, model_name: str) -> LoadedProbe | None:
    """Pick and load a probe from a Hugging Face repo.

    Returns None when the repo field is blank, and (after showing a warning or
    error in the UI) when the repo cannot be listed, holds no probe files, or
    the probe fails to load.
    """
    repo_key = widget_key(context_key, "probe_repo")
    repo_id = st.text_input(
        "Probe repo",
        value=st.session_state.get(repo_key, DEFAULT_PROBE_REPO),
        key=repo_key,
    ).strip()
    if not repo_id:
        return None

    # The repo id is user-typed, so report listing failures the same way load
    # failures are reported instead of letting them abort the page.
    try:
        files = list_probe_files(repo_id)
    except Exception as exc:
        st.error(f"Could not list probe files: {exc}")
        return None
    if not files:
        st.warning("No probe files found in that repo.")
        return None

    selected = _select_probe_file(
        files=files,
        model_name=model_name,
        remember_key=_LAST_HUB_FILE_KEY,
        select_key=widget_key(context_key, "probe_hub_file"),
    )
    try:
        return load_probe(repo_id, selected)
    except Exception as exc:
        st.error(f"Could not load probe: {exc}")
        return None


def _render_upload_probe(*, context_key: str) -> LoadedProbe | None:
    """Load a probe from an uploaded ``.pt`` file.

    Returns None while nothing is uploaded, or (after showing an error) when
    the uploaded bytes cannot be loaded as a probe.
    """
    uploaded = st.file_uploader(
        "Upload probe (.pt)",
        type=["pt"],
        key=widget_key(context_key, "probe_upload"),
    )
    if uploaded is None:
        return None
    try:
        return load_probe_from_bytes(uploaded.name, uploaded.getvalue())
    except Exception as exc:
        st.error(f"Could not load probe: {exc}")
        return None


# ---------------------------------------------------------------------------
# Probe card + target validation
# ---------------------------------------------------------------------------


def _render_probe_card(probe: LoadedProbe) -> None:
    """Show a one-line markdown summary of the probe's metadata."""
    parts: list[str] = []
    if probe.attribute_name:
        parts.append(f"**{probe.attribute_name}**")
    parts.append(f"layer `{probe.layer if probe.layer is not None else '?'}`")
    parts.append(f"kind `{probe.model_type}`")
    if probe.feature_space:
        parts.append(f"`{probe.feature_space}`")
    if probe.location:
        parts.append(f"`{probe.location}`")
    # Fall back to a bare count when every label is empty.
    classes = (
        ", ".join(label for label in probe.labels if label)
        or f"{len(probe.labels)} classes"
    )
    parts.append(f"classes: {classes}")
    st.markdown("  ·  ".join(parts))


def _model_dimensions(model: object) -> tuple[int, int]:
    """Return ``(hidden_size, num_layers)`` read from the model or its config.

    ``hidden_size`` is taken from ``model.hidden_size`` or
    ``model.config.hidden_size``; the layer count from ``model.num_layers``,
    ``config.num_hidden_layers`` or ``config.n_layer``, in that order.

    Raises:
        ValueError: If either value cannot be found.
    """
    config = getattr(model, "config", None)
    hidden_size = getattr(model, "hidden_size", None) or getattr(
        config, "hidden_size", None
    )
    num_layers = (
        getattr(model, "num_layers", None)
        or getattr(config, "num_hidden_layers", None)
        or getattr(config, "n_layer", None)
    )
    if hidden_size is None or num_layers is None:
        raise ValueError("Could not read hidden_size and num_layers from the model.")
    return int(hidden_size), int(num_layers)


def _resolve_target(
    *, probe: LoadedProbe, context_key: str, num_layers: int
) -> tuple[int, str]:
    """Return the ``(layer, location)`` to trace for *probe*.

    Values recorded on the probe are used as-is. Anything the probe leaves
    unspecified is asked for through an extra widget.
    """
    layer = probe.layer
    if layer is None:
        last_layer = max(0, num_layers - 1)
        layer = int(
            st.number_input(
                "Layer (probe did not specify one)",
                min_value=0,
                max_value=last_layer,
                # Default to layer 15, clamped for shallower models.
                value=min(15, last_layer),
                step=1,
                key=widget_key(context_key, "probe_layer"),
            )
        )
    location = probe.location
    if location is None:
        location = st.selectbox(
            "Activation location (probe did not specify one)",
            options=("post_reasoning", "pre_reasoning"),
            key=widget_key(context_key, "probe_location"),
        )
    return layer, location


def _validate(
    *, probe: LoadedProbe, layer: int, num_layers: int, hidden_size: int
) -> bool:
    """Check that the probe fits the model, reporting any problem in the UI.

    Returns False when *layer* is outside ``[0, num_layers)`` or the probe's
    input dimension differs from *hidden_size*; True otherwise.
    """
    if not 0 <= layer < num_layers:
        st.error(f"Probe layer {layer} is outside the model's {num_layers} layers.")
        return False
    if probe.input_dim != hidden_size:
        st.warning(
            f"Probe input dim ({probe.input_dim}) does not match the model's hidden "
            f"size ({hidden_size}). Predictions will not be meaningful."
        )
        return False
    return True


# ---------------------------------------------------------------------------
# Cached batched probe forward
# ---------------------------------------------------------------------------


def _store_derived_cache(key: str, value: object) -> None:
    """Store one derived probe result while keeping a small MRU window.

    At most ``_DERIVED_CACHE_ENTRIES`` results are retained; the oldest
    tracked entries are evicted first. A non-positive limit disables
    retention: nothing is stored and previously tracked entries are dropped.
    """
    tracked = st.session_state.get(_DERIVED_CACHE_TRACKER_KEY)
    if not isinstance(tracked, list):
        tracked = []
    tracked = [existing for existing in tracked if existing != key]
    tracked.append(key)

    # Clamp so a negative setting cannot drive pop(0) on an empty list.
    limit = max(_DERIVED_CACHE_ENTRIES, 0)
    while len(tracked) > limit:
        st.session_state.pop(tracked.pop(0), None)
    st.session_state[_DERIVED_CACHE_TRACKER_KEY] = tracked

    # With limit == 0 the new key was evicted above; storing the value anyway
    # would leave an untracked entry that is never cleaned up.
    if tracked:
        st.session_state[key] = value


def _get_derived_cache(key: str) -> object | None:
    """Return a derived probe result and refresh its MRU position.

    Returns None when nothing is stored under *key*.
    """
    cached = st.session_state.get(key)
    if cached is None:
        return None
    tracked = st.session_state.get(_DERIVED_CACHE_TRACKER_KEY)
    if isinstance(tracked, list) and key in tracked:
        # Move the key to the end so it is evicted last.
        tracked = [existing for existing in tracked if existing != key]
        tracked.append(key)
        st.session_state[_DERIVED_CACHE_TRACKER_KEY] = tracked
    return cached


def _classification_predictions(
    probe: LoadedProbe, activations: torch.Tensor, cache_key: str
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(probs, predicted)`` for *activations*, cached per trace/probe.

    The first element returned by ``probe.run_batch`` is not used.
    """
    # NOTE(review): id(probe) is only stable while this exact object is alive.
    # If the loaders return a fresh object per rerun the cache never hits, and
    # a recycled id could match another probe's entry. Confirm the loaders
    # cache their results, or switch to a content-based identity.
    full_key = widget_key("probe_predictions", cache_key, str(id(probe)))
    cached = _get_derived_cache(full_key)
    if cached is not None:
        return cached
    _, probs, predicted = probe.run_batch(activations)
    _store_derived_cache(full_key, (probs, predicted))
    return probs, predicted


def _regression_values(
    probe: LoadedProbe, activations: torch.Tensor, cache_key: str
) -> torch.Tensor:
    """Return ``probe.predict_batch(activations)``, cached per trace/probe."""
    # NOTE(review): same id(probe) caveat as in _classification_predictions.
    full_key = widget_key("probe_values", cache_key, str(id(probe)))
    cached = _get_derived_cache(full_key)
    if cached is not None:
        return cached
    values = probe.predict_batch(activations)
    _store_derived_cache(full_key, values)
    return values


# ---------------------------------------------------------------------------
# Entry point
# --------------------------------------------------------------------------- def _has_assistant_message(messages: list[dict]) -> bool: return any(m.get("role") == "assistant" and m.get("content") for m in messages) def _apply_overlays( *, probe: LoadedProbe, trace: ConversationTrace, messages: list[dict] ) -> bool: if probe.is_regression: values = _regression_values(probe, trace.activations, trace.cache_key) overlays = build_regression_overlays( trace=trace, values=values, labels=probe.labels, attribute_name=probe.attribute_name, ) else: probs, predicted = _classification_predictions( probe, trace.activations, trace.cache_key ) binary = probs.shape[1] == 1 or (probs.shape[1] == 2 and len(probe.labels) == 2) overlays = build_classification_overlays( trace=trace, probs=probs, predicted=predicted, labels=probe.labels, binary=binary, attribute_name=probe.attribute_name, ) attach_overlays(messages, overlays) return bool(overlays) def render_probe_inspector( *, context_key: str, model_name: str, remote: bool, active_system_prompt: str | None, chat_state: dict[str, object], enabled: bool, ) -> None: messages: list[dict] = chat_state["messages"] # type: ignore[assignment] if not enabled: clear_overlays(messages) return status_key = widget_key(context_key, "probe_status") sig_key = widget_key(context_key, "probe_scored_sig") def _conversation_sig() -> int: return hash( tuple( (m.get("role"), m.get("content")) for m in messages if m.get("content") ) ) def _reset() -> None: clear_overlays(messages) st.session_state.pop(status_key, None) st.session_state.pop(sig_key, None) with st.expander("Probe", expanded=True): if not _has_assistant_message(messages): _reset() st.caption("Probe overlay shows up after the first assistant reply.") return probe = _render_probe_selector(context_key=context_key, model_name=model_name) if probe is None: _reset() return _render_probe_card(probe) model = cached_model(model_name=model_name) try: hidden_size, num_layers = _model_dimensions(model) 
except Exception as exc: _reset() st.error(str(exc)) return layer, location = _resolve_target( probe=probe, context_key=context_key, num_layers=num_layers ) if not _validate( probe=probe, layer=layer, num_layers=num_layers, hidden_size=hidden_size ): _reset() return # The probe scores via a separate forward pass over the whole # conversation, so it's fully decoupled from generation: pick or switch # probes any time and score on demand. Gate that pass behind a button # instead of re-running it on every Streamlit rerun. Overlays live on # the message dicts, so they persist across reruns until refreshed. run = st.button( "Run probe", type="primary", key=widget_key(context_key, "probe_run"), help="Score the current conversation with the selected probe.", ) if not run: status = st.session_state.get(status_key) if not status: st.caption("Press **Run probe** to score the conversation.") elif st.session_state.get(sig_key) != _conversation_sig(): # Conversation changed since it was scored: drop the now-stale # overlay so it can't paint over edited/new text. clear_overlays(messages) st.caption("Conversation changed — press **Run probe** to refresh.") else: st.caption(f"{status} · press **Run probe** to refresh.") return chat_messages = build_chat_messages(active_system_prompt, messages) with st.spinner("Tracing conversation..."): try: trace = trace_conversation( model=model, model_name=model_name, messages=chat_messages, layer=layer, location=location, remote=remote, ndif_api_key=session_ndif_api_key(), ) except Exception as exc: _reset() st.error(f"Trace failed: {exc}") return if not trace.assistant_spans: _reset() st.warning( "Could not locate assistant tokens in the traced sequence, so " "the overlay can't be aligned to message bodies." 
) return try: applied = _apply_overlays(probe=probe, trace=trace, messages=messages) except Exception as exc: _reset() st.error(f"Probe execution failed: {exc}") return if not applied: _reset() return n_body = sum( sum(1 for i in range(s, e) if not bool(trace.is_special[i].item())) for s, e in trace.assistant_spans ) kind = "regression" if probe.is_regression else "classification" status = ( f"{kind} · {len(trace.assistant_spans)} assistant message(s) · " f"{n_body} body tokens · layer {trace.layer} · {trace.location}" ) st.session_state[status_key] = status st.session_state[sig_key] = _conversation_sig() st.caption(status)