|
|
"""Model Router for multi-model rotation with rate limiting and caching.""" |
|
|
|
|
|
import google.generativeai as genai |
|
|
import time |
|
|
import hashlib |
|
|
import os |
|
|
from datetime import datetime, timedelta |
|
|
from typing import Optional |
|
|
from collections import deque |
|
|
import asyncio |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
|
|
|
# Cooldown (seconds) applied to a key after a 429 / quota-exhausted error.
KEY_COOLDOWN_RATE_LIMIT = 60

# Cooldown (seconds) intended for other key-level failures.
# NOTE(review): this constant is not referenced anywhere in the visible code —
# confirm it is used elsewhere or wire it into the error handling.
KEY_COOLDOWN_OTHER = 30
|
|
|
|
|
|
|
|
def _load_api_keys() -> list[str]:
    """Load API keys from environment (backward compatible).

    Prefers the comma-separated ``GEMINI_API_KEYS`` variable; falls back to
    the single-key ``GEMINI_API_KEY``. Returns an empty list when neither
    is set.
    """
    multi = os.getenv("GEMINI_API_KEYS", "")
    if multi:
        keys = []
        for raw in multi.split(","):
            cleaned = raw.strip()
            if cleaned:
                keys.append(cleaned)
        return keys
    fallback = os.getenv("GEMINI_API_KEY")
    if fallback:
        return [fallback]
    return []
|
|
|
|
|
|
|
|
|
|
|
# Available models with their per-key requests-per-minute limit ("rpm") and a
# relative quality rank ("quality", 1 = best). Both values drive routing only;
# they do not change the upstream API's own quotas.
MODEL_CONFIGS = {
    "gemini-2.0-flash": {"rpm": 15, "quality": 1},
    "gemini-2.0-flash-lite": {"rpm": 30, "quality": 2},
    "gemma-3-27b-it": {"rpm": 30, "quality": 3},
    "gemma-3-12b-it": {"rpm": 30, "quality": 4},
    "gemma-3-4b-it": {"rpm": 30, "quality": 5},
    "gemma-3-1b-it": {"rpm": 30, "quality": 6},
}

# Preferred model order per task type; "default" is used for unknown task
# types and covers every configured model.
TASK_PRIORITIES = {
    "chat": ["gemini-2.0-flash", "gemini-2.0-flash-lite", "gemma-3-27b-it"],
    "smart_query": ["gemini-2.0-flash", "gemma-3-27b-it", "gemma-3-12b-it"],
    "documentation": ["gemini-2.0-flash-lite", "gemma-3-27b-it", "gemma-3-12b-it"],
    "synthesis": ["gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it"],
    "default": ["gemini-2.0-flash", "gemini-2.0-flash-lite", "gemma-3-27b-it",
                "gemma-3-12b-it", "gemma-3-4b-it", "gemma-3-1b-it"],
}

# Response-cache time-to-live, in seconds.
CACHE_TTL = 300

