"""Test cases for Model Router - multi-model rotation with rate limiting and caching.

This file is a self-contained async test runner. Every test is wrapped by the
``@test`` decorator, which prints a ``[PASS]`` / ``[FAIL]`` / ``[ERROR]`` /
``[SKIP]`` line and updates the module-level counters. Running the file as a
script executes ``run_tests()`` and exits with status 0 only when no test
failed.
"""

import asyncio
import functools
import inspect
import time
from unittest import SkipTest
from unittest.mock import patch, MagicMock, AsyncMock
from datetime import datetime, timedelta
import sys
import os

# Put this file's own directory on sys.path so `app.model_router` resolves.
# NOTE(review): the previous comment said "parent", but the code adds the
# directory containing this file - confirm that `app/` lives next to it.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from dotenv import load_dotenv
load_dotenv()

# Test configuration: running totals updated by the @test wrapper.
TESTS_PASSED = 0
TESTS_FAILED = 0


def test(name):
    """Decorator for test functions.

    Wraps a sync or async test callable in a zero-argument coroutine that
    runs it, prints one status line and updates the global counters.

    Args:
        name: Human-readable test description used in the status line.

    Returns:
        A decorator producing an ``async def wrapper()`` that never raises
        for ``Exception`` subclasses:

        * ``unittest.SkipTest``  -> ``[SKIP]``, counted as neither pass nor fail
        * ``AssertionError``     -> ``[FAIL]``, increments ``TESTS_FAILED``
        * any other ``Exception``-> ``[ERROR]``, increments ``TESTS_FAILED``
        * no exception           -> ``[PASS]``, increments ``TESTS_PASSED``
    """
    def decorator(func):
        # Decide once, at decoration time, whether the test must be awaited.
        is_async = inspect.iscoroutinefunction(func)

        @functools.wraps(func)  # keep the test's __name__ / __doc__
        async def wrapper():
            global TESTS_PASSED, TESTS_FAILED
            try:
                if is_async:
                    await func()
                else:
                    func()
            except SkipTest as e:
                # Must precede `except Exception`: SkipTest is a subclass.
                print(f"[SKIP] {name}: {e}")
            except AssertionError as e:
                print(f"[FAIL] {name}: {e}")
                TESTS_FAILED += 1
            except Exception as e:
                print(f"[ERROR] {name}: {e}")
                TESTS_FAILED += 1
            else:
                print(f"[PASS] {name}")
                TESTS_PASSED += 1
        return wrapper
    return decorator


# ========== Model Selection Tests ==========

@test("Model selection returns best model for chat task")
def test_model_selection_chat():
    """A fresh router is expected to pick gemini-2.0-flash for "chat"."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    model = router.get_model_for_task("chat")
    assert model == "gemini-2.0-flash", f"Expected gemini-2.0-flash, got {model}"


@test("Model selection returns best model for documentation task")
def test_model_selection_documentation():
    """A fresh router is expected to pick gemini-2.0-flash-lite for "documentation"."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    model = router.get_model_for_task("documentation")
    assert model == "gemini-2.0-flash-lite", f"Expected gemini-2.0-flash-lite, got {model}"


@test("Model selection returns best model for synthesis task")
def test_model_selection_synthesis():
    """A fresh router is expected to pick gemma-3-27b-it for "synthesis"."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    model = router.get_model_for_task("synthesis")
    assert model == "gemma-3-27b-it", f"Expected gemma-3-27b-it, got {model}"


@test("Model selection falls back to default for unknown task")
def test_model_selection_unknown():
    """An unrecognised task type is expected to resolve to gemini-2.0-flash."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    model = router.get_model_for_task("unknown_task_type")
    assert model == "gemini-2.0-flash", f"Expected gemini-2.0-flash (default), got {model}"


# ========== Rate Limiting Tests ==========

