Spaces:
Running
Running
| """ | |
| Unified API Key Manager with automatic failover and rotation. | |
| This module manages multiple API keys for each service and automatically | |
| switches to backup keys when one fails due to rate limiting or errors. | |
| """ | |
| import os | |
| import time | |
| from typing import List, Dict, Optional, Tuple | |
| from dataclasses import dataclass, field | |
| from datetime import datetime, timedelta | |
| import threading | |
| import logging | |
| logger = logging.getLogger(__name__) | |
@dataclass
class APIKeyStatus:
    """Tracks the status of an individual API key.

    Declared as a dataclass so instances can be created with keyword
    arguments (``APIKeyStatus(key=..., service=...)``); all remaining
    fields start from the defaults listed below.
    """

    # The secret itself. Excluded from the generated repr so that printing
    # or logging a status object does not expose the key.
    key: str = field(repr=False)
    # Name of the service this key belongs to (used in log messages).
    service: str
    # Time of the most recent call, successful or not.
    last_used: Optional[datetime] = None
    # Consecutive failures since the last success or unblock.
    failure_count: int = 0
    # Time of the most recent failed call.
    last_failure: Optional[datetime] = None
    # True while the key is in its cooldown period.
    is_blocked: bool = False
    # Moment after which a blocked key may be used again.
    blocked_until: Optional[datetime] = None
    # Cumulative counters; only ever incremented by this class.
    total_requests: int = 0
    successful_requests: int = 0

    def mark_success(self):
        """Mark a successful API call.

        Increments both request counters, resets the consecutive failure
        count and lifts any block on the key.
        """
        self.last_used = datetime.now()
        self.total_requests += 1
        self.successful_requests += 1
        self.failure_count = 0  # Reset failure count on success
        self.is_blocked = False
        self.blocked_until = None

    def mark_failure(self, block_duration_minutes: int = 5):
        """Mark a failed API call and potentially block the key.

        Args:
            block_duration_minutes: Length of the cooldown applied once the
                key has failed three times in a row.
        """
        # Take one timestamp so last_used, last_failure and blocked_until
        # all refer to the same instant.
        now = datetime.now()
        self.last_used = now
        self.last_failure = now
        self.total_requests += 1
        self.failure_count += 1

        # Block key after 3 consecutive failures
        if self.failure_count >= 3:
            self.is_blocked = True
            self.blocked_until = now + timedelta(minutes=block_duration_minutes)
            logger.warning(f"API key for {self.service} blocked until {self.blocked_until} after {self.failure_count} failures")

    def is_available(self) -> bool:
        """Check if this key is available for use.

        Note that this is not a pure query: when the cooldown has expired
        the block is lifted and the failure count is reset.

        Returns:
            True if the key is not blocked (or its block just expired),
            False otherwise.
        """
        if not self.is_blocked:
            return True

        # Check if block has expired
        if self.blocked_until and datetime.now() > self.blocked_until:
            self.is_blocked = False
            self.blocked_until = None
            self.failure_count = 0
            logger.info(f"API key for {self.service} unblocked after cooldown period")
            return True

        return False

    def get_success_rate(self) -> float:
        """Calculate success rate percentage.

        Returns:
            Successful requests as a percentage of all requests, or 100.0
            when no request has been recorded yet.
        """
        if self.total_requests == 0:
            return 100.0
        return (self.successful_requests / self.total_requests) * 100
