File size: 13,352 Bytes
35765b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
"""Test cases for Model Router - multi-model rotation with rate limiting and caching."""

import asyncio
import functools
import os
import sys
import time
from datetime import datetime, timedelta
from unittest.mock import patch, MagicMock, AsyncMock

# Make this file's directory importable so `from app...` resolves when the
# script is executed directly rather than through an installed package.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Pull API keys and other settings from a local .env file.
from dotenv import load_dotenv
load_dotenv()

# Test configuration: module-level pass/fail counters, incremented by the
# wrapper that the @test decorator installs around each test function.
TESTS_PASSED = 0
TESTS_FAILED = 0


def test(name):
    """Decorator that labels a test function with *name* for reporting.

    Wraps the function in an async runner that awaits coroutine tests,
    calls sync tests directly, prints a [PASS]/[FAIL]/[ERROR] line, and
    bumps the module-level TESTS_PASSED / TESTS_FAILED counters.
    """
    def decorator(func):
        @functools.wraps(func)  # preserve the test's __name__/__doc__ for debugging
        async def wrapper():
            global TESTS_PASSED, TESTS_FAILED
            try:
                if asyncio.iscoroutinefunction(func):
                    await func()
                else:
                    func()
                print(f"[PASS] {name}")
                TESTS_PASSED += 1
            except AssertionError as e:
                # Assertion failures are ordinary test failures.
                print(f"[FAIL] {name}: {e}")
                TESTS_FAILED += 1
            except Exception as e:
                # Anything else is an unexpected runtime error.
                print(f"[ERROR] {name}: {e}")
                TESTS_FAILED += 1
        return wrapper
    return decorator


# ========== Model Selection Tests ==========

@test("Model selection returns best model for chat task")
def test_model_selection_chat():
    from app.model_router import ModelRouter, TASK_PRIORITIES
    router = ModelRouter()
    model = router.get_model_for_task("chat")
    assert model == "gemini-2.0-flash", f"Expected gemini-2.0-flash, got {model}"


@test("Model selection returns best model for documentation task")
def test_model_selection_documentation():
    from app.model_router import ModelRouter, TASK_PRIORITIES
    router = ModelRouter()
    model = router.get_model_for_task("documentation")
    assert model == "gemini-2.0-flash-lite", f"Expected gemini-2.0-flash-lite, got {model}"


@test("Model selection returns best model for synthesis task")
def test_model_selection_synthesis():
    from app.model_router import ModelRouter, TASK_PRIORITIES
    router = ModelRouter()
    model = router.get_model_for_task("synthesis")
    assert model == "gemma-3-27b-it", f"Expected gemma-3-27b-it, got {model}"


@test("Model selection falls back to default for unknown task")
def test_model_selection_unknown():
    from app.model_router import ModelRouter
    router = ModelRouter()
    model = router.get_model_for_task("unknown_task_type")
    assert model == "gemini-2.0-flash", f"Expected gemini-2.0-flash (default), got {model}"


# ========== Rate Limiting Tests ==========

@test("Rate limit tracking works correctly")
def test_rate_limit_tracking():
    from app.model_router import ModelRouter, MODEL_CONFIGS
    router = ModelRouter()

    # Initially all models should be available (key 0)
    assert router._check_rate_limit("gemini-2.0-flash", 0) == True

    # Record usage up to limit
    rpm_limit = MODEL_CONFIGS["gemini-2.0-flash"]["rpm"]
    for _ in range(rpm_limit):
        router._record_usage("gemini-2.0-flash", 0)

    # Should now be rate limited for key 0
    assert router._check_rate_limit("gemini-2.0-flash", 0) == False


@test("Model falls back when primary is rate limited")
def test_model_fallback():
    from app.model_router import ModelRouter, MODEL_CONFIGS
    router = ModelRouter()

    # Exhaust gemini-2.0-flash rate limit on all keys
    rpm_limit = MODEL_CONFIGS["gemini-2.0-flash"]["rpm"]
    for key_idx in range(len(router.api_keys)):
        for _ in range(rpm_limit):
            router._record_usage("gemini-2.0-flash", key_idx)

    # Should fall back to next model in chat priority
    model = router.get_model_for_task("chat")
    assert model == "gemini-2.0-flash-lite", f"Expected fallback to gemini-2.0-flash-lite, got {model}"


@test("Returns None when all models exhausted on all keys")
def test_all_models_exhausted():
    from app.model_router import ModelRouter, MODEL_CONFIGS
    router = ModelRouter()

    # Exhaust all models on all keys
    for key_idx in range(len(router.api_keys)):
        for model_name, config in MODEL_CONFIGS.items():
            for _ in range(config["rpm"]):
                router._record_usage(model_name, key_idx)

    # Should return None
    model = router.get_model_for_task("chat")
    assert model is None, f"Expected None when all exhausted, got {model}"


