persona-ui / utils /probes.py
Jac-Zac
Big refactoring
b279884
from __future__ import annotations
import io
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from persona_vectors.probes import ProbeArtifact, load_probe_artifact
from utils.helpers import env_int
from utils.probe_files import (
download_probe_file,
download_probe_json_and_weights,
parse_probe_filename,
)
# Upper bound on cached probes per loader: this value is passed as
# ``max_entries`` to the ``st.cache_resource`` decorators in this module.
# NOTE(review): presumably ``env_int`` reads the PERSONA_UI_PROBE_CACHE_ENTRIES
# environment variable and falls back to 8 -- confirm against utils.helpers.
_PROBE_CACHE_ENTRIES = env_int("PERSONA_UI_PROBE_CACHE_ENTRIES", 8)
@dataclass(frozen=True)
class ProbeRunResult:
    """Immutable outcome of running a probe on one activation vector."""

    # Length of the activation vector that was fed to the probe.
    input_dim: int
    # Raw model outputs for that vector.
    logits: torch.Tensor
    # Scores derived from ``logits`` (same number of entries).
    probabilities: torch.Tensor
    # Index of the winning class.
    predicted_index: int
    # Human-readable name for ``predicted_index``; ``None`` when unknown.
    predicted_label: str | None
class _LinearProbe(nn.Module):
    """Single affine layer mapping ``input_dim`` features to ``num_classes`` outputs."""

    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        # The attribute name fixes the state-dict keys: ``linear.weight`` and
        # ``linear.bias``.
        self.linear = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine layer to ``x`` and return the raw outputs."""
        outputs = self.linear(x)
        return outputs
class _MLPProbe(nn.Module):
    """Feed-forward probe: repeated Linear -> ReLU -> Dropout, then a Linear head.

    Raises:
        ValueError: if ``hidden_dims`` is empty.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        num_classes: int,
        dropout: float,
    ):
        super().__init__()
        if not hidden_dims:
            raise ValueError("MLP probe requires at least one hidden dimension")

        # Consecutive pairs of ``widths`` give each hidden layer's (in, out) sizes.
        widths = [input_dim, *hidden_dims]
        stages: list[nn.Module] = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)]
        stages.append(nn.Linear(widths[-1], num_classes))

        # Attribute name and layer order fix the state-dict keys
        # (``network.0.weight``, ``network.3.weight``, ...).
        self.network = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the whole stack and return the head's outputs."""
        outputs = self.network(x)
        return outputs
