Spaces:
Build error
Build error
| """Model Router for multi-model rotation with rate limiting and caching.""" | |
| import google.generativeai as genai | |
| import time | |
| import hashlib | |
| import os | |
| from datetime import datetime, timedelta | |
| from typing import Optional | |
| from collections import deque | |
| import asyncio | |
| from dotenv import load_dotenv | |
# Read .env at import time so the key lookups below see configured values.
load_dotenv()

# Cooldown durations in seconds
KEY_COOLDOWN_RATE_LIMIT = 60  # For 429/quota errors
KEY_COOLDOWN_OTHER = 30  # For other transient errors
# NOTE(review): KEY_COOLDOWN_OTHER is not referenced anywhere in this file --
# confirm whether the non-rate-limit error path in generate() should use it.
def _load_api_keys() -> list[str]:
    """Load API keys from environment (backward compatible).

    Prefers the comma-separated GEMINI_API_KEYS list; falls back to the
    single GEMINI_API_KEY. Returns an empty list when neither is set.
    """
    multi = os.getenv("GEMINI_API_KEYS", "")
    if multi:
        keys = []
        for raw in multi.split(","):
            cleaned = raw.strip()
            if cleaned:
                keys.append(cleaned)
        return keys
    fallback = os.getenv("GEMINI_API_KEY")
    if fallback:
        return [fallback]
    return []
# Model configurations: per-key requests-per-minute budget ("rpm") and a
# relative quality rank ("quality", lower number = better model).
MODEL_CONFIGS = {
    "gemini-2.0-flash": {"rpm": 15, "quality": 1},
    "gemini-2.0-flash-lite": {"rpm": 30, "quality": 2},
    "gemma-3-27b-it": {"rpm": 30, "quality": 3},
    "gemma-3-12b-it": {"rpm": 30, "quality": 4},
    "gemma-3-4b-it": {"rpm": 30, "quality": 5},
    "gemma-3-1b-it": {"rpm": 30, "quality": 6},
}

# Task type -> ordered model preference list (best model first).
# Task types not listed here fall back to the "default" ordering.
TASK_PRIORITIES = {
    "chat": ["gemini-2.0-flash", "gemini-2.0-flash-lite", "gemma-3-27b-it"],
    "smart_query": ["gemini-2.0-flash", "gemma-3-27b-it", "gemma-3-12b-it"],
    "documentation": ["gemini-2.0-flash-lite", "gemma-3-27b-it", "gemma-3-12b-it"],
    "synthesis": ["gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it"],
    "default": ["gemini-2.0-flash", "gemini-2.0-flash-lite", "gemma-3-27b-it",
                "gemma-3-12b-it", "gemma-3-4b-it", "gemma-3-1b-it"],
}

# How long a cached response stays valid, in seconds.
CACHE_TTL = 300  # 5 minutes

