File size: 13,463 Bytes
77c2d62
 
 
 
 
 
 
 
d8ae160
77c2d62
c30bbc5
77c2d62
 
b279884
77c2d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8ae160
 
 
 
 
 
77c2d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae347c6
77c2d62
 
 
 
 
 
d8ae160
77c2d62
 
 
 
 
 
 
 
 
 
 
 
 
ae347c6
 
 
 
 
 
 
77c2d62
 
 
c30bbc5
 
77c2d62
 
 
 
 
 
 
 
 
 
 
 
 
 
d8ae160
 
b279884
d8ae160
 
 
 
 
 
 
 
 
77c2d62
 
 
 
 
 
 
 
 
 
 
 
 
 
d8ae160
 
77c2d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8ae160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b279884
d8ae160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b279884
d8ae160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
from __future__ import annotations

import hashlib
from collections.abc import Mapping
from dataclasses import dataclass

import streamlit as st
import torch
from nnterp import StandardizedTransformer
from persona_data.prompts import normalize_messages, supports_system_role

from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor

# st.session_state key for the dict of cached ConversationTrace objects
# (ordered oldest -> most recently used; see _get_cached_trace/_store_cached_trace).
_TRACE_CACHE_KEY = "probe:trace_cache"
# st.session_state key for an optional list of derived-result keys
# (probe predictions/values). When present, it limits which keys are removed
# when a trace is evicted; presumably populated by callers elsewhere — verify.
_DERIVED_CACHE_TRACKER_KEY = "probe:derived_cache_keys"
# Maximum number of traces kept in the session cache before the oldest is evicted.
_MAX_CACHED_TRACES = 3


@dataclass(frozen=True)
class ConversationTrace:
    """Immutable result of tracing one rendered conversation at one layer.

    Bundles the rendered prompt, the token ids the model actually consumed,
    the activations captured at ``layer``/``location``, the decoded token
    strings, and the token ranges used to paint an overlay over assistant
    content.
    """

    cache_key: str
    model_name: str
    remote: bool
    prompt_text: str
    prompt_hash: str
    layer: int
    location: str
    input_ids: torch.Tensor
    activations: torch.Tensor
    tokens: list[str]
    # Ordered (start, end_exclusive) token ranges covering assistant content.
    # Empty when no assistant-span detection strategy succeeded.
    assistant_spans: list[tuple[int, int]]
    # Boolean flag per position: True where the id is a tokenizer special id
    # (role markers, BOS/EOS, ...) that the overlay should leave unpainted.
    is_special: torch.Tensor

    @property
    def hidden_size(self) -> int:
        """Width of one activation row (the last dimension of ``activations``)."""
        return int(self.activations.size(-1))

    @property
    def n_tokens(self) -> int:
        """Number of traced positions (the first dimension of ``input_ids``)."""
        return int(self.input_ids.size(0))