@test("Rate limit tracking works correctly")
def test_rate_limit_tracking():
    """Recording `rpm` usages on key 0 must flip the rate-limit check to False."""
    from app.model_router import ModelRouter, MODEL_CONFIGS

    router = ModelRouter()

    # Initially all models should be available (key 0)
    assert router._check_rate_limit("gemini-2.0-flash", 0) == True

    # Record usage up to limit
    rpm_limit = MODEL_CONFIGS["gemini-2.0-flash"]["rpm"]
    for _ in range(rpm_limit):
        router._record_usage("gemini-2.0-flash", 0)

    # Should now be rate limited for key 0
    assert router._check_rate_limit("gemini-2.0-flash", 0) == False


@test("Model falls back when primary is rate limited")
def test_model_fallback():
    """With gemini-2.0-flash exhausted on every key, "chat" must use the next model."""
    from app.model_router import ModelRouter, MODEL_CONFIGS

    router = ModelRouter()

    # Exhaust gemini-2.0-flash rate limit on all keys
    rpm_limit = MODEL_CONFIGS["gemini-2.0-flash"]["rpm"]
    for key_idx in range(len(router.api_keys)):
        for _ in range(rpm_limit):
            router._record_usage("gemini-2.0-flash", key_idx)

    # Should fall back to next model in chat priority
    model = router.get_model_for_task("chat")
    assert model == "gemini-2.0-flash-lite", f"Expected fallback to gemini-2.0-flash-lite, got {model}"


@test("Returns None when all models exhausted on all keys")
def test_all_models_exhausted():
    """With every configured model exhausted on every key, selection returns None."""
    from app.model_router import ModelRouter, MODEL_CONFIGS

    router = ModelRouter()

    # Exhaust all models on all keys
    for key_idx in range(len(router.api_keys)):
        for model_name, config in MODEL_CONFIGS.items():
            for _ in range(config["rpm"]):
                router._record_usage(model_name, key_idx)

    # Should return None
    model = router.get_model_for_task("chat")
    assert model is None, f"Expected None when all exhausted, got {model}"


# ========== Cache Tests ==========

@test("Cache stores and retrieves responses")
def test_cache_store_retrieve():
    """A stored response must be returned by a later lookup with the same key."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    cache_key = router._get_cache_key("chat", "user1", "test prompt")

    # Initially empty
    assert router._check_cache(cache_key) is None

    # Store response
    router._store_cache(cache_key, "cached response", "gemini-2.0-flash")

    # Should retrieve
    cached = router._check_cache(cache_key)
    assert cached == "cached response", f"Expected 'cached response', got {cached}"


@test("Cache key includes user_id")
def test_cache_key_user_differentiation():
    """Same task and prompt with different users must yield different keys."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    key1 = router._get_cache_key("chat", "user1", "same prompt")
    key2 = router._get_cache_key("chat", "user2", "same prompt")

    assert key1 != key2, "Cache keys should differ for different users"


