"""
Contrastive token-level log-probability comparison for compare mode.

For a pair of responses generated under different persona contexts, each token
gets a weight:

    w(token) = log P(token | context_A) − log P(token | context_B)

Positive (red)   → token is more characteristic of persona A.
Negative (blue)  → token is more characteristic of persona B.
Near-zero (gray) → both personas would emit this token with similar likelihood.
"""

from dataclasses import dataclass
from html import escape

import torch
from nnterp import StandardizedTransformer

from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor


@dataclass
class TokenContrast:
    """Per-token contrast data for one response, as consumed by the renderer."""

    tokens: list[str]  # decoded display tokens (special tokens removed)
    weights: list[float]  # normalised to [-1, 1], used for coloring
    raw_diffs: list[float]  # unclipped log P(A) - log P(B) per token
    label_a: str
    label_b: str


# ── Weight computation ────────────────────────────────────────────────────────


def _normalise_diffs(diffs: torch.Tensor) -> list[float]:
    """
    Clip at the 95th percentile of |diff| and scale to [-1, 1] so a few
    high-magnitude tokens don't wash out everything else.

    The clip value never drops below 0.3, so a response whose diffs are all
    tiny is not stretched to full intensity. An empty tensor yields ``[]``.
    A single-token tensor goes through the same clip-and-scale path as any
    other input, so the result always lies in [-1, 1].
    """
    if diffs.numel() == 0:
        return []
    # torch.quantile only accepts float/double input, so cast before calling
    # it (log-probs may arrive as half/bfloat16).
    values = diffs.detach().float()
    clip_val = max(torch.quantile(values.abs(), 0.95).item(), 0.3)
    return (values.clamp(-clip_val, clip_val) / clip_val).tolist()


