File size: 5,728 Bytes
bc2b724
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Integration tests for LLM provider fallback through the API."""

import os
from contextlib import ExitStack, contextmanager
from unittest.mock import MagicMock, patch

from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.testclient import TestClient

from tests.conftest import MOCK_LLM_OUTPUT, MOCK_RETRIEVE_RESULT


def _setup(env_overrides: dict | None = None):
    """Build the shared fixtures used by every provider-fallback test.

    The helper creates a mock retriever whose ``retrieve()`` returns
    ``MOCK_RETRIEVE_RESULT``, assembles the environment mapping (Anthropic
    defaults, with ``env_overrides`` layered on top), empties the generation
    and retrieval caches, and installs a brand-new rate limiter on
    ``app.state``.

    Nothing is patched here: the environment mapping is only returned, so
    callers apply it (and the other patches) in their own ``with`` blocks.

    Returns:
        Tuple ``(app, mock_retriever, env, fresh_limiter)``.
    """
    retriever_double = MagicMock(
        **{"retrieve.return_value": MOCK_RETRIEVE_RESULT},
    )

    environment = {
        "ANTHROPIC_API_KEY": "fake-key-for-testing",
        "LLM_PROVIDER": "anthropic",
        **(env_overrides or {}),
    }

    # Project modules are imported at call time, not at module import time.
    from api.main import app
    from api.cache import generation_cache, retrieval_cache

    # Start every test from empty caches.
    for cache in (generation_cache, retrieval_cache):
        cache.clear()

    # A new limiter per call replaces whatever limiter the app held before.
    limiter = Limiter(key_func=get_remote_address)
    app.state.limiter = limiter

    return app, retriever_double, environment, limiter


class TestProviderFallback:
    """API-level tests for primary/fallback LLM provider handling.

    Every test obtains its fixtures from ``_setup`` and talks to the app
    through a ``TestClient`` opened by ``_client``, which applies the patch
    stack shared by all tests.
    """

    @staticmethod
    @contextmanager
    def _client(
        app,
        mock_retriever,
        env,
        limiter,
        *,
        provider_mocks=None,
        **client_kwargs,
    ):
        """Yield a ``TestClient`` for *app* with the shared patches active.

        Args:
            app: Application under test.
            mock_retriever: Replacement for ``retrieve._retriever``.
            env: Complete environment for the test; ``os.environ`` is
                cleared and replaced with exactly these variables.
            limiter: Limiter patched into ``api.dependencies`` and
                ``api.routes``.
            provider_mocks: Optional mapping from an attribute name in the
                ``providers`` module (e.g. ``"_call_anthropic"``) to the
                mock that replaces it. Patched in mapping order.
            **client_kwargs: Extra keyword arguments for ``TestClient``.

        Yields:
            The entered ``TestClient``.

        The client is entered last, so every patch is already active when it
        starts and is undone only after the client has shut down.
        """
        with ExitStack() as stack:
            stack.enter_context(patch("retrieve._retriever", mock_retriever))
            stack.enter_context(patch.dict(os.environ, env, clear=True))
            stack.enter_context(patch("api.dependencies.limiter", limiter))
            stack.enter_context(patch("api.routes.limiter", limiter))
            for attr_name, mock in (provider_mocks or {}).items():
                stack.enter_context(patch(f"providers.{attr_name}", mock))
            yield stack.enter_context(TestClient(app, **client_kwargs))

    def test_generate_uses_primary_provider(self):
        """The primary provider's output is returned when it succeeds."""
        app, mock_retriever, env, fresh_limiter = _setup()
        mock_anthropic = MagicMock(return_value=MOCK_LLM_OUTPUT)

        with self._client(
            app,
            mock_retriever,
            env,
            fresh_limiter,
            provider_mocks={"_call_anthropic": mock_anthropic},
        ) as client:
            resp = client.post("/api/generate", json={"message": "add numbers"})
            assert resp.status_code == 200
            assert resp.json()["generated_code"] == MOCK_LLM_OUTPUT
            mock_anthropic.assert_called()

    def test_generate_falls_back_to_vllm(self):
        """A failing primary provider is covered by the vLLM fallback."""
        app, mock_retriever, env, fresh_limiter = _setup(
            env_overrides={"LLM_FALLBACK_PROVIDER": "vllm"},
        )
        fallback_code = "fun fallback():\n    return 0\nend fun"
        mock_anthropic = MagicMock(side_effect=RuntimeError("API key invalid"))
        mock_vllm = MagicMock(return_value=fallback_code)

        with self._client(
            app,
            mock_retriever,
            env,
            fresh_limiter,
            provider_mocks={
                "_call_anthropic": mock_anthropic,
                "_call_vllm": mock_vllm,
            },
        ) as client:
            resp = client.post("/api/generate", json={"message": "add numbers"})
            assert resp.status_code == 200
            assert resp.json()["generated_code"] == fallback_code
            mock_vllm.assert_called()

    def test_generate_503_when_not_configured(self):
        """No API key -> is_provider_configured() returns False -> 503."""
        app, mock_retriever, _, fresh_limiter = _setup()
        env_no_key = {"LLM_PROVIDER": "anthropic"}  # No ANTHROPIC_API_KEY

        with self._client(
            app, mock_retriever, env_no_key, fresh_limiter,
        ) as client:
            resp = client.post("/api/generate", json={"message": "add numbers"})
            assert resp.status_code == 503
            assert "not configured" in resp.json()["detail"]

    def test_health_reflects_provider_status(self):
        """``/api/health`` reports whether the provider has an API key."""
        app, mock_retriever, _, fresh_limiter = _setup()

        # With API key -> configured
        env_with_key = {
            "LLM_PROVIDER": "anthropic",
            "ANTHROPIC_API_KEY": "fake",
        }
        with self._client(
            app, mock_retriever, env_with_key, fresh_limiter,
        ) as client:
            resp = client.get("/api/health")
            assert resp.json()["provider_configured"] is True

        # Without API key -> not configured
        env_without_key = {
            "LLM_PROVIDER": "anthropic",
        }
        with self._client(
            app, mock_retriever, env_without_key, fresh_limiter,
        ) as client:
            resp = client.get("/api/health")
            assert resp.json()["provider_configured"] is False

    def test_generate_500_when_all_providers_fail(self):
        """Primary and fallback both raising surfaces as an HTTP 500."""
        app, mock_retriever, env, fresh_limiter = _setup(
            env_overrides={"LLM_FALLBACK_PROVIDER": "vllm"},
        )
        mock_anthropic = MagicMock(side_effect=RuntimeError("primary down"))
        mock_vllm = MagicMock(side_effect=RuntimeError("fallback down"))

        # raise_server_exceptions=False makes the client return the error
        # response instead of re-raising the server-side exception here.
        with self._client(
            app,
            mock_retriever,
            env,
            fresh_limiter,
            provider_mocks={
                "_call_anthropic": mock_anthropic,
                "_call_vllm": mock_vllm,
            },
            raise_server_exceptions=False,
        ) as client:
            resp = client.post("/api/generate", json={"message": "add numbers"})
            assert resp.status_code == 500