File size: 16,301 Bytes
d8ae160
 
 
12cdb17
77c2d62
 
 
 
b279884
 
 
 
 
 
 
 
 
d8ae160
 
 
 
 
 
 
77c2d62
 
d8ae160
77c2d62
 
 
ae347c6
b279884
77c2d62
d8ae160
 
 
77c2d62
d8ae160
b279884
 
 
 
 
 
77c2d62
 
d8ae160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77c2d62
 
b279884
d8ae160
b279884
77c2d62
d8ae160
b279884
 
 
d8ae160
 
 
 
 
 
 
 
 
 
b279884
d8ae160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77c2d62
d8ae160
 
 
 
 
 
77c2d62
 
b279884
77c2d62
 
d8ae160
 
 
77c2d62
 
 
 
d8ae160
 
 
77c2d62
d8ae160
 
 
 
 
 
 
 
77c2d62
d8ae160
 
 
 
 
 
77c2d62
 
d8ae160
 
 
 
 
77c2d62
d8ae160
 
 
 
 
 
 
 
 
 
 
 
77c2d62
d8ae160
 
9ba2da4
d8ae160
 
 
 
 
 
 
 
 
 
 
77c2d62
d8ae160
 
77c2d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8ae160
 
9ba2da4
 
 
 
 
d8ae160
9ba2da4
 
 
 
 
 
 
 
 
 
d8ae160
9ba2da4
 
 
 
 
 
d8ae160
 
9ba2da4
 
 
 
 
 
d8ae160
 
9ba2da4
 
 
 
 
d8ae160
 
 
9ba2da4
 
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8ae160
 
 
 
b279884
d8ae160
 
 
b279884
d8ae160
9ba2da4
 
d8ae160
 
 
 
b279884
d8ae160
 
 
b279884
d8ae160
9ba2da4
 
d8ae160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b279884
d8ae160
 
 
 
 
 
 
 
 
 
9ba2da4
 
77c2d62
 
 
 
 
 
 
 
 
d8ae160
77c2d62
d8ae160
77c2d62
 
d8ae160
 
77c2d62
d8ae160
 
 
b279884
d8ae160
 
 
 
 
 
 
 
 
 
 
 
77c2d62
d8ae160
b279884
77c2d62
d8ae160
77c2d62
d8ae160
77c2d62
d8ae160
 
 
 
 
 
77c2d62
 
d8ae160
 
77c2d62
d8ae160
 
77c2d62
d8ae160
77c2d62
 
d8ae160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77c2d62
 
d8ae160
 
 
 
 
 
 
 
 
 
ae347c6
d8ae160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ba2da4
d8ae160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
from __future__ import annotations

from pathlib import Path

import streamlit as st
import torch

from utils.chat import build_chat_messages
from utils.helpers import env_int, session_key, widget_key
from utils.probe_files import (
    DEFAULT_LOCAL_PROBE_DIR,
    DEFAULT_PROBE_REPO,
    list_local_probe_files,
    list_probe_files,
    model_probe_dir_name,
    parse_probe_filename,
)
from utils.probe_overlay import (
    attach_overlays,
    build_classification_overlays,
    build_regression_overlays,
    clear_overlays,
)
from utils.probe_trace import ConversationTrace, trace_conversation
from utils.probes import (
    LoadedProbe,
    load_local_probe,
    load_probe,
    load_probe_from_bytes,
)
from utils.runtime import cached_model, session_ndif_api_key
from utils.selection_controls import remembered_segmented_control

# Session-state keys remembering the user's last probe choices. They are not
# namespaced by chat context, so every context shares the same memory.
_LAST_SOURCE_KEY = session_key("probe", "last_source")
_LAST_LOCAL_FILE_KEY = session_key("probe", "last_local_file")
_LAST_HUB_FILE_KEY = session_key("probe", "last_hub_file")

