Spaces:
Running
Running
| """Tests for proxy streaming resilience and concurrent session handling. | |
| These tests verify: | |
| 1. CostTracker model resolution caching (prevents event loop blocking) | |
| 2. Streaming generate() error handling (prevents ASGI crashes) | |
| 3. Concurrent session safety (multiple sessions don't interfere) | |
| """ | |
| import asyncio | |
| import json | |
| import time | |
| from unittest.mock import MagicMock, patch | |
| import httpx | |
| import pytest | |
| # --------------------------------------------------------------------------- | |
| # CostTracker model resolution caching | |
| # --------------------------------------------------------------------------- | |
class TestModelResolutionCaching:
    """Test that _resolve_litellm_model caches results to avoid repeated sync calls.

    ``CostTracker._resolved_model_cache`` is a class-level dict shared by every
    ``CostTracker`` instance (see ``test_cache_is_class_level_shared_across_instances``),
    so it is cleared both before AND after each test. Without the teardown,
    entries produced under mocked resolvers (e.g. ``"resolved/test-model"``)
    would stay in the process-wide cache and leak into later tests.
    """

    def setup_method(self):
        """Clear the cache before each test."""
        from headroom.proxy.server import CostTracker

        CostTracker._resolved_model_cache.clear()

    def teardown_method(self):
        """Clear the cache after each test so mocked resolutions never leak out."""
        from headroom.proxy.server import CostTracker

        CostTracker._resolved_model_cache.clear()

    def test_cache_returns_same_result_on_second_call(self):
        """First call resolves, second call returns cached value without calling litellm."""
        from headroom.proxy.server import CostTracker

        with patch.object(
            CostTracker, "_resolve_litellm_model_uncached", return_value="anthropic/claude-opus-4-6"
        ) as mock_uncached:
            # First call — should invoke uncached resolution
            result1 = CostTracker._resolve_litellm_model("claude-opus-4-6")
            assert result1 == "anthropic/claude-opus-4-6"
            assert mock_uncached.call_count == 1

            # Second call — should use cache, NOT call uncached again
            result2 = CostTracker._resolve_litellm_model("claude-opus-4-6")
            assert result2 == "anthropic/claude-opus-4-6"
            assert mock_uncached.call_count == 1  # Still 1, not 2

    def test_cache_is_per_model_name(self):
        """Different model names get separate cache entries."""
        from headroom.proxy.server import CostTracker

        with patch.object(
            CostTracker,
            "_resolve_litellm_model_uncached",
            side_effect=lambda m: f"resolved/{m}",
        ) as mock_uncached:
            result1 = CostTracker._resolve_litellm_model("gpt-4o")
            result2 = CostTracker._resolve_litellm_model("claude-opus-4-6")
            result3 = CostTracker._resolve_litellm_model("gpt-4o")  # cached

        assert result1 == "resolved/gpt-4o"
        assert result2 == "resolved/claude-opus-4-6"
        assert result3 == "resolved/gpt-4o"
        assert mock_uncached.call_count == 2  # Only 2, not 3

    def test_cached_call_is_fast(self):
        """A cached resolution is a plain dict lookup: no uncached call, negligible cost.

        The primary check is deterministic: the uncached resolver must never be
        invoked on a cache hit. The timing bound is deliberately generous so it
        only catches gross regressions instead of flaking on a loaded CI box.
        """
        from headroom.proxy.server import CostTracker

        # Pre-populate cache
        CostTracker._resolved_model_cache["test-model"] = "resolved/test-model"

        with patch.object(CostTracker, "_resolve_litellm_model_uncached") as mock_uncached:
            start = time.perf_counter()
            for _ in range(10_000):
                result = CostTracker._resolve_litellm_model("test-model")
            elapsed_ms = (time.perf_counter() - start) * 1000

        assert result == "resolved/test-model"
        mock_uncached.assert_not_called()
        # 10k dict lookups normally finish in a few ms; 200ms leaves ample slack.
        assert elapsed_ms < 200, f"10k cached lookups took {elapsed_ms:.1f}ms — too slow"

    def _resolve_after_bare_lookup_fails(self, model: str) -> str:
        """Run the uncached resolver where the bare name fails and the prefixed one succeeds.

        ``litellm.cost_per_token`` is mocked to raise on its first call (the bare
        model name) and to return a cost tuple on its second call (the
        provider-prefixed name). The return value is whatever name the resolver
        settled on.
        """
        from headroom.proxy.server import CostTracker

        with (
            patch("headroom.proxy.cost.LITELLM_AVAILABLE", True),
            patch("headroom.proxy.cost.litellm") as mock_litellm,
        ):
            mock_litellm.cost_per_token.side_effect = [
                Exception("Unknown model"),  # bare name, e.g. "claude-opus-4-6"
                (0.001, 0.002),  # prefixed name, e.g. "anthropic/claude-opus-4-6"
            ]
            return CostTracker._resolve_litellm_model_uncached(model)

    def test_uncached_adds_provider_prefix_for_claude(self):
        """_resolve_litellm_model_uncached tries provider prefix for claude- models."""
        result = self._resolve_after_bare_lookup_fails("claude-opus-4-6")
        assert result == "anthropic/claude-opus-4-6"

    def test_uncached_adds_provider_prefix_for_gpt(self):
        """_resolve_litellm_model_uncached tries provider prefix for gpt- models."""
        result = self._resolve_after_bare_lookup_fails("gpt-4o")
        assert result == "openai/gpt-4o"

    def test_uncached_adds_provider_prefix_for_gemini(self):
        """_resolve_litellm_model_uncached tries provider prefix for gemini- models."""
        result = self._resolve_after_bare_lookup_fails("gemini-1.5-pro")
        assert result == "google/gemini-1.5-pro"

    def test_uncached_returns_original_when_both_fail(self):
        """If both bare and prefixed lookups fail, return original model name."""
        from headroom.proxy.server import CostTracker

        with (
            patch("headroom.proxy.cost.LITELLM_AVAILABLE", True),
            patch("headroom.proxy.cost.litellm") as mock_litellm,
        ):
            mock_litellm.cost_per_token.side_effect = Exception("Unknown model")
            result = CostTracker._resolve_litellm_model_uncached("totally-unknown-model-xyz")

        assert result == "totally-unknown-model-xyz"

    def test_uncached_returns_original_when_litellm_unavailable(self):
        """When litellm is not available, return model as-is."""
        from headroom.proxy.server import CostTracker

        with patch("headroom.proxy.cost.LITELLM_AVAILABLE", False):
            result = CostTracker._resolve_litellm_model_uncached("claude-opus-4-6")

        assert result == "claude-opus-4-6"

    def test_uncached_returns_bare_when_it_works(self):
        """If bare model name works, don't add prefix."""
        from headroom.proxy.server import CostTracker

        with (
            patch("headroom.proxy.cost.LITELLM_AVAILABLE", True),
            patch("headroom.proxy.cost.litellm") as mock_litellm,
        ):
            mock_litellm.cost_per_token.return_value = (0.001, 0.002)
            result = CostTracker._resolve_litellm_model_uncached("claude-3-5-sonnet-20241022")

        assert result == "claude-3-5-sonnet-20241022"

    def test_cache_is_class_level_shared_across_instances(self):
        """Cache is shared across CostTracker instances (class variable)."""
        from headroom.proxy.server import CostTracker

        tracker1 = CostTracker()
        tracker2 = CostTracker()

        with patch.object(
            CostTracker, "_resolve_litellm_model_uncached", return_value="resolved/model-a"
        ) as mock_uncached:
            # Resolve via instance 1
            result1 = tracker1._resolve_litellm_model("model-a")
            assert mock_uncached.call_count == 1

            # Instance 2 should get cached result
            result2 = tracker2._resolve_litellm_model("model-a")
            assert mock_uncached.call_count == 1  # Not called again

        assert result1 == result2