# Pause before trying the next key after a rate-limit error, in seconds.
RETRY_DELAY = 2.5
class ModelRouter:
    """Manages model rotation, rate limiting, response caching, and multi-key support.

    Responsibilities:
      * Round-robin rotation over multiple API keys, with per-key health
        tracking and cooldowns after rate-limit/auth failures.
      * Per-key, per-model sliding-window RPM limiting (60-second window).
      * In-memory response caching with a TTL (CACHE_TTL seconds).

    NOTE(review): genai.configure() sets a process-global API key, so two
    concurrent requests can race between _configure_key() and the actual
    SDK call. Usage from a single asyncio event loop is assumed.
    """

    def __init__(self):
        """Load API keys from the environment and initialize per-key state.

        Raises:
            ValueError: If neither GEMINI_API_KEYS nor GEMINI_API_KEY is set.
        """
        self.api_keys = _load_api_keys()
        if not self.api_keys:
            raise ValueError("No API keys found. Set GEMINI_API_KEYS or GEMINI_API_KEY in .env")

        # Round-robin cursor plus per-key health/cooldown bookkeeping.
        self.key_index = 0
        self.key_health: dict[int, dict] = {
            i: {"healthy": True, "last_error": None, "retry_after": None}
            for i in range(len(self.api_keys))
        }

        # Sliding-window usage timestamps: {key_idx: {model_name: deque[float]}}
        self.usage: dict[int, dict[str, deque]] = {
            i: {model: deque() for model in MODEL_CONFIGS}
            for i in range(len(self.api_keys))
        }

        # Response cache: {cache_key: {"response": str, "timestamp": datetime, "model": str}}
        self.cache: dict[str, dict] = {}

        # Configure the first key up front. self.models is retained for
        # backward compatibility with external readers; the request paths
        # build models on-demand via _get_model_with_key() so the correct
        # key is active for each request.
        self._configure_key(0)
        self.models: dict[str, "genai.GenerativeModel"] = {
            model: genai.GenerativeModel(model) for model in MODEL_CONFIGS
        }

    def _configure_key(self, key_idx: int):
        """Point the process-global genai client at the key with this index."""
        genai.configure(api_key=self.api_keys[key_idx])

    def _is_key_healthy(self, key_idx: int) -> bool:
        """Return True if the key is usable, lazily clearing expired cooldowns."""
        health = self.key_health[key_idx]
        if not health["healthy"] and health["retry_after"]:
            if datetime.now() > health["retry_after"]:
                # Cooldown elapsed -- restore the key to the rotation.
                health["healthy"] = True
                health["last_error"] = None
                health["retry_after"] = None
        return health["healthy"]

    def _mark_key_unhealthy(self, key_idx: int, error: Exception, cooldown_seconds: int):
        """Take a key out of rotation until now + cooldown_seconds."""
        self.key_health[key_idx] = {
            "healthy": False,
            "last_error": str(error),
            "retry_after": datetime.now() + timedelta(seconds=cooldown_seconds)
        }

    def _get_next_key(self) -> tuple[int, str]:
        """Return (index, key) of the next healthy API key, round-robin.

        If every key is in cooldown, the key whose cooldown expires soonest
        is forcibly restored so callers always get something to try.
        """
        num_keys = len(self.api_keys)
        for _ in range(num_keys):
            idx = self.key_index % num_keys
            self.key_index += 1
            if self._is_key_healthy(idx):
                return idx, self.api_keys[idx]

        # All keys unhealthy: pick the one with the earliest retry_after.
        earliest_idx = 0
        earliest_time = datetime.max
        for idx, health in self.key_health.items():
            if health["retry_after"] and health["retry_after"] < earliest_time:
                earliest_time = health["retry_after"]
                earliest_idx = idx
        self.key_health[earliest_idx]["healthy"] = True
        return earliest_idx, self.api_keys[earliest_idx]

    def _get_model_with_key(self, model_name: str, key_idx: int) -> "genai.GenerativeModel":
        """Return a model instance bound (via the global client) to key_idx."""
        self._configure_key(key_idx)
        return genai.GenerativeModel(model_name)

    def _get_cache_key(self, task_type: str, user_id: Optional[str], prompt: str) -> str:
        """Build a cache key from task type, user, and the FULL prompt.

        FIX: the previous version hashed only prompt[:200], so distinct
        prompts sharing a 200-character prefix collided and could return
        each other's cached responses. Hashing the full prompt keeps the
        key short while making collisions vanishingly unlikely. MD5 is used
        purely as a fast fingerprint, not for security.
        """
        key_string = f"{task_type}:{user_id or 'anon'}:{prompt}"
        return hashlib.md5(key_string.encode()).hexdigest()

    def _check_cache(self, cache_key: str) -> Optional[str]:
        """Return the cached response, or None if absent or expired.

        Expired entries are evicted as a side effect.
        """
        if cache_key in self.cache:
            entry = self.cache[cache_key]
            if datetime.now() - entry["timestamp"] < timedelta(seconds=CACHE_TTL):
                return entry["response"]
            # Expired -- drop it so stale entries don't accumulate.
            del self.cache[cache_key]
        return None

    def _store_cache(self, cache_key: str, response: str, model_used: str):
        """Cache a response; sweep expired entries once the cache grows past 100."""
        self.cache[cache_key] = {
            "response": response,
            "timestamp": datetime.now(),
            "model": model_used
        }
        if len(self.cache) > 100:
            self._clean_cache()

    def _clean_cache(self):
        """Evict every expired cache entry."""
        now = datetime.now()
        expired_keys = [
            key for key, entry in self.cache.items()
            if now - entry["timestamp"] >= timedelta(seconds=CACHE_TTL)
        ]
        for key in expired_keys:
            del self.cache[key]

    def _check_rate_limit(self, model_name: str, key_idx: int = 0) -> bool:
        """Return True if model_name still has RPM budget on this key.

        Also prunes usage timestamps older than the 60-second window.
        """
        rpm_limit = MODEL_CONFIGS[model_name]["rpm"]
        usage_queue = self.usage[key_idx][model_name]
        now = time.time()
        while usage_queue and usage_queue[0] < now - 60:
            usage_queue.popleft()
        return len(usage_queue) < rpm_limit

    def _record_usage(self, model_name: str, key_idx: int = 0):
        """Record one request against this key/model's rate-limit window."""
        self.usage[key_idx][model_name].append(time.time())

    def get_model_for_task(self, task_type: str) -> Optional[str]:
        """Return the best available model name for task_type, or None.

        Preference order: the task's priority list on any healthy key,
        then any model at all on any healthy key.
        """
        priorities = TASK_PRIORITIES.get(task_type, TASK_PRIORITIES["default"])
        for key_idx in range(len(self.api_keys)):
            if not self._is_key_healthy(key_idx):
                continue
            for model_name in priorities:
                if self._check_rate_limit(model_name, key_idx):
                    return model_name
        # Preferred models exhausted -- accept any model with budget left.
        for key_idx in range(len(self.api_keys)):
            if not self._is_key_healthy(key_idx):
                continue
            for model_name in MODEL_CONFIGS:
                if self._check_rate_limit(model_name, key_idx):
                    return model_name
        return None

    async def generate(
        self,
        prompt: str,
        task_type: str = "default",
        user_id: Optional[str] = None,
        use_cache: bool = True
    ) -> tuple[str, str]:
        """Generate a response with model rotation, key rotation, and caching.

        Args:
            prompt: The prompt to send to the model.
            task_type: Type of task (chat, smart_query, documentation, synthesis).
            user_id: User ID for cache key differentiation.
            use_cache: Whether to use caching (default True).

        Returns:
            Tuple of (response_text, model_used); model_used is "cache" on a hit.

        Raises:
            Exception: When every key/model combination fails or is rate limited.
        """
        cache_key = None
        if use_cache:
            cache_key = self._get_cache_key(task_type, user_id, prompt)
            cached = self._check_cache(cache_key)
            if cached:
                return cached, "cache"

        # Task-preferred models first, then every remaining model as fallback.
        priorities = TASK_PRIORITIES.get(task_type, TASK_PRIORITIES["default"])
        all_models = list(priorities) + [m for m in MODEL_CONFIGS if m not in priorities]

        last_error = None
        tried_combinations = set()
        max_attempts = len(self.api_keys) * len(all_models)

        for _ in range(max_attempts):
            # FIX: exit early once every key/model pair has been attempted,
            # instead of spinning through the remaining iterations.
            if len(tried_combinations) >= max_attempts:
                break
            key_idx, _api_key = self._get_next_key()
            for model_name in all_models:
                combo = (key_idx, model_name)
                if combo in tried_combinations:
                    continue
                if not self._check_rate_limit(model_name, key_idx):
                    continue  # no budget on this key; another key may have some
                tried_combinations.add(combo)
                try:
                    model = self._get_model_with_key(model_name, key_idx)
                    self._record_usage(model_name, key_idx)
                    # FIX: the SDK call is blocking; run it in a worker thread
                    # so the asyncio event loop stays responsive.
                    response = await asyncio.to_thread(model.generate_content, prompt)
                    response_text = response.text
                    if use_cache:
                        self._store_cache(cache_key, response_text, model_name)
                    return response_text, model_name
                except Exception as e:
                    error_str = str(e).lower()
                    last_error = e
                    if "429" in error_str or "resource exhausted" in error_str or "quota" in error_str:
                        # Rate limited -- cool this key down, pause, try next key.
                        self._mark_key_unhealthy(key_idx, e, KEY_COOLDOWN_RATE_LIMIT)
                        await asyncio.sleep(RETRY_DELAY)
                        break
                    elif "401" in error_str or "403" in error_str or "invalid" in error_str:
                        # Auth failure -- effectively disable this key (24 hours).
                        self._mark_key_unhealthy(key_idx, e, 86400)
                        break
                    else:
                        # Transient/unknown error -- brief pause, try next model.
                        await asyncio.sleep(0.5)
                        continue

        if last_error:
            raise Exception(f"All models/keys exhausted. Last error: {last_error}")
        raise Exception("All models are rate limited. Please try again in a minute.")

    async def generate_with_model(
        self,
        model_name: str,
        prompt: str,
        user_id: Optional[str] = None,
        use_cache: bool = True
    ) -> str:
        """Generate with a specific model (for chat sessions that need consistency).

        Falls back to other models if the specified model is rate limited.

        FIX: the previous implementation ignored model_name entirely and
        always delegated to the generic rotation. The requested model is now
        attempted first on each healthy key; only if it has no budget or
        errors everywhere do we fall back to generate().
        """
        # Same cache namespace as the generic "default" path, matching the
        # effective behavior of the old delegation.
        cache_key = self._get_cache_key("default", user_id, prompt) if use_cache else None
        if use_cache:
            cached = self._check_cache(cache_key)
            if cached:
                return cached

        if model_name in MODEL_CONFIGS:
            for key_idx in range(len(self.api_keys)):
                if not self._is_key_healthy(key_idx):
                    continue
                if not self._check_rate_limit(model_name, key_idx):
                    continue
                try:
                    model = self._get_model_with_key(model_name, key_idx)
                    self._record_usage(model_name, key_idx)
                    response = await asyncio.to_thread(model.generate_content, prompt)
                    response_text = response.text
                    if use_cache:
                        self._store_cache(cache_key, response_text, model_name)
                    return response_text
                except Exception as e:
                    error_str = str(e).lower()
                    if "429" in error_str or "resource exhausted" in error_str or "quota" in error_str:
                        self._mark_key_unhealthy(key_idx, e, KEY_COOLDOWN_RATE_LIMIT)
                    continue  # try the requested model on the next key

        # Requested model unavailable everywhere -- fall back to rotation.
        response, _ = await self.generate(
            prompt=prompt,
            task_type="default",
            user_id=user_id,
            use_cache=use_cache
        )
        return response

    def get_stats(self) -> dict:
        """Return usage/health statistics for monitoring.

        Returns:
            Dict with key health details, per-model aggregate usage vs. the
            total RPM budget (per-key limit * number of keys), and cache size.
        """
        now = time.time()
        stats = {
            "keys": {
                "total": len(self.api_keys),
                "healthy": sum(1 for i in range(len(self.api_keys)) if self._is_key_healthy(i)),
                "details": {}
            },
            "models": {},
            "cache_size": len(self.cache)
        }

        # Per-key health snapshot.
        for key_idx in range(len(self.api_keys)):
            health = self.key_health[key_idx]
            stats["keys"]["details"][f"key_{key_idx}"] = {
                "healthy": self._is_key_healthy(key_idx),
                "last_error": health["last_error"],
                "retry_after": health["retry_after"].isoformat() if health["retry_after"] else None
            }

        # Aggregate model usage across all keys over the last 60 seconds.
        for model_name in MODEL_CONFIGS:
            total_used = 0
            for key_idx in range(len(self.api_keys)):
                usage_queue = self.usage[key_idx][model_name]
                total_used += sum(1 for t in usage_queue if t > now - 60)
            # The RPM limit is per key, so the total budget scales with keys.
            per_key_limit = MODEL_CONFIGS[model_name]["rpm"]
            total_limit = per_key_limit * len(self.api_keys)
            stats["models"][model_name] = {
                "used": total_used,
                "limit": total_limit,
                "available": total_limit - total_used
            }
        return stats