# Options of the probe-source control; the first entry is the default.
_PROBE_SOURCES = ("Local artifact", "Hugging Face repo", "Upload .pt")
# Session-state key of the MRU list that orders derived-cache entry keys.
_DERIVED_CACHE_TRACKER_KEY = session_key("probe", "derived_cache_keys")
# Keep enough room for the three retained traces plus a few recently explored
# probes per trace. Derived outputs are much smaller than the trace activations
# themselves, so this avoids needless recomputation without reopening
# unbounded growth.
# Clamped to at least 1. With 0, the MRU eviction would drop the just-added
# key from the tracker while its value is still stored, so that entry would
# never be evicted. With a negative value, eviction would pop from an empty
# list.
_DERIVED_CACHE_ENTRIES = max(
    1, env_int("PERSONA_UI_PROBE_DERIVED_CACHE_ENTRIES", 12)
)


# ---------------------------------------------------------------------------
# Probe selection
# ---------------------------------------------------------------------------


def _probe_label(filename: str) -> str:
    """Build the selectbox label for a probe file.

    The label is ``"<model_name> / <label>"`` when the parsed filename names
    a model, otherwise just ``"<label>"``.
    """
    meta = parse_probe_filename(filename)
    if not meta.model_name:
        return f"{meta.label}"
    return f"{meta.model_name} / {meta.label}"


def _model_compatible_files(files: list[str], model_name: str) -> list[str]:
    """Keep only the files stored under *model_name*'s probe directory.

    A file matches when its first path component equals
    ``model_probe_dir_name(model_name)``. If nothing matches, the full,
    unfiltered list is returned so the user still has something to pick.
    """
    wanted_dir = model_probe_dir_name(model_name)
    matching: list[str] = []
    for name in files:
        components = Path(name).parts
        if components and components[0] == wanted_dir:
            matching.append(name)
    return matching if matching else files


def _default_file(files: list[str], remembered: str | None) -> str:
    """Pick the file to preselect.

    Uses *remembered* when it is non-empty and still present in *files*;
    otherwise falls back to the first file. *files* must be non-empty.
    """
    use_remembered = bool(remembered) and remembered in files
    return remembered if use_remembered else files[0]


def _render_probe_selector(*, context_key: str, model_name: str) -> LoadedProbe | None:
    """Render the probe-source control and delegate to the matching picker.

    Returns whatever the chosen picker returns: the loaded probe, or None
    when nothing was selected or loading failed.
    """
    choice = remembered_segmented_control(
        "Probe source",
        options=_PROBE_SOURCES,
        key=widget_key(context_key, "probe_source"),
        remember_key=_LAST_SOURCE_KEY,
        default=_PROBE_SOURCES[0],
        label_visibility="collapsed",
    )
    if choice == "Hugging Face repo":
        return _render_hub_probe(context_key=context_key, model_name=model_name)
    if choice != "Local artifact":
        # The upload picker is the fallback for any other control value.
        return _render_upload_probe(context_key=context_key)
    return _render_local_probe(context_key=context_key, model_name=model_name)


def _render_local_probe(*, context_key: str, model_name: str) -> LoadedProbe | None:
    """Let the user pick a probe file from a local directory and load it.

    Returns None after showing a warning when the directory has no probe
    files, or after showing an error when the selected file fails to load.
    """
    dir_key = widget_key(context_key, "probe_local_dir")
    directory = st.text_input(
        "Probe directory",
        value=st.session_state.get(dir_key, DEFAULT_LOCAL_PROBE_DIR),
        key=dir_key,
    ).strip()

    candidates = list_local_probe_files(directory)
    if not candidates:
        st.warning("No probe files found in that directory.")
        return None

    candidates = _model_compatible_files(candidates, model_name)
    preferred = _default_file(candidates, st.session_state.get(_LAST_LOCAL_FILE_KEY))
    choice = st.selectbox(
        "Probe",
        options=candidates,
        index=candidates.index(preferred),
        format_func=_probe_label,
        key=widget_key(context_key, "probe_local_file"),
    )
    # Remember the pick so the next render preselects it.
    st.session_state[_LAST_LOCAL_FILE_KEY] = choice

    try:
        return load_local_probe(directory, choice)
    except Exception as exc:
        st.error(f"Could not load probe: {exc}")
        return None


