File size: 10,754 Bytes
a9950fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c30bbc5
93d5dc5
a9950fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ba2da4
a9950fb
 
 
9ba2da4
 
77c2d62
a9950fb
9ba2da4
a9950fb
 
9ba2da4
a9950fb
 
 
 
 
 
 
 
 
 
 
 
 
 
c30bbc5
a9950fb
 
 
 
 
 
 
9ba2da4
 
a9950fb
 
 
 
 
 
ae347c6
a9950fb
 
 
 
 
 
 
 
 
 
9ba2da4
a9950fb
 
 
 
ae347c6
 
 
 
 
 
 
a9950fb
 
 
 
 
 
 
c30bbc5
a9950fb
220e208
9ba2da4
 
220e208
a9950fb
 
 
 
 
 
 
 
 
 
9ba2da4
 
 
 
a9950fb
9ba2da4
 
a9950fb
 
 
 
 
 
 
 
 
 
 
ae347c6
a9950fb
 
 
 
 
 
 
ae347c6
a9950fb
 
 
 
 
 
 
 
 
 
 
 
 
 
ae347c6
a9950fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae347c6
a9950fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ac8f1c
 
 
 
b279884
9ac8f1c
 
 
 
 
 
 
a9950fb
9ac8f1c
 
 
a9950fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220e208
a9950fb
220e208
a9950fb
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
"""
Contrastive token-level log-probability comparison for compare mode.

For a pair of responses generated under different persona contexts, each token
gets a weight:

    w(token) = log P(token | context_A) − log P(token | context_B)

Positive (red)  → token is more characteristic of persona A.
Negative (blue) → token is more characteristic of persona B.
Near-zero (gray) → both personas would emit this token with similar likelihood.
"""

from dataclasses import dataclass
from html import escape

import torch
from nnterp import StandardizedTransformer

from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor


@dataclass
class TokenContrast:
    """Per-token contrast data for one response, as consumed by the renderer.

    ``tokens``, ``weights`` and ``raw_diffs`` are index-aligned: entry ``i`` of
    each list describes the same displayed token. The renderer zips them with
    ``strict=True``, so they must have equal length.
    """

    # Decoded display text of each token.
    tokens: list[str]
    weights: list[float]  # normalised to [-1, 1], used for coloring
    raw_diffs: list[float]  # unclipped log P(A) - log P(B) per token
    # Names shown in the legend for the "A" (red) and "B" (blue) sides.
    label_a: str
    label_b: str


# ── Weight computation ────────────────────────────────────────────────────────


def _normalise_diffs(diffs: torch.Tensor) -> list[float]:
    """
    Scale raw log-prob differences to colour weights in [-1, 1].

    Values are clipped at the 95th percentile of |diff| and divided by that
    clip value, so a few high-magnitude tokens don't wash out everything else.
    The clip value is never allowed below 0.3, which keeps uniformly tiny
    differences from being stretched to full intensity.

    Args:
        diffs: 1-D tensor of per-token ``log P(A) - log P(B)`` values.

    Returns:
        One float per input element, each within [-1, 1]. An empty tensor
        yields ``[]``.
    """
    if diffs.numel() == 0:
        # torch.quantile rejects empty input; there is nothing to scale anyway.
        return []
    # Cast first: torch.quantile only accepts float32/float64 tensors, so
    # half/bfloat16 diffs would otherwise raise.
    values = diffs.float()
    # A single-element input goes through the same path as longer ones, so
    # its weight also honours the [-1, 1] contract instead of leaking the raw
    # difference.
    clip_val = max(torch.quantile(values.abs(), 0.95).item(), 0.3)
    return (values.clamp(-clip_val, clip_val) / clip_val).tolist()


