Report-Generator / api_key_manager.py
Jaimodiji's picture
Upload folder using huggingface_hub
92a22cd verified
"""
Unified API Key Manager with automatic failover and rotation.
This module manages multiple API keys for each service and automatically
switches to backup keys when one fails due to rate limiting or errors.
"""
import os
import time
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import threading
import logging
logger = logging.getLogger(__name__)
@dataclass
class APIKeyStatus:
    """Tracks the status of an individual API key.

    Records usage counters and a temporary "blocked" state that is entered
    after three consecutive failures and lifted once the cooldown expires.
    """

    # The secret itself. Excluded from the generated repr so that printing or
    # logging a status object never exposes the credential.
    key: str = field(repr=False)
    # Name of the service this key belongs to (used in log messages).
    service: str
    # Time of the most recent call, successful or not.
    last_used: Optional[datetime] = None
    # Consecutive failures since the last success or unblock.
    failure_count: int = 0
    last_failure: Optional[datetime] = None
    is_blocked: bool = False
    # When the current block expires; None while the key is not blocked.
    blocked_until: Optional[datetime] = None
    total_requests: int = 0
    successful_requests: int = 0

    def mark_success(self):
        """Mark a successful API call and clear any failure/block state."""
        self.last_used = datetime.now()
        self.total_requests += 1
        self.successful_requests += 1
        self.failure_count = 0  # Reset failure count on success
        self.is_blocked = False
        self.blocked_until = None

    def mark_failure(self, block_duration_minutes: int = 5):
        """Mark a failed API call and potentially block the key.

        Args:
            block_duration_minutes: Length of the cooldown applied once the
                key reaches three consecutive failures.
        """
        # One timestamp for the whole update keeps the fields consistent.
        now = datetime.now()
        self.last_used = now
        self.last_failure = now
        self.total_requests += 1
        self.failure_count += 1

        # Block key after 3 consecutive failures
        if self.failure_count >= 3:
            self.is_blocked = True
            self.blocked_until = now + timedelta(minutes=block_duration_minutes)
            logger.warning(f"API key for {self.service} blocked until {self.blocked_until} after {self.failure_count} failures")

    def is_available(self) -> bool:
        """Check if this key is available for use.

        Note: this is not a pure query. When the block has expired it clears
        the blocked state and the failure counter before returning True.
        """
        if not self.is_blocked:
            return True

        # Check if block has expired
        if self.blocked_until and datetime.now() > self.blocked_until:
            self.is_blocked = False
            self.blocked_until = None
            self.failure_count = 0
            logger.info(f"API key for {self.service} unblocked after cooldown period")
            return True

        return False

    def get_success_rate(self) -> float:
        """Calculate success rate percentage (100.0 when no requests were made)."""
        if self.total_requests == 0:
            return 100.0
        return (self.successful_requests / self.total_requests) * 100