@test("Cache key includes task_type")
def test_cache_key_task_differentiation():
    """Same user and prompt with different task types must yield different keys."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    key1 = router._get_cache_key("chat", "user1", "same prompt")
    key2 = router._get_cache_key("documentation", "user1", "same prompt")

    assert key1 != key2, "Cache keys should differ for different task types"


@test("Cache expires after TTL")
def test_cache_expiry():
    """An entry whose timestamp is older than CACHE_TTL must not be returned."""
    from app.model_router import ModelRouter, CACHE_TTL

    router = ModelRouter()
    cache_key = router._get_cache_key("chat", "user1", "test prompt")
    router._store_cache(cache_key, "cached response", "gemini-2.0-flash")

    # Manually expire the cache entry by back-dating its timestamp
    router.cache[cache_key]["timestamp"] = datetime.now() - timedelta(seconds=CACHE_TTL + 1)

    # Should not retrieve expired entry
    cached = router._check_cache(cache_key)
    assert cached is None, "Expired cache entry should return None"


@test("Cache cleaning removes expired entries")
def test_cache_cleaning():
    """_clean_cache() must drop the five back-dated entries and keep the fresh one."""
    from app.model_router import ModelRouter, CACHE_TTL

    router = ModelRouter()

    # Add expired entries
    for i in range(5):
        key = f"expired_{i}"
        router.cache[key] = {
            "response": f"response_{i}",
            "timestamp": datetime.now() - timedelta(seconds=CACHE_TTL + 1),
            "model": "test",
        }

    # Add valid entry
    router.cache["valid"] = {
        "response": "valid_response",
        "timestamp": datetime.now(),
        "model": "test",
    }

    # Clean cache
    router._clean_cache()

    # Only valid entry should remain
    assert len(router.cache) == 1, f"Expected 1 entry after cleaning, got {len(router.cache)}"
    assert "valid" in router.cache, "Valid entry should remain after cleaning"


# ========== Stats Tests ==========

@test("Stats returns correct usage info")
def test_stats():
    """get_stats() must report per-model usage and a limit of rpm * number of keys."""
    from app.model_router import ModelRouter, MODEL_CONFIGS

    router = ModelRouter()

    # Record some usage on key 0
    router._record_usage("gemini-2.0-flash", 0)
    router._record_usage("gemini-2.0-flash", 0)
    router._record_usage("gemma-3-27b-it", 0)

    stats = router.get_stats()

    assert stats["models"]["gemini-2.0-flash"]["used"] == 2, "Should show 2 uses for gemini-2.0-flash"
    assert stats["models"]["gemma-3-27b-it"]["used"] == 1, "Should show 1 use for gemma-3-27b-it"

    # Limit is per-key * num_keys
    expected_limit = MODEL_CONFIGS["gemini-2.0-flash"]["rpm"] * len(router.api_keys)
    assert stats["models"]["gemini-2.0-flash"]["limit"] == expected_limit


# ========== Multi-Key Tests ==========

@test("Multiple keys are loaded from environment")
def test_multi_key_loading():
    """The router must expose at least one entry in `api_keys`."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    assert len(router.api_keys) >= 1, "Should have at least one API key"


@test("Key health tracking works")
def test_key_health_tracking():
    """Marking key 0 unhealthy must flip its health check and record the error text."""
    from app.model_router import ModelRouter, KEY_COOLDOWN_RATE_LIMIT

    router = ModelRouter()

    # Initially all keys should be healthy
    for i in range(len(router.api_keys)):
        assert router._is_key_healthy(i) == True, f"Key {i} should be healthy initially"

    # Mark first key as unhealthy
    router._mark_key_unhealthy(0, Exception("Test error"), KEY_COOLDOWN_RATE_LIMIT)

    assert router._is_key_healthy(0) == False, "Key 0 should be unhealthy after marking"
    assert router.key_health[0]["last_error"] == "Test error"


@test("Key rotation skips unhealthy keys")
def test_key_rotation_skips_unhealthy():
    """With key 0 marked unhealthy, _get_next_key() must return another index.

    Raises:
        SkipTest: when fewer than two API keys are configured, because there
            is no alternative key to rotate to. This is reported as [SKIP]
            rather than as a pass that verified nothing.
    """
    from app.model_router import ModelRouter

    router = ModelRouter()

    if len(router.api_keys) < 2:
        raise SkipTest("needs at least two API keys")

    # Mark key 0 as unhealthy
    router._mark_key_unhealthy(0, Exception("Test"), 60)

    # Get next key should skip key 0
    key_idx, _ = router._get_next_key()
    assert key_idx != 0, "Should skip unhealthy key 0"


@test("Key auto-recovers after cooldown")
def test_key_auto_recovery():
    """A key whose `retry_after` is in the past must be healthy again when checked."""
    from app.model_router import ModelRouter

    router = ModelRouter()

    # Mark key as unhealthy with expired cooldown
    router.key_health[0] = {
        "healthy": False,
        "last_error": "Test",
        "retry_after": datetime.now() - timedelta(seconds=1),  # Already expired
    }

    # Should recover when checked
    assert router._is_key_healthy(0) == True, "Key should auto-recover after cooldown"
    assert router.key_health[0]["healthy"] == True
    assert router.key_health[0]["last_error"] is None


