# avp-rag-system/tests/test_fallback_integration.py
# Commit bc2b724 (BeefStewBibi): Add e2e frontend tests to the CI/CD pipeline
"""Integration tests for LLM provider fallback through the API."""
import os
from contextlib import ExitStack, contextmanager
from unittest.mock import MagicMock, patch

from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.testclient import TestClient

from tests.conftest import MOCK_LLM_OUTPUT, MOCK_RETRIEVE_RESULT
def _setup(env_overrides: dict | None = None):
    """Build the shared fixtures used by every fallback test.

    Creates a retriever mock whose ``retrieve()`` returns
    ``MOCK_RETRIEVE_RESULT``, assembles the environment mapping, empties the
    generation and retrieval caches, and installs a newly created limiter on
    ``app.state``.

    Args:
        env_overrides: Optional variables applied on top of the default
            environment (Anthropic as provider, with a fake API key).

    Returns:
        A ``(app, mock_retriever, env, fresh_limiter)`` tuple. Nothing is
        patched here; callers compose their own ``with`` blocks from these
        pieces.
    """
    retriever_stub = MagicMock(
        **{"retrieve.return_value": MOCK_RETRIEVE_RESULT}
    )

    environment = {
        "ANTHROPIC_API_KEY": "fake-key-for-testing",
        "LLM_PROVIDER": "anthropic",
    }
    # A falsy override (None or empty) leaves the defaults untouched.
    environment.update(env_overrides or {})

    # Imported at call time rather than at module import, as in the original.
    from api.main import app
    from api.cache import generation_cache, retrieval_cache

    for cache in (generation_cache, retrieval_cache):
        cache.clear()

    app.state.limiter = limiter = Limiter(key_func=get_remote_address)
    return app, retriever_stub, environment, limiter
class TestProviderFallback:
    """API-level tests for primary/fallback LLM provider selection.

    Every test obtains its fixtures from ``_setup()`` and then talks to the
    app through a Starlette ``TestClient`` while the retriever, the process
    environment, the limiter references and (where relevant) the provider
    call functions are patched.
    """

    @staticmethod
    @contextmanager
    def _client(app, mock_retriever, env, limiter, provider_mocks=None,
                **client_kwargs):
        """Yield a ``TestClient`` with the patches shared by all tests applied.

        The patches are entered in a fixed order: retriever, environment,
        both limiter references, provider mocks, and finally the client.
        They are undone in reverse order when the block exits.

        Args:
            app: Application object returned by ``_setup()``.
            mock_retriever: Replacement for ``retrieve._retriever``.
            env: Mapping that *replaces* ``os.environ`` (``clear=True``) for
                the duration of the block.
            limiter: Limiter patched into ``api.dependencies`` and
                ``api.routes``.
            provider_mocks: Optional ``{patch_target: mock}`` mapping, e.g.
                ``{"providers._call_anthropic": mock}``.
            **client_kwargs: Extra keyword arguments forwarded to
                ``TestClient`` (e.g. ``raise_server_exceptions=False``).

        Yields:
            The entered ``TestClient`` instance.
        """
        with ExitStack() as stack:
            stack.enter_context(patch("retrieve._retriever", mock_retriever))
            stack.enter_context(patch.dict(os.environ, env, clear=True))
            stack.enter_context(patch("api.dependencies.limiter", limiter))
            stack.enter_context(patch("api.routes.limiter", limiter))
            for target, mock in (provider_mocks or {}).items():
                stack.enter_context(patch(target, mock))
            yield stack.enter_context(TestClient(app, **client_kwargs))

    def test_generate_uses_primary_provider(self):
        """A healthy primary provider serves the request."""
        app, mock_retriever, env, fresh_limiter = _setup()
        mock_anthropic = MagicMock(return_value=MOCK_LLM_OUTPUT)

        with self._client(
            app, mock_retriever, env, fresh_limiter,
            provider_mocks={"providers._call_anthropic": mock_anthropic},
        ) as client:
            resp = client.post("/api/generate", json={"message": "add numbers"})
            assert resp.status_code == 200
            assert resp.json()["generated_code"] == MOCK_LLM_OUTPUT
            mock_anthropic.assert_called()

    def test_generate_falls_back_to_vllm(self):
        """When the primary raises, the configured fallback answers instead."""
        app, mock_retriever, env, fresh_limiter = _setup(
            env_overrides={"LLM_FALLBACK_PROVIDER": "vllm"},
        )
        fallback_code = "fun fallback():\n return 0\nend fun"
        mock_anthropic = MagicMock(side_effect=RuntimeError("API key invalid"))
        mock_vllm = MagicMock(return_value=fallback_code)

        with self._client(
            app, mock_retriever, env, fresh_limiter,
            provider_mocks={
                "providers._call_anthropic": mock_anthropic,
                "providers._call_vllm": mock_vllm,
            },
        ) as client:
            resp = client.post("/api/generate", json={"message": "add numbers"})
            assert resp.status_code == 200
            assert resp.json()["generated_code"] == fallback_code
            # Prove this was a real fallback: the primary must have been
            # attempted (and failed) before vLLM was used.
            mock_anthropic.assert_called()
            mock_vllm.assert_called()

    def test_generate_503_when_not_configured(self):
        """No API key -> is_provider_configured() returns False -> 503."""
        app, mock_retriever, _, fresh_limiter = _setup()
        env_no_key = {"LLM_PROVIDER": "anthropic"}  # No ANTHROPIC_API_KEY

        with self._client(
            app, mock_retriever, env_no_key, fresh_limiter,
        ) as client:
            resp = client.post("/api/generate", json={"message": "add numbers"})
            assert resp.status_code == 503
            assert "not configured" in resp.json()["detail"]

    def test_health_reflects_provider_status(self):
        """``/api/health`` reports whether the provider has credentials."""
        app, mock_retriever, _, fresh_limiter = _setup()

        # With API key -> configured
        env_with_key = {
            "LLM_PROVIDER": "anthropic",
            "ANTHROPIC_API_KEY": "fake",
        }
        with self._client(
            app, mock_retriever, env_with_key, fresh_limiter,
        ) as client:
            resp = client.get("/api/health")
            assert resp.json()["provider_configured"] is True

        # Without API key -> not configured
        env_without_key = {
            "LLM_PROVIDER": "anthropic",
        }
        with self._client(
            app, mock_retriever, env_without_key, fresh_limiter,
        ) as client:
            resp = client.get("/api/health")
            assert resp.json()["provider_configured"] is False

    def test_generate_500_when_all_providers_fail(self):
        """Primary and fallback both raising results in a 500 response."""
        app, mock_retriever, env, fresh_limiter = _setup(
            env_overrides={"LLM_FALLBACK_PROVIDER": "vllm"},
        )
        mock_anthropic = MagicMock(side_effect=RuntimeError("primary down"))
        mock_vllm = MagicMock(side_effect=RuntimeError("fallback down"))

        # raise_server_exceptions=False makes the client return the error
        # response instead of re-raising the server-side exception here.
        with self._client(
            app, mock_retriever, env, fresh_limiter,
            provider_mocks={
                "providers._call_anthropic": mock_anthropic,
                "providers._call_vllm": mock_vllm,
            },
            raise_server_exceptions=False,
        ) as client:
            resp = client.post("/api/generate", json={"message": "add numbers"})
            assert resp.status_code == 500
            # Tie the 500 to provider exhaustion rather than to any
            # unrelated server error: both providers must have been tried.
            mock_anthropic.assert_called()
            mock_vllm.assert_called()