def trace_conversation(
    *,
    model: StandardizedTransformer,
    model_name: str,
    messages: list[dict[str, str]],
    layer: int,
    location: str,
    remote: bool,
    ndif_api_key: str | None = None,
) -> ConversationTrace:
    """Trace a full conversation and capture one layer's activations, with caching.

    The conversation is rendered with ``format_generation_prompt`` (no
    generation prompt appended) and run through ``model.trace``. This happens
    locally, or on the backend from ``utils.runtime.remote_backend`` when
    ``remote`` is set. The activations from the accessor selected by
    ``location`` at index ``layer`` are saved for every token.

    Results are memoized in ``st.session_state``, keyed by model name, remote
    flag, prompt hash, layer and location.

    Args:
        model: nnterp-wrapped model to trace.
        model_name: Name recorded on the trace and used in the cache key.
        messages: Chat messages as ``role``/``content`` dicts.
        layer: Index into the selected layer accessor.
        location: Accessor alias understood by ``_select_accessor``.
        remote: Run on the remote backend instead of locally.
        ndif_api_key: Forwarded to ``remote_backend`` when ``remote`` is set.

    Returns:
        A ``ConversationTrace``, possibly served from the session cache.

    Raises:
        ValueError: If ``location`` is unsupported, or the traced ids or
            activations have unexpected shapes.
    """
    prompt_text, _ = format_generation_prompt(
        messages,
        model.tokenizer,
        add_generation_prompt=False,
    )
    prompt_hash = hashlib.sha256(prompt_text.encode("utf-8")).hexdigest()
    cache_key = _trace_cache_key(
        model_name=model_name,
        remote=remote,
        prompt_hash=prompt_hash,
        layer=layer,
        location=location,
    )
    cached = _get_cached_trace(cache_key)
    if cached is not None:
        # Cache hit: skip the model run and all tokenizer-side span detection
        # (which used to run unconditionally before this check).
        return cached

    accessor = _select_accessor(model, location)
    if remote:
        from utils.runtime import remote_backend

        backend = remote_backend(model, ndif_api_key)
    else:
        backend = None
    with torch.no_grad(), model.trace(prompt_text, remote=remote, backend=backend):
        # Index 0 selects the single prompt in the batch; moving to CPU inside
        # the trace means the saved values come back as CPU tensors.
        saved_ids = model.input_ids[0].detach().cpu().save()
        saved_acts = accessor[layer][0].detach().float().cpu().save()

    input_ids = resolve_saved_tensor(saved_ids)
    activations = resolve_saved_tensor(saved_acts)
    if input_ids.ndim != 1:
        raise ValueError(
            f"Expected traced input ids to be [seq], got {tuple(input_ids.shape)}"
        )
    if activations.ndim != 2:
        raise ValueError(
            f"Expected traced activations to be [seq, hidden], got {tuple(activations.shape)}"
        )
    if int(input_ids.shape[0]) != int(activations.shape[0]):
        raise ValueError(
            "Trace produced a different number of token ids and activation rows: "
            f"{tuple(input_ids.shape)} vs {tuple(activations.shape)}"
        )

    n_tokens = int(input_ids.shape[0])
    assistant_spans = _resolve_assistant_spans(
        model.tokenizer, prompt_text, messages, n_tokens
    )
    is_special = _special_token_mask(model.tokenizer, input_ids)

    trace = ConversationTrace(
        cache_key=cache_key,
        model_name=model_name,
        remote=remote,
        prompt_text=prompt_text,
        prompt_hash=prompt_hash,
        layer=layer,
        location=location,
        input_ids=input_ids,
        activations=activations,
        tokens=[
            decode_token(model.tokenizer, int(token_id))
            for token_id in input_ids.tolist()
        ],
        assistant_spans=assistant_spans,
        is_special=is_special,
    )
    _store_cached_trace(cache_key, trace)
    return trace


def _resolve_assistant_spans(
    tokenizer: object,
    prompt_text: str,
    messages: list[dict[str, str]],
    n_tokens: int,
) -> list[tuple[int, int]]:
    """Return assistant token spans from the first detection strategy that finds any.

    Strategies are tried in this order:

    1. Char-offset alignment against ``prompt_text``.
    2. The chat template's assistant-token mask, used only when its length
       equals ``n_tokens``.
    3. Chat-template prefix arithmetic.

    All returned spans lie within ``[0, n_tokens]``. An empty list means no
    strategy succeeded.
    """
    spans = _clip_spans(
        _assistant_spans_from_offsets(tokenizer, prompt_text, messages, n_tokens),
        n_tokens,
    )
    if spans:
        return spans
    # Only pay for the extra template pass when offset alignment failed.
    spans = _assistant_spans(_compute_assistant_mask(tokenizer, messages), n_tokens)
    if spans:
        return spans
    return _clip_spans(
        _assistant_spans_from_prefixes(tokenizer, messages) or [], n_tokens
    )


def _select_accessor(model: StandardizedTransformer, location: str):
    normalized = location.lower()
    if normalized in {"pre_reasoning", "pre", "input", "layers_input"}:
        return model.layers_input
    if normalized in {"post_reasoning", "post", "output", "layers_output"}:
        return model.layers_output
    raise ValueError(f"Unsupported trace location: {location!r}")


def _trace_cache_key(
    *,
    model_name: str,
    remote: bool,
    prompt_hash: str,
    layer: int,
    location: str,
) -> str:
    return "::".join(
        (
            "probe-trace",
            model_name,
            str(remote),
            prompt_hash,
            str(layer),
            location,
        )
    )


def _get_cached_trace(cache_key: str) -> ConversationTrace | None:
    """Fetch ``cache_key`` from the session trace cache, refreshing its recency.

    A hit is moved to the end of the (insertion-ordered) dict, so
    ``_store_cached_trace`` evicts it last. Returns ``None`` in two cases: the
    cache is missing or not a dict, or the stored value is not a
    ``ConversationTrace``.
    """
    store = st.session_state.get(_TRACE_CACHE_KEY)
    hit = store.get(cache_key) if isinstance(store, dict) else None
    if not isinstance(hit, ConversationTrace):
        return None
    # Delete and re-insert so the entry becomes the most recently used.
    del store[cache_key]
    store[cache_key] = hit
    return hit