# Global router instance.
# NOTE: constructed at import time; this raises ValueError if no API keys
# are configured in the environment (see ModelRouter.__init__).
router = ModelRouter()
| # Convenience functions | |
async def generate(
    prompt: str,
    task_type: str = "default",
    user_id: Optional[str] = None,
    use_cache: bool = True
) -> str:
    """Generate a response using the global model router.

    Args:
        prompt: The prompt to send.
        task_type: One of 'chat', 'smart_query', 'documentation', 'synthesis', 'default'.
        user_id: User ID for cache differentiation.
        use_cache: Whether to use the response cache.

    Returns:
        Response text.
    """
    # The model name is discarded here (fixes the previously unused local);
    # callers that need it should use generate_with_info() instead.
    response, _ = await router.generate(prompt, task_type, user_id, use_cache)
    return response
async def generate_with_info(
    prompt: str,
    task_type: str = "default",
    user_id: Optional[str] = None,
    use_cache: bool = True
) -> tuple[str, str]:
    """Generate a response and report which model produced it.

    Returns:
        Tuple of (response_text, model_name).
    """
    result = await router.generate(prompt, task_type, user_id, use_cache)
    return result
def get_model_for_task(task_type: str) -> Optional[str]:
    """Return the best currently-available model for this task type, or None."""
    best = router.get_model_for_task(task_type)
    return best
def get_stats() -> dict:
    """Return the global router's current usage and health statistics."""
    snapshot = router.get_stats()
    return snapshot