class APIKeyManager:
    """
    Manages multiple API keys for different services with automatic failover.

    Supports multiple keys per service and automatically rotates to backup keys
    when one fails or hits rate limits. Methods that read or modify the key
    tables do so while holding ``self.lock`` (a non-reentrant
    ``threading.Lock``).
    """

    def __init__(self):
        # service name -> ordered list of per-key status trackers
        self.keys: Dict[str, List[APIKeyStatus]] = {}
        # service name -> index at which the next get_key() scan starts
        self.current_index: Dict[str, int] = {}
        self.lock = threading.Lock()
        self._load_keys_from_env()

    @staticmethod
    def _dedupe(keys: List[str]) -> List[str]:
        """Return *keys* without duplicates, keeping first-seen order."""
        return list(dict.fromkeys(keys))

    def _load_keys_from_env(self):
        """Load API keys from environment variables.

        Registers the 'nvidia', 'gemini' and 'openrouter' services for every
        prefix that yields at least one key.
        """
        # NVIDIA API Keys
        nvidia_keys = self._get_keys_from_env('NVIDIA_API_KEY')
        if nvidia_keys:
            self.register_service('nvidia', nvidia_keys)

        # Gemini API Keys. Keys are accepted under both prefixes; the same
        # value may be exported under each name, so de-duplicate across the
        # combined list. Otherwise a key would be registered as its own backup.
        gemini_keys = self._get_keys_from_env('GEMINI_API_KEY')
        google_keys = self._get_keys_from_env('GOOGLE_API_KEY')
        all_gemini_keys = self._dedupe(gemini_keys + google_keys)
        if all_gemini_keys:
            self.register_service('gemini', all_gemini_keys)

        # OpenRouter API Keys (for Nova)
        openrouter_keys = self._get_keys_from_env('OPENROUTER_API_KEY')
        if openrouter_keys:
            self.register_service('openrouter', openrouter_keys)

        logger.info(f"Loaded API keys: NVIDIA={len(nvidia_keys)}, Gemini={len(all_gemini_keys)}, OpenRouter={len(openrouter_keys)}")

    def _get_keys_from_env(self, base_name: str, max_numbered: int = 10) -> List[str]:
        """
        Get API keys from environment variables.

        Loads keys in order:
        1. BASE_NAME (as index 0)
        2. BASE_NAME_1, BASE_NAME_2, BASE_NAME_3, etc. (as indices 1, 2, 3...)

        Example:
        - GEMINI_API_KEY → index 0
        - GEMINI_API_KEY_1 → index 1
        - GEMINI_API_KEY_2 → index 2

        Args:
            base_name: Environment variable prefix to look up.
            max_numbered: Highest numeric suffix to probe (inclusive).

        Returns:
            The keys found, stripped of surrounding whitespace, with empty
            values skipped and duplicates removed (first occurrence wins).
        """
        names = [base_name] + [f"{base_name}_{i}" for i in range(1, max_numbered + 1)]
        found = []
        for name in names:
            # Strip so a trailing newline or padding copied from a secrets
            # file does not become part of the credential; blank values are
            # treated as unset.
            value = os.environ.get(name, "").strip()
            if value:
                found.append(value)
        return self._dedupe(found)

    def register_service(self, service: str, api_keys: List[str]):
        """Register multiple API keys for a service.

        Replaces any keys previously registered for *service* and resets its
        rotation index to the first key.
        """
        with self.lock:
            self.keys[service] = [
                APIKeyStatus(key=key, service=service)
                for key in api_keys
            ]
            self.current_index[service] = 0
            logger.info(f"Registered {len(api_keys)} API key(s) for service: {service}")

    def get_key(self, service: str) -> Tuple[Optional[str], int]:
        """
        Get an available API key for the specified service.

        Returns (api_key, key_index) or (None, -1) if no keys available.
        """
        with self.lock:
            if service not in self.keys or not self.keys[service]:
                logger.warning(f"No API keys registered for service: {service}")
                return None, -1

            service_keys = self.keys[service]
            start_index = self.current_index[service]

            # Try to find an available key, starting from current index and
            # wrapping around so every key is checked exactly once.
            for attempt in range(len(service_keys)):
                current_idx = (start_index + attempt) % len(service_keys)
                key_status = service_keys[current_idx]
                if key_status.is_available():
                    self.current_index[service] = current_idx
                    logger.debug(f"Using API key {current_idx + 1}/{len(service_keys)} for {service}")
                    return key_status.key, current_idx

            # All keys are blocked
            logger.error(f"All API keys for {service} are currently blocked or unavailable")
            return None, -1

    def mark_success(self, service: str, key_index: int):
        """Mark an API call as successful.

        Unknown services and out-of-range indices are ignored.
        """
        with self.lock:
            if service in self.keys and 0 <= key_index < len(self.keys[service]):
                self.keys[service][key_index].mark_success()
                logger.debug(f"API key {key_index + 1} for {service} marked as successful")
                # Move to next key for load balancing (round-robin)
                self.current_index[service] = (key_index + 1) % len(self.keys[service])

    def mark_failure(self, service: str, key_index: int, block_duration_minutes: int = 5):
        """Mark an API call as failed and potentially block the key.

        Unknown services and out-of-range indices are ignored.
        """
        with self.lock:
            if service in self.keys and 0 <= key_index < len(self.keys[service]):
                self.keys[service][key_index].mark_failure(block_duration_minutes)
                logger.warning(f"API key {key_index + 1} for {service} marked as failed")
                # Move to next key immediately
                self.current_index[service] = (key_index + 1) % len(self.keys[service])

    def get_service_status(self, service: str) -> Dict:
        """Get status information for a service.

        The result always has the same keys; for an unregistered service the
        counts are zero and 'keys' is an empty list. Raw key values are never
        included.
        """
        with self.lock:
            if service not in self.keys:
                return {
                    'service': service,
                    'available': False,
                    'total_keys': 0,
                    'available_keys': 0,
                    'blocked_keys': 0,
                    'keys': []
                }

            service_keys = self.keys[service]
            # Evaluate availability once per key. is_available() also lifts
            # expired blocks, so it must run before is_blocked is read.
            availability = [k.is_available() for k in service_keys]
            available_keys = sum(availability)
            blocked_keys = sum(1 for k in service_keys if k.is_blocked)

            return {
                'service': service,
                'available': available_keys > 0,
                'total_keys': len(service_keys),
                'available_keys': available_keys,
                'blocked_keys': blocked_keys,
                'keys': [
                    {
                        'index': i,
                        'is_available': is_free,
                        'is_blocked': k.is_blocked,
                        'failure_count': k.failure_count,
                        'total_requests': k.total_requests,
                        'success_rate': round(k.get_success_rate(), 2),
                        'blocked_until': k.blocked_until.isoformat() if k.blocked_until else None
                    }
                    for i, (k, is_free) in enumerate(zip(service_keys, availability))
                ]
            }

    def get_all_services_status(self) -> Dict[str, Dict]:
        """Get status for all registered services."""
        # Snapshot the service names under the lock so a concurrent
        # register_service() cannot resize the dict mid-iteration. The lock is
        # released before calling get_service_status(), which acquires it
        # itself (threading.Lock is not reentrant).
        with self.lock:
            services = list(self.keys)
        return {
            service: self.get_service_status(service)
            for service in services
        }

    def reset_service(self, service: str):
        """Unblock every key of a service and clear its failure counter.

        Request totals (total/successful requests) and timestamps are left
        untouched. Unknown services are ignored.
        """
        with self.lock:
            if service in self.keys:
                for key_status in self.keys[service]:
                    key_status.is_blocked = False
                    key_status.blocked_until = None
                    key_status.failure_count = 0
                logger.info(f"Reset all keys for service: {service}")
# Global singleton instance, created lazily on first access.
_api_key_manager = None
# Guards first-time construction so concurrent callers share one manager.
_api_key_manager_lock = threading.Lock()


def get_api_key_manager() -> APIKeyManager:
    """Get the global API key manager instance.

    The manager is built on the first call and reused afterwards. Creation is
    serialized so that threads racing on the first call all receive the same
    instance, and therefore the same failure/blocking state.
    """
    global _api_key_manager
    if _api_key_manager is None:
        with _api_key_manager_lock:
            # Re-check inside the lock: another thread may have finished
            # constructing the manager while this one was waiting.
            if _api_key_manager is None:
                _api_key_manager = APIKeyManager()
    return _api_key_manager