# ProjectMemory/backend/tests/test_model_router.py
# Amal Nimmy Lal
# feat: Project Memory
# 35765b5
"""Test cases for Model Router - multi-model rotation with rate limiting and caching."""
import asyncio
import functools
import inspect
import os
import sys
import time
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
# Put the backend root (the parent of this tests/ directory, which presumably
# holds the `app` package) first on sys.path so `from app.model_router import ...`
# resolves when this file is run directly as a script. Inserting the tests/
# directory itself would not help: Python already puts the script's own
# directory on sys.path.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment before any
# test runs, so the GEMINI_API_KEY / GEMINI_API_KEYS check in run_tests() sees
# them. NOTE(review): ModelRouter presumably reads its keys from os.environ too
# (see the "loaded from environment" test) — confirm in app.model_router.
load_dotenv()
# Running outcome tallies: bumped by the `test` decorator's wrapper and
# summarised at the end of run_tests().
TESTS_PASSED = TESTS_FAILED = 0
def test(name):
    """Wrap a test function so it reports and tallies its own outcome.

    The returned decorator turns ``func`` (plain or ``async``) into an
    argument-less coroutine function that runs it, prints ``[PASS]``,
    ``[FAIL]`` or ``[ERROR]`` followed by ``name``, and increments the
    module-level ``TESTS_PASSED`` / ``TESTS_FAILED`` counters. Exceptions
    are caught and reported, so one broken test does not abort the run.

    Args:
        name: Human-readable description printed next to the outcome.

    Returns:
        A decorator producing an awaitable wrapper around the test function.
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper():
            global TESTS_PASSED, TESTS_FAILED
            try:
                # inspect.iscoroutinefunction replaces the asyncio variant,
                # which is deprecated as of Python 3.14.
                if inspect.iscoroutinefunction(func):
                    await func()
                else:
                    func()
            except AssertionError as e:
                # A bare `assert cond` carries no message; say so explicitly
                # instead of printing an empty reason.
                detail = str(e) or "assertion failed"
                print(f"[FAIL] {name}: {detail}")
                TESTS_FAILED += 1
            except Exception as e:
                # Include the exception type: str(KeyError('x')) is just "'x'".
                print(f"[ERROR] {name}: {type(e).__name__}: {e}")
                TESTS_FAILED += 1
            else:
                print(f"[PASS] {name}")
                TESTS_PASSED += 1
        return wrapper
    return decorator


# `test` matches pytest's default `test*` collection pattern; without this
# opt-out pytest would try to run it and fail with "fixture 'name' not found".
test.__test__ = False
# ========== Model Selection Tests ==========
@test("Model selection returns best model for chat task")
def test_model_selection_chat():
    """A fresh router selects gemini-2.0-flash for the "chat" task."""
    # Only ModelRouter is needed; the unused TASK_PRIORITIES import was dropped.
    from app.model_router import ModelRouter

    router = ModelRouter()
    model = router.get_model_for_task("chat")
    assert model == "gemini-2.0-flash", f"Expected gemini-2.0-flash, got {model}"
@test("Model selection returns best model for documentation task")
def test_model_selection_documentation():
    """A fresh router selects gemini-2.0-flash-lite for the "documentation" task."""
    # Only ModelRouter is needed; the unused TASK_PRIORITIES import was dropped.
    from app.model_router import ModelRouter

    router = ModelRouter()
    model = router.get_model_for_task("documentation")
    assert model == "gemini-2.0-flash-lite", f"Expected gemini-2.0-flash-lite, got {model}"
@test("Model selection returns best model for synthesis task")
def test_model_selection_synthesis():
    """A fresh router selects gemma-3-27b-it for the "synthesis" task."""
    # Only ModelRouter is needed; the unused TASK_PRIORITIES import was dropped.
    from app.model_router import ModelRouter

    router = ModelRouter()
    model = router.get_model_for_task("synthesis")
    assert model == "gemma-3-27b-it", f"Expected gemma-3-27b-it, got {model}"
@test("Model selection falls back to default for unknown task")
def test_model_selection_unknown():
    """An unrecognised task type is served by the default model."""
    from app.model_router import ModelRouter

    chosen = ModelRouter().get_model_for_task("unknown_task_type")
    expected = "gemini-2.0-flash"
    assert chosen == expected, f"Expected {expected} (default), got {chosen}"
# ========== Rate Limiting Tests ==========
@test("Rate limit tracking works correctly")
def test_rate_limit_tracking():
    """A key stops accepting a model once its configured `rpm` budget is used."""
    from app.model_router import ModelRouter, MODEL_CONFIGS

    router = ModelRouter()
    model_name = "gemini-2.0-flash"

    # A fresh router has recorded no usage, so key 0 is available.
    assert router._check_rate_limit(model_name, 0), (
        f"{model_name} should be available on key 0 initially"
    )

    # Record exactly as many requests as the model's rpm limit allows.
    rpm_limit = MODEL_CONFIGS[model_name]["rpm"]
    for _ in range(rpm_limit):
        router._record_usage(model_name, 0)

    # The budget is now exhausted for key 0.
    assert not router._check_rate_limit(model_name, 0), (
        f"{model_name} should be rate limited on key 0 after {rpm_limit} requests"
    )
@test("Model falls back when primary is rate limited")
def test_model_fallback():
    """Chat moves to the next model once the primary is exhausted on every key."""
    from app.model_router import ModelRouter, MODEL_CONFIGS

    router = ModelRouter()
    primary = "gemini-2.0-flash"
    per_key_budget = MODEL_CONFIGS[primary]["rpm"]

    # Burn through the primary model's budget on each configured key in turn.
    for key_index, _key in enumerate(router.api_keys):
        for _ in range(per_key_budget):
            router._record_usage(primary, key_index)

    fallback = router.get_model_for_task("chat")
    assert fallback == "gemini-2.0-flash-lite", f"Expected fallback to gemini-2.0-flash-lite, got {fallback}"
@test("Returns None when all models exhausted on all keys")
def test_all_models_exhausted():
    """With every model at its limit on every key, no model is selected."""
    from app.model_router import ModelRouter, MODEL_CONFIGS

    router = ModelRouter()
    key_count = len(router.api_keys)

    # Saturate every (key, model) pair up to that model's rpm limit.
    for key_index in range(key_count):
        for model_name in MODEL_CONFIGS:
            budget = MODEL_CONFIGS[model_name]["rpm"]
            for _ in range(budget):
                router._record_usage(model_name, key_index)

    selected = router.get_model_for_task("chat")
    assert selected is None, f"Expected None when all exhausted, got {selected}"
# ========== Cache Tests ==========
@test("Cache stores and retrieves responses")
def test_cache_store_retrieve():
    """A stored response is returned again for the same cache key."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    key = router._get_cache_key("chat", "user1", "test prompt")

    # Nothing has been stored under this key yet.
    assert router._check_cache(key) is None

    router._store_cache(key, "cached response", "gemini-2.0-flash")

    hit = router._check_cache(key)
    assert hit == "cached response", f"Expected 'cached response', got {hit}"
