"""
Contrastive token-level log-probability comparison for compare mode.
For a pair of responses generated under different persona contexts, each token
gets a weight:
w(token) = log P(token | context_A) − log P(token | context_B)
Positive (red) — token is more characteristic of persona A.
Negative (blue) — token is more characteristic of persona B.
Near-zero (gray) — both personas would emit this token with similar likelihood.
"""
import math
from dataclasses import dataclass
from html import escape

import torch
from nnterp import StandardizedTransformer

from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor
@dataclass
class TokenContrast:
    """Per-token contrast of one response scored under two persona contexts.

    ``tokens``, ``weights`` and ``raw_diffs`` are parallel lists: entry ``i``
    of each describes the same displayed response token.
    """

    # Decoded text of each displayed token.
    tokens: list[str]
    # Display weights, normalised to [-1, 1]; these drive the background color.
    weights: list[float]
    # Unclipped log P(A) - log P(B) for each displayed token.
    raw_diffs: list[float]
    # Names of the two contexts being contrasted (A, then B).
    label_a: str
    label_b: str
# ── Weight computation ────────────────────────────────────────────────────────
def _normalise_diffs(
    diffs: torch.Tensor,
    quantile: float = 0.95,
    min_clip: float = 0.3,
) -> list[float]:
    """
    Clip at the ``quantile`` percentile of |diff| and scale to [-1, 1] so a few
    high-magnitude tokens don't wash out everything else.

    Args:
        diffs: per-token log-probability differences, any floating dtype
            (expected 1-D).
        quantile: fraction in [0, 1]; that percentile of the finite |diff|
            values becomes the clip reference.
        min_clip: lower bound on the clip value (must be > 0), so a response
            whose diffs are all tiny is not blown up to full saturation.

    Returns:
        One float per element of ``diffs``, each in [-1, 1]. NaN entries map
        to 0.0 (neutral) and ±inf saturate to ±1.

    Raises:
        ValueError: if ``min_clip`` is not strictly positive.
    """
    if not min_clip > 0:
        raise ValueError(f"min_clip must be > 0, got {min_clip!r}")
    # torch.quantile only accepts float32/float64, so convert before any math
    # (half/bfloat16 diffs would otherwise raise).
    values = diffs.detach().float()
    if values.numel() == 0:
        return []
    # Take the clip reference from finite entries only: a single NaN/inf would
    # otherwise make the quantile non-finite and wipe out every weight.
    finite_abs = values[torch.isfinite(values)].abs()
    if finite_abs.numel() > 0:
        reference = torch.quantile(finite_abs, quantile).item()
    else:
        reference = 0.0
    clip_val = max(reference, min_clip)
    # NaN carries no direction -> neutral; ±inf become ±max-float and then
    # saturate to ±clip_val in the clamp below.
    values = torch.nan_to_num(values, nan=0.0)
    return (values.clamp(-clip_val, clip_val) / clip_val).tolist()
