|
|
"""Test cases for Model Router - multi-model rotation with rate limiting and caching.""" |
|
|
|
|
|
import asyncio |
|
|
import time |
|
|
from unittest.mock import patch, MagicMock, AsyncMock |
|
|
from datetime import datetime, timedelta |
|
|
import sys |
|
|
import os |
|
|
|
|
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
|
|
|
from dotenv import load_dotenv |
|
|
load_dotenv() |
|
|
|
|
|
|
|
|
# Module-level result counters, incremented by the @test decorator's wrapper.
TESTS_PASSED = 0


# Count of tests that raised AssertionError or any other exception.
TESTS_FAILED = 0
|
|
|
|
|
|
|
|
def test(name):
    """Decorator that wraps a test function with pass/fail bookkeeping.

    The decorated function becomes an async callable that runs the test
    (awaiting it when it is a coroutine function), prints a tagged
    [PASS]/[FAIL]/[ERROR] line for *name*, and updates the module-level
    TESTS_PASSED / TESTS_FAILED counters.
    """
    def decorator(func):
        # Decide once, at decoration time, whether the test is async.
        needs_await = asyncio.iscoroutinefunction(func)

        async def wrapper():
            global TESTS_PASSED, TESTS_FAILED
            try:
                outcome = func()
                if needs_await:
                    await outcome
                print(f"[PASS] {name}")
                TESTS_PASSED += 1
            except AssertionError as exc:
                # A failed assertion is a test failure, not a harness error.
                print(f"[FAIL] {name}: {exc}")
                TESTS_FAILED += 1
            except Exception as exc:
                # Anything else is an unexpected error in the test itself.
                print(f"[ERROR] {name}: {exc}")
                TESTS_FAILED += 1

        return wrapper
    return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@test("Model selection returns best model for chat task")
def test_model_selection_chat():
    """Chat tasks should route to the primary gemini-2.0-flash model."""
    # Dropped the unused TASK_PRIORITIES import from the original.
    from app.model_router import ModelRouter

    router = ModelRouter()
    model = router.get_model_for_task("chat")
    assert model == "gemini-2.0-flash", f"Expected gemini-2.0-flash, got {model}"
|
|
|
|
|
|
|
|
@test("Model selection returns best model for documentation task")
def test_model_selection_documentation():
    """Documentation tasks should route to the lighter flash-lite model."""
    # Dropped the unused TASK_PRIORITIES import from the original.
    from app.model_router import ModelRouter

    router = ModelRouter()
    model = router.get_model_for_task("documentation")
    assert model == "gemini-2.0-flash-lite", f"Expected gemini-2.0-flash-lite, got {model}"
|
|
|
|
|
|
|
|
@test("Model selection returns best model for synthesis task")
def test_model_selection_synthesis():
    """Synthesis tasks should route to the gemma-3-27b-it model."""
    # Dropped the unused TASK_PRIORITIES import from the original.
    from app.model_router import ModelRouter

    router = ModelRouter()
    model = router.get_model_for_task("synthesis")
    assert model == "gemma-3-27b-it", f"Expected gemma-3-27b-it, got {model}"
|
|
|
|
|
|
|
|
@test("Model selection falls back to default for unknown task")
def test_model_selection_unknown():
    """An unrecognized task type must route to the default model."""
    from app.model_router import ModelRouter

    # No need to keep the router around; only one lookup is made.
    model = ModelRouter().get_model_for_task("unknown_task_type")
    assert model == "gemini-2.0-flash", f"Expected gemini-2.0-flash (default), got {model}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@test("Rate limit tracking works correctly")
def test_rate_limit_tracking():
    """Recording usage up to the RPM cap should flip the rate-limit check.

    Fixes the PEP 8 `== True` / `== False` boolean comparisons and adds
    assertion messages for better failure diagnostics.
    """
    from app.model_router import ModelRouter, MODEL_CONFIGS

    router = ModelRouter()

    # A fresh router has no recorded usage, so the model is available on key 0.
    assert router._check_rate_limit("gemini-2.0-flash", 0), \
        "Model should be available before any usage"

    # Exhaust the per-minute quota for key 0.
    rpm_limit = MODEL_CONFIGS["gemini-2.0-flash"]["rpm"]
    for _ in range(rpm_limit):
        router._record_usage("gemini-2.0-flash", 0)

    # The same key/model pair must now report as rate limited.
    assert not router._check_rate_limit("gemini-2.0-flash", 0), \
        "Model should be rate limited after hitting the RPM cap"
