| from __future__ import annotations |
|
|
| import torch |
|
|
| from tabs import probe_ui |
| from utils import probe_trace |
|
|
|
|
| def test_store_derived_cache_evicts_oldest(monkeypatch): |
| session_state: dict[str, object] = {} |
| monkeypatch.setattr(probe_ui.st, "session_state", session_state) |
| monkeypatch.setattr(probe_ui, "_DERIVED_CACHE_ENTRIES", 2) |
|
|
| probe_ui._store_derived_cache("k1", 1) |
| probe_ui._store_derived_cache("k2", 2) |
| probe_ui._store_derived_cache("k3", 3) |
|
|
| assert "k1" not in session_state |
| assert session_state["k2"] == 2 |
| assert session_state["k3"] == 3 |
| assert session_state[probe_ui._DERIVED_CACHE_TRACKER_KEY] == ["k2", "k3"] |
|
|
|
|
| def test_get_derived_cache_refreshes_recently_used_entry(monkeypatch): |
| session_state: dict[str, object] = {} |
| monkeypatch.setattr(probe_ui.st, "session_state", session_state) |
| monkeypatch.setattr(probe_ui, "_DERIVED_CACHE_ENTRIES", 2) |
|
|
| probe_ui._store_derived_cache("k1", 1) |
| probe_ui._store_derived_cache("k2", 2) |
|
|
| assert probe_ui._get_derived_cache("k1") == 1 |
| probe_ui._store_derived_cache("k3", 3) |
|
|
| assert "k1" in session_state |
| assert "k2" not in session_state |
| assert session_state[probe_ui._DERIVED_CACHE_TRACKER_KEY] == ["k1", "k3"] |
|
|
|
|
| def test_trace_eviction_drops_derived_results(monkeypatch): |
| session_state: dict[str, object] = {} |
| monkeypatch.setattr(probe_trace.st, "session_state", session_state) |
| monkeypatch.setattr(probe_trace, "_MAX_CACHED_TRACES", 1) |
|
|
| trace = probe_trace.ConversationTrace( |
| cache_key="old", |
| model_name="m", |
| remote=False, |
| prompt_text="p", |
| prompt_hash="h", |
| layer=0, |
| location="post_reasoning", |
| input_ids=torch.tensor([1]), |
| activations=torch.zeros((1, 1)), |
| tokens=["x"], |
| assistant_spans=[], |
| is_special=torch.tensor([False]), |
| ) |
| old_prediction_key = "probe_predictions::old::probe" |
| kept_prediction_key = "probe_predictions::new::probe" |
| session_state[probe_trace._DERIVED_CACHE_TRACKER_KEY] = [ |
| old_prediction_key, |
| kept_prediction_key, |
| ] |
| session_state[old_prediction_key] = object() |
| session_state[kept_prediction_key] = object() |
|
|
| probe_trace._store_cached_trace("old", trace) |
| probe_trace._store_cached_trace( |
| "new", |
| probe_trace.ConversationTrace( |
| **{**trace.__dict__, "cache_key": "new"}, |
| ), |
| ) |
|
|
| assert old_prediction_key not in session_state |
| assert kept_prediction_key in session_state |
| assert session_state[probe_trace._DERIVED_CACHE_TRACKER_KEY] == [ |
| kept_prediction_key |
| ] |
|
|