def _render_hub_probe(*, context_key: str, model_name: str) -> LoadedProbe | None:
    """Let the user pick a probe file from a Hugging Face repo and load it.

    Returns None when the repo field is blank, when the repo cannot be listed
    or contains no probe files, or when the selected file fails to load. In
    the failure cases an inline warning or error is shown.
    """
    repo_key = widget_key(context_key, "probe_repo")
    repo_id = st.text_input(
        "Probe repo",
        value=st.session_state.get(repo_key, DEFAULT_PROBE_REPO),
        key=repo_key,
    ).strip()
    if not repo_id:
        return None

    # The repo id is free text and listing presumably queries the Hub. Surface
    # a typo or an outage as an inline error instead of letting the exception
    # abort the whole page. This matches how load failures are handled below.
    try:
        files = list_probe_files(repo_id)
    except Exception as exc:
        st.error(f"Could not list probe files: {exc}")
        return None
    if not files:
        st.warning("No probe files found in that repo.")
        return None

    files = _model_compatible_files(files, model_name)
    default = _default_file(files, st.session_state.get(_LAST_HUB_FILE_KEY))
    selected = st.selectbox(
        "Probe",
        options=files,
        index=files.index(default),
        format_func=_probe_label,
        key=widget_key(context_key, "probe_hub_file"),
    )
    # Remember the pick so the next render preselects it.
    st.session_state[_LAST_HUB_FILE_KEY] = selected
    try:
        return load_probe(repo_id, selected)
    except Exception as exc:
        st.error(f"Could not load probe: {exc}")
        return None


def _render_upload_probe(*, context_key: str) -> LoadedProbe | None:
    """Load a probe from a user-uploaded ``.pt`` file.

    Returns None when nothing is uploaded, or after showing an error when the
    upload cannot be loaded.
    """
    upload = st.file_uploader(
        "Upload probe (.pt)",
        type=["pt"],
        key=widget_key(context_key, "probe_upload"),
    )
    if upload is not None:
        # NOTE(review): if load_probe_from_bytes deserializes with torch.load
        # without weights_only=True, an uploaded file could execute arbitrary
        # code. Confirm before exposing this app to untrusted users.
        try:
            return load_probe_from_bytes(upload.name, upload.getvalue())
        except Exception as exc:
            st.error(f"Could not load probe: {exc}")
    return None


# ---------------------------------------------------------------------------
# Probe card + target validation
# ---------------------------------------------------------------------------


def _render_probe_card(probe: LoadedProbe) -> None:
    """Render a one-line markdown summary of the probe's metadata.

    Fields are joined with a middle dot. Optional fields (attribute name,
    feature space, location) are omitted when falsy. A missing layer is shown
    as ``?``. The class list falls back to a count when every label is empty.
    """
    layer_text = "?" if probe.layer is None else probe.layer
    segments: list[str] = [f"**{probe.attribute_name}**"] if probe.attribute_name else []
    segments += [f"layer `{layer_text}`", f"kind `{probe.model_type}`"]
    segments += [f"`{extra}`" for extra in (probe.feature_space, probe.location) if extra]

    named_labels = [label for label in probe.labels if label]
    class_text = ", ".join(named_labels) or f"{len(probe.labels)} classes"
    segments.append(f"classes: {class_text}")

    st.markdown("  ·  ".join(segments))


def _model_dimensions(model: object) -> tuple[int, int]:
    """Return ``(hidden_size, num_layers)`` read from *model* or its config.

    Sources are tried in order and the first truthy value wins:

    * attributes set directly on the model (``hidden_size`` / ``num_layers``);
    * ``model.config``: ``hidden_size`` or GPT-2 style ``n_embd``, and
      ``num_hidden_layers`` or ``n_layer``;
    * ``model.config.text_config`` with the same names, for composite
      configs that nest the language-model settings.

    Raises:
        ValueError: if either dimension cannot be found.
    """
    config = getattr(model, "config", None)
    # getattr(None, name, None) is None, so a missing config or text_config
    # simply contributes nothing.
    sources = (config, getattr(config, "text_config", None))

    hidden_size = getattr(model, "hidden_size", None)
    num_layers = getattr(model, "num_layers", None)
    for source in sources:
        hidden_size = (
            hidden_size
            or getattr(source, "hidden_size", None)
            or getattr(source, "n_embd", None)
        )
        num_layers = (
            num_layers
            or getattr(source, "num_hidden_layers", None)
            or getattr(source, "n_layer", None)
        )
    if hidden_size is None or num_layers is None:
        raise ValueError("Could not read hidden_size and num_layers from the model.")
    return int(hidden_size), int(num_layers)


