File size: 6,046 Bytes
d04007f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import pytest
from unittest.mock import AsyncMock, MagicMock
from ankigen.llm_interface import OpenAIClientManager, OpenAIRateLimiter
from openai import OpenAIError

# --- OpenAIClientManager Tests ---


@pytest.mark.anyio
async def test_initialize_client_valid_key(mocker):
    """A well-formed key is stored and passed to the AsyncOpenAI constructor."""
    mgr = OpenAIClientManager()
    key = "sk-valid-key"
    # Replace the AsyncOpenAI class so no real client is constructed.
    openai_cls = mocker.patch("ankigen.llm_interface.AsyncOpenAI", autospec=True)

    await mgr.initialize_client(key)

    assert mgr._api_key == key
    assert mgr._client is not None
    openai_cls.assert_called_once_with(api_key=key)


@pytest.mark.anyio
async def test_initialize_client_invalid_key():
    """A key that does not start with ``sk-`` is rejected and no client is created."""
    manager = OpenAIClientManager()
    # ``match`` is applied as a regex search, so the trailing period is escaped
    # to require a literal "." instead of "any character".
    with pytest.raises(ValueError, match=r"Invalid OpenAI API key format\."):
        await manager.initialize_client("invalid-key")
    assert manager._client is None


@pytest.mark.anyio
async def test_initialize_client_empty_key():
    """An empty key is rejected with the same format error and no client is created."""
    manager = OpenAIClientManager()
    # ``match`` is a regex; escape the period so it must appear literally.
    with pytest.raises(ValueError, match=r"Invalid OpenAI API key format\."):
        await manager.initialize_client("")
    assert manager._client is None


@pytest.mark.anyio
async def test_initialize_client_openai_error(mocker):
    """An OpenAIError raised by the client constructor surfaces and leaves no client."""
    manager = OpenAIClientManager()
    failure = OpenAIError("API Error")
    mocker.patch("ankigen.llm_interface.AsyncOpenAI", side_effect=failure)

    with pytest.raises(OpenAIError):
        await manager.initialize_client("sk-error")

    # A failed initialization must not leave a client reference behind.
    assert manager._client is None


@pytest.mark.anyio
async def test_initialize_client_unexpected_error(mocker):
    """A non-OpenAI exception from the constructor is reported as a RuntimeError.

    The manager must also leave ``_client`` unset after the failure.
    """
    manager = OpenAIClientManager()
    mocker.patch(
        "ankigen.llm_interface.AsyncOpenAI", side_effect=Exception("Unexpected")
    )

    # ``match`` is a regex search; escape the period so it is matched literally.
    with pytest.raises(
        RuntimeError, match=r"Unexpected error initializing AsyncOpenAI client\."
    ):
        await manager.initialize_client("sk-unexpected")

    assert manager._client is None


def test_get_client_not_initialized():
    """get_client() raises RuntimeError before initialize_client() has run."""
    fresh_manager = OpenAIClientManager()
    expected_message = "AsyncOpenAI client is not initialized"
    with pytest.raises(RuntimeError, match=expected_message):
        fresh_manager.get_client()


@pytest.mark.anyio
async def test_get_client_initialized(mocker):
    """After a successful initialize_client(), get_client() returns a client."""
    mgr = OpenAIClientManager()
    # Patch the client class so initialization does not build a real client.
    mocker.patch("ankigen.llm_interface.AsyncOpenAI")
    await mgr.initialize_client("sk-valid")

    assert mgr.get_client() is not None


def test_context_manager_sync(mocker):
    """A ``with`` block yields the manager itself and calls close() on exit."""
    manager = OpenAIClientManager()
    close_spy = mocker.patch.object(manager, "close")

    with manager as entered:
        assert entered is manager

    close_spy.assert_called_once()


@pytest.mark.anyio
async def test_context_manager_async(mocker):
    """An ``async with`` block yields the manager and awaits aclose() on exit."""
    manager = OpenAIClientManager()
    mock_aclose = mocker.patch.object(manager, "aclose", new_callable=AsyncMock)

    async with manager as m:
        assert m is manager

    # assert_awaited_once (not assert_called_once) also fails if __aexit__
    # merely created the aclose() coroutine without awaiting it.
    mock_aclose.assert_awaited_once()