def _strip_special_ids(
    ids: torch.Tensor,
    tokenizer: object,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Filter special tokens out of ``ids``.

    Returns ``(kept_ids, keep_mask)``: ``keep_mask`` is a boolean tensor with
    one entry per input id (``True`` = not a special token) and ``kept_ids``
    is ``ids`` (moved to CPU) indexed by that mask. When the tokenizer exposes
    no ``all_special_ids`` or ``ids`` is empty, everything is kept.
    """
    cpu_ids = ids.cpu()
    specials = set(getattr(tokenizer, "all_special_ids", []) or [])

    nothing_to_filter = cpu_ids.numel() == 0 or not specials
    if nothing_to_filter:
        keep_all = torch.ones(cpu_ids.shape[0], dtype=torch.bool)
        return cpu_ids, keep_all

    flags = [token_id not in specials for token_id in cpu_ids.tolist()]
    mask = torch.tensor(flags, dtype=torch.bool)
    return cpu_ids[mask], mask


def _prepare_trace_input_ids(
    tokenizer: object,
    context_messages: list[dict[str, str]],
    response_ids: torch.Tensor,
) -> tuple[torch.Tensor, int, int]:
    """Assemble the token ids for one scoring pass.

    The context messages are turned into a prompt string, tokenised, and the
    given response ids are appended unchanged. Returns
    ``(input_ids, n_ctx, n_resp)`` where ``n_ctx`` is the number of context
    tokens and ``n_resp`` the number of response tokens in ``input_ids``.
    """
    prompt_text, _ = format_generation_prompt(context_messages, tokenizer)
    # NOTE(review): the tokenizer is called with its default special-token
    # handling; if the formatted prompt already carries a BOS token, confirm
    # it is not being added a second time here.
    encoded = tokenizer(prompt_text, return_tensors="pt")
    ctx_ids = encoded.input_ids[0]
    resp_cpu = response_ids.detach().cpu()
    full_ids = torch.cat([ctx_ids.cpu(), resp_cpu])
    return full_ids, len(ctx_ids), len(response_ids)


def _build_contrast(
    tokenizer: object,
    response_ids: torch.Tensor,
    lp_a: torch.Tensor,
    lp_b: torch.Tensor,
    label_a: str,
    label_b: str,
) -> TokenContrast:
    """Turn two per-token log-prob vectors into a ``TokenContrast``.

    The raw difference ``lp_a - lp_b`` is computed for every response token,
    special tokens are dropped from both the ids and the differences, and the
    surviving differences are stored both raw and normalised.
    """
    raw = (lp_a - lp_b).cpu()
    shown_ids, mask = _strip_special_ids(response_ids, tokenizer)
    # Apply the same mask to the diffs so they stay aligned with shown_ids.
    shown = raw[mask]

    token_strs = [decode_token(tokenizer, tok_id.item()) for tok_id in shown_ids]
    scaled = _normalise_diffs(shown)
    unscaled = shown.float().tolist()
    return TokenContrast(
        tokens=token_strs,
        weights=scaled,
        raw_diffs=unscaled,
        label_a=label_a,
        label_b=label_b,
    )


# Each spec: (key, input_ids, n_ctx, n_resp, target_ids).
#   key        - name under which the pass's result is stored
#   input_ids  - context token ids followed by the response token ids
#   n_ctx      - number of context tokens at the start of input_ids
#   n_resp     - number of response tokens to score
#   target_ids - token ids whose log-probabilities are picked out
PassSpec = tuple[str, torch.Tensor, int, int, torch.Tensor]


def _score_passes(
    model: StandardizedTransformer,
    specs: list[PassSpec],
    remote: bool,
    ndif_api_key: str | None = None,
) -> dict[str, torch.Tensor]:
    """
    Run one forward pass per spec and return reduced per-token logprobs.

    The log-softmax and target-pick happen *inside* the trace, so only the
    reduced ``[n_resp]`` logprob vector per pass is shipped back — not the full
    ``[1, seq, vocab]`` logits (which would be hundreds of MB per pass on NDIF).

    Args:
        model: Model providing the ``trace`` context and ``logits``.
        specs: Passes to run, each a
            ``(key, input_ids, n_ctx, n_resp, target_ids)`` tuple.
        remote: Forwarded to ``model.trace``; when true a backend is obtained
            from ``utils.runtime.remote_backend``.
        ndif_api_key: Passed to ``remote_backend``; not used for local runs.

    Returns:
        Dict mapping each spec's key to whatever ``resolve_saved_tensor``
        returns for that pass (expected to be the ``[n_resp]`` logprob
        vector). Passes run one after another, in ``specs`` order.
    """

    def _score_pass(
        input_ids: torch.Tensor,
        n_ctx: int,
        n_resp: int,
        target_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Score one pass: log P of each target id at its response position."""
        if remote:
            # Imported only on the remote path.
            from utils.runtime import remote_backend

            # A backend is requested for every pass, not shared across passes.
            backend = remote_backend(model, ndif_api_key)
        else:
            backend = None
        with torch.no_grad(), model.trace(input_ids, remote=remote, backend=backend):
            # logit at position i predicts token i+1, so response token j
            # (at full-text position n_ctx+j) uses logit at n_ctx+j-1.
            # NOTE(review): assumes n_ctx >= 1; with n_ctx == 0 the slice
            # start would be -1 and select the wrong positions.
            resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float()
            log_probs = torch.log_softmax(resp_logits, dim=-1)
            # Column vector of target ids so gather picks one entry per row.
            targets = target_ids.to(log_probs.device).view(-1, 1)
            picked = log_probs.gather(1, targets).view(-1)
            # Only the reduced vector is marked to be saved out of the trace.
            out = picked.detach().cpu().save()
        return resolve_saved_tensor(out)

    return {
        key: _score_pass(input_ids, n_ctx, n_resp, target_ids)
        for key, input_ids, n_ctx, n_resp, target_ids in specs
    }


def _specs_for_response(
    tokenizer: object,
    response_ids: torch.Tensor,
    context_a: list[dict[str, str]],
    context_b: list[dict[str, str]],
    prefix: str,
) -> list[PassSpec]:
    """Create the two scoring passes for one response.

    The same ``response_ids`` are scored once after ``context_a`` and once
    after ``context_b``. The resulting specs are keyed
    ``"<prefix>_under_a"`` and ``"<prefix>_under_b"``, in that order.
    """
    passes: list[PassSpec] = []
    for side, context in (("a", context_a), ("b", context_b)):
        full_ids, ctx_len, resp_len = _prepare_trace_input_ids(
            tokenizer, context, response_ids
        )
        passes.append(
            (f"{prefix}_under_{side}", full_ids, ctx_len, resp_len, response_ids)
        )
    return passes


def compute_contrast(
    model: StandardizedTransformer,
    context_a: list[dict[str, str]],
    context_b: list[dict[str, str]],
    response_ids: torch.Tensor,
    label_a: str,
    label_b: str,
    remote: bool = False,
    ndif_api_key: str | None = None,
) -> "TokenContrast | None":
    """Score one response under both contexts and build its token contrast.

    Two forward passes are run (response after ``context_a``, response after
    ``context_b``). Returns ``None`` when ``response_ids`` is empty.
    """
    tokenizer = model.tokenizer
    if not response_ids.numel():
        return None

    passes = _specs_for_response(tokenizer, response_ids, context_a, context_b, "r")
    logprobs = _score_passes(model, passes, remote, ndif_api_key)
    lp_under_a = logprobs["r_under_a"]
    lp_under_b = logprobs["r_under_b"]
    return _build_contrast(
        tokenizer, response_ids, lp_under_a, lp_under_b, label_a, label_b
    )


def compute_contrast_pair(
    model: StandardizedTransformer,
    context_a: list[dict[str, str]],
    context_b: list[dict[str, str]],
    response_ids_a: torch.Tensor,
    response_ids_b: torch.Tensor,
    label_a: str,
    label_b: str,
    remote: bool = False,
    ndif_api_key: str | None = None,
) -> tuple["TokenContrast | None", "TokenContrast | None"]:
    """
    Build token contrasts for both panel responses in one call.

    Each non-empty response is scored under both contexts, so up to four
    passes are run. The result is ``(contrast_for_a, contrast_for_b)``; an
    entry is ``None`` when the corresponding response is empty.
    """
    tokenizer = model.tokenizer
    responses = (("a", response_ids_a), ("b", response_ids_b))
    if all(ids.numel() == 0 for _, ids in responses):
        return None, None

    # Gather the passes for every non-empty response, A first, then B.
    specs: list[PassSpec] = []
    for prefix, ids in responses:
        if ids.numel() > 0:
            specs.extend(
                _specs_for_response(tokenizer, ids, context_a, context_b, prefix)
            )

    scored = _score_passes(model, specs, remote, ndif_api_key)

    def _contrast_for(prefix: str, ids: torch.Tensor) -> "TokenContrast | None":
        key_a = f"{prefix}_under_a"
        key_b = f"{prefix}_under_b"
        if ids.numel() == 0:
            return None
        if not (key_a in scored and key_b in scored):
            return None
        return _build_contrast(
            tokenizer, ids, scored[key_a], scored[key_b], label_a, label_b
        )

    contrast_a = _contrast_for("a", response_ids_a)
    contrast_b = _contrast_for("b", response_ids_b)
    return contrast_a, contrast_b


# ── HTML rendering ────────────────────────────────────────────────────────────


def _weight_to_bg(w: float) -> str:
    """Map a normalised weight to a CSS ``rgba(...)`` background colour.

    Args:
        w: Contrast weight, nominally in [-1, 1]. Values outside that range
            are clamped; NaN is treated as "no signal".

    Returns:
        A red ``rgba`` string for weights above 0.05, a blue one for weights
        below -0.05, and fully transparent otherwise. Opacity grows linearly
        with ``|w|`` up to 0.5.
    """
    # NaN is the only float that is not equal to itself. Without this guard
    # min(1.0, nan) evaluates to 1.0 and an undefined weight would be drawn
    # as the strongest possible "A" colour.
    if w != w:
        return "rgba(0,0,0,0)"
    w = max(-1.0, min(1.0, w))
    alpha = abs(w) * 0.5  # cap at 0.5 opacity so text stays readable
    if w > 0.05:
        return f"rgba(210,60,60,{alpha:.3f})"
    if w < -0.05:
        return f"rgba(50,110,210,{alpha:.3f})"
    return "rgba(0,0,0,0)"


# Inline stylesheet emitted at the start of every rendered contrast block.
# ``.contrast-tok`` styles each token span (whitespace preserved); its child
# ``.contrast-tip`` is hidden by default and shown as a small tooltip above
# the token while the token is hovered.
_CONTRAST_CSS = (
    "<style>"
    ".contrast-tok{position:relative;border-radius:2px;padding:0 1px;"
    "cursor:default;white-space:pre;}"
    ".contrast-tok>.contrast-tip{display:none;position:absolute;bottom:100%;"
    "left:50%;transform:translateX(-50%);margin-bottom:4px;padding:2px 6px;"
    "border-radius:3px;background:#222;color:#eee;font-size:0.72em;"
    "font-family:ui-monospace,monospace;white-space:nowrap;pointer-events:none;"
    "z-index:10;box-shadow:0 2px 6px rgba(0,0,0,0.3);}"
    ".contrast-tok:hover>.contrast-tip{display:block;}"
    "</style>"
)


def render_contrast_html(result: TokenContrast) -> str:
    """
    Render a ``TokenContrast`` as an HTML fragment.

    Each token becomes a span whose background colour reflects how A- or
    B-specific it is, with a hover tooltip showing the raw Δlog P. A legend
    naming the two sides follows the tokens.

    Args:
        result: Tokens with their normalised weights, raw differences and the
            two legend labels.

    Returns:
        HTML string made of the inline stylesheet and a wrapper ``<div>``
        holding the token spans and the legend. Token text, tooltip text and
        labels are HTML-escaped.

    Raises:
        ValueError: If ``tokens``, ``weights`` and ``raw_diffs`` differ in
            length (they are zipped with ``strict=True``).
    """
    # The model often opens a response with newline tokens; under pre-wrap
    # those render as blank lines before the first word. Drop leading
    # whitespace-only tokens (and left-trim the first visible one) so the
    # contrast starts at real content. Display-only — weights stay aligned.
    items = list(zip(result.tokens, result.weights, result.raw_diffs, strict=True))
    first_visible = next(
        (pos for pos, (tok, _, _) in enumerate(items) if tok.strip()),
        None,
    )
    # An all-whitespace response has no visible token: render it as-is, not
    # blank. That also means its first token must not be trimmed below.
    has_content = first_visible is not None
    if has_content:
        items = items[first_visible:]

    spans: list[str] = []
    for idx, (token, weight, raw) in enumerate(items):
        if idx == 0 and has_content:
            token = token.lstrip()
        bg = _weight_to_bg(weight)
        # Escape sequences keep the tooltip's Greek delta (U+0394) and minus
        # sign (U+2212) intact regardless of how this file is encoded.
        tip = escape(f"\u0394log P(A\u2212B): {raw:+.3f}")
        text = escape(token)
        spans.append(
            f'<span class="contrast-tok" style="background:{bg};">'
            f'{text}<span class="contrast-tip">{tip}</span></span>'
        )

    la = escape(result.label_a)
    lb = escape(result.label_b)

    return (
        _CONTRAST_CSS + '<div style="font-family:inherit;line-height:1.75;'
        'white-space:pre-wrap;word-break:break-word;padding:2px 0 6px 0;">'
        + "".join(spans)
        + '<div style="margin-top:10px;font-size:0.72em;color:#888;'
        + 'display:flex;gap:12px;flex-wrap:wrap;">'
        + '<span><span style="background:rgba(210,60,60,0.45);'
        + f'padding:1px 6px;border-radius:2px;">&thinsp;</span>&nbsp;{la}</span>'
        + '<span><span style="background:rgba(50,110,210,0.45);'
        + f'padding:1px 6px;border-radius:2px;">&thinsp;</span>&nbsp;{lb}</span>'
        + '<span style="color:#aaa;">gray = shared by both</span>'
        + "</div>"
        + "</div>"
    )