| # --------------------------------------------------------------------------- | |
| # Streaming generate() error handling | |
| # --------------------------------------------------------------------------- | |
class TestStreamingErrorHandling:
    """Test that streaming errors are caught and returned as SSE error events.

    NOTE(review): the production ``generate()`` is a nested closure inside
    ``_handle_openai_streaming`` and cannot be called from here. These tests
    therefore drive ``_call_generate``, a local copy of its try/except/finally
    shape. They pin the intended contract but will not detect drift in
    server.py itself.
    """

    async def test_connect_error_yields_sse_error(self):
        """httpx.ConnectError should yield an SSE error event, not crash."""
        proxy = self._create_mock_proxy()
        # Make http_client.stream raise ConnectError
        proxy.http_client.stream = MagicMock(side_effect=httpx.ConnectError("Connection refused"))

        chunks = await self._collect(proxy)

        # Exactly one error event was yielded instead of an exception escaping
        assert len(chunks) == 1
        error_data = self._parse_sse_error(chunks[0])
        assert error_data["error"]["type"] == "connection_error"
        assert "Connection refused" in error_data["error"]["message"]

    async def test_connect_timeout_yields_sse_error(self):
        """httpx.ConnectTimeout should yield an SSE error event."""
        proxy = self._create_mock_proxy()
        proxy.http_client.stream = MagicMock(
            side_effect=httpx.ConnectTimeout("Timed out connecting")
        )

        chunks = await self._collect(proxy)

        assert len(chunks) == 1
        error_data = self._parse_sse_error(chunks[0])
        assert error_data["error"]["type"] == "connection_error"

    async def test_pool_timeout_yields_sse_error(self):
        """httpx.PoolTimeout should yield an SSE error event."""
        proxy = self._create_mock_proxy()
        proxy.http_client.stream = MagicMock(
            side_effect=httpx.PoolTimeout("Pool timeout: all connections busy")
        )

        chunks = await self._collect(proxy)

        assert len(chunks) == 1
        error_data = self._parse_sse_error(chunks[0])
        assert error_data["error"]["type"] == "connection_error"
        assert "Pool timeout" in error_data["error"]["message"]

    async def test_http_status_error_forwards_upstream_response(self):
        """httpx.HTTPStatusError should forward the upstream error body."""
        proxy = self._create_mock_proxy()

        # Create a realistic HTTP 429 error
        upstream_error_body = json.dumps(
            {"error": {"type": "rate_limit_error", "message": "Too many requests"}}
        ).encode()
        mock_response = MagicMock()
        mock_response.content = upstream_error_body
        mock_response.status_code = 429
        http_error = httpx.HTTPStatusError(
            "429 Too Many Requests", request=MagicMock(), response=mock_response
        )
        proxy.http_client.stream = MagicMock(side_effect=http_error)

        chunks = await self._collect(proxy)

        # The upstream error body is forwarded verbatim, and nothing else
        assert chunks == [upstream_error_body]

    async def test_unexpected_error_yields_sse_error(self):
        """Unexpected exceptions should yield an SSE error event, not crash."""
        proxy = self._create_mock_proxy()
        proxy.http_client.stream = MagicMock(
            side_effect=RuntimeError("Something unexpected went wrong")
        )

        chunks = await self._collect(proxy)

        assert len(chunks) == 1
        error_data = self._parse_sse_error(chunks[0])
        assert error_data["error"]["type"] == "api_error"
        assert "Something unexpected" in error_data["error"]["message"]

    async def test_finally_block_runs_after_error(self):
        """The finally block (metrics recording) should still run after errors."""
        proxy = self._create_mock_proxy()
        proxy.http_client.stream = MagicMock(side_effect=httpx.ConnectError("fail"))
        finally_calls = []

        chunks = await self._collect(proxy, on_finally=lambda: finally_calls.append("ran"))

        # The error was reported AND the finally hook ran exactly once without re-raising
        assert len(chunks) == 1
        assert finally_calls == ["ran"]

    async def test_error_event_is_valid_sse_format(self):
        """Error events should be valid SSE format (event: error\\ndata: {...}\\n\\n)."""
        proxy = self._create_mock_proxy()
        proxy.http_client.stream = MagicMock(side_effect=httpx.ConnectError("refused"))

        chunks = await self._collect(proxy)

        raw = chunks[-1].decode("utf-8")
        assert raw.startswith("event: error\n")
        assert "data: " in raw
        assert raw.endswith("\n\n")

        # Data portion should be valid JSON with the expected envelope
        data_line = next(line for line in raw.split("\n") if line.startswith("data: "))
        parsed = json.loads(data_line[len("data: ") :])
        assert parsed["type"] == "error"
        assert parsed["error"]["type"] == "connection_error"

    # --- Helpers ---

    def _create_mock_proxy(self):
        """Create a HeadroomProxy-like object with mocked internals for testing generate().

        ``object.__new__`` skips ``HeadroomProxy.__init__`` so no real clients
        are built. Only ``http_client`` is read by ``_call_generate``. The other
        attributes are presumably what the production closure touches; verify
        against server.py.
        """
        from headroom.proxy.server import HeadroomProxy

        proxy = object.__new__(HeadroomProxy)
        proxy.http_client = MagicMock(spec=httpx.AsyncClient)
        proxy.cost_tracker = MagicMock()
        proxy.cost_tracker.estimate_cost.return_value = 0.001
        proxy.cost_tracker.record_request.return_value = None
        proxy.stats = {
            "requests_total": 0,
            "requests_optimized": 0,
            "tokens": {"original": 0, "optimized": 0, "saved": 0},
            "cost": {"total_usd": 0, "savings_usd": 0},
            "errors": 0,
            "active_requests": 0,
            "requests_per_model": {},
        }
        proxy.memory_manager = None
        proxy._config = MagicMock()
        proxy._config.memory_enabled = False
        proxy._parse_sse_usage_from_buffer = MagicMock(return_value=None)
        return proxy

    async def _collect(self, proxy, on_finally=None):
        """Drain ``_call_generate`` for *proxy* and return every chunk it yielded, in order."""
        return [chunk async for chunk in self._call_generate(proxy, on_finally=on_finally)]

    async def _call_generate(self, proxy, on_finally=None):
        """Call the streaming generate pattern matching server.py's generate() function.

        Since generate() is a nested closure inside _handle_openai_streaming,
        we test the error handling pattern directly — same try/except/finally
        structure as the real code.

        Args:
            proxy: Object exposing ``http_client.stream(...)`` as an async context manager.
            on_finally: Optional zero-argument callable invoked from the
                ``finally`` clause. It lets tests observe that the finally
                block actually ran.

        Yields:
            Upstream byte chunks on success. On failure, one SSE ``error`` event
            (connection/unexpected errors) or the raw upstream body (HTTP status errors).
        """
        url = "https://api.openai.com/v1/chat/completions"
        body = {"model": "gpt-4o", "messages": [{"role": "user", "content": "Hi"}], "stream": True}
        headers = {"Authorization": "Bearer sk-test"}

        try:
            async with proxy.http_client.stream("POST", url, json=body, headers=headers) as resp:
                async for chunk in resp.aiter_bytes():
                    yield chunk
        except (httpx.ConnectError, httpx.ConnectTimeout, httpx.PoolTimeout) as e:
            error_event = {
                "type": "error",
                "error": {
                    "type": "connection_error",
                    "message": f"Failed to connect to upstream API: {e}",
                },
            }
            yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode()
        except httpx.HTTPStatusError as e:
            yield e.response.content
        except Exception as e:
            error_event = {
                "type": "error",
                "error": {"type": "api_error", "message": str(e)},
            }
            yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode()
        finally:
            # Mirrors the finally block in server.py (metrics recording), which
            # must not raise. The hook makes its execution observable to tests.
            if on_finally is not None:
                on_finally()

    def _parse_sse_error(self, chunk: bytes) -> dict:
        """Parse an SSE error event chunk into a dict.

        Raises:
            ValueError: if the chunk contains no ``data: `` line.
        """
        raw = chunk.decode("utf-8")
        for line in raw.split("\n"):
            if line.startswith("data: "):
                return json.loads(line[len("data: ") :])
        raise ValueError(f"No data: line found in SSE chunk: {raw}")