def _trace_cache() -> dict[str, ConversationTrace]:
    """Return the session's trace-cache dict, creating it if absent or malformed.

    Any non-dict value stored under ``_TRACE_CACHE_KEY`` is replaced with a
    fresh empty dict.
    """
    existing = st.session_state.get(_TRACE_CACHE_KEY)
    if not isinstance(existing, dict):
        existing = {}
        st.session_state[_TRACE_CACHE_KEY] = existing
    return existing


def _store_cached_trace(cache_key: str, trace: ConversationTrace) -> None:
    """Insert ``trace`` as the newest cache entry and evict anything over the cap.

    Entries are evicted oldest-first until at most ``_MAX_CACHED_TRACES``
    remain. Each evicted trace also has its derived session-state results
    dropped.
    """
    traces = _trace_cache()
    # Remove first so re-storing an existing key moves it to the newest slot.
    traces.pop(cache_key, None)
    traces[cache_key] = trace
    overflow = max(len(traces) - _MAX_CACHED_TRACES, 0)
    for stale_key in list(traces)[:overflow]:
        del traces[stale_key]
        _drop_derived_results_for_trace(stale_key)


def _drop_derived_results_for_trace(cache_key: str) -> None:
    """Remove probe predictions tied to a trace that just aged out.

    Targets session-state keys that start with
    ``probe_predictions::<cache_key>::`` or ``probe_values::<cache_key>::``.

    - If a list exists under ``_DERIVED_CACHE_TRACKER_KEY``, only the keys
      named in it are considered, and the list is rewritten without the
      removed ones.
    - Otherwise every session-state key is scanned.
    """
    doomed_prefixes = (
        f"probe_predictions::{cache_key}::",
        f"probe_values::{cache_key}::",
    )

    def is_doomed(key: object) -> bool:
        return isinstance(key, str) and key.startswith(doomed_prefixes)

    tracked = st.session_state.get(_DERIVED_CACHE_TRACKER_KEY)
    if not isinstance(tracked, list):
        # No tracker: full scan over a snapshot of keys taken before popping.
        for key in [k for k in st.session_state if is_doomed(k)]:
            st.session_state.pop(key, None)
        return

    survivors: list[str] = []
    for key in tracked:
        if is_doomed(key):
            st.session_state.pop(key, None)
        else:
            survivors.append(key)
    st.session_state[_DERIVED_CACHE_TRACKER_KEY] = survivors


def _compute_assistant_mask(
    tokenizer: object, messages: list[dict[str, str]]
) -> list[int] | None:
    """Return a per-token 0/1 mask marking assistant content, or None if unknown.

    Uses ``apply_chat_template(return_assistant_tokens_mask=True)`` when the
    tokenizer supports it (modern chat templates with ``{% generation %}``
    blocks). Returns ``None`` when the template doesn't mark assistant spans.
    """
    apply = getattr(tokenizer, "apply_chat_template", None)
    if apply is None or not messages:
        return None
    try:
        encoded = apply(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            return_assistant_tokens_mask=True,
            return_dict=True,
        )
    except Exception:
        return None
    mask = encoded.get("assistant_masks") if isinstance(encoded, dict) else None
    if not mask:
        return None
    if isinstance(mask, list) and mask and isinstance(mask[0], list):
        mask = mask[0]
    values = [int(value) for value in mask]
    if not any(values):
        return None
    return values