def test_close_method(mocker):
    """close() calls the wrapped client's close() once and clears the reference."""
    manager = OpenAIClientManager()
    fake_client = MagicMock()
    fake_client.close = MagicMock()  # explicit synchronous close() to be invoked
    manager._client = fake_client

    manager.close()

    fake_client.close.assert_called_once()
    assert manager._client is None


@pytest.mark.anyio
async def test_aclose_method(mocker):
    """aclose() awaits the wrapped client's aclose() once and clears the reference."""
    manager = OpenAIClientManager()
    mock_client = AsyncMock()
    # Give the client an explicit async aclose() for the manager to await.
    mock_client.aclose = AsyncMock()
    manager._client = mock_client

    await manager.aclose()

    # assert_awaited_once proves the coroutine was actually awaited, not just
    # created (which assert_called_once would also accept).
    mock_client.aclose.assert_awaited_once()
    assert manager._client is None


@pytest.mark.anyio
async def test_aclose_method_fallback_to_sync(mocker):
    """aclose() falls back to a synchronous close() when the client has no aclose()."""
    manager = OpenAIClientManager()
    # spec_set limits the mock to a 'close' attribute only, so any access to
    # 'aclose' raises AttributeError rather than auto-creating a child mock.
    sync_only_client = MagicMock(spec_set=["close"])
    sync_only_client.close = MagicMock()
    manager._client = sync_only_client

    await manager.aclose()

    sync_only_client.close.assert_called_once()
    assert manager._client is None


# --- OpenAIRateLimiter Tests ---


@pytest.mark.anyio
async def test_rate_limiter_within_limits():
    """Requests that stay under the per-minute budget accumulate in the window."""
    limiter = OpenAIRateLimiter(tokens_per_minute=100)
    # 50 + 40 = 90 <= 100, so neither call is expected to have to wait.
    for requested, running_total in ((50, 50), (40, 90)):
        await limiter.wait_if_needed(requested)
        assert limiter.tokens_used_current_window == running_total


@pytest.mark.anyio
async def test_rate_limiter_window_reset(mocker):
    """Once the clock passes the one-minute window, usage restarts at the new request."""
    clock = mocker.patch("time.monotonic")
    clock.return_value = 0.0  # frozen clock: the first window opens at t=0

    limiter = OpenAIRateLimiter(tokens_per_minute=100)
    await limiter.wait_if_needed(80)
    assert limiter.tokens_used_current_window == 80

    clock.return_value = 61.0  # jump 61 s forward, past the window
    await limiter.wait_if_needed(10)

    # A new window began at t=61, so only the latest 10 tokens are counted.
    assert limiter.tokens_used_current_window == 10
    assert limiter.current_window_start_time == 61.0


@pytest.mark.anyio
async def test_rate_limiter_wait_needed(mocker):
    """A request that would exceed the budget sleeps until the window ends.

    After the sleep, the limiter starts a new window containing only that request.
    """
    mock_sleep = mocker.patch("asyncio.sleep", new_callable=AsyncMock)
    mock_monotonic = mocker.patch("time.monotonic")

    # Freeze the clock at t=0 so the first window starts at 0.0.
    mock_monotonic.return_value = 0.0
    limiter = OpenAIRateLimiter(tokens_per_minute=100)

    await limiter.wait_if_needed(90)
    assert limiter.tokens_used_current_window == 90

    # 90 + 20 = 110 > 100, so the limiter must wait for the window to end:
    #   time_to_wait = (window_start + 60.0) - now = (0.0 + 60.0) - 0.0 = 60.0
    #
    # Advance the clock from inside the fake sleep instead of using a
    # side_effect list on time.monotonic: the event loop also reads
    # time.monotonic(), so a list would be consumed unpredictably. A
    # return_value that changes only when sleep runs does not depend on how
    # many times the clock is read.
    async def side_effect_sleep(delay):
        mock_monotonic.return_value = 60.0  # fast-forward to the window end

    mock_sleep.side_effect = side_effect_sleep

    await limiter.wait_if_needed(20)

    # assert_awaited_* (rather than assert_called_*) proves the sleep coroutine
    # was actually awaited, i.e. the limiter really paused.
    mock_sleep.assert_awaited_once_with(60.0)
    assert limiter.tokens_used_current_window == 20
    assert limiter.current_window_start_time == 60.0