File size: 2,621 Bytes
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from __future__ import annotations

import torch

from tabs import probe_ui
from utils import probe_trace


def test_store_derived_cache_evicts_oldest(monkeypatch):
    session_state: dict[str, object] = {}
    monkeypatch.setattr(probe_ui.st, "session_state", session_state)
    monkeypatch.setattr(probe_ui, "_DERIVED_CACHE_ENTRIES", 2)

    probe_ui._store_derived_cache("k1", 1)
    probe_ui._store_derived_cache("k2", 2)
    probe_ui._store_derived_cache("k3", 3)

    assert "k1" not in session_state
    assert session_state["k2"] == 2
    assert session_state["k3"] == 3
    assert session_state[probe_ui._DERIVED_CACHE_TRACKER_KEY] == ["k2", "k3"]


def test_get_derived_cache_refreshes_recently_used_entry(monkeypatch):
    session_state: dict[str, object] = {}
    monkeypatch.setattr(probe_ui.st, "session_state", session_state)
    monkeypatch.setattr(probe_ui, "_DERIVED_CACHE_ENTRIES", 2)

    probe_ui._store_derived_cache("k1", 1)
    probe_ui._store_derived_cache("k2", 2)

    assert probe_ui._get_derived_cache("k1") == 1
    probe_ui._store_derived_cache("k3", 3)

    assert "k1" in session_state
    assert "k2" not in session_state
    assert session_state[probe_ui._DERIVED_CACHE_TRACKER_KEY] == ["k1", "k3"]


def test_trace_eviction_drops_derived_results(monkeypatch):
    session_state: dict[str, object] = {}
    monkeypatch.setattr(probe_trace.st, "session_state", session_state)
    monkeypatch.setattr(probe_trace, "_MAX_CACHED_TRACES", 1)

    trace = probe_trace.ConversationTrace(
        cache_key="old",
        model_name="m",
        remote=False,
        prompt_text="p",
        prompt_hash="h",
        layer=0,
        location="post_reasoning",
        input_ids=torch.tensor([1]),
        activations=torch.zeros((1, 1)),
        tokens=["x"],
        assistant_spans=[],
        is_special=torch.tensor([False]),
    )
    old_prediction_key = "probe_predictions::old::probe"
    kept_prediction_key = "probe_predictions::new::probe"
    session_state[probe_trace._DERIVED_CACHE_TRACKER_KEY] = [
        old_prediction_key,
        kept_prediction_key,
    ]
    session_state[old_prediction_key] = object()
    session_state[kept_prediction_key] = object()

    probe_trace._store_cached_trace("old", trace)
    probe_trace._store_cached_trace(
        "new",
        probe_trace.ConversationTrace(
            **{**trace.__dict__, "cache_key": "new"},
        ),
    )

    assert old_prediction_key not in session_state
    assert kept_prediction_key in session_state
    assert session_state[probe_trace._DERIVED_CACHE_TRACKER_KEY] == [
        kept_prediction_key
    ]