def _resolve_target(
    *, probe: LoadedProbe, context_key: str, num_layers: int
) -> tuple[int, str]:
    """Return the ``(layer, location)`` the probe should be applied at.

    Values recorded on the probe are used as-is. Any missing value is asked
    for with an inline widget: the layer first, then the activation location.
    """
    last_layer = max(0, num_layers - 1)

    if probe.layer is not None:
        layer = probe.layer
    else:
        layer = int(
            st.number_input(
                "Layer (probe did not specify one)",
                min_value=0,
                max_value=last_layer,
                value=min(15, last_layer),
                step=1,
                key=widget_key(context_key, "probe_layer"),
            )
        )

    if probe.location is not None:
        location = probe.location
    else:
        location = st.selectbox(
            "Activation location (probe did not specify one)",
            options=("post_reasoning", "pre_reasoning"),
            key=widget_key(context_key, "probe_location"),
        )
    return layer, location


def _validate(
    *, probe: LoadedProbe, layer: int, num_layers: int, hidden_size: int
) -> bool:
    """Check that *probe* can be applied at *layer* of this model.

    Returns False after showing an error (layer out of range) or a warning
    (probe input dim differs from the model hidden size). Returns True when
    both checks pass.
    """
    in_range = 0 <= layer < num_layers
    if not in_range:
        st.error(f"Probe layer {layer} is outside the model's {num_layers} layers.")
        return False

    if probe.input_dim != hidden_size:
        st.warning(
            f"Probe input dim ({probe.input_dim}) does not match the model's hidden "
            f"size ({hidden_size}). Predictions will not be meaningful."
        )
        return False

    return True


# ---------------------------------------------------------------------------
# Cached batched probe forward
# ---------------------------------------------------------------------------


def _store_derived_cache(key: str, value: object) -> None:
    """Save *value* in session state under *key*, evicting old derived results.

    A list kept under ``_DERIVED_CACHE_TRACKER_KEY`` orders the cached keys
    from least to most recently used. *key* is moved to the tail. Entries at
    the head are removed from session state until at most
    ``_DERIVED_CACHE_ENTRIES`` remain.
    """
    order = st.session_state.setdefault(_DERIVED_CACHE_TRACKER_KEY, [])
    # An unexpected non-list tracker is treated as empty.
    order = [k for k in order if k != key] if isinstance(order, list) else []
    order.append(key)

    while len(order) > _DERIVED_CACHE_ENTRIES:
        oldest = order.pop(0)
        st.session_state.pop(oldest, None)

    st.session_state[_DERIVED_CACHE_TRACKER_KEY] = order
    st.session_state[key] = value


def _get_derived_cache(key: str) -> object | None:
    """Look up a derived probe result and mark it most recently used.

    Returns None when nothing (or None) is stored under *key*. The MRU order
    is refreshed only when the tracker is a list that already contains *key*.
    """
    value = st.session_state.get(key)
    if value is not None:
        order = st.session_state.get(_DERIVED_CACHE_TRACKER_KEY)
        if isinstance(order, list) and key in order:
            # Move the key to the most-recent (tail) end.
            st.session_state[_DERIVED_CACHE_TRACKER_KEY] = [
                k for k in order if k != key
            ] + [key]
    return value


