"""Tests for PrefixCacheTracker — cache-aware compression."""

import time

import pytest

from headroom.cache.prefix_tracker import (
    FreezeStats,
    PrefixCacheTracker,
    PrefixFreezeConfig,
    SessionTrackerStore,
)


class TestPrefixCacheTracker:
    """Test PrefixCacheTracker core functionality."""

    @pytest.fixture
    def tracker(self):
        """Return a fresh tracker configured for the "anthropic" provider."""
        return PrefixCacheTracker("anthropic")

    @pytest.fixture
    def openai_tracker(self):
        """Return a fresh tracker configured for the "openai" provider."""
        return PrefixCacheTracker("openai")

    def test_turn_0_no_freeze(self, tracker):
        """First turn should never freeze — no cache state yet."""
        assert tracker.get_frozen_message_count() == 0

    def test_turn_1_with_cache_hit_freezes(self, tracker):
        """A turn-1 response with cache writes only must not freeze anything.

        The follow-up turn (writes becoming reads) is exercised by
        ``test_cache_write_freezes_next_turn``.
        """
        messages = [
            {"role": "system", "content": "You are a helpful assistant." * 100},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        # Simulate: provider wrote 2050 tokens to the cache, which equals the
        # sum of all three per-message counts below.
        token_counts = [1500, 50, 500]
        tracker.update_from_response(
            cache_read_tokens=0,
            cache_write_tokens=2050,
            messages=messages,
            message_token_counts=token_counts,
        )
        # On turn 1 there were no cache reads, so nothing is frozen.
        assert tracker.get_frozen_message_count() == 0

    def test_partial_freeze(self, tracker):
        """Only messages that fit within cached tokens are frozen."""
        messages = [
            {"role": "system", "content": "System prompt" * 50},
            {"role": "user", "content": "First question" * 50},
            {"role": "assistant", "content": "First answer" * 50},
            {"role": "user", "content": "Second question"},
        ]
        token_counts = [2000, 500, 500, 50]
        tracker.update_from_response(
            cache_read_tokens=2500,
            cache_write_tokens=0,
            messages=messages,
            message_token_counts=token_counts,
        )
        # 2000 + 500 = 2500 <= 2500, but 2000 + 500 + 500 = 3000 > 2500
        assert tracker.get_frozen_message_count() == 2

    def test_cold_start_no_freeze(self, tracker):
        """If cache_read=0 and cache_write=0, don't freeze."""
        messages = [{"role": "user", "content": "Hello"}]
        tracker.update_from_response(
            cache_read_tokens=0,
            cache_write_tokens=0,
            messages=messages,
        )
        assert tracker.get_frozen_message_count() == 0

    def test_cache_write_freezes_next_turn(self, tracker):
        """Cache writes (new cache entries) should be frozen on the next turn."""
        messages = [
            {"role": "system", "content": "System" * 200},
            {"role": "user", "content": "Hello"},
        ]
        token_counts = [1500, 50]

        # Turn 1: provider writes to cache (above min threshold)
        tracker.update_from_response(
            cache_read_tokens=0,
            cache_write_tokens=1550,
            messages=messages,
            message_token_counts=token_counts,
        )
        # Turn 1: writes don't freeze current request
        assert tracker.get_frozen_message_count() == 0

        # Simulate second turn where those writes become reads
        tracker.update_from_response(
            cache_read_tokens=1550,
            cache_write_tokens=0,
            messages=messages,
            message_token_counts=token_counts,
        )
        # Turn 2: now freeze what was written previously
        assert tracker.get_frozen_message_count() == 2

    def test_min_cached_tokens_threshold(self):
        """Below min_cached_tokens, no freeze."""
        config = PrefixFreezeConfig(min_cached_tokens=2000)
        tracker = PrefixCacheTracker("anthropic", config)
        messages = [{"role": "user", "content": "Hello"}]

        # Turn 1: only 500 tokens cached — below threshold
        tracker.update_from_response(
            cache_read_tokens=0,
            cache_write_tokens=500,
            messages=messages,
            message_token_counts=[500],
        )
        assert tracker.get_frozen_message_count() == 0

    def test_disabled_config(self):
        """Disabled config always returns 0."""
        config = PrefixFreezeConfig(enabled=False)
        tracker = PrefixCacheTracker("anthropic", config)
        messages = [{"role": "system", "content": "System" * 500}]
        tracker.update_from_response(
            cache_read_tokens=5000,
            cache_write_tokens=0,
            messages=messages,
            message_token_counts=[5000],
        )
        assert tracker.get_frozen_message_count() == 0

    def test_turn_number_increments(self, tracker):
        """Turn number should increment on each update."""
        messages = [{"role": "user", "content": "Hello"}]
        assert tracker._turn_number == 0
        tracker.update_from_response(0, 0, messages)
        assert tracker._turn_number == 1
        tracker.update_from_response(0, 0, messages)
        assert tracker._turn_number == 2

    def test_stats_tracking(self, tracker):
        """Stats of a fresh tracker should be a zeroed FreezeStats."""
        stats = tracker.stats
        assert isinstance(stats, FreezeStats)
        assert stats.busts_avoided == 0
        assert stats.tokens_preserved == 0
        assert stats.turn_number == 0

    def test_record_bust_avoided(self, tracker):
        """Recording bust avoided should accumulate into the stats."""
        tracker.record_bust_avoided(tokens_preserved=5000, compression_foregone=500)
        tracker.record_bust_avoided(tokens_preserved=3000, compression_foregone=200)
        stats = tracker.stats
        assert stats.busts_avoided == 2
        assert stats.tokens_preserved == 8000
        assert stats.compression_foregone_tokens == 700
        # 8000 preserved - 700 foregone
        assert stats.net_benefit_tokens == 7300

    def test_should_force_compress_outside_frozen(self, tracker):
        """Messages outside frozen prefix should always be compressed."""
        tracker._cached_message_count = 3
        # Index 5 lies beyond the 3 cached messages.
        assert tracker.should_force_compress(5, 1000, 200) is True

    def test_should_force_compress_when_savings_exceed_discount(self, tracker):
        """For Anthropic (90% discount), compression must save >90% to be worth it."""
        tracker._cached_message_count = 5
        # 95% savings > 90% discount — should force compress
        assert tracker.should_force_compress(2, 1000, 50) is True
        # 50% savings < 90% discount — should NOT force compress
        assert tracker.should_force_compress(2, 1000, 500) is False

    def test_should_force_compress_openai(self, openai_tracker):
        """For OpenAI (50% discount), compression must save >50% to be worth it."""
        openai_tracker._cached_message_count = 5
        # 60% savings > 50% discount — should force compress
        assert openai_tracker.should_force_compress(2, 1000, 400) is True
        # 40% savings < 50% discount — should NOT force compress
        assert openai_tracker.should_force_compress(2, 1000, 600) is False

    def test_estimate_message_tokens(self):
        """Longer content should produce a larger token estimate.

        The inline token figures assume roughly 3.5 characters per token;
        only the relative ordering is asserted here.
        """
        messages = [
            {"role": "system", "content": "A" * 350},  # ~100 tokens
            {"role": "user", "content": "B" * 70},  # ~20 tokens
        ]
        counts = PrefixCacheTracker._estimate_message_tokens(messages)
        assert len(counts) == 2
        assert counts[0] > counts[1]  # System should have more tokens

    def test_estimate_content_blocks(self):
        """Token estimation should handle Anthropic content blocks."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "A" * 350},
                    {"type": "text", "text": "B" * 350},
                ],
            },
        ]
        counts = PrefixCacheTracker._estimate_message_tokens(messages)
        assert len(counts) == 1
        assert counts[0] > 100

    def test_estimate_tool_result_content(self):
        """Token estimation should count tool_result content field."""
        tool_content = "x" * 3500  # ~1000 tokens
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "tool_result",
                        "tool_use_id": "t1",
                        "content": tool_content,
                    }
                ],
            },
        ]
        counts = PrefixCacheTracker._estimate_message_tokens(messages)
        assert len(counts) == 1
        # Should be ~1000 tokens, definitely > 100
        assert counts[0] > 100

    def test_estimate_tool_use_input(self):
        """Token estimation should count tool_use input field."""
        messages = [
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "tool_use",
                        "id": "t1",
                        "name": "Read",
                        "input": {"file_path": "/very/long/path/" + "x" * 700},
                    }
                ],
            },
        ]
        counts = PrefixCacheTracker._estimate_message_tokens(messages)
        assert len(counts) == 1
        # Should count the serialized input dict
        assert counts[0] > 50

    def test_estimate_tool_result_nested_blocks(self):
        """Token estimation should handle nested content blocks in tool_result."""
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "tool_result",
                        "tool_use_id": "t1",
                        "content": [
                            {"type": "text", "text": "A" * 3500},
                        ],
                    }
                ],
            },
        ]
        counts = PrefixCacheTracker._estimate_message_tokens(messages)
        assert len(counts) == 1
        assert counts[0] > 100

    def test_session_ttl_expiry(self):
        """Tracker should report as expired after TTL."""
        config = PrefixFreezeConfig(session_ttl_seconds=1)
        tracker = PrefixCacheTracker("anthropic", config)
        assert tracker.is_expired is False

        # Simulate time passing by back-dating the last activity past the
        # 1-second TTL instead of sleeping.
        tracker._last_activity = time.time() - 2
        assert tracker.is_expired is True


class TestSessionTrackerStore:
    """Test SessionTrackerStore management."""

    @pytest.fixture
    def store(self):
        """Return an empty store built with default settings."""
        return SessionTrackerStore()

    def test_get_or_create_new(self, store):
        """Should create a new tracker for unknown session."""
        tracker = store.get_or_create("session-1", "anthropic")
        assert isinstance(tracker, PrefixCacheTracker)
        assert tracker.provider == "anthropic"

    def test_get_or_create_existing(self, store):
        """Should return the same tracker for the same session."""
        tracker1 = store.get_or_create("session-1", "anthropic")
        tracker2 = store.get_or_create("session-1", "anthropic")
        assert tracker1 is tracker2

    def test_different_sessions(self, store):
        """Different sessions should get different trackers."""
        tracker1 = store.get_or_create("session-1", "anthropic")
        tracker2 = store.get_or_create("session-2", "openai")
        assert tracker1 is not tracker2
        assert tracker1.provider == "anthropic"
        assert tracker2.provider == "openai"

    def test_active_sessions_count(self, store):
        """Should track the number of active sessions."""
        assert store.active_sessions == 0
        store.get_or_create("s1", "anthropic")
        assert store.active_sessions == 1
        store.get_or_create("s2", "openai")
        assert store.active_sessions == 2

    def test_cleanup_expired(self):
        """Should remove expired sessions on cleanup."""
        # A dedicated store is built here because the test needs a short TTL;
        # the shared default-config fixture is therefore not requested.
        config = PrefixFreezeConfig(session_ttl_seconds=1)
        ttl_store = SessionTrackerStore(default_config=config)
        tracker = ttl_store.get_or_create("expired-session", "anthropic")
        tracker._last_activity = time.time() - 2

        # Force cleanup
        ttl_store._last_cleanup = 0
        ttl_store._maybe_cleanup()
        assert ttl_store.active_sessions == 0

    def test_compute_session_id_from_header(self, store):
        """Should use x-headroom-session-id header if present."""

        class MockRequest:
            headers = {"x-headroom-session-id": "explicit-id-123"}

        session_id = store.compute_session_id(
            MockRequest(), "claude-3", [{"role": "user", "content": "Hi"}]
        )
        assert session_id == "explicit-id-123"

    def test_compute_session_id_from_hash(self, store):
        """Without the header, the ID should be stable and model-dependent."""

        class MockRequest:
            headers = {}

        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hi"},
        ]
        id1 = store.compute_session_id(MockRequest(), "claude-3", messages)
        id2 = store.compute_session_id(MockRequest(), "claude-3", messages)
        assert id1 == id2  # Stable hash
        assert len(id1) == 16

        # Different model = different session
        id3 = store.compute_session_id(MockRequest(), "gpt-4", messages)
        assert id3 != id1

    def test_compute_session_id_no_system(self, store):
        """Should work without system messages."""

        class MockRequest:
            headers = {}

        messages = [{"role": "user", "content": "Hi"}]
        session_id = store.compute_session_id(MockRequest(), "claude-3", messages)
        assert isinstance(session_id, str)
        assert len(session_id) == 16


class TestMultiTurnScenario:
    """Integration-style tests simulating multi-turn conversations."""

    def test_five_turn_conversation(self):
        """Simulate a multi-turn conversation with a growing prefix.

        Despite the test name, three provider responses are fed to the
        tracker; the final assertion pins the turn counter at 3.
        """
        tracker = PrefixCacheTracker("anthropic")

        # Turn 1: System + User (cold start, no cache)
        messages_t1 = [
            {"role": "system", "content": "System prompt" * 200},
            {"role": "user", "content": "Question 1"},
        ]
        token_counts_t1 = [2000, 50]
        assert tracker.get_frozen_message_count() == 0  # No freeze on turn 1

        tracker.update_from_response(
            cache_read_tokens=0,
            cache_write_tokens=2050,
            messages=messages_t1,
            message_token_counts=token_counts_t1,
        )
        # Turn 1 had no reads, so nothing frozen yet (writes become reads next turn)
        frozen = tracker.get_frozen_message_count()
        assert frozen == 0

        # Turn 2: Previous messages cached, new user message added
        messages_t2 = messages_t1 + [
            {"role": "assistant", "content": "Answer 1"},
            {"role": "user", "content": "Question 2"},
        ]
        token_counts_t2 = [2000, 50, 200, 50]
        tracker.update_from_response(
            cache_read_tokens=2050,
            cache_write_tokens=250,
            messages=messages_t2,
            message_token_counts=token_counts_t2,
        )
        frozen = tracker.get_frozen_message_count()
        assert frozen == 2  # System + User1 frozen (from turn 1 writes → now reads)

        # Turn 3: Even more cached
        messages_t3 = messages_t2 + [
            {"role": "assistant", "content": "Answer 2"},
            {"role": "user", "content": "Question 3"},
        ]
        token_counts_t3 = [2000, 50, 200, 50, 200, 50]
        frozen = tracker.get_frozen_message_count()
        assert frozen == 2  # Still 2 — no update_from_response for turn 3 yet

        tracker.update_from_response(
            cache_read_tokens=2300,
            cache_write_tokens=250,
            messages=messages_t3,
            message_token_counts=token_counts_t3,
        )
        # After turn 3's response: 2000 + 50 + 200 + 50 = 2300 read tokens
        # cover the first four messages.
        frozen = tracker.get_frozen_message_count()
        assert frozen == 4  # System + User1 + Asst1 + User2 frozen

        # Verify turn count
        assert tracker._turn_number == 3

    def test_cache_bust_resets_freeze(self):
        """If cache is busted (0 read, 0 write), freeze should reset."""
        tracker = PrefixCacheTracker("anthropic")
        messages = [
            {"role": "system", "content": "System" * 200},
            {"role": "user", "content": "Hello"},
        ]

        # Turn 1: Cache established (writes only, reads=0 → nothing frozen yet)
        tracker.update_from_response(
            cache_read_tokens=0,
            cache_write_tokens=2000,
            messages=messages,
            message_token_counts=[1500, 500],
        )
        assert tracker.get_frozen_message_count() == 0  # Writes don't freeze current turn

        # Turn 2: Cache bust (0 reads and 0 writes reported by the provider)
        tracker.update_from_response(
            cache_read_tokens=0,
            cache_write_tokens=0,
            messages=messages,
            message_token_counts=[1500, 500],
        )
        # After a bust with 0 total, freeze should reset
        assert tracker.get_frozen_message_count() == 0