persona-ui / utils /contrast.py
Jac-Zac
add session-scoped NDIF execution and improve cold-load UX
ae347c6
"""
Contrastive token-level log-probability comparison for compare mode.
For a pair of responses generated under different persona contexts, each token
gets a weight:
w(token) = log P(token | context_A) - log P(token | context_B)
Positive (red) -> token is more characteristic of persona A.
Negative (blue) -> token is more characteristic of persona B.
Near-zero (uncolored) -> both personas would emit this token with similar likelihood.
"""
import math
from dataclasses import dataclass
from html import escape

import torch
from nnterp import StandardizedTransformer

from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor
@dataclass
class TokenContrast:
    """
    Per-token contrast between two persona contexts for a single response.

    ``tokens``, ``weights`` and ``raw_diffs`` are parallel lists with one entry
    per displayed token.
    """

    # Decoded text of each displayed response token.
    tokens: list[str]
    # Display weights, normalised to [-1, 1]; these drive the background colour.
    weights: list[float]
    # Unclipped log P(A) - log P(B) for each displayed token.
    raw_diffs: list[float]
    # Human-readable names of the two contexts, shown in the legend.
    label_a: str
    label_b: str
# ── Weight computation ────────────────────────────────────────────────────────
def _normalise_diffs(diffs: torch.Tensor) -> list[float]:
"""
Clip at the 95th percentile of |diff| and scale to [-1, 1] so a few
high-magnitude tokens don't wash out everything else.
"""
if len(diffs) < 2:
return diffs.tolist()
clip_val = max(torch.quantile(diffs.abs(), 0.95).item(), 0.3)
return (diffs.float().clamp(-clip_val, clip_val) / clip_val).tolist()
def _strip_special_ids(
ids: torch.Tensor,
tokenizer: object,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Return display ids and a mask that excludes special tokens."""
ids = ids.cpu()
special_ids = set(getattr(tokenizer, "all_special_ids", []) or [])
if not special_ids or ids.numel() == 0:
return ids, torch.ones(ids.shape[0], dtype=torch.bool)
keep = torch.tensor(
[tid.item() not in special_ids for tid in ids], dtype=torch.bool
)
return ids[keep], keep
def _prepare_trace_input_ids(
    tokenizer: object,
    context_messages: list[dict[str, str]],
    response_ids: torch.Tensor,
) -> tuple[torch.Tensor, int, int]:
    """
    Tokenise the generation prompt for ``context_messages`` and append the
    already-generated ``response_ids``.

    Returns ``(input_ids, n_ctx, n_resp)``:
    - ``input_ids``: the concatenated CPU id tensor.
    - ``n_ctx``: the number of prompt tokens.
    - ``n_resp``: the number of response tokens.
    """
    prompt_text, _ = format_generation_prompt(context_messages, tokenizer)
    # NOTE(review): the tokenizer is called with its default special-token
    # handling. If the formatted prompt already contains a BOS token this
    # may duplicate it; confirm it matches how generation tokenised it.
    prompt_ids = tokenizer(prompt_text, return_tensors="pt").input_ids[0]
    full_ids = torch.cat((prompt_ids.cpu(), response_ids.detach().cpu()))
    ctx_len, resp_len = len(prompt_ids), len(response_ids)
    return full_ids, ctx_len, resp_len
def _build_contrast(
    tokenizer: object,
    response_ids: torch.Tensor,
    lp_a: torch.Tensor,
    lp_b: torch.Tensor,
    label_a: str,
    label_b: str,
) -> TokenContrast:
    """
    Turn per-token log-probs under both contexts into a ``TokenContrast``.

    Special tokens are dropped from the display. The same mask is applied to
    the diffs, so tokens, weights and raw diffs stay aligned.
    """
    per_token_diff = (lp_a - lp_b).cpu()
    visible_ids, mask = _strip_special_ids(response_ids, tokenizer)
    kept = per_token_diff[mask]
    token_strings = [decode_token(tokenizer, tid.item()) for tid in visible_ids]
    return TokenContrast(
        tokens=token_strings,
        weights=_normalise_diffs(kept),
        raw_diffs=kept.float().tolist(),
        label_a=label_a,
        label_b=label_b,
    )
# Each spec describes one scoring forward pass:
#   (key, input_ids, n_ctx, n_resp, target_ids)
# key        - name under which _score_passes stores the pass's result
# input_ids  - context prompt ids followed by the response ids
# n_ctx      - number of context (prompt) tokens at the start of input_ids
# n_resp     - number of response tokens to score
# target_ids - response token ids whose log-probs are picked out
PassSpec = tuple[str, torch.Tensor, int, int, torch.Tensor]
def _score_passes(
    model: StandardizedTransformer,
    specs: list[PassSpec],
    remote: bool,
    ndif_api_key: str | None = None,
) -> dict[str, torch.Tensor]:
    """
    Run one forward pass per spec and return reduced per-token logprobs.
    The log-softmax and target-pick happen *inside* the trace, so only the
    reduced ``[n_resp]`` logprob vector per pass is shipped back -- not the full
    ``[1, seq, vocab]`` logits (which would be hundreds of MB per pass on NDIF).

    Args:
        model: the traced model; ``model.logits`` is read inside each trace.
        specs: one ``(key, input_ids, n_ctx, n_resp, target_ids)`` tuple per
            pass, processed in order.
        remote: when True, each pass is traced with ``remote=True`` and a
            backend built by ``utils.runtime.remote_backend``.
        ndif_api_key: forwarded to ``remote_backend``; unused when local.

    Returns:
        A mapping from spec key to a 1-D tensor of the log-probs of
        ``target_ids``. If keys repeat, the later spec wins.
    """
    def _score_pass(
        input_ids: torch.Tensor,
        n_ctx: int,
        n_resp: int,
        target_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Trace ``input_ids`` once and return log P(target_j) for each response token."""
        # A fresh backend is requested for every pass (not shared across specs).
        if remote:
            from utils.runtime import remote_backend
            backend = remote_backend(model, ndif_api_key)
        else:
            backend = None
        # NOTE(review): nnsight captures this with-body as the traced
        # computation (and ships it to NDIF when remote). Keep it minimal.
        # NOTE(review): assumes n_ctx >= 1. With n_ctx == 0 the slice start
        # below would be -1.
        with torch.no_grad(), model.trace(input_ids, remote=remote, backend=backend):
            # logit at position i predicts token i+1, so response token j
            # (at full-text position n_ctx+j) uses logit at n_ctx+j-1.
            resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float()
            log_probs = torch.log_softmax(resp_logits, dim=-1)
            targets = target_ids.to(log_probs.device).view(-1, 1)
            picked = log_probs.gather(1, targets).view(-1)
            out = picked.detach().cpu().save()
        # resolve_saved_tensor (utils.chat) presumably unwraps the saved
        # proxy into a concrete tensor; verify against its definition.
        return resolve_saved_tensor(out)
    return {
        key: _score_pass(input_ids, n_ctx, n_resp, target_ids)
        for key, input_ids, n_ctx, n_resp, target_ids in specs
    }
def _specs_for_response(
    tokenizer: object,
    response_ids: torch.Tensor,
    context_a: list[dict[str, str]],
    context_b: list[dict[str, str]],
    prefix: str,
) -> list[PassSpec]:
    """
    Build two pass specs that score one response under each context.

    The keys are ``f"{prefix}_under_a"`` and ``f"{prefix}_under_b"``, in
    that order. Both passes target the same ``response_ids``.
    """
    specs: list[PassSpec] = []
    for suffix, context in (("a", context_a), ("b", context_b)):
        ids, ctx_len, resp_len = _prepare_trace_input_ids(
            tokenizer, context, response_ids
        )
        specs.append(
            (f"{prefix}_under_{suffix}", ids, ctx_len, resp_len, response_ids)
        )
    return specs
def compute_contrast(
    model: StandardizedTransformer,
    context_a: list[dict[str, str]],
    context_b: list[dict[str, str]],
    response_ids: torch.Tensor,
    label_a: str,
    label_b: str,
    remote: bool = False,
    ndif_api_key: str | None = None,
) -> "TokenContrast | None":
    """
    Score a single response under both contexts and build its contrast.

    This uses two forward passes. Returns None when ``response_ids`` is empty.
    """
    tokenizer = model.tokenizer
    if not response_ids.numel():
        return None
    pass_specs = _specs_for_response(
        tokenizer, response_ids, context_a, context_b, "r"
    )
    logprobs = _score_passes(model, pass_specs, remote, ndif_api_key)
    under_a, under_b = logprobs["r_under_a"], logprobs["r_under_b"]
    return _build_contrast(
        tokenizer, response_ids, under_a, under_b, label_a, label_b
    )
def compute_contrast_pair(
    model: StandardizedTransformer,
    context_a: list[dict[str, str]],
    context_b: list[dict[str, str]],
    response_ids_a: torch.Tensor,
    response_ids_b: torch.Tensor,
    label_a: str,
    label_b: str,
    remote: bool = False,
    ndif_api_key: str | None = None,
) -> tuple["TokenContrast | None", "TokenContrast | None"]:
    """
    Compute contrasts for both panel responses in one batch of passes.

    Each non-empty response is scored under both contexts, giving up to four
    passes in total. An empty response yields None in its slot. Returns
    ``(contrast_for_a, contrast_for_b)``.
    """
    tokenizer = model.tokenizer
    panels = (("a", response_ids_a), ("b", response_ids_b))
    if all(ids.numel() == 0 for _, ids in panels):
        return None, None

    specs: list[PassSpec] = []
    for prefix, ids in panels:
        if ids.numel() > 0:
            specs.extend(
                _specs_for_response(tokenizer, ids, context_a, context_b, prefix)
            )
    scores = _score_passes(model, specs, remote, ndif_api_key)

    def _panel_contrast(prefix: str, ids: torch.Tensor) -> "TokenContrast | None":
        key_a = f"{prefix}_under_a"
        key_b = f"{prefix}_under_b"
        missing = key_a not in scores or key_b not in scores
        if ids.numel() == 0 or missing:
            return None
        return _build_contrast(
            tokenizer, ids, scores[key_a], scores[key_b], label_a, label_b
        )

    return tuple(_panel_contrast(prefix, ids) for prefix, ids in panels)
# ── HTML rendering ────────────────────────────────────────────────────────────
def _weight_to_bg(w: float) -> str:
"""Map a normalised weight in [-1, 1] to a CSS rgba background color."""
w = max(-1.0, min(1.0, w))
alpha = abs(w) * 0.5 # cap at 0.5 opacity so text stays readable
if w > 0.05:
return f"rgba(210,60,60,{alpha:.3f})"
if w < -0.05:
return f"rgba(50,110,210,{alpha:.3f})"
return "rgba(0,0,0,0)"
_CONTRAST_CSS = (
"<style>"
".contrast-tok{position:relative;border-radius:2px;padding:0 1px;"
"cursor:default;white-space:pre;}"
".contrast-tok>.contrast-tip{display:none;position:absolute;bottom:100%;"
"left:50%;transform:translateX(-50%);margin-bottom:4px;padding:2px 6px;"
"border-radius:3px;background:#222;color:#eee;font-size:0.72em;"
"font-family:ui-monospace,monospace;white-space:nowrap;pointer-events:none;"
"z-index:10;box-shadow:0 2px 6px rgba(0,0,0,0.3);}"
".contrast-tok:hover>.contrast-tip{display:block;}"
"</style>"
)
def render_contrast_html(result: TokenContrast) -> str:
    """
    Render each token with a colored background reflecting how A- or B-specific
    it is, with a hover tooltip showing the raw delta log P, plus a legend.

    Args:
        result: parallel ``tokens`` / ``weights`` / ``raw_diffs`` lists plus
            the two persona labels. The lists must have equal length;
            ``zip(..., strict=True)`` raises ``ValueError`` otherwise.

    Returns:
        An HTML string made of three parts:
        - the ``_CONTRAST_CSS`` stylesheet;
        - a ``<div>`` with one ``contrast-tok`` span per token;
        - a legend.
        Token text, tooltips and labels are HTML-escaped.
    """
    # The model often opens a response with newline tokens; under pre-wrap
    # those render as blank lines before the first word. Drop leading
    # whitespace-only tokens (and left-trim the first visible one) so the
    # contrast starts at real content. Display-only -- weights stay aligned.
    items = list(zip(result.tokens, result.weights, result.raw_diffs, strict=True))
    start = 0
    while start < len(items) and not items[start][0].strip():
        start += 1
    if start >= len(items):
        start = 0  # all-whitespace response: render as-is, not blank
    items = items[start:]

    spans: list[str] = []
    for idx, (token, weight, raw) in enumerate(items):
        if idx == 0:
            token = token.lstrip()
        bg = _weight_to_bg(weight)
        # The glyphs are written as escapes so they survive any encoding round
        # trip of this source file. U+0394 is the capital delta and U+2212 is
        # the minus sign; the tooltip renders as "Δlog P(A−B): +x.xxx".
        tip = escape(f"\u0394log P(A\u2212B): {raw:+.3f}")
        text = escape(token)
        spans.append(
            f'<span class="contrast-tok" style="background:{bg};">'
            f'{text}<span class="contrast-tip">{tip}</span></span>'
        )

    la = escape(result.label_a)
    lb = escape(result.label_b)
    return (
        _CONTRAST_CSS + '<div style="font-family:inherit;line-height:1.75;'
        'white-space:pre-wrap;word-break:break-word;padding:2px 0 6px 0;">'
        + "".join(spans)
        + '<div style="margin-top:10px;font-size:0.72em;color:#888;'
        + 'display:flex;gap:12px;flex-wrap:wrap;">'
        + '<span><span style="background:rgba(210,60,60,0.45);'
        + f'padding:1px 6px;border-radius:2px;">&thinsp;</span>&nbsp;{la}</span>'
        + '<span><span style="background:rgba(50,110,210,0.45);'
        + f'padding:1px 6px;border-radius:2px;">&thinsp;</span>&nbsp;{lb}</span>'
        + '<span style="color:#aaa;">gray = shared by both</span>'
        + "</div>"
        + "</div>"
    )