# Source: AI-Agent-Book, utils/api_manager.py
# Commit ded29b0 ("init project") by Cuong2004
import os
import time
from typing import List, Optional
from dotenv import load_dotenv
import logging
import sys
# Route log records at INFO and above to stdout. Note: logging.basicConfig does
# nothing if the root logger already has handlers, so an application that
# configured logging before importing this module keeps its own setup.
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Load environment variables (e.g. GEMINI_API_KEYS) from a .env file via python-dotenv.
load_dotenv()

# Seconds a key reported as failed is skipped before it is retried.
KEY_COOLDOWN = 60
class ApiKeyManager:
    """
    Manages the rotation and cooldown of API keys to prevent rate limiting issues.

    Keys are handed out round-robin by ``get_next_api_key``. A key reported via
    ``mark_key_as_failed`` is skipped until more than ``KEY_COOLDOWN`` seconds
    have elapsed since it was marked. Availability is always computed from the
    managed keys in ``self.api_keys``. An unknown key passed to
    ``mark_key_as_failed`` therefore cannot make the manager believe it has
    run out of keys.
    """

    def __init__(self, api_keys: Optional[List[str]] = None):
        """
        Initialize the API key manager with a list of API keys.

        Args:
            api_keys: List of API keys. If None, keys are read from the
                comma-separated ``GEMINI_API_KEYS`` environment variable
                (whitespace around each entry stripped, empty entries dropped).
                The list is copied, so later changes made by the caller do
                not affect the manager.

        Raises:
            ValueError: If ``api_keys`` is None and ``GEMINI_API_KEYS`` is unset
                or empty, or if no API keys remain.
        """
        if api_keys is None:
            # Load API keys from environment variable
            env_keys = os.environ.get('GEMINI_API_KEYS', '')
            if not env_keys:
                logger.error("No API keys found in environment variables")
                raise ValueError("No API keys found in environment variables")
            api_keys = [key.strip() for key in env_keys.split(',') if key.strip()]
            logger.info(f"Loaded {len(api_keys)} API keys from environment variables")

        # Copy so that a caller later shrinking its own list cannot push
        # current_index out of range.
        self.api_keys = list(api_keys)
        self.current_index = 0  # Index in api_keys where the next rotation search starts
        self.failed_keys = {}  # Maps key -> time.time() at which it was marked as failed

        if not self.api_keys:
            logger.error("No valid API keys provided")
            raise ValueError("No API keys provided. Please add at least one valid API key.")

        # Log only a masked prefix of each key for debugging
        for position, key in enumerate(self.api_keys, start=1):
            logger.info(f"API key {position}: {self._mask_key(key)}")
        logger.info(f"ApiKeyManager initialized with {len(self.api_keys)} keys")

    @staticmethod
    def _mask_key(key: str) -> str:
        """Return a log-safe form of ``key``: its first 8 characters plus "...", or just "..." if it has 8 or fewer."""
        return key[:8] + "..." if len(key) > 8 else "..."

    def _in_cooldown(self, key: str, now: float) -> bool:
        """Return True if ``key`` was marked failed no more than ``KEY_COOLDOWN`` seconds before ``now``."""
        failed_at = self.failed_keys.get(key)
        return failed_at is not None and now - failed_at <= KEY_COOLDOWN

    def _count_available(self, now: float) -> int:
        """Return how many entries of ``api_keys`` are not in cooldown at ``now`` (duplicate entries counted separately)."""
        return sum(1 for key in self.api_keys if not self._in_cooldown(key, now))

    def get_next_api_key(self) -> str:
        """
        Get the next working API key in rotation.

        Keys whose cooldown has expired are first released from ``failed_keys``.
        If every managed key is still cooling down, this call blocks
        (``time.sleep``) until the earliest-failed key is due, then returns it.

        Returns:
            str: The next available API key.

        Raises:
            ValueError: Only if internal state is inconsistent (not expected to happen).
        """
        current_time = time.time()

        # Release keys that have completed their cooldown period
        for key, failed_at in list(self.failed_keys.items()):
            if current_time - failed_at > KEY_COOLDOWN:
                logger.info(f"API key {self._mask_key(key)} has cooled down after {KEY_COOLDOWN} seconds and is being retried")
                del self.failed_keys[key]

        # Test the managed keys directly rather than comparing len(failed_keys)
        # with len(api_keys). Duplicate entries in api_keys, or unknown keys
        # passed to mark_key_as_failed, would otherwise skew this check.
        if all(key in self.failed_keys for key in self.api_keys):
            logger.warning("All API keys are currently in cooldown. Waiting for the first key to be available again...")
            # Among managed keys, find the one whose cooldown ends first
            earliest_key = min(self.api_keys, key=self.failed_keys.get)
            # +1 guarantees the strict "> KEY_COOLDOWN" condition holds after sleeping
            wait_time = KEY_COOLDOWN - (current_time - self.failed_keys[earliest_key]) + 1
            wait_time = max(1, int(wait_time))
            logger.info(f"Waiting {wait_time} seconds for cooldown...")
            time.sleep(wait_time)
            # Remove this key from the failed keys to retry it
            logger.info(f"Cooldown complete for {self._mask_key(earliest_key)}, removing from failed keys")
            del self.failed_keys[earliest_key]
            return earliest_key

        # Round-robin search for the next key not in cooldown. The check above
        # guarantees at least one such key, so one full pass always finds it.
        key_count = len(self.api_keys)
        for _ in range(key_count):
            api_key = self.api_keys[self.current_index]
            self.current_index = (self.current_index + 1) % key_count
            if api_key not in self.failed_keys:
                logger.info(f"Using API key: {self._mask_key(api_key)}")
                return api_key

        logger.error("Could not find an available API key, which is unexpected")
        raise ValueError("Could not find an available API key, which is unexpected")

    def mark_key_as_failed(self, api_key: str) -> None:
        """
        Mark an API key as temporarily unavailable.

        The key is skipped by ``get_next_api_key`` until more than
        ``KEY_COOLDOWN`` seconds have elapsed. Marking a key again restarts
        its cooldown.

        Args:
            api_key: The API key to mark as failed.
        """
        now = time.time()
        self.failed_keys[api_key] = now
        logger.warning(f"API key {self._mask_key(api_key)} marked as temporarily unavailable")

        # Count only managed keys so unknown keys or duplicates don't skew the numbers
        available_keys = self._count_available(now)
        cooling_keys = len(self.api_keys) - available_keys
        logger.info(f"{available_keys} keys immediately available. {cooling_keys} keys in cooldown.")

        # If we're running out of keys, log a more urgent warning
        if available_keys <= 1:
            logger.warning(f"WARNING: Only {available_keys} API keys remaining. Consider adding more keys.")

    def has_working_keys(self) -> bool:
        """
        Check whether any managed key is usable right now.

        A key whose cooldown has already expired counts as working even if
        ``get_next_api_key`` has not yet removed it from ``failed_keys``.
        This method does not modify state.

        Returns:
            bool: True if at least one key is available, False otherwise.
        """
        available_keys = self._count_available(time.time())
        logger.info(f"Checking API key availability: {available_keys} keys available")
        return available_keys > 0

    def get_api_keys_status(self) -> dict:
        """
        Get status information about API keys (read-only).

        Returns:
            dict: ``total_keys``, ``available_keys``, ``cooldown_keys`` counts,
            plus ``available_key_details`` / ``cooldown_key_details`` lists of
            per-key dicts (``key_id`` is the index in ``api_keys``; cooldown
            entries also carry ``cooldown_remaining_seconds``). Keys whose
            cooldown has expired are reported as available.
        """
        current_time = time.time()
        available_keys = []
        cooldown_keys = []
        for i, key in enumerate(self.api_keys):
            masked_key = self._mask_key(key)
            if self._in_cooldown(key, current_time):
                cooldown_remaining = max(0, KEY_COOLDOWN - (current_time - self.failed_keys[key]))
                cooldown_keys.append({
                    "key_id": i,
                    "masked_key": masked_key,
                    "cooldown_remaining_seconds": int(cooldown_remaining),
                })
            else:
                available_keys.append({
                    "key_id": i,
                    "masked_key": masked_key,
                })
        return {
            "total_keys": len(self.api_keys),
            "available_keys": len(available_keys),
            "cooldown_keys": len(cooldown_keys),
            "available_key_details": available_keys,
            "cooldown_key_details": cooldown_keys,
        }