def _strip_special_ids(
    ids: torch.Tensor,
    tokenizer: object,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return display ids and a mask that excludes special tokens.

    The mask is a CPU bool tensor with one entry per input id (``True`` means
    keep). The special-token set is read from ``tokenizer.all_special_ids``;
    when that attribute is missing or empty, or ``ids`` is empty, everything
    is kept.
    """
    ids = ids.cpu()
    special_ids = set(getattr(tokenizer, "all_special_ids", []) or [])
    if not special_ids or ids.numel() == 0:
        return ids, torch.ones(ids.shape[0], dtype=torch.bool)
    # One bulk conversion instead of a tensor -> Python round trip per id.
    keep = torch.tensor(
        [tid not in special_ids for tid in ids.tolist()], dtype=torch.bool
    )
    return ids[keep], keep


def _prepare_trace_input_ids(
    tokenizer: object,
    context_messages: list[dict[str, str]],
    response_ids: torch.Tensor,
) -> tuple[torch.Tensor, int, int]:
    """
    Build exact trace input ids and return ``(input_ids, n_ctx, n_resp)``.

    The context is rendered with ``format_generation_prompt`` and tokenized;
    the already-tokenized response ids are appended unchanged, so the response
    is scored with exactly the ids that were supplied. ``n_ctx`` and
    ``n_resp`` are the lengths of the context and response parts.
    """
    context_prompt, _ = format_generation_prompt(context_messages, tokenizer)
    context_ids = tokenizer(context_prompt, return_tensors="pt").input_ids[0]
    input_ids = torch.cat([context_ids.cpu(), response_ids.detach().cpu()])
    return input_ids, len(context_ids), len(response_ids)


def _build_contrast(
    tokenizer: object,
    response_ids: torch.Tensor,
    lp_a: torch.Tensor,
    lp_b: torch.Tensor,
    label_a: str,
    label_b: str,
) -> TokenContrast:
    """
    Turn two per-token log-prob vectors into a ``TokenContrast``.

    ``lp_a`` / ``lp_b`` hold the log-probability of each response token under
    context A and context B. Special tokens are dropped from both the token
    list and the diffs using the same mask, so the two stay aligned.
    """
    diffs = (lp_a - lp_b).cpu()
    display_ids, keep_mask = _strip_special_ids(response_ids, tokenizer)
    display_diffs = diffs[keep_mask]
    return TokenContrast(
        tokens=[decode_token(tokenizer, tid) for tid in display_ids.tolist()],
        weights=_normalise_diffs(display_diffs),
        raw_diffs=display_diffs.float().tolist(),
        label_a=label_a,
        label_b=label_b,
    )


# Each spec: (key, input_ids, n_ctx, n_resp, target_ids).
PassSpec = tuple[str, torch.Tensor, int, int, torch.Tensor] def _score_passes( model: StandardizedTransformer, specs: list[PassSpec], remote: bool, ndif_api_key: str | None = None, ) -> dict[str, torch.Tensor]: """ Run one forward pass per spec and return reduced per-token logprobs. The log-softmax and target-pick happen *inside* the trace, so only the reduced ``[n_resp]`` logprob vector per pass is shipped back — not the full ``[1, seq, vocab]`` logits (which would be hundreds of MB per pass on NDIF). """ def _score_pass( input_ids: torch.Tensor, n_ctx: int, n_resp: int, target_ids: torch.Tensor, ) -> torch.Tensor: if remote: from utils.runtime import remote_backend backend = remote_backend(model, ndif_api_key) else: backend = None with torch.no_grad(), model.trace(input_ids, remote=remote, backend=backend): # logit at position i predicts token i+1, so response token j # (at full-text position n_ctx+j) uses logit at n_ctx+j-1. resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float() log_probs = torch.log_softmax(resp_logits, dim=-1) targets = target_ids.to(log_probs.device).view(-1, 1) picked = log_probs.gather(1, targets).view(-1) out = picked.detach().cpu().save() return resolve_saved_tensor(out) return { key: _score_pass(input_ids, n_ctx, n_resp, target_ids) for key, input_ids, n_ctx, n_resp, target_ids in specs } def _specs_for_response( tokenizer: object, response_ids: torch.Tensor, context_a: list[dict[str, str]], context_b: list[dict[str, str]], prefix: str, ) -> list[PassSpec]: """Build the (under_a, under_b) pass specs for a single response.""" input_a, n_ctx_a, n_resp = _prepare_trace_input_ids( tokenizer, context_a, response_ids ) input_b, n_ctx_b, _ = _prepare_trace_input_ids(tokenizer, context_b, response_ids) return [ (f"{prefix}_under_a", input_a, n_ctx_a, n_resp, response_ids), (f"{prefix}_under_b", input_b, n_ctx_b, n_resp, response_ids), ] def compute_contrast( model: StandardizedTransformer, context_a: list[dict[str, str]], 
context_b: list[dict[str, str]], response_ids: torch.Tensor, label_a: str, label_b: str, remote: bool = False, ndif_api_key: str | None = None, ) -> "TokenContrast | None": """Compute per-token contrast weights for a single response (2 forward passes).""" tokenizer = model.tokenizer if response_ids.numel() == 0: return None specs = _specs_for_response(tokenizer, response_ids, context_a, context_b, "r") out = _score_passes(model, specs, remote, ndif_api_key) return _build_contrast( tokenizer, response_ids, out["r_under_a"], out["r_under_b"], label_a, label_b ) def compute_contrast_pair( model: StandardizedTransformer, context_a: list[dict[str, str]], context_b: list[dict[str, str]], response_ids_a: torch.Tensor, response_ids_b: torch.Tensor, label_a: str, label_b: str, remote: bool = False, ndif_api_key: str | None = None, ) -> tuple["TokenContrast | None", "TokenContrast | None"]: """ Compute contrast weights for both panel responses (up to 4 remote passes). """ tokenizer = model.tokenizer if response_ids_a.numel() == 0 and response_ids_b.numel() == 0: return None, None specs: list[PassSpec] = [] if response_ids_a.numel() > 0: specs += _specs_for_response( tokenizer, response_ids_a, context_a, context_b, "a" ) if response_ids_b.numel() > 0: specs += _specs_for_response( tokenizer, response_ids_b, context_a, context_b, "b" ) out = _score_passes(model, specs, remote, ndif_api_key) def _build(resp_ids: torch.Tensor, prefix: str) -> "TokenContrast | None": k_a, k_b = f"{prefix}_under_a", f"{prefix}_under_b" if resp_ids.numel() == 0 or k_a not in out or k_b not in out: return None return _build_contrast( tokenizer, resp_ids, out[k_a], out[k_b], label_a, label_b ) return _build(response_ids_a, "a"), _build(response_ids_b, "b") # ── HTML rendering ──────────────────────────────────────────────────────────── def _weight_to_bg(w: float) -> str: """Map a normalised weight in [-1, 1] to a CSS rgba background color.""" w = max(-1.0, min(1.0, w)) alpha = abs(w) * 0.5 # cap at 
0.5 opacity so text stays readable if w > 0.05: return f"rgba(210,60,60,{alpha:.3f})" if w < -0.05: return f"rgba(50,110,210,{alpha:.3f})" return "rgba(0,0,0,0)" _CONTRAST_CSS = ( "" ) def render_contrast_html(result: TokenContrast) -> str: """ Render each token with a colored background reflecting how A- or B-specific it is, with a hover tooltip showing the raw Δlog P, plus a legend. """ # The model often opens a response with newline tokens; under pre-wrap # those render as blank lines before the first word. Drop leading # whitespace-only tokens (and left-trim the first visible one) so the # contrast starts at real content. Display-only — weights stay aligned. items = list(zip(result.tokens, result.weights, result.raw_diffs, strict=True)) start = 0 while start < len(items) and not items[start][0].strip(): start += 1 if start >= len(items): start = 0 # all-whitespace response: render as-is, not blank items = items[start:] spans: list[str] = [] for idx, (token, weight, raw) in enumerate(items): if idx == 0: token = token.lstrip() bg = _weight_to_bg(weight) tip = escape(f"Δlog P(A−B): {raw:+.3f}") text = escape(token) spans.append( f'' f'{text}{tip}' ) la = escape(result.label_a) lb = escape(result.label_b) return ( _CONTRAST_CSS + '
' + "".join(spans) + '
' + ' {la}' + ' {lb}' + 'gray = shared by both' + "
" + "
" )