|
|
|
|
|
|
|
|
@test("Model falls back when primary is rate limited")
def test_model_fallback():
    """Exhausting the primary model on every key forces a fallback model."""
    from app.model_router import ModelRouter, MODEL_CONFIGS

    router = ModelRouter()
    per_key_quota = MODEL_CONFIGS["gemini-2.0-flash"]["rpm"]

    # Burn through gemini-2.0-flash's quota on each configured API key.
    for key_idx, _key in enumerate(router.api_keys):
        for _ in range(per_key_quota):
            router._record_usage("gemini-2.0-flash", key_idx)

    # With the primary saturated everywhere, chat should use the fallback.
    model = router.get_model_for_task("chat")
    assert model == "gemini-2.0-flash-lite", f"Expected fallback to gemini-2.0-flash-lite, got {model}"
|
|
|
|
|
|
|
|
@test("Returns None when all models exhausted on all keys")
def test_all_models_exhausted():
    """When every model is rate limited on every key, no model is returned."""
    from app.model_router import ModelRouter, MODEL_CONFIGS

    router = ModelRouter()

    # Saturate every model's per-minute quota on every configured API key.
    for key_idx, _key in enumerate(router.api_keys):
        for model_name, config in MODEL_CONFIGS.items():
            for _ in range(config["rpm"]):
                router._record_usage(model_name, key_idx)

    model = router.get_model_for_task("chat")
    assert model is None, f"Expected None when all exhausted, got {model}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@test("Cache stores and retrieves responses")
def test_cache_store_retrieve():
    """A stored response must come back on the next lookup of the same key."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    cache_key = router._get_cache_key("chat", "user1", "test prompt")

    # A key that was never stored is a miss.
    assert router._check_cache(cache_key) is None

    # After storing, the identical key yields the stored response.
    router._store_cache(cache_key, "cached response", "gemini-2.0-flash")
    cached = router._check_cache(cache_key)
    assert cached == "cached response", f"Expected 'cached response', got {cached}"
|
|
|
|
|
|
|
|
@test("Cache key includes user_id")
def test_cache_key_user_differentiation():
    """Identical prompts from different users must map to distinct cache keys."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    key1, key2 = (
        router._get_cache_key("chat", user, "same prompt")
        for user in ("user1", "user2")
    )
    assert key1 != key2, "Cache keys should differ for different users"
