from __future__ import annotations import io from dataclasses import dataclass from pathlib import Path from typing import Any import streamlit as st import torch import torch.nn as nn import torch.nn.functional as F from persona_vectors.probes import ProbeArtifact, load_probe_artifact from utils.helpers import env_int from utils.probe_files import ( download_probe_file, download_probe_json_and_weights, parse_probe_filename, ) _PROBE_CACHE_ENTRIES = env_int("PERSONA_UI_PROBE_CACHE_ENTRIES", 8) @dataclass(frozen=True) class ProbeRunResult: input_dim: int logits: torch.Tensor probabilities: torch.Tensor predicted_index: int predicted_label: str | None class _LinearProbe(nn.Module): def __init__(self, input_dim: int, num_classes: int): super().__init__() self.linear = nn.Linear(input_dim, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) class _MLPProbe(nn.Module): def __init__( self, input_dim: int, hidden_dims: list[int], num_classes: int, dropout: float, ): super().__init__() if not hidden_dims: raise ValueError("MLP probe requires at least one hidden dimension") layers: list[nn.Module] = [] prev_dim = input_dim for hidden_dim in hidden_dims: layers.append(nn.Linear(prev_dim, hidden_dim)) layers.append(nn.ReLU()) layers.append(nn.Dropout(dropout)) prev_dim = hidden_dim layers.append(nn.Linear(prev_dim, num_classes)) self.network = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.network(x) @dataclass class LoadedProbe: model: nn.Module input_dim: int labels: list[str | None] model_type: str layer: int | None location: str | None model_name: str | None = None attribute_name: str | None = None feature_space: str | None = None task: str | None = None probe_kind: str | None = None scaler_mean: torch.Tensor | None = None scaler_std: torch.Tensor | None = None pca_mean: torch.Tensor | None = None pca_components: torch.Tensor | None = None def __post_init__(self) -> None: self.model.eval() @property def is_regression(self) -> bool: """True when the probe outputs a continuous value rather than a class.""" if self.task is not None: return self.task in {"numeric", "ordinal"} if self.probe_kind is not None: return self.probe_kind == "ridge_regression" return False def predict_batch(self, activations: torch.Tensor) -> torch.Tensor: """Return raw linear-output values for each token — no sigmoid/softmax.""" if activations.ndim != 2: raise ValueError( f"predict_batch expects [N, hidden], got {tuple(activations.shape)}" ) if activations.shape[1] != self.input_dim: raise ValueError( f"Probe expects input dim {self.input_dim}, got {activations.shape[1]}" ) batch = activations.detach().to(dtype=torch.float32, device="cpu") normalized = self._normalize_batch(batch) with torch.no_grad(): outputs = self.model(normalized).detach().cpu() if outputs.ndim == 1: outputs = outputs.unsqueeze(-1) return outputs def run(self, vector: torch.Tensor) -> ProbeRunResult: if vector.ndim != 1: raise ValueError( f"Probe expects a 1D activation vector, got shape {tuple(vector.shape)}" ) if vector.shape[0] != self.input_dim: raise ValueError( f"Probe expects input dim {self.input_dim}, got {vector.shape[0]}" ) batch = vector.detach().to(dtype=torch.float32, device="cpu").unsqueeze(0) logits_batch, probs_batch = self._forward_batch(batch) logits = logits_batch.squeeze(0) probs = probs_batch.squeeze(0) if logits.ndim == 0: logits = logits.unsqueeze(0) if probs.ndim == 0: probs = probs.unsqueeze(0) predicted_index = ( int(probs.item() >= 0.5) if probs.numel() == 1 else int(torch.argmax(probs).item()) ) predicted_label = ( self.labels[predicted_index] if 0 <= predicted_index < len(self.labels) else None ) return ProbeRunResult( input_dim=int(vector.shape[0]), logits=logits, probabilities=probs, predicted_index=predicted_index, predicted_label=predicted_label, ) def run_batch( self, activations: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Run the probe over a batch of activations. Returns ``(logits[N, C], probs[N, C], predicted_index[N])``. For single-output probes ``C == 1`` and ``probs`` holds sigmoid scores. """ if activations.ndim != 2: raise ValueError( f"run_batch expects [N, hidden], got {tuple(activations.shape)}" ) if activations.shape[1] != self.input_dim: raise ValueError( f"Probe expects input dim {self.input_dim}, got {activations.shape[1]}" ) batch = activations.detach().to(dtype=torch.float32, device="cpu") logits, probs = self._forward_batch(batch) if probs.shape[-1] == 1: predicted = (probs.squeeze(-1) >= 0.5).long() else: predicted = torch.argmax(probs, dim=-1) return logits, probs, predicted def _forward_batch(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: normalized = self._normalize_batch(batch) with torch.no_grad(): logits = self.model(normalized).detach().cpu() if logits.ndim == 1: logits = logits.unsqueeze(-1) if logits.shape[-1] == 1: probs = torch.sigmoid(logits) else: probs = F.softmax(logits, dim=-1) return logits, probs def _normalize_batch(self, batch: torch.Tensor) -> torch.Tensor: if self.scaler_mean is not None and self.scaler_std is not None: mean = self.scaler_mean.to(dtype=torch.float32) std = self.scaler_std.to(dtype=torch.float32) if mean.ndim != 1 or std.ndim != 1 or mean.shape[0] != batch.shape[1]: raise ValueError( "Probe scaler shape does not match activation hidden size: " f"mean={tuple(mean.shape)} std={tuple(std.shape)} " f"batch={tuple(batch.shape)}" ) safe_std = torch.where(std == 0, torch.ones_like(std), std) batch = (batch - mean) / safe_std if self.pca_mean is not None and self.pca_components is not None: pca_mean = self.pca_mean.to(dtype=torch.float32) components = self.pca_components.to(dtype=torch.float32) if pca_mean.ndim != 1 or pca_mean.shape[0] != batch.shape[1]: raise ValueError( "Probe PCA mean shape does not match activation hidden size: " f"mean={tuple(pca_mean.shape)} batch={tuple(batch.shape)}" ) batch = (batch - pca_mean) @ components.T return batch @st.cache_resource(show_spinner=False, max_entries=_PROBE_CACHE_ENTRIES) def load_probe(repo_id: str, filename: str) -> LoadedProbe: if filename.endswith("probe.json"): metadata_path, weights_path = download_probe_json_and_weights(repo_id, filename) return _load_persona_probe_artifact( filename=filename, metadata_path=Path(metadata_path), weights_path=Path(weights_path), ) path = download_probe_file(repo_id, filename) return _load_probe_payload( filename=filename, payload=_torch_load(path), ) @st.cache_resource(show_spinner=False, max_entries=_PROBE_CACHE_ENTRIES) def load_local_probe(root_dir: str, filename: str) -> LoadedProbe: root = Path(root_dir).expanduser() path = (root / filename).resolve() if root.resolve() not in path.parents: raise ValueError("Probe path must stay inside the selected local directory.") if path.name == "probe.json": return _load_persona_probe_artifact( filename=filename, metadata_path=path, weights_path=path.with_name("weights.safetensors"), ) if path.name == "weights.safetensors": return _load_persona_probe_artifact( filename=filename, metadata_path=path.with_name("probe.json"), weights_path=path, ) return _load_probe_payload( filename=filename, payload=_torch_load(path), ) @st.cache_resource(show_spinner=False, max_entries=_PROBE_CACHE_ENTRIES) def load_probe_from_bytes(filename: str, data: bytes) -> LoadedProbe: return _load_probe_payload( filename=filename, payload=_torch_load(io.BytesIO(data)), ) def _load_probe_payload( *, filename: str, payload: object, ) -> LoadedProbe: if not isinstance(payload, dict): raise TypeError(f"Probe payload must be a dict, got {type(payload)!r}") metadata = parse_probe_filename(filename) state_dict = _get_state_dict(payload) input_dim = _coerce_probe_dim(payload.get("input_dim"), state_dict, dim="input") model_input_dim = _coerce_probe_dim( payload.get("artifact_feature_dim") or input_dim, state_dict, dim="input", ) num_classes = _coerce_probe_dim( payload.get("num_classes"), state_dict, dim="classes" ) model = _build_probe_module( payload, state_dict=state_dict, input_dim=model_input_dim, num_classes=num_classes, ) labels = _normalize_labels( payload.get("idx_to_label") or payload.get("class_names"), num_classes, ) raw_layer = payload.get("layer") try: layer = int(raw_layer) if raw_layer is not None else metadata.layer except (TypeError, ValueError): layer = metadata.layer raw_location = payload.get("location") location = ( raw_location if isinstance(raw_location, str) and raw_location else metadata.location ) return LoadedProbe( model=model, input_dim=input_dim, labels=labels, model_type=str(payload.get("model_type") or metadata.model_type), layer=layer, location=location, model_name=_optional_str(payload.get("model_name")) or metadata.model_name, attribute_name=( _optional_str(payload.get("attribute_name")) or metadata.attribute_name ), feature_space=( ( f"pca{payload['n_pca_components']}" if payload.get("n_pca_components") else None ) or _optional_str(payload.get("feature_space")) or metadata.scope ), task=_optional_str(payload.get("task")), probe_kind=_optional_str(payload.get("probe_kind")), scaler_mean=_as_cpu_tensor(payload.get("scaler_mean")), scaler_std=_as_cpu_tensor( _first_present(payload, "scaler_std", "scaler_scale") ), pca_mean=_as_cpu_tensor(payload.get("pca_mean")), pca_components=_as_cpu_tensor(payload.get("pca_components")), ) def _torch_load(file_or_buffer: object) -> object: return torch.load(file_or_buffer, map_location="cpu", weights_only=True) def _load_persona_probe_artifact( *, filename: str, metadata_path: Path, weights_path: Path, ) -> LoadedProbe: if metadata_path.parent != weights_path.parent: raise ValueError("Canonical probe files must share one artifact directory.") artifact = load_probe_artifact(metadata_path) return _loaded_probe_from_artifact(filename=filename, artifact=artifact) def _loaded_probe_from_artifact( *, filename: str, artifact: ProbeArtifact, ) -> LoadedProbe: metadata = artifact.metadata tensors = artifact.tensors payload = { **metadata, "model_type": "linear", "model_state_dict": { "linear.weight": tensors["weight"], "linear.bias": tensors["bias"], }, "num_classes": int(tensors["weight"].shape[0]), "idx_to_label": metadata.get("class_names"), "scaler_mean": tensors.get("scaler_mean"), "scaler_std": tensors.get("scaler_scale"), "pca_mean": tensors.get("pca_mean"), "pca_components": tensors.get("pca_components"), } return _load_probe_payload(filename=filename, payload=payload) def _build_probe_module( payload: dict[str, Any], *, state_dict: dict[str, torch.Tensor], input_dim: int, num_classes: int, ) -> nn.Module: model_type = str(payload.get("model_type") or "").lower() if model_type in {"lr", "linear", "logreg", "logistic_regression"}: module = _LinearProbe(input_dim=input_dim, num_classes=num_classes) state_dict = _normalize_linear_state_dict(state_dict) elif model_type == "mlp": hidden_dims = _coerce_hidden_dims(payload.get("hidden_dims")) dropout = float(payload.get("dropout") or 0.0) module = _MLPProbe( input_dim=input_dim, hidden_dims=hidden_dims, num_classes=num_classes, dropout=dropout, ) state_dict = _strip_known_prefixes(state_dict) else: if _looks_linear(state_dict): module = _LinearProbe(input_dim=input_dim, num_classes=num_classes) state_dict = _normalize_linear_state_dict(state_dict) else: raise ValueError(f"Unsupported probe model type: {model_type!r}") module.load_state_dict(state_dict, strict=True) return module def _get_state_dict(payload: dict[str, Any]) -> dict[str, torch.Tensor]: for key in ("model_state_dict", "state_dict", "probe_state_dict"): value = payload.get(key) if isinstance(value, dict): return { str(k): v.detach().cpu() if isinstance(v, torch.Tensor) else v for k, v in value.items() } raise TypeError("Probe payload is missing model_state_dict") def _coerce_probe_dim( value: object, state_dict: dict[str, torch.Tensor], *, dim: str, ) -> int: if value is not None: return int(value) weights = [ tensor for key, tensor in state_dict.items() if key.endswith("weight") and isinstance(tensor, torch.Tensor) and tensor.ndim == 2 ] if not weights: raise ValueError(f"Cannot infer probe {dim} dimension from state dict") tensor = weights[0] if dim == "input" else weights[-1] return int(tensor.shape[1] if dim == "input" else tensor.shape[0]) def _normalize_linear_state_dict( state_dict: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: stripped = _strip_known_prefixes(state_dict) if "linear.weight" in stripped: return stripped if "weight" in stripped: out = {"linear.weight": stripped["weight"]} if "bias" in stripped: out["linear.bias"] = stripped["bias"] return out return stripped def _strip_known_prefixes( state_dict: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: out: dict[str, torch.Tensor] = {} for key, value in state_dict.items(): stripped = key for prefix in ("module.", "model.", "probe."): if stripped.startswith(prefix): stripped = stripped[len(prefix) :] out[stripped] = value return out def _looks_linear(state_dict: dict[str, torch.Tensor]) -> bool: stripped = _strip_known_prefixes(state_dict) return "weight" in stripped or "linear.weight" in stripped def _coerce_hidden_dims(value: Any) -> list[int]: if value is None: return [] if isinstance(value, int): return [value] if isinstance(value, str): return [int(part.strip()) for part in value.split(",") if part.strip()] if isinstance(value, (list, tuple)): return [int(part) for part in value] raise TypeError(f"Unsupported hidden_dims value: {type(value)!r}") def _as_cpu_tensor(value: Any) -> torch.Tensor | None: if not isinstance(value, torch.Tensor): return None return value.detach().cpu() def _optional_str(value: Any) -> str | None: if isinstance(value, str) and value: return value return None def _first_present(payload: dict[str, Any], *keys: str) -> Any: for key in keys: value = payload.get(key) if value is not None: return value return None def _normalize_labels(raw_labels: Any, num_classes: int) -> list[str | None]: if isinstance(raw_labels, (list, tuple)): labels = [str(label) for label in raw_labels[:num_classes]] return labels + [None] * (num_classes - len(labels)) if not isinstance(raw_labels, dict): return [None] * num_classes labels: list[str | None] = [None] * num_classes for raw_idx, raw_label in raw_labels.items(): try: idx = int(raw_idx) except (TypeError, ValueError): continue if 0 <= idx < num_classes: labels[idx] = str(raw_label) return labels