def _classification_predictions(
    probe: LoadedProbe, activations: torch.Tensor, cache_key: str
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(probs, predicted)`` for *activations*, memoised per trace and probe.

    Args:
        probe: Classification probe. The second and third items of its
            ``run_batch`` result are returned.
        activations: Traced activations, passed straight to ``run_batch``.
        cache_key: Identifies the trace the activations came from.

    The session key embeds ``id(probe)``, which is only unique among objects
    that are alive at the same time. A discarded probe's id can be reused by
    a different probe. The cached entry therefore holds a reference to the
    probe it was computed with. That keeps the id from being recycled while
    the entry exists, and the ``is`` check rejects any entry not produced by
    this exact probe, including entries in the older two-tuple format.
    """
    full_key = widget_key("probe_predictions", cache_key, str(id(probe)))
    cached = _get_derived_cache(full_key)
    if isinstance(cached, tuple) and len(cached) == 3 and cached[0] is probe:
        return cached[1], cached[2]
    _, probs, predicted = probe.run_batch(activations)
    _store_derived_cache(full_key, (probe, probs, predicted))
    return probs, predicted


def _regression_values(
    probe: LoadedProbe, activations: torch.Tensor, cache_key: str
) -> torch.Tensor:
    """Return ``probe.predict_batch(activations)``, memoised per trace and probe.

    Args:
        probe: Regression probe.
        activations: Traced activations, passed straight to ``predict_batch``.
        cache_key: Identifies the trace the activations came from.

    The session key embeds ``id(probe)``, which is only unique among objects
    that are alive at the same time. The cached entry therefore stores
    ``(probe, values)``. Holding the probe keeps its id from being recycled
    by another probe while the entry exists, and the ``is`` check rejects any
    entry computed with a different probe, or written in the older
    bare-tensor format.
    """
    full_key = widget_key("probe_values", cache_key, str(id(probe)))
    cached = _get_derived_cache(full_key)
    if isinstance(cached, tuple) and len(cached) == 2 and cached[0] is probe:
        return cached[1]
    values = probe.predict_batch(activations)
    _store_derived_cache(full_key, (probe, values))
    return values


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------


def _has_assistant_message(messages: list[dict]) -> bool:
    """Return True if any message has role ``"assistant"`` and truthy content."""
    for message in messages:
        if message.get("role") != "assistant":
            continue
        if message.get("content"):
            return True
    return False


def _apply_overlays(
    *, probe: LoadedProbe, trace: ConversationTrace, messages: list[dict]
) -> bool:
    """Score *trace* with *probe* and attach the resulting overlays to *messages*.

    Returns the truthiness of the built overlays, i.e. whether anything was
    produced to show.
    """
    if not probe.is_regression:
        probs, predicted = _classification_predictions(
            probe, trace.activations, trace.cache_key
        )
        n_outputs = probs.shape[1]
        # Binary display when there is a single output column, or two
        # columns paired with exactly two labels.
        is_binary = n_outputs == 1 or (n_outputs == 2 and len(probe.labels) == 2)
        overlays = build_classification_overlays(
            trace=trace,
            probs=probs,
            predicted=predicted,
            labels=probe.labels,
            binary=is_binary,
            attribute_name=probe.attribute_name,
        )
    else:
        overlays = build_regression_overlays(
            trace=trace,
            values=_regression_values(probe, trace.activations, trace.cache_key),
            labels=probe.labels,
            attribute_name=probe.attribute_name,
        )
    attach_overlays(messages, overlays)
    return bool(overlays)


def render_probe_inspector(
    *,
    context_key: str,
    model_name: str,
    remote: bool,
    active_system_prompt: str | None,
    chat_state: dict[str, object],
    enabled: bool,
) -> None:
    """Render the "Probe" expander and overlay probe scores on the chat.

    Overlays are attached to and cleared from the message dicts in
    ``chat_state["messages"]`` in place. Scoring only happens on the rerun in
    which the "Run probe" button is pressed. On every other rerun the
    function only shows the status saved by the last successful run, or
    clears the overlays if the conversation text changed since then.

    Args:
        context_key: Namespace passed to ``widget_key`` so widgets and
            session keys of different chat contexts do not collide.
        model_name: Model to trace. Also used to prefer probe files stored
            under that model's directory.
        remote: Forwarded to ``trace_conversation``. Presumably selects
            remote (NDIF) execution, given the API key passed alongside —
            confirm.
        active_system_prompt: Forwarded to ``build_chat_messages`` together
            with the messages.
        chat_state: Must contain ``"messages"``, the list of message dicts
            that overlays are written to.
        enabled: When False, existing overlays are cleared and nothing is
            rendered.
    """
    messages: list[dict] = chat_state["messages"]  # type: ignore[assignment]
    if not enabled:
        clear_overlays(messages)
        return

    # Per-context session keys: the status line of the last successful run,
    # and the conversation signature that run scored.
    status_key = widget_key(context_key, "probe_status")
    sig_key = widget_key(context_key, "probe_scored_sig")

    def _conversation_sig() -> int:
        # Hash of (role, content) for every message with non-empty content.
        # Used to detect edits since the last scoring. str hashes are salted
        # per interpreter process, so signatures only compare within one
        # server process.
        return hash(
            tuple(
                (m.get("role"), m.get("content")) for m in messages if m.get("content")
            )
        )

    def _reset() -> None:
        # Drop overlays and the remembered status/signature. Called on every
        # early exit so nothing stale survives a failed or skipped run.
        clear_overlays(messages)
        st.session_state.pop(status_key, None)
        st.session_state.pop(sig_key, None)

    with st.expander("Probe", expanded=True):
        if not _has_assistant_message(messages):
            _reset()
            st.caption("Probe overlay shows up after the first assistant reply.")
            return

        probe = _render_probe_selector(context_key=context_key, model_name=model_name)
        if probe is None:
            _reset()
            return
        _render_probe_card(probe)

        model = cached_model(model_name=model_name)
        try:
            hidden_size, num_layers = _model_dimensions(model)
        except Exception as exc:
            _reset()
            st.error(str(exc))
            return

        # May render layer/location widgets when the probe leaves them unset.
        layer, location = _resolve_target(
            probe=probe, context_key=context_key, num_layers=num_layers
        )
        if not _validate(
            probe=probe, layer=layer, num_layers=num_layers, hidden_size=hidden_size
        ):
            _reset()
            return

        # The probe scores via a separate forward pass over the whole
        # conversation, so it's fully decoupled from generation: pick or switch
        # probes any time and score on demand. Gate that pass behind a button
        # instead of re-running it on every Streamlit rerun. Overlays live on
        # the message dicts, so they persist across reruns until refreshed.
        run = st.button(
            "Run probe",
            type="primary",
            key=widget_key(context_key, "probe_run"),
            help="Score the current conversation with the selected probe.",
        )
        if not run:
            status = st.session_state.get(status_key)
            if not status:
                st.caption("Press **Run probe** to score the conversation.")
            elif st.session_state.get(sig_key) != _conversation_sig():
                # Conversation changed since it was scored: drop the now-stale
                # overlay so it can't paint over edited/new text.
                clear_overlays(messages)
                st.caption("Conversation changed — press **Run probe** to refresh.")
            else:
                st.caption(f"{status} · press **Run probe** to refresh.")
            return

        # Button pressed: trace the full conversation (system prompt included
        # via build_chat_messages) at the resolved layer/location.
        chat_messages = build_chat_messages(active_system_prompt, messages)
        with st.spinner("Tracing conversation..."):
            try:
                trace = trace_conversation(
                    model=model,
                    model_name=model_name,
                    messages=chat_messages,
                    layer=layer,
                    location=location,
                    remote=remote,
                    ndif_api_key=session_ndif_api_key(),
                )
            except Exception as exc:
                _reset()
                st.error(f"Trace failed: {exc}")
                return

        # Overlays are aligned to assistant spans; without any there is
        # nothing to attach to.
        if not trace.assistant_spans:
            _reset()
            st.warning(
                "Could not locate assistant tokens in the traced sequence, so "
                "the overlay can't be aligned to message bodies."
            )
            return

        try:
            applied = _apply_overlays(probe=probe, trace=trace, messages=messages)
        except Exception as exc:
            _reset()
            st.error(f"Probe execution failed: {exc}")
            return

        if not applied:
            _reset()
            return

        # Count tokens inside the assistant spans that are not flagged in
        # trace.is_special; reported as "body tokens" in the status line.
        n_body = sum(
            sum(1 for i in range(s, e) if not bool(trace.is_special[i].item()))
            for s, e in trace.assistant_spans
        )
        kind = "regression" if probe.is_regression else "classification"
        status = (
            f"{kind} · {len(trace.assistant_spans)} assistant message(s) · "
            f"{n_body} body tokens · layer {trace.layer} · {trace.location}"
        )
        # Persist status and the scored signature so later reruns can show
        # the status and detect conversation edits.
        st.session_state[status_key] = status
        st.session_state[sig_key] = _conversation_sig()
        st.caption(status)