# ========== Cache Tests ==========

@test("Cache stores and retrieves responses")
def test_cache_store_retrieve():
    from app.model_router import ModelRouter
    router = ModelRouter()

    cache_key = router._get_cache_key("chat", "user1", "test prompt")

    # Initially empty
    assert router._check_cache(cache_key) is None

    # Store response
    router._store_cache(cache_key, "cached response", "gemini-2.0-flash")

    # Should retrieve
    cached = router._check_cache(cache_key)
    assert cached == "cached response", f"Expected 'cached response', got {cached}"


@test("Cache key includes user_id")
def test_cache_key_user_differentiation():
    from app.model_router import ModelRouter
    router = ModelRouter()

    key1 = router._get_cache_key("chat", "user1", "same prompt")
    key2 = router._get_cache_key("chat", "user2", "same prompt")

    assert key1 != key2, "Cache keys should differ for different users"


@test("Cache key includes task_type")
def test_cache_key_task_differentiation():
    from app.model_router import ModelRouter
    router = ModelRouter()

    key1 = router._get_cache_key("chat", "user1", "same prompt")
    key2 = router._get_cache_key("documentation", "user1", "same prompt")

    assert key1 != key2, "Cache keys should differ for different task types"


@test("Cache expires after TTL")
def test_cache_expiry():
    from app.model_router import ModelRouter, CACHE_TTL
    router = ModelRouter()

    cache_key = router._get_cache_key("chat", "user1", "test prompt")
    router._store_cache(cache_key, "cached response", "gemini-2.0-flash")

    # Manually expire the cache entry
    router.cache[cache_key]["timestamp"] = datetime.now() - timedelta(seconds=CACHE_TTL + 1)

    # Should not retrieve expired entry
    cached = router._check_cache(cache_key)
    assert cached is None, "Expired cache entry should return None"


@test("Cache cleaning removes expired entries")
def test_cache_cleaning():
    from app.model_router import ModelRouter, CACHE_TTL
    router = ModelRouter()

    # Add expired entries
    for i in range(5):
        key = f"expired_{i}"
        router.cache[key] = {
            "response": f"response_{i}",
            "timestamp": datetime.now() - timedelta(seconds=CACHE_TTL + 1),
            "model": "test"
        }

    # Add valid entry
    router.cache["valid"] = {
        "response": "valid_response",
        "timestamp": datetime.now(),
        "model": "test"
    }

    # Clean cache
    router._clean_cache()

    # Only valid entry should remain
    assert len(router.cache) == 1, f"Expected 1 entry after cleaning, got {len(router.cache)}"
    assert "valid" in router.cache, "Valid entry should remain after cleaning"


# ========== Stats Tests ==========

@test("Stats returns correct usage info")
def test_stats():
    from app.model_router import ModelRouter, MODEL_CONFIGS
    router = ModelRouter()

    # Record some usage on key 0
    router._record_usage("gemini-2.0-flash", 0)
    router._record_usage("gemini-2.0-flash", 0)
    router._record_usage("gemma-3-27b-it", 0)

    stats = router.get_stats()

    assert stats["models"]["gemini-2.0-flash"]["used"] == 2, "Should show 2 uses for gemini-2.0-flash"
    assert stats["models"]["gemma-3-27b-it"]["used"] == 1, "Should show 1 use for gemma-3-27b-it"
    # Limit is per-key * num_keys
    expected_limit = MODEL_CONFIGS["gemini-2.0-flash"]["rpm"] * len(router.api_keys)
    assert stats["models"]["gemini-2.0-flash"]["limit"] == expected_limit


# ========== Multi-Key Tests ==========

@test("Multiple keys are loaded from environment")
def test_multi_key_loading():
    from app.model_router import ModelRouter
    router = ModelRouter()
    assert len(router.api_keys) >= 1, "Should have at least one API key"


@test("Key health tracking works")
def test_key_health_tracking():
    from app.model_router import ModelRouter, KEY_COOLDOWN_RATE_LIMIT
    router = ModelRouter()

    # Initially all keys should be healthy
    for i in range(len(router.api_keys)):
        assert router._is_key_healthy(i) == True, f"Key {i} should be healthy initially"

    # Mark first key as unhealthy
    router._mark_key_unhealthy(0, Exception("Test error"), KEY_COOLDOWN_RATE_LIMIT)

    assert router._is_key_healthy(0) == False, "Key 0 should be unhealthy after marking"
    assert router.key_health[0]["last_error"] == "Test error"