@test("Cache key includes user_id")
def test_cache_key_user_differentiation():
    """The same prompt from two different users yields two different cache keys."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    first, second = (
        router._get_cache_key("chat", user, "same prompt")
        for user in ("user1", "user2")
    )
    assert first != second, "Cache keys should differ for different users"
@test("Cache key includes task_type")
def test_cache_key_task_differentiation():
    """The same prompt under two task types yields two different cache keys."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    first, second = (
        router._get_cache_key(task, "user1", "same prompt")
        for task in ("chat", "documentation")
    )
    assert first != second, "Cache keys should differ for different task types"
@test("Cache expires after TTL")
def test_cache_expiry():
    """An entry older than CACHE_TTL seconds is no longer returned."""
    from app.model_router import ModelRouter, CACHE_TTL

    router = ModelRouter()
    key = router._get_cache_key("chat", "user1", "test prompt")
    router._store_cache(key, "cached response", "gemini-2.0-flash")

    # Backdate the entry to just past the TTL instead of sleeping.
    stale_time = datetime.now() - timedelta(seconds=CACHE_TTL + 1)
    router.cache[key]["timestamp"] = stale_time

    assert router._check_cache(key) is None, "Expired cache entry should return None"
@test("Cache cleaning removes expired entries")
def test_cache_cleaning():
    """_clean_cache() drops stale entries and keeps fresh ones."""
    from app.model_router import ModelRouter, CACHE_TTL

    router = ModelRouter()

    # Five entries backdated past the TTL...
    for idx in range(5):
        router.cache[f"expired_{idx}"] = {
            "response": f"response_{idx}",
            "timestamp": datetime.now() - timedelta(seconds=CACHE_TTL + 1),
            "model": "test",
        }

    # ...plus one written just now.
    router.cache["valid"] = {
        "response": "valid_response",
        "timestamp": datetime.now(),
        "model": "test",
    }

    router._clean_cache()

    remaining = len(router.cache)
    assert remaining == 1, f"Expected 1 entry after cleaning, got {remaining}"
    assert "valid" in router.cache, "Valid entry should remain after cleaning"