# Pause (seconds) after a rate-limit error before trying the next key.
RETRY_DELAY = 2.5
|
|
|
|
|
|
|
|
class ModelRouter:
    """Manages model rotation, rate limiting, response caching, and multi-key support.

    Keys rotate round-robin and are placed on a cooldown after quota or auth
    errors; per-key/per-model sliding windows of request timestamps enforce
    client-side RPM limits; successful responses are cached for CACHE_TTL
    seconds.
    """

    def __init__(self):
        # Fail fast at construction time if no credentials are configured.
        self.api_keys = _load_api_keys()
        if not self.api_keys:
            raise ValueError("No API keys found. Set GEMINI_API_KEYS or GEMINI_API_KEY in .env")

        # Round-robin cursor plus per-key health (cooldown) bookkeeping.
        self.key_index = 0
        self.key_health: dict[int, dict] = {
            i: {"healthy": True, "last_error": None, "retry_after": None}
            for i in range(len(self.api_keys))
        }

        # usage[key_idx][model] -> deque of epoch-second timestamps inside the
        # last-minute window (pruned lazily in _check_rate_limit).
        self.usage: dict[int, dict[str, deque]] = {
            i: {model: deque() for model in MODEL_CONFIGS}
            for i in range(len(self.api_keys))
        }

        # Response cache: cache_key -> {"response", "timestamp", "model"}.
        self.cache: dict[str, dict] = {}

        # Pre-built model handles, kept for backward compatibility; generate()
        # builds per-call instances via _get_model_with_key() so the correct
        # key is active at request time.
        self._configure_key(0)
        self.models: dict[str, genai.GenerativeModel] = {
            model: genai.GenerativeModel(model) for model in MODEL_CONFIGS
        }

    def _configure_key(self, key_idx: int):
        """Configure genai with the specified API key (process-global state)."""
        genai.configure(api_key=self.api_keys[key_idx])

    def _is_key_healthy(self, key_idx: int) -> bool:
        """Check if a key is healthy, auto-recovering it once its cooldown expires."""
        health = self.key_health[key_idx]
        if not health["healthy"] and health["retry_after"]:
            if datetime.now() > health["retry_after"]:
                health["healthy"] = True
                health["last_error"] = None
                health["retry_after"] = None
        return health["healthy"]

    def _mark_key_unhealthy(self, key_idx: int, error: Exception, cooldown_seconds: int):
        """Mark a key as unhealthy with a cooldown window."""
        self.key_health[key_idx] = {
            "healthy": False,
            "last_error": str(error),
            "retry_after": datetime.now() + timedelta(seconds=cooldown_seconds)
        }

    def _get_next_key(self) -> tuple[int, str]:
        """Get the next healthy API key using round-robin.

        If every key is cooling down, force-reactivate the key whose cooldown
        ends soonest rather than failing outright.
        """
        num_keys = len(self.api_keys)

        for _ in range(num_keys):
            idx = self.key_index % num_keys
            self.key_index += 1
            if self._is_key_healthy(idx):
                return idx, self.api_keys[idx]

        # No healthy key: pick the one recovering soonest and use it anyway.
        earliest_idx = 0
        earliest_time = datetime.max
        for idx, health in self.key_health.items():
            if health["retry_after"] and health["retry_after"] < earliest_time:
                earliest_time = health["retry_after"]
                earliest_idx = idx

        self.key_health[earliest_idx]["healthy"] = True
        return earliest_idx, self.api_keys[earliest_idx]

    def _get_model_with_key(self, model_name: str, key_idx: int) -> genai.GenerativeModel:
        """Get a model instance configured with the specified key.

        genai.configure is process-global, so the returned instance should be
        used before another key is configured.
        """
        self._configure_key(key_idx)
        return genai.GenerativeModel(model_name)

    def _get_cache_key(self, task_type: str, user_id: Optional[str], prompt: str) -> str:
        """Generate a cache key from task type, user, and prompt.

        Only the first 200 characters of the prompt are hashed, so long
        prompts that differ only after that point share a cache entry.
        """
        key_string = f"{task_type}:{user_id or 'anon'}:{prompt[:200]}"
        # md5 is acceptable here: the digest is a cache key, not a security boundary.
        return hashlib.md5(key_string.encode()).hexdigest()

    def _check_cache(self, cache_key: str) -> Optional[str]:
        """Return the cached response if present and fresh, else None.

        Expired entries are evicted eagerly on lookup.
        """
        if cache_key in self.cache:
            entry = self.cache[cache_key]
            if datetime.now() - entry["timestamp"] < timedelta(seconds=CACHE_TTL):
                return entry["response"]
            else:
                # Stale: drop it so the dict does not accumulate dead entries.
                del self.cache[cache_key]
        return None

    def _store_cache(self, cache_key: str, response: str, model_used: str):
        """Store a response in the cache."""
        self.cache[cache_key] = {
            "response": response,
            "timestamp": datetime.now(),
            "model": model_used
        }

        # Opportunistic cleanup once the cache grows past 100 entries.
        if len(self.cache) > 100:
            self._clean_cache()

    def _clean_cache(self):
        """Remove expired cache entries."""
        now = datetime.now()
        expired_keys = [
            key for key, entry in self.cache.items()
            if now - entry["timestamp"] >= timedelta(seconds=CACHE_TTL)
        ]
        for key in expired_keys:
            del self.cache[key]

    def _check_rate_limit(self, model_name: str, key_idx: int = 0) -> bool:
        """Check if a model is within its RPM limit for a specific key.

        Returns True when another request may be sent. Prunes timestamps that
        have aged out of the 60-second window as a side effect.
        """
        config = MODEL_CONFIGS[model_name]
        rpm_limit = config["rpm"]
        usage_queue = self.usage[key_idx][model_name]

        now = time.time()
        while usage_queue and usage_queue[0] < now - 60:
            usage_queue.popleft()

        return len(usage_queue) < rpm_limit

    def _record_usage(self, model_name: str, key_idx: int = 0):
        """Record one request timestamp for rate limiting."""
        self.usage[key_idx][model_name].append(time.time())

    def get_model_for_task(self, task_type: str) -> Optional[str]:
        """Get the best available model for a task type (checks all keys).

        Returns None only when every key is unhealthy or every model is at
        its client-side rate limit.
        """
        priorities = TASK_PRIORITIES.get(task_type, TASK_PRIORITIES["default"])

        # First pass: honor the task's priority order.
        for key_idx in range(len(self.api_keys)):
            if not self._is_key_healthy(key_idx):
                continue
            for model_name in priorities:
                if self._check_rate_limit(model_name, key_idx):
                    return model_name

        # Second pass: any configured model is better than none.
        for key_idx in range(len(self.api_keys)):
            if not self._is_key_healthy(key_idx):
                continue
            for model_name in MODEL_CONFIGS:
                if self._check_rate_limit(model_name, key_idx):
                    return model_name

        return None

    async def generate(
        self,
        prompt: str,
        task_type: str = "default",
        user_id: Optional[str] = None,
        use_cache: bool = True,
        preferred_models: Optional[list[str]] = None
    ) -> tuple[str, str]:
        """Generate response with model rotation, key rotation, and caching.

        Args:
            prompt: The prompt to send to the model
            task_type: Type of task (chat, smart_query, documentation, synthesis)
            user_id: User ID for cache key differentiation
            use_cache: Whether to use caching (default True)
            preferred_models: Optional explicit model priority list tried before
                the task_type table; names not in MODEL_CONFIGS are ignored

        Returns:
            Tuple of (response_text, model_used); model_used is "cache" on a hit

        Raises:
            Exception: when every key/model combination failed or is rate limited
        """
        cache_key = None
        if use_cache:
            cache_key = self._get_cache_key(task_type, user_id, prompt)
            cached = self._check_cache(cache_key)
            # Compare against None so an empty cached response still counts as a hit.
            if cached is not None:
                return cached, "cache"

        # Build the candidate list: preferred/priority models first, then the rest.
        priorities: list[str] = []
        if preferred_models:
            priorities = [m for m in preferred_models if m in MODEL_CONFIGS]
        if not priorities:
            priorities = TASK_PRIORITIES.get(task_type, TASK_PRIORITIES["default"])
        all_models = list(priorities) + [m for m in MODEL_CONFIGS if m not in priorities]

        last_error = None
        tried_combinations = set()

        # Upper bound: each (key, model) pair is attempted at most once.
        max_attempts = len(self.api_keys) * len(all_models)

        for _ in range(max_attempts):
            key_idx, _unused_key = self._get_next_key()

            for model_name in all_models:
                combo = (key_idx, model_name)
                if combo in tried_combinations:
                    continue

                # Skip models at their client-side RPM limit for this key.
                if not self._check_rate_limit(model_name, key_idx):
                    continue

                tried_combinations.add(combo)

                try:
                    model = self._get_model_with_key(model_name, key_idx)
                    self._record_usage(model_name, key_idx)

                    # NOTE(review): generate_content is a blocking call inside an
                    # async method; consider asyncio.to_thread if it stalls the loop.
                    response = model.generate_content(prompt)
                    response_text = response.text

                    if use_cache:
                        self._store_cache(cache_key, response_text, model_name)

                    return response_text, model_name

                except Exception as e:
                    error_str = str(e).lower()
                    last_error = e

                    if "429" in error_str or "resource exhausted" in error_str or "quota" in error_str:
                        # Quota/rate-limit error: cool this key down, move to the next.
                        self._mark_key_unhealthy(key_idx, e, KEY_COOLDOWN_RATE_LIMIT)
                        await asyncio.sleep(RETRY_DELAY)
                        break

                    elif "401" in error_str or "403" in error_str or "invalid" in error_str:
                        # Auth failure: take the key out of rotation for a day.
                        self._mark_key_unhealthy(key_idx, e, 86400)
                        break

                    else:
                        # Transient/model-specific error: brief pause, try next model.
                        await asyncio.sleep(0.5)
                        continue

        if last_error:
            raise Exception(f"All models/keys exhausted. Last error: {last_error}")
        else:
            raise Exception("All models are rate limited. Please try again in a minute.")

    async def generate_with_model(
        self,
        model_name: str,
        prompt: str,
        user_id: Optional[str] = None,
        use_cache: bool = True
    ) -> str:
        """Generate with a specific model (for chat sessions that need consistency).

        The requested model is tried first; other models are used as fallback
        when it is rate limited or failing. Unknown model names fall through
        to the default rotation order.
        """
        preferred = [model_name] if model_name in MODEL_CONFIGS else None
        response, _ = await self.generate(
            prompt=prompt,
            task_type="default",
            user_id=user_id,
            use_cache=use_cache,
            preferred_models=preferred
        )
        return response

    def get_stats(self) -> dict:
        """Get current usage stats for monitoring.

        Returns:
            Dict with per-key health details, per-model usage vs. aggregate
            limits (per-key rpm x number of keys), and current cache size.
        """
        now = time.time()
        stats = {
            "keys": {
                "total": len(self.api_keys),
                "healthy": sum(1 for i in range(len(self.api_keys)) if self._is_key_healthy(i)),
                "details": {}
            },
            "models": {},
            "cache_size": len(self.cache)
        }

        for key_idx in range(len(self.api_keys)):
            health = self.key_health[key_idx]
            stats["keys"]["details"][f"key_{key_idx}"] = {
                "healthy": self._is_key_healthy(key_idx),
                "last_error": health["last_error"],
                "retry_after": health["retry_after"].isoformat() if health["retry_after"] else None
            }

        for model_name in MODEL_CONFIGS:
            # Count requests in the trailing 60-second window across all keys.
            total_used = 0
            for key_idx in range(len(self.api_keys)):
                usage_queue = self.usage[key_idx][model_name]
                total_used += sum(1 for t in usage_queue if t > now - 60)

            per_key_limit = MODEL_CONFIGS[model_name]["rpm"]
            total_limit = per_key_limit * len(self.api_keys)

            stats["models"][model_name] = {
                "used": total_used,
                "limit": total_limit,
                "available": total_limit - total_used
            }

        return stats
