# persona-ui: tests/test_probe_cache_bounds.py
# Jac-Zac — "Big refactoring" (b279884)
from __future__ import annotations
import torch
from tabs import probe_ui
from utils import probe_trace
def test_store_derived_cache_evicts_oldest(monkeypatch):
    """Storing past capacity drops the oldest derived-cache entry.

    The derived cache is capped at two entries for this test, so the third
    store must remove ``k1`` from session state. The tracker list must then
    hold only the surviving keys, oldest first.
    """
    state: dict[str, object] = {}
    monkeypatch.setattr(probe_ui.st, "session_state", state)
    monkeypatch.setattr(probe_ui, "_DERIVED_CACHE_ENTRIES", 2)

    # Store k1 -> 1, k2 -> 2, k3 -> 3 in that order.
    for value, key in enumerate(("k1", "k2", "k3"), start=1):
        probe_ui._store_derived_cache(key, value)

    assert "k1" not in state
    assert (state["k2"], state["k3"]) == (2, 3)
    assert state[probe_ui._DERIVED_CACHE_TRACKER_KEY] == ["k2", "k3"]
def test_get_derived_cache_refreshes_recently_used_entry(monkeypatch):
    """Reading an entry via ``_get_derived_cache`` marks it most recently used.

    With capacity 2, ``k1`` is read after both ``k1`` and ``k2`` are stored.
    Storing ``k3`` must then evict ``k2``, which is now the least recently
    used entry, instead of ``k1``.
    """
    session_state: dict[str, object] = {}
    monkeypatch.setattr(probe_ui.st, "session_state", session_state)
    monkeypatch.setattr(probe_ui, "_DERIVED_CACHE_ENTRIES", 2)

    probe_ui._store_derived_cache("k1", 1)
    probe_ui._store_derived_cache("k2", 2)
    # This read is what should promote k1 ahead of k2 in the eviction order.
    assert probe_ui._get_derived_cache("k1") == 1

    probe_ui._store_derived_cache("k3", 3)

    # Check stored values, not just membership. This catches a refresh that
    # clobbers the entry, or a store of k3 that silently drops its value.
    assert session_state["k1"] == 1
    assert "k2" not in session_state
    assert session_state["k3"] == 3
    assert session_state[probe_ui._DERIVED_CACHE_TRACKER_KEY] == ["k1", "k3"]
def test_trace_eviction_drops_derived_results(monkeypatch):
    """Evicting a cached trace also purges derived results tied to it.

    ``_MAX_CACHED_TRACES`` is patched to 1, so storing the ``"new"`` trace
    is expected to evict the ``"old"`` one. The derived entry for the
    evicted trace (``probe_predictions::old::probe``) must then be removed
    from session state and from the derived-cache tracker. The entry for the
    surviving trace (``probe_predictions::new::probe``) must be kept.
    """
    session_state: dict[str, object] = {}
    monkeypatch.setattr(probe_trace.st, "session_state", session_state)
    monkeypatch.setattr(probe_trace, "_MAX_CACHED_TRACES", 1)

    def make_trace(cache_key: str) -> probe_trace.ConversationTrace:
        # Build each trace through the constructor instead of copying
        # ``trace.__dict__``. The copy breaks if ConversationTrace uses
        # __slots__ or carries attributes that are not __init__ arguments.
        return probe_trace.ConversationTrace(
            cache_key=cache_key,
            model_name="m",
            remote=False,
            prompt_text="p",
            prompt_hash="h",
            layer=0,
            location="post_reasoning",
            input_ids=torch.tensor([1]),
            activations=torch.zeros((1, 1)),
            tokens=["x"],
            assistant_spans=[],
            is_special=torch.tensor([False]),
        )

    old_prediction_key = "probe_predictions::old::probe"
    kept_prediction_key = "probe_predictions::new::probe"
    session_state[probe_trace._DERIVED_CACHE_TRACKER_KEY] = [
        old_prediction_key,
        kept_prediction_key,
    ]
    session_state[old_prediction_key] = object()
    session_state[kept_prediction_key] = object()

    probe_trace._store_cached_trace("old", make_trace("old"))
    # With the cap at 1, this second store should evict the "old" trace.
    probe_trace._store_cached_trace("new", make_trace("new"))

    assert old_prediction_key not in session_state
    assert kept_prediction_key in session_state
    assert session_state[probe_trace._DERIVED_CACHE_TRACKER_KEY] == [
        kept_prediction_key
    ]