# ========== Stats Tests ==========
@test("Stats returns correct usage info")
def test_stats():
    """get_stats() reports recorded usage and the combined per-key limit."""
    from app.model_router import ModelRouter, MODEL_CONFIGS

    router = ModelRouter()

    # Three requests on key 0: two for gemini-2.0-flash, one for gemma-3-27b-it.
    for model_name in ("gemini-2.0-flash", "gemini-2.0-flash", "gemma-3-27b-it"):
        router._record_usage(model_name, 0)

    per_model = router.get_stats()["models"]
    assert per_model["gemini-2.0-flash"]["used"] == 2, "Should show 2 uses for gemini-2.0-flash"
    assert per_model["gemma-3-27b-it"]["used"] == 1, "Should show 1 use for gemma-3-27b-it"

    # The reported limit is the per-key rpm multiplied by the number of keys.
    expected_limit = MODEL_CONFIGS["gemini-2.0-flash"]["rpm"] * len(router.api_keys)
    assert per_model["gemini-2.0-flash"]["limit"] == expected_limit
# ========== Multi-Key Tests ==========
@test("Multiple keys are loaded from environment")
def test_multi_key_loading():
    """A freshly built router holds at least one API key."""
    from app.model_router import ModelRouter

    key_count = len(ModelRouter().api_keys)
    assert key_count >= 1, "Should have at least one API key"
@test("Key health tracking works")
def test_key_health_tracking():
    """Keys start healthy and become unhealthy once marked with an error."""
    from app.model_router import ModelRouter, KEY_COOLDOWN_RATE_LIMIT

    router = ModelRouter()

    # Every key of a fresh router should report healthy.
    for i in range(len(router.api_keys)):
        assert router._is_key_healthy(i), f"Key {i} should be healthy initially"

    # Mark key 0 unhealthy, passing KEY_COOLDOWN_RATE_LIMIT as its cooldown.
    router._mark_key_unhealthy(0, Exception("Test error"), KEY_COOLDOWN_RATE_LIMIT)
    assert not router._is_key_healthy(0), "Key 0 should be unhealthy after marking"

    last_error = router.key_health[0]["last_error"]
    assert last_error == "Test error", f"Expected last_error 'Test error', got {last_error!r}"
@test("Key rotation skips unhealthy keys")
def test_key_rotation_skips_unhealthy():
    """Key rotation never hands out a key that has been marked unhealthy."""
    from app.model_router import ModelRouter

    router = ModelRouter()
    key_count = len(router.api_keys)
    if key_count < 2:
        # Skipping needs a healthy alternative key. NOTE: the `test` decorator
        # has no skip outcome, so this early return is reported as [PASS].
        return

    router._mark_key_unhealthy(0, Exception("Test"), 60)

    # One draw is not enough: if rotation simply advances from key 0, the first
    # call returns key 1 whether or not health is consulted. Drawing two full
    # cycles shows key 0 is actually excluded, not merely not reached yet.
    for _attempt in range(2 * key_count):
        key_idx, _key = router._get_next_key()
        assert key_idx != 0, "Should skip unhealthy key 0"
@test("Key auto-recovers after cooldown")
def test_key_auto_recovery():
    """A key whose retry_after time has passed is restored on the next health check."""
    # datetime/timedelta come from the module-level import; the redundant
    # local re-import was removed.
    from app.model_router import ModelRouter

    router = ModelRouter()

    # Simulate a key that was benched but whose cooldown ended a second ago.
    router.key_health[0] = {
        "healthy": False,
        "last_error": "Test",
        "retry_after": datetime.now() - timedelta(seconds=1),
    }

    # Checking health should report the key usable and reset its record.
    assert router._is_key_healthy(0), "Key should auto-recover after cooldown"
    entry = router.key_health[0]
    assert entry["healthy"], "Recovered key should be flagged healthy"
    assert entry["last_error"] is None, f"last_error should be cleared, got {entry['last_error']!r}"