def _assistant_spans_from_offsets(
    tokenizer: object,
    prompt_text: str,
    messages: list[dict[str, str]],
    n_tokens: int,
) -> list[tuple[int, int]]:
    """Locate assistant bodies by char-offset, aligned to the traced sequence.

    The chat-template token arithmetic in ``_assistant_spans_from_prefixes``
    drifts whenever the template tokenizes differently than how ``model.trace``
    tokenizes the rendered prompt string (extra/missing BOS, trailing
    whitespace, etc.), which leaves the overlay unalignable. This instead finds
    each assistant message's text inside ``prompt_text`` and maps those char
    ranges to token indices via the fast tokenizer's offset mapping, retokenizing
    the exact string the trace ran on so the indices line up by construction.
    """
    if not getattr(tokenizer, "is_fast", False):
        return []
    contents = [
        message["content"]
        for message in messages
        if message.get("role") == "assistant" and message.get("content")
    ]
    if not contents:
        return []

    offsets = None
    for add_special_tokens in (True, False):
        try:
            encoded = tokenizer(
                prompt_text,
                return_offsets_mapping=True,
                add_special_tokens=add_special_tokens,
            )
        except Exception:
            return []
        mapping = encoded.get("offset_mapping")
        if mapping is not None and len(mapping) == n_tokens:
            offsets = mapping
            break
    if offsets is None:
        return []

    spans: list[tuple[int, int]] = []
    search_from = 0
    for content in contents:
        char_start = prompt_text.find(content, search_from)
        if char_start < 0:
            return []
        char_end = char_start + len(content)
        search_from = char_end
        tok_start: int | None = None
        tok_end: int | None = None
        for i, (start, end) in enumerate(offsets):
            if start == end:  # special tokens map to an empty (0, 0) range
                continue
            if tok_start is None and end > char_start:
                tok_start = i
            if start < char_end:
                tok_end = i + 1
        if tok_start is not None and tok_end is not None and tok_start < tok_end:
            spans.append((tok_start, tok_end))
    return spans


def _assistant_spans_from_prefixes(
    tokenizer: object, messages: list[dict[str, str]]
) -> list[tuple[int, int]] | None:
    """Fallback span detection from chat-template prefix lengths.

    For the assistant message at index ``i``:

    - the body starts at the token count of ``messages[:i]`` rendered with
      ``add_generation_prompt=True``;
    - it ends at the token count of ``messages[:i + 1]`` rendered with
      ``add_generation_prompt=False``.

    Mirrors the prefix arithmetic used by ``utils.contrast``. Messages are
    passed through ``normalize_messages`` first when the tokenizer does not
    support a system role.

    Returns ``None`` when any of these hold: the tokenizer has no
    ``apply_chat_template``, ``messages`` is empty, any rendering raises, or a
    rendering can't be flattened to ids. Otherwise returns the (possibly
    empty) list of non-empty spans.
    """
    render = getattr(tokenizer, "apply_chat_template", None)
    if render is None or not messages:
        return None
    if not supports_system_role(tokenizer):
        messages = normalize_messages(messages)

    def token_count(history, add_generation_prompt):
        ids = _flatten_ids(
            render(history, tokenize=True, add_generation_prompt=add_generation_prompt)
        )
        return None if ids is None else len(ids)

    found: list[tuple[int, int]] = []
    try:
        for idx, message in enumerate(messages):
            if message.get("role") != "assistant":
                continue
            body_start = token_count(messages[:idx], True)
            body_end = token_count(messages[: idx + 1], False)
            if body_start is None or body_end is None:
                return None
            if 0 <= body_start < body_end:
                found.append((body_start, body_end))
    except Exception:
        return None
    return found


def _flatten_ids(value: object) -> list[int] | None:
    if not isinstance(value, list):
        return None
    if value and isinstance(value[0], list):
        value = value[0]
    try:
        return [int(v) for v in value]
    except (TypeError, ValueError):
        return None


def _clip_spans(spans: list[tuple[int, int]], n_tokens: int) -> list[tuple[int, int]]:
    clipped: list[tuple[int, int]] = []
    for start, end in spans:
        s = max(0, min(start, n_tokens))
        e = max(0, min(end, n_tokens))
        if s < e:
            clipped.append((s, e))
    return clipped


def _assistant_spans(
    assistant_mask_seq: list[int] | None, n_tokens: int
) -> list[tuple[int, int]]:
    """Convert a per-token mask into ``[(start, end_exclusive), ...]`` runs.

    Returns an empty list when the mask is missing or doesn't line up with the
    traced sequence, so the caller can show a clear "no overlay" state instead
    of painting the entire conversation.
    """
    if assistant_mask_seq is None or len(assistant_mask_seq) != n_tokens:
        return []
    spans: list[tuple[int, int]] = []
    start: int | None = None
    for i, value in enumerate(assistant_mask_seq):
        if value and start is None:
            start = i
        elif not value and start is not None:
            spans.append((start, i))
            start = None
    if start is not None:
        spans.append((start, n_tokens))
    return spans


def _special_token_mask(tokenizer: object, input_ids: torch.Tensor) -> torch.Tensor:
    special_ids = set(getattr(tokenizer, "all_special_ids", []) or [])
    if not special_ids:
        return torch.zeros(int(input_ids.shape[0]), dtype=torch.bool)
    return torch.tensor(
        [int(token_id) in special_ids for token_id in input_ids.tolist()],
        dtype=torch.bool,
    )