@dataclass
class LoadedProbe:
    """A probe module bundled with its metadata and optional preprocessing.

    ``model`` is switched to eval mode on construction. Before the model is
    called, activations are optionally standardised (``scaler_mean`` /
    ``scaler_std``) and then optionally projected (``pca_mean`` /
    ``pca_components``); each step is applied only when both of its tensors
    are present.

    Fields other than ``model``, ``input_dim``, ``labels``, ``task``,
    ``probe_kind`` and the preprocessing tensors are descriptive metadata and
    are not read by the methods of this class.
    """

    model: nn.Module
    # Width of the raw activations accepted by ``run`` / ``run_batch`` /
    # ``predict_batch`` (checked before any preprocessing).
    input_dim: int
    # Class names indexed by predicted class; entries may be ``None``.
    labels: list[str | None]
    model_type: str
    layer: int | None
    location: str | None
    model_name: str | None = None
    attribute_name: str | None = None
    feature_space: str | None = None
    task: str | None = None
    probe_kind: str | None = None
    scaler_mean: torch.Tensor | None = None
    scaler_std: torch.Tensor | None = None
    pca_mean: torch.Tensor | None = None
    pca_components: torch.Tensor | None = None

    def __post_init__(self) -> None:
        # Inference only: put the module in eval mode (e.g. disables dropout).
        self.model.eval()

    @property
    def is_regression(self) -> bool:
        """True when the probe outputs a continuous value rather than a class.

        ``task`` takes precedence over ``probe_kind``; with neither set the
        probe is treated as a classifier.
        """
        if self.task is not None:
            return self.task in {"numeric", "ordinal"}
        if self.probe_kind is not None:
            return self.probe_kind == "ridge_regression"
        return False

    def predict_batch(self, activations: torch.Tensor) -> torch.Tensor:
        """Return raw linear-output values for each token — no sigmoid/softmax.

        Args:
            activations: 2-D tensor ``[N, input_dim]``.

        Returns:
            CPU tensor of model outputs; a 1-D model output is expanded to
            ``[N, 1]``.

        Raises:
            ValueError: on wrong rank, wrong width, or mis-shaped
                preprocessing tensors.
        """
        batch = self._prepare_batch(activations, caller="predict_batch")
        return self._model_outputs(batch)

    def run(self, vector: torch.Tensor) -> ProbeRunResult:
        """Run the probe on a single 1-D activation vector.

        A single-output probe is thresholded at 0.5 on its sigmoid score;
        otherwise the arg-max of the softmax is taken. ``predicted_label`` is
        ``None`` when the index falls outside ``labels``.

        Raises:
            ValueError: if ``vector`` is not 1-D or its length differs from
                ``input_dim``.
        """
        if vector.ndim != 1:
            raise ValueError(
                f"Probe expects a 1D activation vector, got shape {tuple(vector.shape)}"
            )
        if vector.shape[0] != self.input_dim:
            raise ValueError(
                f"Probe expects input dim {self.input_dim}, got {vector.shape[0]}"
            )

        batch = vector.detach().to(dtype=torch.float32, device="cpu").unsqueeze(0)
        logits_batch, probs_batch = self._forward_batch(batch)
        logits = logits_batch.squeeze(0)
        probs = probs_batch.squeeze(0)
        # Keep at least one dimension so callers can always index the result.
        if logits.ndim == 0:
            logits = logits.unsqueeze(0)
        if probs.ndim == 0:
            probs = probs.unsqueeze(0)

        predicted_index = (
            int(probs.item() >= 0.5)
            if probs.numel() == 1
            else int(torch.argmax(probs).item())
        )
        predicted_label = (
            self.labels[predicted_index]
            if 0 <= predicted_index < len(self.labels)
            else None
        )
        return ProbeRunResult(
            input_dim=int(vector.shape[0]),
            logits=logits,
            probabilities=probs,
            predicted_index=predicted_index,
            predicted_label=predicted_label,
        )

    def run_batch(
        self, activations: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the probe over a batch of activations.

        Returns ``(logits[N, C], probs[N, C], predicted_index[N])``. For
        single-output probes ``C == 1`` and ``probs`` holds sigmoid scores.

        Raises:
            ValueError: on wrong rank, wrong width, or mis-shaped
                preprocessing tensors.
        """
        batch = self._prepare_batch(activations, caller="run_batch")
        logits, probs = self._forward_batch(batch)
        if probs.shape[-1] == 1:
            predicted = (probs.squeeze(-1) >= 0.5).long()
        else:
            predicted = torch.argmax(probs, dim=-1)
        return logits, probs, predicted

    def _prepare_batch(self, activations: torch.Tensor, *, caller: str) -> torch.Tensor:
        """Validate a ``[N, input_dim]`` batch and return it as float32 on CPU."""
        if activations.ndim != 2:
            raise ValueError(
                f"{caller} expects [N, hidden], got {tuple(activations.shape)}"
            )
        if activations.shape[1] != self.input_dim:
            raise ValueError(
                f"Probe expects input dim {self.input_dim}, got {activations.shape[1]}"
            )
        return activations.detach().to(dtype=torch.float32, device="cpu")

    def _model_outputs(self, batch: torch.Tensor) -> torch.Tensor:
        """Preprocess ``batch``, call the model without grad, return 2-D CPU outputs."""
        normalized = self._normalize_batch(batch)
        with torch.no_grad():
            outputs = self.model(normalized).detach().cpu()
        if outputs.ndim == 1:
            outputs = outputs.unsqueeze(-1)
        return outputs

    def _forward_batch(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(logits, probs)``: sigmoid for one output, softmax otherwise."""
        logits = self._model_outputs(batch)
        if logits.shape[-1] == 1:
            probs = torch.sigmoid(logits)
        else:
            probs = F.softmax(logits, dim=-1)
        return logits, probs

    def _normalize_batch(self, batch: torch.Tensor) -> torch.Tensor:
        """Apply the optional standardisation and PCA projection to ``batch``.

        Raises:
            ValueError: if a scaler or PCA tensor does not match the width of
                the batch it is applied to.
        """
        if self.scaler_mean is not None and self.scaler_std is not None:
            mean = self.scaler_mean.to(dtype=torch.float32, device=batch.device)
            std = self.scaler_std.to(dtype=torch.float32, device=batch.device)
            width = batch.shape[1]
            # Check std as well as mean: a mis-sized std would otherwise
            # broadcast silently or fail with an opaque torch error.
            if (
                mean.ndim != 1
                or std.ndim != 1
                or mean.shape[0] != width
                or std.shape[0] != width
            ):
                raise ValueError(
                    "Probe scaler shape does not match activation hidden size: "
                    f"mean={tuple(mean.shape)} std={tuple(std.shape)} "
                    f"batch={tuple(batch.shape)}"
                )
            # Zero-variance features are left unscaled rather than divided by zero.
            safe_std = torch.where(std == 0, torch.ones_like(std), std)
            batch = (batch - mean) / safe_std

        if self.pca_mean is not None and self.pca_components is not None:
            pca_mean = self.pca_mean.to(dtype=torch.float32, device=batch.device)
            components = self.pca_components.to(
                dtype=torch.float32, device=batch.device
            )
            if pca_mean.ndim != 1 or pca_mean.shape[0] != batch.shape[1]:
                raise ValueError(
                    "Probe PCA mean shape does not match activation hidden size: "
                    f"mean={tuple(pca_mean.shape)} batch={tuple(batch.shape)}"
                )
            # The projection multiplies by ``components.T``, so components
            # must be [n_components, hidden].
            if components.ndim != 2 or components.shape[1] != batch.shape[1]:
                raise ValueError(
                    "Probe PCA components shape does not match activation hidden size: "
                    f"components={tuple(components.shape)} batch={tuple(batch.shape)}"
                )
            batch = (batch - pca_mean) @ components.T
        return batch
@st.cache_resource(show_spinner=False, max_entries=_PROBE_CACHE_ENTRIES)
def load_probe(repo_id: str, filename: str) -> LoadedProbe:
    """Download a probe from ``repo_id`` and build a ``LoadedProbe`` from it.

    A filename ending in ``probe.json`` is loaded as a metadata + weights
    artifact pair; anything else is read with ``_torch_load`` as a single
    payload file. Results are cached through ``st.cache_resource``.
    """
    if not filename.endswith("probe.json"):
        local_file = download_probe_file(repo_id, filename)
        payload = _torch_load(local_file)
        return _load_probe_payload(filename=filename, payload=payload)

    metadata_file, weights_file = download_probe_json_and_weights(repo_id, filename)
    return _load_persona_probe_artifact(
        filename=filename,
        metadata_path=Path(metadata_file),
        weights_path=Path(weights_file),
    )
@st.cache_resource(show_spinner=False, max_entries=_PROBE_CACHE_ENTRIES)
def load_local_probe(root_dir: str, filename: str) -> LoadedProbe:
    """Load a probe stored under ``root_dir`` on the local filesystem.

    ``probe.json`` and ``weights.safetensors`` are treated as the two halves
    of one artifact in the same directory; any other file is read with
    ``_torch_load`` as a single payload. Results are cached through
    ``st.cache_resource``.

    Raises:
        ValueError: if the resolved path is not strictly inside ``root_dir``.
    """
    base_dir = Path(root_dir).expanduser()
    target = (base_dir / filename).resolve()
    # Compare resolved paths so ``..`` segments and symlinks cannot escape
    # the selected directory.
    if base_dir.resolve() not in target.parents:
        raise ValueError("Probe path must stay inside the selected local directory.")

    if target.name in ("probe.json", "weights.safetensors"):
        # Whichever half was selected, its sibling sits next to it.
        return _load_persona_probe_artifact(
            filename=filename,
            metadata_path=target.with_name("probe.json"),
            weights_path=target.with_name("weights.safetensors"),
        )

    payload = _torch_load(target)
    return _load_probe_payload(filename=filename, payload=payload)
@st.cache_resource(show_spinner=False, max_entries=_PROBE_CACHE_ENTRIES)
def load_probe_from_bytes(filename: str, data: bytes) -> LoadedProbe:
    """Build a ``LoadedProbe`` from an in-memory serialized payload.

    ``data`` is wrapped in a ``BytesIO`` and read with ``_torch_load``;
    ``filename`` is forwarded to ``_load_probe_payload``. Results are cached
    through ``st.cache_resource``.
    """
    buffer = io.BytesIO(data)
    payload = _torch_load(buffer)
    return _load_probe_payload(filename=filename, payload=payload)
def _load_probe_payload(
    *,
    filename: str,
    payload: object,
) -> LoadedProbe:
    """Turn a payload dict into a ready-to-run ``LoadedProbe``.

    Values in ``payload`` win; missing or unusable ones fall back to what
    ``parse_probe_filename`` extracts from ``filename`` or, for dimensions,
    to what can be inferred from the state dict.

    Raises:
        TypeError: if ``payload`` is not a dict.
    """
    if not isinstance(payload, dict):
        raise TypeError(f"Probe payload must be a dict, got {type(payload)!r}")

    name_info = parse_probe_filename(filename)
    weights = _get_state_dict(payload)

    # ``input_dim`` is the width callers must supply; the module itself may
    # take a different width (``artifact_feature_dim``) when one is declared.
    input_dim = _coerce_probe_dim(payload.get("input_dim"), weights, dim="input")
    module_input_dim = _coerce_probe_dim(
        payload.get("artifact_feature_dim") or input_dim,
        weights,
        dim="input",
    )
    num_classes = _coerce_probe_dim(payload.get("num_classes"), weights, dim="classes")

    module = _build_probe_module(
        payload,
        state_dict=weights,
        input_dim=module_input_dim,
        num_classes=num_classes,
    )
    labels = _normalize_labels(
        payload.get("idx_to_label") or payload.get("class_names"),
        num_classes,
    )

    declared_layer = payload.get("layer")
    if declared_layer is None:
        layer = name_info.layer
    else:
        try:
            layer = int(declared_layer)
        except (TypeError, ValueError):
            # Unparseable layer value: use the one from the filename.
            layer = name_info.layer

    declared_location = payload.get("location")
    if isinstance(declared_location, str) and declared_location:
        location = declared_location
    else:
        location = name_info.location

    # A declared PCA component count takes priority over any explicit
    # feature-space string.
    pca_count = payload.get("n_pca_components")
    if pca_count:
        feature_space = f"pca{pca_count}"
    else:
        feature_space = _optional_str(payload.get("feature_space")) or name_info.scope

    model_type = str(payload.get("model_type") or name_info.model_type)
    model_name = _optional_str(payload.get("model_name")) or name_info.model_name
    attribute_name = (
        _optional_str(payload.get("attribute_name")) or name_info.attribute_name
    )
    scaler_std = _as_cpu_tensor(_first_present(payload, "scaler_std", "scaler_scale"))

    return LoadedProbe(
        model=module,
        input_dim=input_dim,
        labels=labels,
        model_type=model_type,
        layer=layer,
        location=location,
        model_name=model_name,
        attribute_name=attribute_name,
        feature_space=feature_space,
        task=_optional_str(payload.get("task")),
        probe_kind=_optional_str(payload.get("probe_kind")),
        scaler_mean=_as_cpu_tensor(payload.get("scaler_mean")),
        scaler_std=scaler_std,
        pca_mean=_as_cpu_tensor(payload.get("pca_mean")),
        pca_components=_as_cpu_tensor(payload.get("pca_components")),
    )
def _torch_load(file_or_buffer: object) -> object:
    """Deserialize a torch checkpoint onto the CPU.

    ``weights_only=True`` is kept deliberately: it makes ``torch.load`` use
    its restricted unpickler, which matters because probe files may come
    from downloads or user uploads.
    """
    load_options = {"map_location": "cpu", "weights_only": True}
    return torch.load(file_or_buffer, **load_options)
def _load_persona_probe_artifact(
    *,
    filename: str,
    metadata_path: Path,
    weights_path: Path,
) -> LoadedProbe:
    """Load a metadata + weights artifact and convert it to a ``LoadedProbe``.

    ``weights_path`` is used only for the same-directory check; the artifact
    is opened through ``load_probe_artifact(metadata_path)``.
    NOTE(review): presumably that loader locates the weights file next to
    the metadata itself -- confirm against persona_vectors.probes.

    Raises:
        ValueError: if the two files do not share a parent directory.
    """
    shares_directory = metadata_path.parent == weights_path.parent
    if not shares_directory:
        raise ValueError("Canonical probe files must share one artifact directory.")
    return _loaded_probe_from_artifact(
        filename=filename,
        artifact=load_probe_artifact(metadata_path),
    )
def _loaded_probe_from_artifact(
    *,
    filename: str,
    artifact: ProbeArtifact,
) -> LoadedProbe:
    """Re-express a ``ProbeArtifact`` as a payload dict and load it.

    The artifact's metadata is copied first; the keys set below then
    override any metadata entries of the same name. The probe is always
    declared as ``"linear"``.
    """
    info = artifact.metadata
    tensors = artifact.tensors
    weight = tensors["weight"]
    bias = tensors["bias"]

    overrides = {
        "model_type": "linear",
        # Key names match the parameters of ``_LinearProbe``.
        "model_state_dict": {"linear.weight": weight, "linear.bias": bias},
        "num_classes": int(weight.shape[0]),
        "idx_to_label": info.get("class_names"),
        "scaler_mean": tensors.get("scaler_mean"),
        "scaler_std": tensors.get("scaler_scale"),
        "pca_mean": tensors.get("pca_mean"),
        "pca_components": tensors.get("pca_components"),
    }
    payload = {**info, **overrides}
    return _load_probe_payload(filename=filename, payload=payload)
def _build_probe_module(
    payload: dict[str, Any],
    *,
    state_dict: dict[str, torch.Tensor],
    input_dim: int,
    num_classes: int,
) -> nn.Module:
    """Instantiate the probe architecture named in ``payload`` and load weights.

    Recognised ``model_type`` values (case-insensitive) are ``"mlp"`` and
    the linear aliases; an unrecognised or missing type is still accepted
    when the state dict looks like a linear layer.

    Raises:
        ValueError: if the architecture cannot be determined.
    """
    kind = str(payload.get("model_type") or "").lower()
    linear_aliases = {"lr", "linear", "logreg", "logistic_regression"}

    if kind == "mlp":
        hidden_dims = _coerce_hidden_dims(payload.get("hidden_dims"))
        dropout = float(payload.get("dropout") or 0.0)
        module = _MLPProbe(
            input_dim=input_dim,
            hidden_dims=hidden_dims,
            num_classes=num_classes,
            dropout=dropout,
        )
        weights = _strip_known_prefixes(state_dict)
    elif kind in linear_aliases or _looks_linear(state_dict):
        module = _LinearProbe(input_dim=input_dim, num_classes=num_classes)
        weights = _normalize_linear_state_dict(state_dict)
    else:
        raise ValueError(f"Unsupported probe model type: {kind!r}")

    # strict=True: any missing or unexpected key is an error.
    module.load_state_dict(weights, strict=True)
    return module
def _get_state_dict(payload: dict[str, Any]) -> dict[str, torch.Tensor]:
    """Return the first state dict found in ``payload``.

    Keys are tried in the order ``model_state_dict``, ``state_dict``,
    ``probe_state_dict``; the first whose value is a dict wins. Its keys are
    stringified and tensor values are detached and moved to the CPU;
    non-tensor values pass through unchanged.

    Raises:
        TypeError: if none of the keys holds a dict.
    """
    candidate_keys = ("model_state_dict", "state_dict", "probe_state_dict")
    candidates = (payload.get(key) for key in candidate_keys)
    found = next((item for item in candidates if isinstance(item, dict)), None)
    if found is None:
        raise TypeError("Probe payload is missing model_state_dict")

    cleaned = {}
    for name, entry in found.items():
        if isinstance(entry, torch.Tensor):
            entry = entry.detach().cpu()
        cleaned[str(name)] = entry
    return cleaned
def _coerce_probe_dim(
    value: object,
    state_dict: dict[str, torch.Tensor],
    *,
    dim: str,
) -> int:
    """Return ``value`` as an int, or infer the dimension from the weights.

    When ``value`` is ``None`` the 2-D tensors whose key ends in ``weight``
    are collected in state-dict order. ``dim="input"`` reads the column
    count of the first one; any other ``dim`` reads the row count of the
    last one.

    Raises:
        ValueError: if inference is needed but no such tensor exists.
    """
    if value is not None:
        return int(value)

    def is_weight_matrix(key: str, entry: object) -> bool:
        return (
            key.endswith("weight")
            and isinstance(entry, torch.Tensor)
            and entry.ndim == 2
        )

    matrices = [
        entry for key, entry in state_dict.items() if is_weight_matrix(key, entry)
    ]
    if not matrices:
        raise ValueError(f"Cannot infer probe {dim} dimension from state dict")

    if dim == "input":
        return int(matrices[0].shape[1])
    return int(matrices[-1].shape[0])
def _normalize_linear_state_dict(
    state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Rename a linear probe's parameters to ``linear.weight`` / ``linear.bias``.

    Known wrapper prefixes are stripped first. If the result already uses
    the ``linear.`` names, or has no bare ``weight`` key, it is returned as
    is. Otherwise a new dict holding only the renamed weight (and bias,
    when present) is returned.
    """
    cleaned = _strip_known_prefixes(state_dict)
    if "linear.weight" in cleaned or "weight" not in cleaned:
        return cleaned

    renamed = {"linear.weight": cleaned["weight"]}
    if "bias" in cleaned:
        renamed["linear.bias"] = cleaned["bias"]
    return renamed
def _strip_known_prefixes(
    state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Return a copy of ``state_dict`` with wrapper prefixes removed from keys.

    Leading ``module.``, ``model.`` and ``probe.`` segments are stripped
    repeatedly, in whatever order they are nested, until the key no longer
    starts with any of them. A single ordered pass would leave e.g.
    ``probe.model.linear.weight`` as ``model.linear.weight``, which then
    fails strict state-dict loading.

    Values are passed through untouched. If two keys collapse to the same
    name, the later one wins.
    """
    prefixes = ("module.", "model.", "probe.")
    out: dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        stripped = key
        # ``str.startswith`` accepts a tuple, so this loops while any known
        # prefix is still in front.
        while stripped.startswith(prefixes):
            for prefix in prefixes:
                if stripped.startswith(prefix):
                    stripped = stripped[len(prefix):]
                    break
        out[stripped] = value
    return out
def _looks_linear(state_dict: dict[str, torch.Tensor]) -> bool:
    """Return True if the prefix-stripped keys include a linear layer's weight.

    Both the bare ``weight`` key and the ``linear.weight`` key count.
    """
    linear_keys = {"weight", "linear.weight"}
    return not linear_keys.isdisjoint(_strip_known_prefixes(state_dict))
def _coerce_hidden_dims(value: Any) -> list[int]:
    """Normalise a ``hidden_dims`` setting to a list of ints.

    Accepts ``None`` (empty list), a single int, a comma-separated string
    (blank parts are skipped), or a list/tuple of int-convertible items.

    Raises:
        TypeError: for any other type.
    """
    if value is None:
        return []
    if isinstance(value, int):
        return [value]
    if isinstance(value, str):
        pieces = (piece.strip() for piece in value.split(","))
        return [int(piece) for piece in pieces if piece]
    if isinstance(value, (list, tuple)):
        return list(map(int, value))
    raise TypeError(f"Unsupported hidden_dims value: {type(value)!r}")
def _as_cpu_tensor(value: Any) -> torch.Tensor | None:
    """Return ``value`` detached on the CPU if it is a tensor, else ``None``.

    Anything that is not a ``torch.Tensor`` -- including lists and
    ``None`` -- yields ``None``.
    """
    return value.detach().cpu() if isinstance(value, torch.Tensor) else None
def _optional_str(value: Any) -> str | None:
    """Return ``value`` when it is a non-empty ``str``; otherwise ``None``."""
    is_usable = isinstance(value, str) and bool(value)
    return value if is_usable else None
def _first_present(payload: dict[str, Any], *keys: str) -> Any:
    """Return the first non-``None`` value found under ``keys``, in order.

    Falsy values such as ``0`` or ``""`` count as present. Returns ``None``
    when every key is missing or maps to ``None``.
    """
    looked_up = map(payload.get, keys)
    return next((found for found in looked_up if found is not None), None)
def _normalize_labels(raw_labels: Any, num_classes: int) -> list[str | None]:
    """Build a label list of length ``num_classes`` from loosely typed input.

    * dict: keys are converted with ``int()`` and used as positions; keys
      that cannot be converted or fall out of range are ignored, and
      unfilled positions stay ``None``.
    * list/tuple: truncated to ``num_classes``, each item passed through
      ``str()``, then padded with ``None``.
    * anything else: all ``None``.
    """
    if isinstance(raw_labels, dict):
        slots: list[str | None] = [None] * num_classes
        for key, label in raw_labels.items():
            try:
                position = int(key)
            except (TypeError, ValueError):
                continue
            if position in range(num_classes):
                slots[position] = str(label)
        return slots

    if isinstance(raw_labels, (list, tuple)):
        named = list(map(str, raw_labels[:num_classes]))
        padding = [None] * (num_classes - len(named))
        return named + padding

    return [None] * num_classes