@test("Stats includes key information")
def test_stats_includes_keys():
    """get_stats() exposes a "keys" section with totals and per-key details."""
    from app.model_router import ModelRouter

    stats = ModelRouter().get_stats()
    assert "keys" in stats, "Stats should include keys info"

    key_stats = stats["keys"]
    assert key_stats["total"] >= 1, "Should have at least one key"
    assert key_stats["healthy"] >= 1, "Should have at least one healthy key"
    assert "details" in key_stats, "Stats should include key details"
# ========== Integration Tests (requires API key) ==========
@test("Generate returns response and model info")
async def test_generate_integration():
    """A live, uncached call returns non-empty text and the name of a configured model.

    Needs a Gemini API key; run_tests() skips this test when none is set.
    """
    from app.model_router import generate_with_info, MODEL_CONFIGS

    response, model = await generate_with_info(
        "Say 'test' in one word.",
        task_type="default",
        use_cache=False,
    )
    assert response is not None, "Response should not be None"
    assert len(response) > 0, "Response should not be empty"
    # Caching is disabled for this call, so the "cache" source name would mean
    # the use_cache flag was ignored; it must not be accepted here.
    assert model != "cache", "use_cache=False must not be served from cache"
    # Check against the router's own model table instead of a hand-copied list.
    assert model in MODEL_CONFIGS, (
        f"Unexpected model {model!r}; expected one of {sorted(MODEL_CONFIGS)}"
    )
@test("Generate uses cache on repeated calls")
async def test_generate_uses_cache():
    """Repeating an identical prompt is answered from the shared router's cache."""
    from app.model_router import generate_with_info, router

    # Start from an empty cache so the first call cannot be a hit.
    router.cache.clear()

    prompt = "Say 'cached test' in two words."

    first_text, first_source = await generate_with_info(prompt, task_type="default", use_cache=True)
    assert first_source != "cache", f"First call should not be from cache, got {first_source}"

    second_text, second_source = await generate_with_info(prompt, task_type="default", use_cache=True)
    assert second_source == "cache", f"Second call should be from cache, got {second_source}"
    assert first_text == second_text, "Cached response should match original"
# ========== Run Tests ==========
async def run_tests():
    """Run every test section in order and print a pass/fail summary.

    Integration tests run only when GEMINI_API_KEY or GEMINI_API_KEYS is set.

    Returns:
        True when no test failed, False otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("Model Router Tests")
    print(banner)
    print()

    # Unit tests (no API needed), as (section header, tests) in run order.
    unit_sections = (
        ("--- Model Selection Tests ---", (
            test_model_selection_chat,
            test_model_selection_documentation,
            test_model_selection_synthesis,
            test_model_selection_unknown,
        )),
        ("--- Rate Limiting Tests ---", (
            test_rate_limit_tracking,
            test_model_fallback,
            test_all_models_exhausted,
        )),
        ("--- Cache Tests ---", (
            test_cache_store_retrieve,
            test_cache_key_user_differentiation,
            test_cache_key_task_differentiation,
            test_cache_expiry,
            test_cache_cleaning,
        )),
        ("--- Stats Tests ---", (
            test_stats,
        )),
        ("--- Multi-Key Tests ---", (
            test_multi_key_loading,
            test_key_health_tracking,
            test_key_rotation_skips_unhealthy,
            test_key_auto_recovery,
            test_stats_includes_keys,
        )),
    )
    for header, section_tests in unit_sections:
        print(header)
        for run_one in section_tests:
            await run_one()
        print()

    print("--- Integration Tests (requires API key) ---")
    has_api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GEMINI_API_KEYS")
    if has_api_key:
        await test_generate_integration()
        await test_generate_uses_cache()
    else:
        print("[SKIP] Integration tests skipped - no API keys")
    print()

    print(banner)
    print(f"Results: {TESTS_PASSED} passed, {TESTS_FAILED} failed")
    print(banner)
    return TESTS_FAILED == 0
if __name__ == "__main__":
    # Use sys.exit: the bare exit() is a convenience that the `site` module
    # adds for the interactive shell and may be absent (e.g. under `python -S`).
    success = asyncio.run(run_tests())
    sys.exit(0 if success else 1)