| # --------------------------------------------------------------------------- | |
| # Concurrent session safety | |
| # --------------------------------------------------------------------------- | |
class TestConcurrentSessionSafety:
    """Test that multiple concurrent sessions don't interfere with each other.

    These tests populate the class-level ``CostTracker._resolved_model_cache``,
    so it is cleared before AND after every test. That keeps resolutions
    produced under mocks from leaking into other tests.
    """

    def setup_method(self):
        """Start every test with an empty model-resolution cache."""
        from headroom.proxy.server import CostTracker

        CostTracker._resolved_model_cache.clear()

    def teardown_method(self):
        """Clear the cache afterwards so mocked resolutions never leak out."""
        from headroom.proxy.server import CostTracker

        CostTracker._resolved_model_cache.clear()

    async def test_concurrent_model_resolution_is_safe(self):
        """Multiple concurrent tasks resolving the same model should all get correct result."""
        import threading

        from headroom.proxy.server import CostTracker

        call_count = 0
        count_lock = threading.Lock()

        def fake_uncached(model: str) -> str:
            nonlocal call_count
            # ``+=`` on a shared int is a read-modify-write, not atomic across
            # threads. Guard it so the count reflects every real invocation.
            with count_lock:
                call_count += 1
            # Stand-in for the litellm lookup. It returns immediately, so this
            # test checks result consistency and bounded duplicate work; it does
            # not deliberately widen the race window.
            return f"resolved/{model}"

        with patch.object(
            CostTracker, "_resolve_litellm_model_uncached", side_effect=fake_uncached
        ):
            # Launch 50 concurrent resolution tasks for the same model (worker threads)
            tasks = [
                asyncio.to_thread(CostTracker._resolve_litellm_model, "claude-opus-4-6")
                for _ in range(50)
            ]
            results = await asyncio.gather(*tasks)

        # All should get the same result
        assert all(r == "resolved/claude-opus-4-6" for r in results)
        # Uncached should be called very few times (ideally 1, but a few races are OK)
        assert call_count <= 5, f"Uncached called {call_count} times — expected ~1"

    async def test_concurrent_resolution_different_models(self):
        """Concurrent resolution of different models should each resolve independently."""
        from headroom.proxy.server import CostTracker

        models = ["gpt-4o", "claude-opus-4-6", "gemini-1.5-pro", "gpt-4o-mini"]
        requested = models * 10  # 40 tasks total

        with patch.object(
            CostTracker,
            "_resolve_litellm_model_uncached",
            side_effect=lambda m: f"resolved/{m}",
        ):
            results = await asyncio.gather(
                *(asyncio.to_thread(CostTracker._resolve_litellm_model, m) for m in requested)
            )

        # gather preserves input order, so each result lines up with its request
        assert list(results) == [f"resolved/{m}" for m in requested]
        # Cache should have exactly one entry per distinct model
        assert len(CostTracker._resolved_model_cache) == len(models)

    async def test_concurrent_streaming_errors_are_independent(self):
        """Each session's streaming error should be independent — one failure shouldn't affect others.

        NOTE(review): ``simulate_session`` is defined locally, so this checks the
        per-session error-isolation pattern, not server code.
        """

        async def simulate_session(session_id: int, should_fail: bool):
            """Simulate a streaming session that either succeeds or fails."""
            chunks = []
            try:
                if should_fail:
                    raise httpx.ConnectError(f"Session {session_id} connection refused")
                # Successful session
                for i in range(3):
                    chunks.append(f"data: chunk-{session_id}-{i}\n\n".encode())
                    await asyncio.sleep(0.001)
            except (httpx.ConnectError, httpx.ConnectTimeout, httpx.PoolTimeout) as e:
                error_event = {
                    "type": "error",
                    "error": {
                        "type": "connection_error",
                        "message": str(e),
                    },
                }
                chunks.append(f"event: error\ndata: {json.dumps(error_event)}\n\n".encode())
            return session_id, chunks, should_fail

        # Run 10 sessions: odd ones fail, even ones succeed
        tasks = [simulate_session(i, should_fail=(i % 2 == 1)) for i in range(10)]
        results = await asyncio.gather(*tasks)

        for session_id, chunks, should_fail in results:
            if should_fail:
                # Failed sessions should have exactly one error chunk naming the session
                assert len(chunks) == 1
                error_data = json.loads(chunks[0].decode("utf-8").split("data: ")[1].strip())
                assert error_data["error"]["type"] == "connection_error"
                assert f"Session {session_id}" in error_data["error"]["message"]
            else:
                # Successful sessions should have their own data chunks, in order
                assert len(chunks) == 3
                for i, chunk in enumerate(chunks):
                    assert f"chunk-{session_id}-{i}".encode() in chunk

    async def test_estimate_cost_concurrent_with_caching(self):
        """Multiple concurrent estimate_cost calls should not block each other."""
        from headroom.proxy.server import CostTracker

        tracker = CostTracker()
        # Pre-populate cache to simulate steady-state
        CostTracker._resolved_model_cache["gpt-4o"] = "openai/gpt-4o"

        with (
            patch("headroom.proxy.cost.LITELLM_AVAILABLE", True),
            patch("headroom.proxy.cost.litellm") as mock_litellm,
        ):
            mock_litellm.cost_per_token.return_value = (0.001, 0.002)
            mock_litellm.get_model_info.return_value = {}

            start = time.perf_counter()
            tasks = [
                asyncio.to_thread(tracker.estimate_cost, "gpt-4o", 1000, 500) for _ in range(100)
            ]
            results = await asyncio.gather(*tasks)
            elapsed_ms = (time.perf_counter() - start) * 1000

        # All should return a valid cost
        assert all(r is not None and r > 0 for r in results)
        # 100 concurrent calls should complete quickly (no blocking); bound is generous
        assert elapsed_ms < 5000, f"100 concurrent estimate_cost took {elapsed_ms:.0f}ms"