@test("Key rotation skips unhealthy keys")
def test_key_rotation_skips_unhealthy():
    from app.model_router import ModelRouter
    router = ModelRouter()

    if len(router.api_keys) < 2:
        return  # Skip if only one key

    # Mark key 0 as unhealthy
    router._mark_key_unhealthy(0, Exception("Test"), 60)

    # Get next key should skip key 0
    key_idx, _ = router._get_next_key()
    assert key_idx != 0 or len(router.api_keys) == 1, "Should skip unhealthy key 0"


@test("Key auto-recovers after cooldown")
def test_key_auto_recovery():
    from app.model_router import ModelRouter
    from datetime import datetime, timedelta
    router = ModelRouter()

    # Mark key as unhealthy with expired cooldown
    router.key_health[0] = {
        "healthy": False,
        "last_error": "Test",
        "retry_after": datetime.now() - timedelta(seconds=1)  # Already expired
    }

    # Should recover when checked
    assert router._is_key_healthy(0) == True, "Key should auto-recover after cooldown"
    assert router.key_health[0]["healthy"] == True
    assert router.key_health[0]["last_error"] is None


@test("Stats includes key information")
def test_stats_includes_keys():
    from app.model_router import ModelRouter
    router = ModelRouter()

    stats = router.get_stats()

    assert "keys" in stats, "Stats should include keys info"
    assert stats["keys"]["total"] >= 1, "Should have at least one key"
    assert stats["keys"]["healthy"] >= 1, "Should have at least one healthy key"
    assert "details" in stats["keys"], "Stats should include key details"


# ========== Integration Tests (requires API key) ==========

@test("Generate returns response and model info")
async def test_generate_integration():
    from app.model_router import generate_with_info

    response, model = await generate_with_info(
        "Say 'test' in one word.",
        task_type="default",
        use_cache=False
    )

    assert response is not None, "Response should not be None"
    assert len(response) > 0, "Response should not be empty"
    assert model in ["gemini-2.0-flash", "gemini-2.0-flash-lite", "gemma-3-27b-it",
                     "gemma-3-12b-it", "gemma-3-4b-it", "gemma-3-1b-it", "cache"]


@test("Generate uses cache on repeated calls")
async def test_generate_uses_cache():
    from app.model_router import generate_with_info, router

    # Clear cache first
    router.cache.clear()

    prompt = "Say 'cached test' in two words."

    # First call - should hit model
    response1, model1 = await generate_with_info(prompt, task_type="default", use_cache=True)
    assert model1 != "cache", f"First call should not be from cache, got {model1}"

    # Second call - should hit cache
    response2, model2 = await generate_with_info(prompt, task_type="default", use_cache=True)
    assert model2 == "cache", f"Second call should be from cache, got {model2}"
    assert response1 == response2, "Cached response should match original"


# ========== Run Tests ==========

async def run_tests():
    """Run every test section in order and report a pass/fail summary.

    Returns True when no test failed.
    """
    banner = "=" * 60
    print(banner)
    print("Model Router Tests")
    print(banner)
    print()

    # Unit-test sections (no API needed), run in declaration order.
    # A blank line separates consecutive sections.
    sections = [
        ("--- Model Selection Tests ---", [
            test_model_selection_chat,
            test_model_selection_documentation,
            test_model_selection_synthesis,
            test_model_selection_unknown,
        ]),
        ("--- Rate Limiting Tests ---", [
            test_rate_limit_tracking,
            test_model_fallback,
            test_all_models_exhausted,
        ]),
        ("--- Cache Tests ---", [
            test_cache_store_retrieve,
            test_cache_key_user_differentiation,
            test_cache_key_task_differentiation,
            test_cache_expiry,
            test_cache_cleaning,
        ]),
        ("--- Stats Tests ---", [
            test_stats,
        ]),
        ("--- Multi-Key Tests ---", [
            test_multi_key_loading,
            test_key_health_tracking,
            test_key_rotation_skips_unhealthy,
            test_key_auto_recovery,
            test_stats_includes_keys,
        ]),
    ]
    for index, (header, cases) in enumerate(sections):
        if index:
            print()
        print(header)
        for run_case in cases:
            await run_case()

    print()
    print("--- Integration Tests (requires API key) ---")

    # Integration tests need a real key; skip cleanly when none is configured.
    if not os.getenv("GEMINI_API_KEY") and not os.getenv("GEMINI_API_KEYS"):
        print("[SKIP] Integration tests skipped - no API keys")
    else:
        await test_generate_integration()
        await test_generate_uses_cache()

    print()
    print(banner)
    print(f"Results: {TESTS_PASSED} passed, {TESTS_FAILED} failed")
    print(banner)

    return TESTS_FAILED == 0


if __name__ == "__main__":
    success = asyncio.run(run_tests())
    exit(0 if success else 1)