Spaces:
Sleeping
Sleeping
File size: 5,728 Bytes
"""Integration tests for LLM provider fallback through the API."""
import os
from contextlib import ExitStack, contextmanager
from unittest.mock import MagicMock, patch

from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.testclient import TestClient

from tests.conftest import MOCK_LLM_OUTPUT, MOCK_RETRIEVE_RESULT
def _setup(env_overrides: dict | None = None):
    """Prepare the shared fixtures for a provider-fallback test.

    Creates a retriever mock whose ``retrieve()`` returns
    ``MOCK_RETRIEVE_RESULT``, assembles the environment mapping (Anthropic
    provider plus a fake API key, with ``env_overrides`` layered on top),
    clears both API caches, and installs a brand-new rate limiter on
    ``app.state``.

    Nothing is patched here. The caller receives
    ``(app, mock_retriever, env, fresh_limiter)`` and composes its own
    ``with`` block from them.
    """
    retriever_stub = MagicMock()
    retriever_stub.retrieve.return_value = MOCK_RETRIEVE_RESULT

    environment = {
        "ANTHROPIC_API_KEY": "fake-key-for-testing",
        "LLM_PROVIDER": "anthropic",
    }
    # A falsy override (None or an empty mapping) contributes nothing.
    environment.update(env_overrides or {})

    # Function-scope imports, kept in their original order.
    from api.main import app
    from api.cache import generation_cache, retrieval_cache

    # Start every test from empty caches.
    for cache in (generation_cache, retrieval_cache):
        cache.clear()

    # A new limiter per test, stored on the app state.
    limiter = Limiter(key_func=get_remote_address)
    app.state.limiter = limiter

    return app, retriever_stub, environment, limiter
class TestProviderFallback:
    """API-level tests for provider selection, fallback and failure modes.

    Each test sends requests to the application through ``TestClient`` while
    the retriever, the process environment, the rate limiter and (where
    relevant) the provider call functions are replaced with test doubles.
    """

    @staticmethod
    @contextmanager
    def _client(app, mock_retriever, env, fresh_limiter,
                provider_mocks=None, **client_kwargs):
        """Yield a ``TestClient`` with the common patch stack applied.

        Patches are entered in a fixed order: retriever, environment,
        limiter references, provider call functions, then the client itself.
        They are unwound in reverse when the ``with`` block exits.

        Args:
            app: The application object returned by ``_setup``.
            mock_retriever: Replacement for ``retrieve._retriever``.
            env: Mapping that fully replaces ``os.environ`` for the duration
                (``clear=True``), so no variable leaks in from the host.
            fresh_limiter: Limiter patched into ``api.dependencies`` and
                ``api.routes``.
            provider_mocks: Optional mapping of patch target (for example
                ``"providers._call_anthropic"``) to the mock to install.
            **client_kwargs: Extra keyword arguments for ``TestClient``.
        """
        with ExitStack() as stack:
            stack.enter_context(patch("retrieve._retriever", mock_retriever))
            stack.enter_context(patch.dict(os.environ, env, clear=True))
            stack.enter_context(patch("api.dependencies.limiter", fresh_limiter))
            stack.enter_context(patch("api.routes.limiter", fresh_limiter))
            for target, mock in (provider_mocks or {}).items():
                stack.enter_context(patch(target, mock))
            yield stack.enter_context(TestClient(app, **client_kwargs))

    def test_generate_uses_primary_provider(self):
        """A working primary provider serves the request."""
        app, mock_retriever, env, fresh_limiter = _setup()
        mock_anthropic = MagicMock(return_value=MOCK_LLM_OUTPUT)

        with self._client(
            app, mock_retriever, env, fresh_limiter,
            provider_mocks={"providers._call_anthropic": mock_anthropic},
        ) as client:
            resp = client.post("/api/generate", json={"message": "add numbers"})
            assert resp.status_code == 200
            assert resp.json()["generated_code"] == MOCK_LLM_OUTPUT

        mock_anthropic.assert_called()

    def test_generate_falls_back_to_vllm(self):
        """A failing primary provider hands over to the vLLM fallback."""
        app, mock_retriever, env, fresh_limiter = _setup(
            env_overrides={"LLM_FALLBACK_PROVIDER": "vllm"},
        )
        fallback_code = "fun fallback():\n return 0\nend fun"
        mock_anthropic = MagicMock(side_effect=RuntimeError("API key invalid"))
        mock_vllm = MagicMock(return_value=fallback_code)

        with self._client(
            app, mock_retriever, env, fresh_limiter,
            provider_mocks={
                "providers._call_anthropic": mock_anthropic,
                "providers._call_vllm": mock_vllm,
            },
        ) as client:
            resp = client.post("/api/generate", json={"message": "add numbers"})
            assert resp.status_code == 200
            assert resp.json()["generated_code"] == fallback_code

        # The primary must have been tried; otherwise this is not a fallback
        # and the test would pass even if vLLM were used as the primary.
        mock_anthropic.assert_called()
        mock_vllm.assert_called()

    def test_generate_503_when_not_configured(self):
        """No API key -> is_provider_configured() returns False -> 503."""
        app, mock_retriever, _, fresh_limiter = _setup()
        env_no_key = {"LLM_PROVIDER": "anthropic"}  # No ANTHROPIC_API_KEY

        with self._client(app, mock_retriever, env_no_key, fresh_limiter) as client:
            resp = client.post("/api/generate", json={"message": "add numbers"})
            assert resp.status_code == 503
            assert "not configured" in resp.json()["detail"]

    def test_health_reflects_provider_status(self):
        """``provider_configured`` tracks whether an API key is set."""
        app, mock_retriever, _, fresh_limiter = _setup()

        # With API key -> configured
        env_with_key = {
            "LLM_PROVIDER": "anthropic",
            "ANTHROPIC_API_KEY": "fake",
        }
        with self._client(app, mock_retriever, env_with_key, fresh_limiter) as client:
            resp = client.get("/api/health")
            assert resp.json()["provider_configured"] is True

        # Without API key -> not configured
        env_without_key = {
            "LLM_PROVIDER": "anthropic",
        }
        with self._client(app, mock_retriever, env_without_key, fresh_limiter) as client:
            resp = client.get("/api/health")
            assert resp.json()["provider_configured"] is False

    def test_generate_500_when_all_providers_fail(self):
        """Primary and fallback both failing yields a 500 response."""
        app, mock_retriever, env, fresh_limiter = _setup(
            env_overrides={"LLM_FALLBACK_PROVIDER": "vllm"},
        )
        mock_anthropic = MagicMock(side_effect=RuntimeError("primary down"))
        mock_vllm = MagicMock(side_effect=RuntimeError("fallback down"))

        # raise_server_exceptions=False makes the client return the error
        # response instead of re-raising the server-side exception here.
        with self._client(
            app, mock_retriever, env, fresh_limiter,
            provider_mocks={
                "providers._call_anthropic": mock_anthropic,
                "providers._call_vllm": mock_vllm,
            },
            raise_server_exceptions=False,
        ) as client:
            resp = client.post("/api/generate", json={"message": "add numbers"})
            assert resp.status_code == 500

        # Confirm the 500 came from exhausting both providers rather than
        # from an unrelated failure earlier in the request.
        mock_anthropic.assert_called()
        mock_vllm.assert_called()
|