| """ |
| Contrastive token-level log-probability comparison for compare mode. |
| |
| For a pair of responses generated under different persona contexts, each token |
| gets a weight: |
| |
| w(token) = log P(token | context_A) − log P(token | context_B) |
| |
| Positive (red) — token is more characteristic of persona A. |
| Negative (blue) — token is more characteristic of persona B. |
| Near-zero (gray) — both personas would emit this token with similar likelihood. |
| """ |
|
|
| from dataclasses import dataclass |
| from html import escape |
|
|
| import torch |
| from nnterp import StandardizedTransformer |
|
|
| from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor |
|
|
|
|
@dataclass
class TokenContrast:
    """Per-token contrast data for one response, as consumed by the renderer.

    ``tokens``, ``weights`` and ``raw_diffs`` are index-aligned: entry ``i``
    of each list describes the same displayed token.
    """

    # Decoded text of each displayed token.
    tokens: list[str]
    # Normalised contrast weight per token (drives the background colour).
    weights: list[float]
    # Un-normalised log-probability difference per token (A minus B).
    raw_diffs: list[float]
    # Display label for persona A (shown in the legend).
    label_a: str
    # Display label for persona B (shown in the legend).
    label_b: str
|
|
|
|
| |
|
|
|
|
| def _normalise_diffs(diffs: torch.Tensor) -> list[float]: |
| """ |
| Clip at the 95th percentile of |diff| and scale to [-1, 1] so a few |
| high-magnitude tokens don't wash out everything else. |
| """ |
| if len(diffs) < 2: |
| return diffs.tolist() |
| clip_val = max(torch.quantile(diffs.abs(), 0.95).item(), 0.3) |
| return (diffs.float().clamp(-clip_val, clip_val) / clip_val).tolist() |
|
|
|
|
| def _strip_special_ids( |
| ids: torch.Tensor, |
| tokenizer: object, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """Return display ids and a mask that excludes special tokens.""" |
| ids = ids.cpu() |
| special_ids = set(getattr(tokenizer, "all_special_ids", []) or []) |
| if not special_ids or ids.numel() == 0: |
| return ids, torch.ones(ids.shape[0], dtype=torch.bool) |
| keep = torch.tensor( |
| [tid.item() not in special_ids for tid in ids], dtype=torch.bool |
| ) |
| return ids[keep], keep |
|
|
|
|
def _prepare_trace_input_ids(
    tokenizer: object,
    context_messages: list[dict[str, str]],
    response_ids: torch.Tensor,
) -> tuple[torch.Tensor, int, int]:
    """Concatenate the tokenised context prompt with the response ids.

    Returns ``(input_ids, n_ctx, n_resp)`` where ``input_ids`` is the context
    ids followed by ``response_ids`` (both on CPU), ``n_ctx`` is the number of
    context ids and ``n_resp`` the number of response ids.
    """
    prompt_text, _ = format_generation_prompt(context_messages, tokenizer)
    # NOTE(review): the tokenizer is called with its default special-token
    # handling; if the formatted prompt already contains a BOS token this
    # could duplicate it -- confirm against format_generation_prompt.
    prompt_ids = tokenizer(prompt_text, return_tensors="pt").input_ids[0]
    full_ids = torch.cat([prompt_ids.cpu(), response_ids.detach().cpu()])
    return full_ids, len(prompt_ids), len(response_ids)
|
|
|
|
def _build_contrast(
    tokenizer: object,
    response_ids: torch.Tensor,
    lp_a: torch.Tensor,
    lp_b: torch.Tensor,
    label_a: str,
    label_b: str,
) -> TokenContrast:
    """Turn two per-token logprob vectors into a displayable TokenContrast.

    ``lp_a`` and ``lp_b`` hold the log-probabilities of the same response
    tokens under context A and context B. Their difference (A minus B) is
    filtered with the special-token mask so only displayed tokens remain,
    then stored both raw and normalised.
    """
    delta = (lp_a - lp_b).cpu()
    shown_ids, shown_mask = _strip_special_ids(response_ids, tokenizer)
    shown_delta = delta[shown_mask]

    token_texts = [decode_token(tokenizer, tok_id.item()) for tok_id in shown_ids]
    return TokenContrast(
        tokens=token_texts,
        weights=_normalise_diffs(shown_delta),
        raw_diffs=shown_delta.float().tolist(),
        label_a=label_a,
        label_b=label_b,
    )
|
|
|
|
| |
# Description of one scoring forward pass, unpacked by the scorer as
# ``(key, input_ids, n_ctx, n_resp, target_ids)``.
PassSpec = tuple[
    str,  # key under which the pass result is returned
    torch.Tensor,  # input ids: context ids followed by response ids
    int,  # n_ctx: number of context ids
    int,  # n_resp: number of response ids
    torch.Tensor,  # target ids whose log-probabilities are picked
]
|
|
|
|
def _score_passes(
    model: StandardizedTransformer,
    specs: list[PassSpec],
    remote: bool,
    ndif_api_key: str | None = None,
) -> dict[str, torch.Tensor]:
    """
    Run one forward pass per spec and return reduced per-token logprobs.

    The log-softmax and target-pick happen *inside* the trace, so only the
    reduced ``[n_resp]`` logprob vector per pass is shipped back — not the full
    ``[1, seq, vocab]`` logits (which would be hundreds of MB per pass on NDIF).

    Args:
        model: Model whose ``trace`` context and ``logits`` are used.
        specs: ``(key, input_ids, n_ctx, n_resp, target_ids)`` tuples, one per
            forward pass.
        remote: Forwarded to ``model.trace``; when true a backend is obtained
            from ``utils.runtime.remote_backend``.
        ndif_api_key: Passed through to ``remote_backend`` for remote runs.

    Returns:
        A dict mapping each spec's key to the result of its pass.
    """

    def _score_pass(
        input_ids: torch.Tensor,
        n_ctx: int,
        n_resp: int,
        target_ids: torch.Tensor,
    ) -> torch.Tensor:
        # Score a single spec. The backend is resolved on every call, and the
        # helper is imported lazily so local runs never import utils.runtime.
        if remote:
            from utils.runtime import remote_backend

            backend = remote_backend(model, ndif_api_key)
        else:
            backend = None
        with torch.no_grad(), model.trace(input_ids, remote=remote, backend=backend):
            # The slice starts at n_ctx - 1 -- presumably because the logits at
            # position p score the token at position p + 1, so these n_resp
            # rows are the predictions for the response tokens (confirm).
            resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float()
            log_probs = torch.log_softmax(resp_logits, dim=-1)
            # One column of target ids, so gather picks one logprob per row.
            targets = target_ids.to(log_probs.device).view(-1, 1)
            picked = log_probs.gather(1, targets).view(-1)
            # Only the reduced vector is saved out of the trace.
            out = picked.detach().cpu().save()
        # resolve_saved_tensor comes from utils.chat; presumably it unwraps the
        # saved handle into a plain tensor -- confirm.
        return resolve_saved_tensor(out)

    return {
        key: _score_pass(input_ids, n_ctx, n_resp, target_ids)
        for key, input_ids, n_ctx, n_resp, target_ids in specs
    }
|
|
|
|
def _specs_for_response(
    tokenizer: object,
    response_ids: torch.Tensor,
    context_a: list[dict[str, str]],
    context_b: list[dict[str, str]],
    prefix: str,
) -> list[PassSpec]:
    """Return the two pass specs that score one response under each context.

    The specs are keyed ``f"{prefix}_under_a"`` and ``f"{prefix}_under_b"``
    (in that order) and both use ``response_ids`` as their target ids.
    """
    specs: list[PassSpec] = []
    for tag, context in (("a", context_a), ("b", context_b)):
        trace_ids, n_ctx, n_resp = _prepare_trace_input_ids(
            tokenizer, context, response_ids
        )
        specs.append((f"{prefix}_under_{tag}", trace_ids, n_ctx, n_resp, response_ids))
    return specs
|
|
|
|
def compute_contrast(
    model: StandardizedTransformer,
    context_a: list[dict[str, str]],
    context_b: list[dict[str, str]],
    response_ids: torch.Tensor,
    label_a: str,
    label_b: str,
    remote: bool = False,
    ndif_api_key: str | None = None,
) -> "TokenContrast | None":
    """Score one response under both contexts and build its contrast.

    Runs two forward passes (one per context). Returns ``None`` when
    ``response_ids`` is empty, otherwise the resulting ``TokenContrast``.
    """
    tokenizer = model.tokenizer
    if not response_ids.numel():
        return None

    scored = _score_passes(
        model,
        _specs_for_response(tokenizer, response_ids, context_a, context_b, "r"),
        remote,
        ndif_api_key,
    )
    lp_under_a = scored["r_under_a"]
    lp_under_b = scored["r_under_b"]
    return _build_contrast(
        tokenizer, response_ids, lp_under_a, lp_under_b, label_a, label_b
    )
|
|
|
|
def compute_contrast_pair(
    model: StandardizedTransformer,
    context_a: list[dict[str, str]],
    context_b: list[dict[str, str]],
    response_ids_a: torch.Tensor,
    response_ids_b: torch.Tensor,
    label_a: str,
    label_b: str,
    remote: bool = False,
    ndif_api_key: str | None = None,
) -> tuple["TokenContrast | None", "TokenContrast | None"]:
    """
    Compute contrast weights for both panel responses (up to 4 forward passes).

    Each non-empty response is scored under both contexts. The returned pair
    is ``(contrast_for_a, contrast_for_b)``; an entry is ``None`` when its
    response is empty or its scores are missing.
    """
    tokenizer = model.tokenizer
    responses = {"a": response_ids_a, "b": response_ids_b}
    if all(ids.numel() == 0 for ids in responses.values()):
        return None, None

    # Collect every pass first so a single scoring call handles all of them.
    specs: list[PassSpec] = []
    for prefix, ids in responses.items():
        if ids.numel() > 0:
            specs.extend(
                _specs_for_response(tokenizer, ids, context_a, context_b, prefix)
            )

    scored = _score_passes(model, specs, remote, ndif_api_key)

    def _contrast_for(prefix: str) -> "TokenContrast | None":
        ids = responses[prefix]
        key_a = f"{prefix}_under_a"
        key_b = f"{prefix}_under_b"
        if ids.numel() == 0 or key_a not in scored or key_b not in scored:
            return None
        return _build_contrast(
            tokenizer, ids, scored[key_a], scored[key_b], label_a, label_b
        )

    return _contrast_for("a"), _contrast_for("b")
|
|
|
|
| |
|
|
|
|
| def _weight_to_bg(w: float) -> str: |
| """Map a normalised weight in [-1, 1] to a CSS rgba background color.""" |
| w = max(-1.0, min(1.0, w)) |
| alpha = abs(w) * 0.5 |
| if w > 0.05: |
| return f"rgba(210,60,60,{alpha:.3f})" |
| if w < -0.05: |
| return f"rgba(50,110,210,{alpha:.3f})" |
| return "rgba(0,0,0,0)" |
|
|
|
|
| _CONTRAST_CSS = ( |
| "<style>" |
| ".contrast-tok{position:relative;border-radius:2px;padding:0 1px;" |
| "cursor:default;white-space:pre;}" |
| ".contrast-tok>.contrast-tip{display:none;position:absolute;bottom:100%;" |
| "left:50%;transform:translateX(-50%);margin-bottom:4px;padding:2px 6px;" |
| "border-radius:3px;background:#222;color:#eee;font-size:0.72em;" |
| "font-family:ui-monospace,monospace;white-space:nowrap;pointer-events:none;" |
| "z-index:10;box-shadow:0 2px 6px rgba(0,0,0,0.3);}" |
| ".contrast-tok:hover>.contrast-tip{display:block;}" |
| "</style>" |
| ) |
|
|
|
|
def render_contrast_html(result: TokenContrast) -> str:
    """
    Render each token with a colored background reflecting how A- or B-specific
    it is, with a hover tooltip showing the raw delta log P (A minus B), plus
    a legend.

    Args:
        result: Tokens with their normalised weights, raw differences and the
            two persona labels. The three per-token lists must have equal
            length (``zip(..., strict=True)`` raises ``ValueError`` otherwise).

    Returns:
        An HTML fragment: the inline stylesheet, one ``<span>`` per token and
        a legend. Token text, tooltip text and labels are HTML-escaped.
    """
    items = list(zip(result.tokens, result.weights, result.raw_diffs, strict=True))

    # Skip leading whitespace-only tokens so the block does not open with
    # blank space; if every token is whitespace, keep them all.
    start = 0
    while start < len(items) and not items[start][0].strip():
        start += 1
    if start >= len(items):
        start = 0
    items = items[start:]

    spans: list[str] = []
    for idx, (token, weight, raw) in enumerate(items):
        if idx == 0:
            # The first visible token often carries a leading space.
            token = token.lstrip()
        bg = _weight_to_bg(weight)
        # Written with escapes (U+0394 delta, U+2212 minus) so the label
        # survives any re-encoding of this source file.
        tip = escape(f"\u0394log P(A\u2212B): {raw:+.3f}")
        text = escape(token)
        spans.append(
            f'<span class="contrast-tok" style="background:{bg};">'
            f'{text}<span class="contrast-tip">{tip}</span></span>'
        )

    la = escape(result.label_a)
    lb = escape(result.label_b)

    return (
        _CONTRAST_CSS + '<div style="font-family:inherit;line-height:1.75;'
        'white-space:pre-wrap;word-break:break-word;padding:2px 0 6px 0;">'
        + "".join(spans)
        + '<div style="margin-top:10px;font-size:0.72em;color:#888;'
        + 'display:flex;gap:12px;flex-wrap:wrap;">'
        + '<span><span style="background:rgba(210,60,60,0.45);'
        + f'padding:1px 6px;border-radius:2px;"> </span> {la}</span>'
        + '<span><span style="background:rgba(50,110,210,0.45);'
        + f'padding:1px 6px;border-radius:2px;"> </span> {lb}</span>'
        + '<span style="color:#aaa;">gray = shared by both</span>'
        + "</div>"
        + "</div>"
    )
|
|