@test("Stats includes key information")
def test_stats_includes_keys():
    """get_stats() must contain a "keys" section with total, healthy and details."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    stats = router.get_stats()

    assert "keys" in stats, "Stats should include keys info"
    assert stats["keys"]["total"] >= 1, "Should have at least one key"
    assert stats["keys"]["healthy"] >= 1, "Should have at least one healthy key"
    assert "details" in stats["keys"], "Stats should include key details"


# ========== Integration Tests (requires API key) ==========

@test("Generate returns response and model info")
async def test_generate_integration():
    """generate_with_info() must return a non-empty response and a known model name."""
    from app.model_router import generate_with_info

    response, model = await generate_with_info(
        "Say 'test' in one word.",
        task_type="default",
        use_cache=False,
    )

    assert response is not None, "Response should not be None"
    assert len(response) > 0, "Response should not be empty"
    assert model in [
        "gemini-2.0-flash",
        "gemini-2.0-flash-lite",
        "gemma-3-27b-it",
        "gemma-3-12b-it",
        "gemma-3-4b-it",
        "gemma-3-1b-it",
        "cache",
    ]


@test("Generate uses cache on repeated calls")
async def test_generate_uses_cache():
    """The second identical call must be served from cache with the same response."""
    from app.model_router import generate_with_info, router

    # Clear cache first so the first call cannot be a cache hit
    router.cache.clear()

    prompt = "Say 'cached test' in two words."

    # First call - should hit model
    response1, model1 = await generate_with_info(prompt, task_type="default", use_cache=True)
    assert model1 != "cache", f"First call should not be from cache, got {model1}"

    # Second call - should hit cache
    response2, model2 = await generate_with_info(prompt, task_type="default", use_cache=True)
    assert model2 == "cache", f"Second call should be from cache, got {model2}"
    assert response1 == response2, "Cached response should match original"


# ========== Run Tests ==========

async def run_tests():
    """Run all tests.

    Unit tests always run. The integration tests run only when
    ``GEMINI_API_KEY`` or ``GEMINI_API_KEYS`` is set in the environment.

    Returns:
        True when no test failed or errored, False otherwise.
    """
    print("=" * 60)
    print("Model Router Tests")
    print("=" * 60)
    print()

    # Unit tests (no API needed)
    print("--- Model Selection Tests ---")
    await test_model_selection_chat()
    await test_model_selection_documentation()
    await test_model_selection_synthesis()
    await test_model_selection_unknown()
    print()

    print("--- Rate Limiting Tests ---")
    await test_rate_limit_tracking()
    await test_model_fallback()
    await test_all_models_exhausted()
    print()

    print("--- Cache Tests ---")
    await test_cache_store_retrieve()
    await test_cache_key_user_differentiation()
    await test_cache_key_task_differentiation()
    await test_cache_expiry()
    await test_cache_cleaning()
    print()

    print("--- Stats Tests ---")
    await test_stats()
    print()

    print("--- Multi-Key Tests ---")
    await test_multi_key_loading()
    await test_key_health_tracking()
    await test_key_rotation_skips_unhealthy()
    await test_key_auto_recovery()
    await test_stats_includes_keys()
    print()

    print("--- Integration Tests (requires API key) ---")
    # Check if API key is available
    if not os.getenv("GEMINI_API_KEY") and not os.getenv("GEMINI_API_KEYS"):
        print("[SKIP] Integration tests skipped - no API keys")
    else:
        await test_generate_integration()
        await test_generate_uses_cache()
    print()

    print("=" * 60)
    print(f"Results: {TESTS_PASSED} passed, {TESTS_FAILED} failed")
    print("=" * 60)

    return TESTS_FAILED == 0


if __name__ == "__main__":
    success = asyncio.run(run_tests())
    # sys.exit, not the site-provided exit(), which is absent under `python -S`.
    sys.exit(0 if success else 1)