def _strip_special_ids(
    ids: torch.Tensor,
    tokenizer: object,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Remove the ids listed in ``tokenizer.all_special_ids`` for display.

    Returns ``(kept_ids, keep_mask)``. ``keep_mask`` is a boolean tensor over
    the original positions, so per-token values computed for ``ids`` can be
    filtered the same way. Both are on the CPU. A tokenizer without the
    attribute, or with an empty/``None`` value, keeps every position.
    """
    cpu_ids = ids.cpu()
    special_ids = set(getattr(tokenizer, "all_special_ids", []) or [])
    if special_ids and cpu_ids.numel() > 0:
        flags = [tid not in special_ids for tid in cpu_ids.tolist()]
        keep_mask = torch.tensor(flags, dtype=torch.bool)
        return cpu_ids[keep_mask], keep_mask
    # Nothing to strip: every position is kept.
    return cpu_ids, torch.ones(cpu_ids.shape[0], dtype=torch.bool)
def _prepare_trace_input_ids(
    tokenizer: object,
    context_messages: list[dict[str, str]],
    response_ids: torch.Tensor,
) -> tuple[torch.Tensor, int, int]:
    """
    Build the exact token sequence to trace: rendered context + response ids.

    The context is rendered with ``format_generation_prompt`` and tokenized.
    The response ids are appended verbatim, so position ``n_ctx + j`` holds
    response token ``j``.

    Returns:
        ``(input_ids, n_ctx, n_resp)``: the concatenated CPU ids, the number
        of context tokens, and the number of response tokens.
    """
    prompt_text, _ = format_generation_prompt(context_messages, tokenizer)
    # NOTE(review): tokenizer(...) uses its default special-token handling;
    # confirm this matches how generation tokenized the prompt, otherwise
    # n_ctx is shifted (e.g. by a duplicated BOS).
    prompt_ids = tokenizer(prompt_text, return_tensors="pt").input_ids[0]
    full_ids = torch.cat([prompt_ids.cpu(), response_ids.detach().cpu()])
    return full_ids, len(prompt_ids), len(response_ids)
def _build_contrast(
    tokenizer: object,
    response_ids: torch.Tensor,
    lp_a: torch.Tensor,
    lp_b: torch.Tensor,
    label_a: str,
    label_b: str,
) -> TokenContrast:
    """
    Turn two per-token log-prob vectors for one response into a TokenContrast.

    ``lp_a[j]`` / ``lp_b[j]`` are the log-probabilities of response token
    ``j`` under context A / B. Special tokens are dropped from the display,
    and the per-token differences are filtered with the same mask so all
    lists stay aligned.
    """
    per_token_diff = (lp_a - lp_b).cpu()
    shown_ids, keep_mask = _strip_special_ids(response_ids, tokenizer)
    shown_diffs = per_token_diff[keep_mask]
    shown_text = [decode_token(tokenizer, tid) for tid in shown_ids.tolist()]
    return TokenContrast(
        tokens=shown_text,
        weights=_normalise_diffs(shown_diffs),
        raw_diffs=shown_diffs.float().tolist(),
        label_a=label_a,
        label_b=label_b,
    )
# One scoring pass, as consumed by ``_score_passes``.
PassSpec = tuple[
    str,  # key under which the pass result is returned
    torch.Tensor,  # input_ids: context ids followed by response ids
    int,  # n_ctx: number of context tokens
    int,  # n_resp: number of response tokens to score
    torch.Tensor,  # target_ids: the response ids whose log-probs are picked
]
def _score_passes(
    model: StandardizedTransformer,
    specs: list[PassSpec],
    remote: bool,
    ndif_api_key: str | None = None,
) -> dict[str, torch.Tensor]:
    """
    Run one forward pass per spec and return reduced per-token logprobs.

    The log-softmax and target-pick happen *inside* the trace, so only the
    reduced ``[n_resp]`` logprob vector per pass is shipped back — not the full
    ``[1, seq, vocab]`` logits (which would be hundreds of MB per pass on NDIF).

    Args:
        model: model traced via ``model.trace``; ``model.logits`` is read.
        specs: ``(key, input_ids, n_ctx, n_resp, target_ids)`` tuples; one
            forward pass is run per spec, in order.
        remote: trace through the backend from ``utils.runtime.remote_backend``.
        ndif_api_key: forwarded to ``remote_backend`` when ``remote`` is set.

    Returns:
        ``{key: logprobs}`` where ``logprobs[j]`` is the log-probability of
        ``target_ids[j]`` given everything before response position ``j``.

    Raises:
        ValueError: if any spec is malformed. All specs are checked before the
            first pass runs, so no remote job is submitted for a bad batch.
    """
    # Validate every spec up front. A malformed spec would otherwise fail
    # inside the trace (possibly remotely), after earlier passes already ran.
    for key, input_ids, n_ctx, n_resp, target_ids in specs:
        if n_ctx < 1:
            raise ValueError(
                f"pass {key!r}: n_ctx must be >= 1 because response token 0 is "
                f"scored from the logit at position n_ctx - 1; got {n_ctx}"
            )
        if target_ids.numel() != n_resp:
            raise ValueError(
                f"pass {key!r}: expected {n_resp} target ids, "
                f"got {target_ids.numel()}"
            )
        if input_ids.shape[-1] < n_ctx + n_resp:
            raise ValueError(
                f"pass {key!r}: input_ids has {input_ids.shape[-1]} tokens, "
                f"fewer than n_ctx + n_resp = {n_ctx + n_resp}"
            )

    def _score_pass(
        input_ids: torch.Tensor,
        n_ctx: int,
        n_resp: int,
        target_ids: torch.Tensor,
    ) -> torch.Tensor:
        # A fresh backend is created for every pass. NOTE(review): not hoisted
        # out of the loop in case the backend object carries per-job state.
        if remote:
            from utils.runtime import remote_backend
            backend = remote_backend(model, ndif_api_key)
        else:
            backend = None
        with torch.no_grad(), model.trace(input_ids, remote=remote, backend=backend):
            # logit at position i predicts token i+1, so response token j
            # (at full-text position n_ctx+j) uses logit at n_ctx+j-1.
            resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float()
            log_probs = torch.log_softmax(resp_logits, dim=-1)
            targets = target_ids.to(log_probs.device).view(-1, 1)
            picked = log_probs.gather(1, targets).view(-1)
            out = picked.detach().cpu().save()
        return resolve_saved_tensor(out)

    return {
        key: _score_pass(input_ids, n_ctx, n_resp, target_ids)
        for key, input_ids, n_ctx, n_resp, target_ids in specs
    }
def _specs_for_response(
    tokenizer: object,
    response_ids: torch.Tensor,
    context_a: list[dict[str, str]],
    context_b: list[dict[str, str]],
    prefix: str,
) -> list[PassSpec]:
    """
    Build the two pass specs that score one response under context A and B.

    Keys are ``f"{prefix}_under_a"`` and ``f"{prefix}_under_b"``, in that
    order. The response ids themselves are the scoring targets.
    """
    specs: list[PassSpec] = []
    for suffix, context in (("a", context_a), ("b", context_b)):
        full_ids, ctx_len, resp_len = _prepare_trace_input_ids(
            tokenizer, context, response_ids
        )
        specs.append(
            (f"{prefix}_under_{suffix}", full_ids, ctx_len, resp_len, response_ids)
        )
    return specs
def compute_contrast(
    model: StandardizedTransformer,
    context_a: list[dict[str, str]],
    context_b: list[dict[str, str]],
    response_ids: torch.Tensor,
    label_a: str,
    label_b: str,
    remote: bool = False,
    ndif_api_key: str | None = None,
) -> "TokenContrast | None":
    """
    Score one response under both contexts (2 forward passes) and return its
    per-token contrast, or ``None`` when ``response_ids`` is empty.
    """
    tokenizer = model.tokenizer
    if response_ids.numel() > 0:
        pass_specs = _specs_for_response(
            tokenizer, response_ids, context_a, context_b, "r"
        )
        scored = _score_passes(model, pass_specs, remote, ndif_api_key)
        lp_under_a = scored["r_under_a"]
        lp_under_b = scored["r_under_b"]
        return _build_contrast(
            tokenizer, response_ids, lp_under_a, lp_under_b, label_a, label_b
        )
    return None
def compute_contrast_pair(
    model: StandardizedTransformer,
    context_a: list[dict[str, str]],
    context_b: list[dict[str, str]],
    response_ids_a: torch.Tensor,
    response_ids_b: torch.Tensor,
    label_a: str,
    label_b: str,
    remote: bool = False,
    ndif_api_key: str | None = None,
) -> tuple["TokenContrast | None", "TokenContrast | None"]:
    """
    Compute contrast weights for both panel responses (up to 4 remote passes).

    Each non-empty response is scored under both contexts, and all passes go
    through a single ``_score_passes`` call. Empty responses produce ``None``
    in their slot of the returned ``(contrast_a, contrast_b)`` pair.
    """
    tokenizer = model.tokenizer
    panels = (("a", response_ids_a), ("b", response_ids_b))
    if all(ids.numel() == 0 for _, ids in panels):
        return None, None

    specs: list[PassSpec] = []
    for prefix, ids in panels:
        if ids.numel() > 0:
            specs.extend(
                _specs_for_response(tokenizer, ids, context_a, context_b, prefix)
            )
    scored = _score_passes(model, specs, remote, ndif_api_key)

    def _contrast_for(prefix: str, ids: torch.Tensor) -> "TokenContrast | None":
        key_a = f"{prefix}_under_a"
        key_b = f"{prefix}_under_b"
        have_both = key_a in scored and key_b in scored
        if ids.numel() > 0 and have_both:
            return _build_contrast(
                tokenizer, ids, scored[key_a], scored[key_b], label_a, label_b
            )
        return None

    return tuple(_contrast_for(prefix, ids) for prefix, ids in panels)
# ── HTML rendering ────────────────────────────────────────────────────────────
def _weight_to_bg(w: float) -> str:
    """
    Map a normalised weight in [-1, 1] to a CSS rgba background color.

    Positive weights get a red tint and negative weights a blue tint. Opacity
    is proportional to |w| and capped at 0.5 so text stays readable.
    Out-of-range weights are clamped to [-1, 1]. Weights within ±0.05 of zero,
    and NaN (which has no direction), are fully transparent.
    """
    if math.isnan(w):
        # max()/min() don't propagate NaN: min(1.0, nan) returns 1.0, which
        # would paint a NaN weight as a fully saturated persona-A token.
        return "rgba(0,0,0,0)"
    w = max(-1.0, min(1.0, w))
    alpha = abs(w) * 0.5  # cap at 0.5 opacity so text stays readable
    if w > 0.05:
        return f"rgba(210,60,60,{alpha:.3f})"
    if w < -0.05:
        return f"rgba(50,110,210,{alpha:.3f})"
    return "rgba(0,0,0,0)"
# Stylesheet emitted ahead of the contrast HTML: token chips plus a dark
# tooltip that is hidden until its token is hovered.
_CONTRAST_CSS = "".join(
    [
        "<style>",
        # Token chip: whitespace inside the token is preserved as decoded.
        ".contrast-tok{position:relative;border-radius:2px;padding:0 1px;"
        "cursor:default;white-space:pre;}",
        # Tooltip: hidden by default, absolutely positioned above the chip.
        ".contrast-tok>.contrast-tip{display:none;position:absolute;bottom:100%;"
        "left:50%;transform:translateX(-50%);margin-bottom:4px;padding:2px 6px;"
        "border-radius:3px;background:#222;color:#eee;font-size:0.72em;"
        "font-family:ui-monospace,monospace;white-space:nowrap;pointer-events:none;"
        "z-index:10;box-shadow:0 2px 6px rgba(0,0,0,0.3);}",
        # Show the tooltip while its chip is hovered.
        ".contrast-tok:hover>.contrast-tip{display:block;}",
        "</style>",
    ]
)
def render_contrast_html(result: TokenContrast) -> str:
    """
    Render each token with a colored background reflecting how A- or B-specific
    it is, with a hover tooltip showing the raw Δlog P, plus a legend.

    Args:
        result: the contrast to render; ``tokens``, ``weights`` and
            ``raw_diffs`` are consumed in lockstep.

    Returns:
        An HTML fragment: ``_CONTRAST_CSS``, then a container holding one
        ``<span class="contrast-tok">`` per rendered token followed by a color
        legend. Token text, tooltips and labels are HTML-escaped.

    Raises:
        ValueError: if ``tokens``, ``weights`` and ``raw_diffs`` differ in
            length (raised by ``zip(..., strict=True)``).
    """
    items = list(zip(result.tokens, result.weights, result.raw_diffs, strict=True))
    # The model often opens a response with newline tokens; under pre-wrap
    # those render as blank lines before the first word. Drop leading
    # whitespace-only tokens (and left-trim the first visible one) so the
    # contrast starts at real content. Display-only — weights stay aligned.
    # An all-whitespace response is rendered as-is, untrimmed, rather than
    # blank.
    first_visible = next(
        (i for i, (token, _, _) in enumerate(items) if token.strip()), None
    )
    if first_visible is not None:
        items = items[first_visible:]
        token0, weight0, raw0 = items[0]
        items[0] = (token0.lstrip(), weight0, raw0)

    spans: list[str] = []
    for token, weight, raw in items:
        bg = _weight_to_bg(weight)
        # Tooltip shows the signed raw difference log P(A) − log P(B).
        tip = escape(f"Δlog P(A−B): {raw:+.3f}")
        text = escape(token)
        spans.append(
            f'<span class="contrast-tok" style="background:{bg};">'
            f'{text}<span class="contrast-tip">{tip}</span></span>'
        )

    la = escape(result.label_a)
    lb = escape(result.label_b)
    legend = (
        '<div style="margin-top:10px;font-size:0.72em;color:#888;'
        'display:flex;gap:12px;flex-wrap:wrap;">'
        '<span><span style="background:rgba(210,60,60,0.45);'
        f'padding:1px 6px;border-radius:2px;"> </span> {la}</span>'
        '<span><span style="background:rgba(50,110,210,0.45);'
        f'padding:1px 6px;border-radius:2px;"> </span> {lb}</span>'
        '<span style="color:#aaa;">gray = shared by both</span>'
        "</div>"
    )
    return (
        _CONTRAST_CSS
        + '<div style="font-family:inherit;line-height:1.75;'
        'white-space:pre-wrap;word-break:break-word;padding:2px 0 6px 0;">'
        + "".join(spans)
        + legend
        + "</div>"
    )
|