| # --------------------------------------------------------------------------- | |
| # Cost tracking — no double-counting of cache tokens | |
| # --------------------------------------------------------------------------- | |
class TestCostTrackingAccuracy:
    """Test that cost calculations don't double-count cache tokens.

    ``estimate_cost`` may populate the class-level
    ``CostTracker._resolved_model_cache`` while litellm is mocked, so the cache
    is cleared before AND after each test to keep those entries from leaking.
    """

    def setup_method(self):
        """Start every test with an empty model-resolution cache."""
        from headroom.proxy.server import CostTracker

        CostTracker._resolved_model_cache.clear()

    def teardown_method(self):
        """Clear the cache afterwards so mocked resolutions never leak out."""
        from headroom.proxy.server import CostTracker

        CostTracker._resolved_model_cache.clear()

    def test_estimate_cost_separates_input_and_cache(self):
        """Input tokens and cache tokens should be billed separately, not double-counted."""
        from headroom.proxy.server import CostTracker

        tracker = CostTracker()

        with (
            patch("headroom.proxy.cost.LITELLM_AVAILABLE", True),
            patch("headroom.proxy.cost.litellm") as mock_litellm,
        ):
            # Setup: $10/M input, $30/M output
            def mock_cost(model, prompt_tokens, completion_tokens, **kwargs):
                input_cost = prompt_tokens * 0.00001
                output_cost = completion_tokens * 0.00003
                # Add cache costs only if the caller passed cache token counts
                cache_read = kwargs.get("cache_read_input_tokens", 0)
                cache_write = kwargs.get("cache_creation_input_tokens", 0)
                if cache_read or cache_write:
                    model_info = mock_litellm.get_model_info()
                    input_cost += cache_read * model_info.get("cache_read_input_token_cost", 0)
                    input_cost += cache_write * model_info.get("cache_creation_input_token_cost", 0)
                return (input_cost, output_cost)

            mock_litellm.cost_per_token.side_effect = mock_cost
            mock_litellm.get_model_info.return_value = {
                "cache_read_input_token_cost": 0.000001,  # 10% of input
                "cache_creation_input_token_cost": 0.0000125,  # 125% of input
            }

            # 1000 input + 500 cache_read + 200 cache_write + 100 output
            cost = tracker.estimate_cost(
                model="gpt-4o",
                input_tokens=1000,
                output_tokens=100,
                cache_read_tokens=500,
                cache_write_tokens=200,
            )

        assert cost is not None
        # input_cost  = 1000 * 0.00001   = 0.01
        # output_cost =  100 * 0.00003   = 0.003
        # cache_read  =  500 * 0.000001  = 0.0005
        # cache_write =  200 * 0.0000125 = 0.0025
        expected = 0.01 + 0.003 + 0.0005 + 0.0025
        assert abs(cost - expected) < 0.0001, f"Expected {expected}, got {cost}"

    def test_estimate_cost_without_cache_tokens(self):
        """Cost without cache tokens should just be input + output."""
        from headroom.proxy.server import CostTracker

        tracker = CostTracker()

        with (
            patch("headroom.proxy.cost.LITELLM_AVAILABLE", True),
            patch("headroom.proxy.cost.litellm") as mock_litellm,
        ):
            mock_litellm.cost_per_token.side_effect = (
                lambda model, prompt_tokens, completion_tokens, **kwargs: (
                    prompt_tokens * 0.00001,
                    completion_tokens * 0.00003,
                )
            )
            mock_litellm.get_model_info.return_value = {}

            cost = tracker.estimate_cost("gpt-4o", input_tokens=1000, output_tokens=100)

        # Fail with a clear assertion (not a TypeError in the arithmetic) if no cost came back
        assert cost is not None
        expected = 1000 * 0.00001 + 100 * 0.00003
        assert abs(cost - expected) < 0.0001, f"Expected {expected}, got {cost}"

    def test_estimate_cost_returns_none_without_litellm(self):
        """When litellm is unavailable, estimate_cost should return None."""
        from headroom.proxy.server import CostTracker

        tracker = CostTracker()

        with patch("headroom.proxy.cost.LITELLM_AVAILABLE", False):
            cost = tracker.estimate_cost("gpt-4o", input_tokens=1000, output_tokens=100)

        assert cost is None