|
|
|
|
|
|
|
|
@test("Cache key includes task_type")
def test_cache_key_task_differentiation():
    """The same user and prompt under different task types gets distinct keys."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    key1, key2 = (
        router._get_cache_key(task, "user1", "same prompt")
        for task in ("chat", "documentation")
    )
    assert key1 != key2, "Cache keys should differ for different task types"
|
|
|
|
|
|
|
|
@test("Cache expires after TTL")
def test_cache_expiry():
    """Entries older than CACHE_TTL seconds must read back as misses."""
    from app.model_router import ModelRouter, CACHE_TTL

    router = ModelRouter()
    cache_key = router._get_cache_key("chat", "user1", "test prompt")
    router._store_cache(cache_key, "cached response", "gemini-2.0-flash")

    # Backdate the entry just past the TTL boundary instead of sleeping.
    expired_at = datetime.now() - timedelta(seconds=CACHE_TTL + 1)
    router.cache[cache_key]["timestamp"] = expired_at

    assert router._check_cache(cache_key) is None, "Expired cache entry should return None"
|
|
|
|
|
|
|
|
@test("Cache cleaning removes expired entries")
def test_cache_cleaning():
    """_clean_cache drops entries past their TTL and keeps fresh ones."""
    from app.model_router import ModelRouter, CACHE_TTL

    router = ModelRouter()
    stale_stamp = datetime.now() - timedelta(seconds=CACHE_TTL + 1)

    # Seed five entries that are already past their TTL...
    for i in range(5):
        router.cache[f"expired_{i}"] = {
            "response": f"response_{i}",
            "timestamp": stale_stamp,
            "model": "test",
        }

    # ...plus one entry that is still fresh.
    router.cache["valid"] = {
        "response": "valid_response",
        "timestamp": datetime.now(),
        "model": "test",
    }

    router._clean_cache()

    assert len(router.cache) == 1, f"Expected 1 entry after cleaning, got {len(router.cache)}"
    assert "valid" in router.cache, "Valid entry should remain after cleaning"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@test("Stats returns correct usage info")
def test_stats():
    """get_stats reflects recorded usage counts and aggregate limits."""
    from app.model_router import ModelRouter, MODEL_CONFIGS

    router = ModelRouter()

    # Two recorded calls against flash, one against gemma, all on key 0.
    for model_name in ("gemini-2.0-flash", "gemini-2.0-flash", "gemma-3-27b-it"):
        router._record_usage(model_name, 0)

    stats = router.get_stats()

    assert stats["models"]["gemini-2.0-flash"]["used"] == 2, "Should show 2 uses for gemini-2.0-flash"
    assert stats["models"]["gemma-3-27b-it"]["used"] == 1, "Should show 1 use for gemma-3-27b-it"

    # The reported limit is the per-key RPM scaled across all keys.
    expected_limit = MODEL_CONFIGS["gemini-2.0-flash"]["rpm"] * len(router.api_keys)
    assert stats["models"]["gemini-2.0-flash"]["limit"] == expected_limit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@test("Multiple keys are loaded from environment")
def test_multi_key_loading():
    """The router should pick up at least one API key from the environment."""
    from app.model_router import ModelRouter

    key_count = len(ModelRouter().api_keys)
    assert key_count >= 1, "Should have at least one API key"
|
|
|
|
|
|
|
|
@test("Key health tracking works")
def test_key_health_tracking():
    """Keys start healthy and become unhealthy once marked with an error.

    Fixes the PEP 8 `== True` / `== False` boolean comparisons.
    """
    from app.model_router import ModelRouter, KEY_COOLDOWN_RATE_LIMIT

    router = ModelRouter()

    # Every key should report healthy before any failures are recorded.
    for i in range(len(router.api_keys)):
        assert router._is_key_healthy(i), f"Key {i} should be healthy initially"

    # Marking a key unhealthy should flip its status and record the error text.
    router._mark_key_unhealthy(0, Exception("Test error"), KEY_COOLDOWN_RATE_LIMIT)

    assert not router._is_key_healthy(0), "Key 0 should be unhealthy after marking"
    assert router.key_health[0]["last_error"] == "Test error"
|
|
|
|
|
|
|
|
@test("Key rotation skips unhealthy keys")
def test_key_rotation_skips_unhealthy():
    """_get_next_key should avoid keys previously marked unhealthy."""
    from app.model_router import ModelRouter

    router = ModelRouter()

    # Rotation is only observable with at least two configured keys.
    if len(router.api_keys) >= 2:
        router._mark_key_unhealthy(0, Exception("Test"), 60)

        key_idx, _ = router._get_next_key()
        assert key_idx != 0 or len(router.api_keys) == 1, "Should skip unhealthy key 0"
|
|
|
|
|
|
|
|
@test("Key auto-recovers after cooldown")
def test_key_auto_recovery():
    """An unhealthy key whose retry window has passed should self-heal.

    Fixes the PEP 8 `== True` boolean comparisons.
    """
    from app.model_router import ModelRouter
    from datetime import datetime, timedelta

    router = ModelRouter()

    # Simulate a key whose cooldown expired one second ago.
    router.key_health[0] = {
        "healthy": False,
        "last_error": "Test",
        "retry_after": datetime.now() - timedelta(seconds=1),
    }

    # The health check itself is expected to flip the key back to healthy.
    assert router._is_key_healthy(0), "Key should auto-recover after cooldown"
    assert router.key_health[0]["healthy"], "Health flag should be reset"
    assert router.key_health[0]["last_error"] is None
|
|
|
|
|
|
|
|
@test("Stats includes key information")
def test_stats_includes_keys():
    """get_stats must expose totals, healthy count, and per-key details."""
    from app.model_router import ModelRouter

    stats = ModelRouter().get_stats()

    assert "keys" in stats, "Stats should include keys info"
    key_stats = stats["keys"]
    assert key_stats["total"] >= 1, "Should have at least one key"
    assert key_stats["healthy"] >= 1, "Should have at least one healthy key"
    assert "details" in key_stats, "Stats should include key details"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@test("Generate returns response and model info")
async def test_generate_integration():
    """generate_with_info returns a non-empty response from a known model.

    Requires live API credentials; the caller skips it when none are set.
    """
    from app.model_router import generate_with_info

    response, model = await generate_with_info(
        "Say 'test' in one word.",
        task_type="default",
        use_cache=False,
    )

    # Any routed model (or a cache hit) is acceptable.
    known_models = [
        "gemini-2.0-flash",
        "gemini-2.0-flash-lite",
        "gemma-3-27b-it",
        "gemma-3-12b-it",
        "gemma-3-4b-it",
        "gemma-3-1b-it",
        "cache",
    ]
    assert response is not None, "Response should not be None"
    assert len(response) > 0, "Response should not be empty"
    assert model in known_models
|
|
|
|
|
|
|
|
@test("Generate uses cache on repeated calls")
async def test_generate_uses_cache():
    """A second identical request must be served from the router cache.

    Requires live API credentials; the caller skips it when none are set.
    """
    from app.model_router import generate_with_info, router

    # Start from an empty cache so the first call is guaranteed to miss.
    router.cache.clear()

    prompt = "Say 'cached test' in two words."

    # First call goes to a live model and populates the cache.
    response1, model1 = await generate_with_info(prompt, task_type="default", use_cache=True)
    assert model1 != "cache", f"First call should not be from cache, got {model1}"

    # Second call with the identical prompt is answered from cache.
    response2, model2 = await generate_with_info(prompt, task_type="default", use_cache=True)
    assert model2 == "cache", f"Second call should be from cache, got {model2}"
    assert response1 == response2, "Cached response should match original"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def run_tests():
    """Run every test group in order and return True when all tests passed."""
    print("=" * 60)
    print("Model Router Tests")
    print("=" * 60)
    print()

    # Each group pairs its header line with the tests that run under it.
    groups = [
        ("--- Model Selection Tests ---", [
            test_model_selection_chat,
            test_model_selection_documentation,
            test_model_selection_synthesis,
            test_model_selection_unknown,
        ]),
        ("--- Rate Limiting Tests ---", [
            test_rate_limit_tracking,
            test_model_fallback,
            test_all_models_exhausted,
        ]),
        ("--- Cache Tests ---", [
            test_cache_store_retrieve,
            test_cache_key_user_differentiation,
            test_cache_key_task_differentiation,
            test_cache_expiry,
            test_cache_cleaning,
        ]),
        ("--- Stats Tests ---", [
            test_stats,
        ]),
        ("--- Multi-Key Tests ---", [
            test_multi_key_loading,
            test_key_health_tracking,
            test_key_rotation_skips_unhealthy,
            test_key_auto_recovery,
            test_stats_includes_keys,
        ]),
    ]

    for header, tests in groups:
        print(header)
        for test_fn in tests:
            await test_fn()
        print()

    print("--- Integration Tests (requires API key) ---")

    # Integration tests need live credentials; skip cleanly when absent.
    if not os.getenv("GEMINI_API_KEY") and not os.getenv("GEMINI_API_KEYS"):
        print("[SKIP] Integration tests skipped - no API keys")
    else:
        await test_generate_integration()
        await test_generate_uses_cache()

    print()
    print("=" * 60)
    print(f"Results: {TESTS_PASSED} passed, {TESTS_FAILED} failed")
    print("=" * 60)

    return TESTS_FAILED == 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the suite and mirror the result in the process exit code.
    # sys.exit() is the scripting API; exit() is the interactive site helper
    # and is not guaranteed to exist when site is disabled.
    success = asyncio.run(run_tests())
    sys.exit(0 if success else 1)
|
|
|