class APIKeyManager:
    """
    Manages multiple API keys for different services with automatic failover.

    Supports multiple keys per service and automatically rotates to backup keys
    when one fails or hits rate limits. All access to the key tables goes
    through ``self.lock``.
    """

    # Highest numeric suffix probed when scanning the environment
    # (BASE_NAME_1 .. BASE_NAME_<MAX_NUMBERED_KEYS>).
    MAX_NUMBERED_KEYS = 10

    def __init__(self):
        # service name -> status objects, in rotation order
        self.keys: Dict[str, List[APIKeyStatus]] = {}
        # service name -> index at which the next get_key() search starts
        self.current_index: Dict[str, int] = {}
        # Plain (non-reentrant) lock: never call a locking method while
        # already holding it.
        self.lock = threading.Lock()
        self._load_keys_from_env()

    @staticmethod
    def _unique(values: List[str]) -> List[str]:
        """Return ``values`` without duplicates, keeping first-seen order."""
        return list(dict.fromkeys(values))

    def _load_keys_from_env(self):
        """Load API keys from environment variables.

        Registers the services ``nvidia``, ``gemini`` and ``openrouter``
        for which at least one key is present.
        """
        # NVIDIA API Keys
        nvidia_keys = self._get_keys_from_env('NVIDIA_API_KEY')
        if nvidia_keys:
            self.register_service('nvidia', nvidia_keys)

        # Gemini API Keys. The same secret may be exported under both the
        # GEMINI_ and GOOGLE_ names; de-duplicate across the two families
        # so failover never rotates onto the key that just failed.
        gemini_keys = self._get_keys_from_env('GEMINI_API_KEY')
        google_keys = self._get_keys_from_env('GOOGLE_API_KEY')
        all_gemini_keys = self._unique(gemini_keys + google_keys)
        if all_gemini_keys:
            self.register_service('gemini', all_gemini_keys)

        # OpenRouter API Keys (for Nova)
        openrouter_keys = self._get_keys_from_env('OPENROUTER_API_KEY')
        if openrouter_keys:
            self.register_service('openrouter', openrouter_keys)

        logger.info(f"Loaded API keys: NVIDIA={len(nvidia_keys)}, Gemini={len(all_gemini_keys)}, OpenRouter={len(openrouter_keys)}")

    def _get_keys_from_env(self, base_name: str) -> List[str]:
        """
        Get API keys from environment variables.

        Loads keys in order:
        1. BASE_NAME (as index 0)
        2. BASE_NAME_1, BASE_NAME_2, BASE_NAME_3, etc. (as indices 1, 2, 3...)

        Example:
        - GEMINI_API_KEY → index 0
        - GEMINI_API_KEY_1 → index 1
        - GEMINI_API_KEY_2 → index 2

        Surrounding whitespace is stripped, and unset, empty or
        whitespace-only variables are skipped, so indices are positions in
        the returned list rather than the numeric suffixes.

        Returns:
            The distinct keys found, in the order described above.
        """
        names = [base_name] + [
            f"{base_name}_{i}" for i in range(1, self.MAX_NUMBERED_KEYS + 1)
        ]
        # Strip so a trailing newline from a secrets file does not end up
        # inside the key or defeat duplicate detection.
        values = ((os.environ.get(name) or "").strip() for name in names)
        return self._unique([value for value in values if value])

    def register_service(self, service: str, api_keys: List[str]):
        """Register multiple API keys for a service.

        Replaces any keys previously registered for ``service`` (and their
        statistics) and restarts rotation at index 0.
        """
        with self.lock:
            self.keys[service] = [
                APIKeyStatus(key=key, service=service)
                for key in api_keys
            ]
            self.current_index[service] = 0
            logger.info(f"Registered {len(api_keys)} API key(s) for service: {service}")

    def get_key(self, service: str) -> Tuple[Optional[str], int]:
        """
        Get an available API key for the specified service.

        The search starts at the service's current index and wraps around
        once over all registered keys.

        Returns (api_key, key_index) or (None, -1) if no keys available.
        """
        with self.lock:
            if service not in self.keys or not self.keys[service]:
                logger.warning(f"No API keys registered for service: {service}")
                return None, -1

            service_keys = self.keys[service]
            start_index = self.current_index[service]

            # Try to find an available key, starting from current index
            for attempt in range(len(service_keys)):
                current_idx = (start_index + attempt) % len(service_keys)
                key_status = service_keys[current_idx]
                if key_status.is_available():
                    self.current_index[service] = current_idx
                    logger.debug(f"Using API key {current_idx + 1}/{len(service_keys)} for {service}")
                    return key_status.key, current_idx

            # All keys are blocked
            logger.error(f"All API keys for {service} are currently blocked or unavailable")
            return None, -1

    def mark_success(self, service: str, key_index: int):
        """Mark an API call as successful.

        Unknown services and out-of-range indices are ignored.
        """
        with self.lock:
            if service in self.keys and 0 <= key_index < len(self.keys[service]):
                self.keys[service][key_index].mark_success()
                logger.debug(f"API key {key_index + 1} for {service} marked as successful")
                # Move to next key for load balancing (round-robin)
                self.current_index[service] = (key_index + 1) % len(self.keys[service])

    def mark_failure(self, service: str, key_index: int, block_duration_minutes: int = 5):
        """Mark an API call as failed and potentially block the key.

        Unknown services and out-of-range indices are ignored.
        ``block_duration_minutes`` is forwarded to the key's own
        ``mark_failure``.
        """
        with self.lock:
            if service in self.keys and 0 <= key_index < len(self.keys[service]):
                self.keys[service][key_index].mark_failure(block_duration_minutes)
                logger.warning(f"API key {key_index + 1} for {service} marked as failed")
                # Move to next key immediately
                self.current_index[service] = (key_index + 1) % len(self.keys[service])

    def get_service_status(self, service: str) -> Dict:
        """Get status information for a service.

        Returns:
            A dict with ``service``, ``available``, ``total_keys``,
            ``available_keys``, ``blocked_keys`` and a per-key ``keys``
            list (empty for an unregistered service). Key secrets are not
            included.
        """
        with self.lock:
            if service not in self.keys:
                return {
                    'service': service,
                    'available': False,
                    'total_keys': 0,
                    'available_keys': 0,
                    'blocked_keys': 0,
                    'keys': []
                }

            service_keys = self.keys[service]
            # Query availability once per key so the totals and the per-key
            # entries agree. is_available() may lift an expired block, so
            # is_blocked is only read afterwards.
            availability = [k.is_available() for k in service_keys]
            available_keys = sum(1 for usable in availability if usable)
            blocked_keys = sum(1 for k in service_keys if k.is_blocked)

            return {
                'service': service,
                'available': available_keys > 0,
                'total_keys': len(service_keys),
                'available_keys': available_keys,
                'blocked_keys': blocked_keys,
                'keys': [
                    {
                        'index': i,
                        'is_available': usable,
                        'is_blocked': k.is_blocked,
                        'failure_count': k.failure_count,
                        'total_requests': k.total_requests,
                        'success_rate': round(k.get_success_rate(), 2),
                        'blocked_until': k.blocked_until.isoformat() if k.blocked_until else None
                    }
                    for i, (k, usable) in enumerate(zip(service_keys, availability))
                ]
            }

    def get_all_services_status(self) -> Dict[str, Dict]:
        """Get status for all registered services."""
        # Snapshot the names under the lock, then release it before calling
        # get_service_status(), which takes the same non-reentrant lock.
        with self.lock:
            services = list(self.keys)
        return {
            service: self.get_service_status(service)
            for service in services
        }

    def reset_service(self, service: str):
        """Unblock all keys for a service and clear their failure counts.

        The cumulative request counters are left untouched. Unknown
        services are ignored.
        """
        with self.lock:
            if service in self.keys:
                for key_status in self.keys[service]:
                    key_status.is_blocked = False
                    key_status.blocked_until = None
                    key_status.failure_count = 0
                logger.info(f"Reset all keys for service: {service}")
# Module-wide manager instance, built lazily on first request.
_api_key_manager = None


def get_api_key_manager() -> APIKeyManager:
    """Return the shared APIKeyManager, constructing it on the first call."""
    global _api_key_manager
    if _api_key_manager is not None:
        return _api_key_manager
    _api_key_manager = APIKeyManager()
    return _api_key_manager