Spaces:
Running
Running
| """Tests for the one-function compress() API and integrations.""" | |
| import json | |
| import pytest | |
| from headroom.compress import CompressResult, compress | |
| from headroom.hooks import CompressionHooks | |
| try: | |
| from starlette.applications import Starlette | |
| from starlette.requests import Request | |
| from starlette.responses import JSONResponse | |
| from starlette.routing import Route | |
| from starlette.testclient import TestClient | |
| from headroom.integrations.asgi import CompressionMiddleware | |
| HAS_STARLETTE = True | |
| except ImportError: | |
| HAS_STARLETTE = False | |
| # ============================================================================= | |
| # Tests: compress() function | |
| # ============================================================================= | |
| class TestCompressFunction: | |
| def test_empty_messages(self): | |
| result = compress([], model="test") | |
| assert result.messages == [] | |
| assert result.tokens_saved == 0 | |
| def test_small_messages_passthrough(self): | |
| """Small messages below compression threshold pass through unchanged.""" | |
| messages = [{"role": "user", "content": "hello"}] | |
| result = compress(messages, model="gpt-4o") | |
| assert result.messages[0]["content"] == "hello" | |
| assert result.tokens_saved == 0 | |
| def test_returns_compress_result(self): | |
| result = compress([{"role": "user", "content": "hi"}]) | |
| assert isinstance(result, CompressResult) | |
| assert hasattr(result, "messages") | |
| assert hasattr(result, "tokens_saved") | |
| assert hasattr(result, "compression_ratio") | |
| assert hasattr(result, "transforms_applied") | |
| def test_large_tool_output_compressed(self): | |
| """Large JSON tool output should be compressed.""" | |
| big_data = json.dumps( | |
| [ | |
| {"id": i, "status": "active", "name": f"item_{i}", "value": i * 17} | |
| for i in range(200) | |
| ] | |
| ) | |
| messages = [ | |
| {"role": "user", "content": "What are the top items?"}, | |
| {"role": "tool", "content": big_data, "tool_call_id": "call_1"}, | |
| ] | |
| result = compress(messages, model="gpt-4o") | |
| assert result.tokens_after <= result.tokens_before | |
| assert len(result.messages) == 2 | |
| def test_compact_json_counts_tokens_not_whitespace(self): | |
| """Compact JSON arrays should still compress under token thresholds.""" | |
| numbers = [42.0 + i * 0.1 for i in range(200)] | |
| messages = [ | |
| {"role": "system", "content": "You are helpful."}, | |
| {"role": "user", "content": "Show metrics"}, | |
| { | |
| "role": "assistant", | |
| "content": None, | |
| "tool_calls": [ | |
| { | |
| "id": "call_1", | |
| "type": "function", | |
| "function": {"name": "get_metrics", "arguments": "{}"}, | |
| } | |
| ], | |
| }, | |
| {"role": "tool", "tool_call_id": "call_1", "content": json.dumps(numbers)}, | |
| ] | |
| result = compress(messages, min_tokens_to_compress=250) | |
| assert result.tokens_saved > 0 | |
| assert any( | |
| transform.startswith("router:smart_crusher") for transform in result.transforms_applied | |
| ) | |
| def test_optimize_false_passthrough(self): | |
| """optimize=False returns messages unchanged.""" | |
| messages = [{"role": "user", "content": "hello world " * 100}] | |
| result = compress(messages, optimize=False) | |
| assert result.messages is messages | |
| assert result.tokens_saved == 0 | |
| def test_with_custom_hooks(self): | |
| """Hooks are called when provided.""" | |
| calls = [] | |
| class TrackingHooks(CompressionHooks): | |
| def pre_compress(self, messages, ctx): | |
| calls.append(("pre", len(messages))) | |
| return messages | |
| def compute_biases(self, messages, ctx): | |
| calls.append(("biases", len(messages))) | |
| return {} | |
| def post_compress(self, event): | |
| calls.append(("post", event.tokens_saved)) | |
| big_data = json.dumps([{"id": i, "status": "active"} for i in range(100)]) | |
| messages = [ | |
| {"role": "user", "content": "analyze"}, | |
| {"role": "tool", "content": big_data, "tool_call_id": "c1"}, | |
| ] | |
| compress(messages, hooks=TrackingHooks()) | |
| assert any(c[0] == "pre" for c in calls) | |
| assert any(c[0] == "biases" for c in calls) | |
| class TestCompressResultFields: | |
| def test_fields_populated(self): | |
| big_data = json.dumps([{"id": i, "type": "log"} for i in range(100)]) | |
| messages = [ | |
| {"role": "user", "content": "summarize"}, | |
| {"role": "tool", "content": big_data, "tool_call_id": "c1"}, | |
| ] | |
| result = compress(messages, model="claude-sonnet-4-5-20250929") | |
| assert result.tokens_before > 0 | |
| assert result.tokens_after >= 0 | |
| assert result.tokens_saved >= 0 | |
| assert 0.0 <= result.compression_ratio <= 1.0 | |
| # ============================================================================= | |
| # Tests: ASGI CompressionMiddleware (requires starlette) | |
| # ============================================================================= | |
def _make_asgi_app(middleware_kwargs=None):
    """Return a throwaway Starlette app wrapped in ``CompressionMiddleware``.

    The app serves a ``/health`` probe plus two POST-only LLM-style routes,
    ``/v1/chat/completions`` and ``/v1/messages``. Both LLM routes reply with a
    canned completion. They also echo how many messages arrived in the request
    body as ``_message_count``, so tests can inspect what the middleware
    forwarded.

    Only usable when the starlette imports at the top of the module succeeded
    (``HAS_STARLETTE``).

    Args:
        middleware_kwargs: optional keyword arguments forwarded to
            ``CompressionMiddleware``; a falsy value means "no options".
    """

    async def health(request: Request) -> JSONResponse:
        return JSONResponse({"status": "ok"})

    async def chat_endpoint(request: Request) -> JSONResponse:
        payload = await request.json()
        received = len(payload.get("messages", []))
        return JSONResponse(
            {
                "model": "gpt-4o",
                "choices": [{"message": {"content": "response"}}],
                "usage": {"prompt_tokens": 10, "completion_tokens": 5},
                "_message_count": received,
            }
        )

    llm_paths = ("/v1/chat/completions", "/v1/messages")
    routes = [Route("/health", health)]
    routes.extend(Route(path, chat_endpoint, methods=["POST"]) for path in llm_paths)

    app = Starlette(routes=routes)
    app.add_middleware(CompressionMiddleware, **(middleware_kwargs or {}))
    return app