|
|
|
|
|
|
|
|
|
|
|
# Module-level singleton, constructed at import time: importing this module
# requires valid GEMINI_API_KEYS / GEMINI_API_KEY in the environment
# (ModelRouter.__init__ raises ValueError otherwise).
router = ModelRouter()
|
|
|
|
|
|
|
|
|
|
|
async def generate(
    prompt: str,
    task_type: str = "default",
    user_id: Optional[str] = None,
    use_cache: bool = True
) -> str:
    """Generate response using the module-level model router.

    Args:
        prompt: The prompt to send
        task_type: One of 'chat', 'smart_query', 'documentation', 'synthesis', 'default'
        user_id: User ID for cache differentiation
        use_cache: Whether to use response cache

    Returns:
        Response text
    """
    # The model name is intentionally discarded here; callers that need it
    # should use generate_with_info() instead.
    response, _ = await router.generate(prompt, task_type, user_id, use_cache)
    return response
|
|
|
|
|
|
|
|
async def generate_with_info(
    prompt: str,
    task_type: str = "default",
    user_id: Optional[str] = None,
    use_cache: bool = True
) -> tuple[str, str]:
    """Generate a response and report which model produced it.

    Returns:
        Tuple of (response_text, model_name)
    """
    result = await router.generate(prompt, task_type, user_id, use_cache)
    return result
|
|
|
|
|
|
|
|
def get_model_for_task(task_type: str) -> Optional[str]:
    """Return the best currently-available model name for the task, or None."""
    best = router.get_model_for_task(task_type)
    return best
|
|
|
|
|
|
|
|
def get_stats() -> dict:
    """Return a snapshot of current router statistics (keys, models, cache)."""
    stats = router.get_stats()
    return stats
|
|
|