@pytest.mark.skipif(not HAS_STARLETTE, reason="starlette is not installed")
class TestASGIMiddleware:
    """End-to-end checks of ``CompressionMiddleware`` through a Starlette app.

    Skipped when the optional starlette dependency (and with it the ASGI
    integration) failed to import. Without this guard every test here would
    error with ``NameError`` instead of being reported as skipped.
    """

    def test_non_llm_paths_passthrough(self):
        """Requests to non-LLM routes reach the app untouched."""
        client = TestClient(_make_asgi_app())
        resp = client.get("/health")
        assert resp.status_code == 200
        assert resp.json()["status"] == "ok"

    def test_small_messages_passthrough(self):
        """A tiny chat request is forwarded with its single message intact."""
        client = TestClient(_make_asgi_app())
        resp = client.post(
            "/v1/chat/completions",
            json={"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]},
        )
        assert resp.status_code == 200
        assert resp.json()["_message_count"] == 1

    def test_large_messages_compressed(self):
        """Large tool output should be compressed by middleware."""
        client = TestClient(_make_asgi_app())
        big_data = json.dumps([{"id": i, "status": "active"} for i in range(200)])
        resp = client.post(
            "/v1/chat/completions",
            json={
                "model": "gpt-4o",
                "messages": [
                    {"role": "user", "content": "analyze"},
                    {"role": "tool", "content": big_data, "tool_call_id": "c1"},
                ],
            },
        )
        assert resp.status_code == 200
        # The endpoint echoes how many messages it parsed. A rewritten request
        # body must still be valid JSON and carry both messages.
        assert resp.json()["_message_count"] == 2

    def test_anthropic_path(self):
        """Works with Anthropic /v1/messages path."""
        client = TestClient(_make_asgi_app())
        resp = client.post(
            "/v1/messages",
            json={
                "model": "claude-sonnet-4-5-20250929",
                "messages": [{"role": "user", "content": "hello"}],
            },
        )
        assert resp.status_code == 200
        assert resp.json()["_message_count"] == 1

    def test_get_requests_passthrough(self):
        """GET requests to LLM paths pass through."""
        client = TestClient(_make_asgi_app())
        resp = client.get("/v1/chat/completions")
        # The route only accepts POST. A GET that the middleware leaves alone
        # reaches Starlette's router, which answers 405 Method Not Allowed.
        assert resp.status_code == 405
| # ============================================================================= | |
| # Tests: LiteLLM Callback | |
| # ============================================================================= | |
| class TestLiteLLMCallback: | |
| def test_callback_imports(self): | |
| """Verify the callback can be imported.""" | |
| from headroom.integrations.litellm_callback import HeadroomCallback | |
| callback = HeadroomCallback() | |
| assert callback.total_tokens_saved == 0 | |
| def test_callback_compresses_messages(self): | |
| """Callback compresses messages in pre_call_hook.""" | |
| import asyncio | |
| from headroom.integrations.litellm_callback import HeadroomCallback | |
| callback = HeadroomCallback() | |
| big_data = json.dumps([{"id": i, "status": "active"} for i in range(200)]) | |
| data = { | |
| "model": "gpt-4o", | |
| "messages": [ | |
| {"role": "user", "content": "analyze"}, | |
| {"role": "tool", "content": big_data, "tool_call_id": "c1"}, | |
| ], | |
| } | |
| result = asyncio.run(callback.async_pre_call_hook("key", data, "completion")) | |
| assert result is data | |
| def test_callback_ignores_non_completion(self): | |
| """Non-completion calls are passed through.""" | |
| import asyncio | |
| from headroom.integrations.litellm_callback import HeadroomCallback | |
| callback = HeadroomCallback() | |
| data = {"messages": [{"role": "user", "content": "hi"}]} | |
| result = asyncio.run(callback.async_pre_call_hook("key", data, "embedding")) | |
| assert result is data | |