Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| from typing import List, Optional | |
| import asyncio | |
| from dataclasses import dataclass | |
| from datetime import datetime, timedelta | |
| from dotenv import load_dotenv | |
| import logging | |
# Configure root logging once at import time and fetch this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The .env file is looked up one directory above the folder containing this file.
env_path = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    '.env',
)
load_dotenv(dotenv_path=env_path)


def _mask_env(name):
    """Return env var *name* masked to its last 4 characters, or 'Not set' if unset/empty."""
    value = os.getenv(name)
    if not value:
        return 'Not set'
    return '*' * 8 + value[-4:]


# Startup diagnostics: working directory, which .env was read, and masked key presence.
logger.info(f"Current working directory: {os.getcwd()}")
logger.info(f"Loading .env from: {env_path}")
logger.info(f"GEMINI_API_KEY: {_mask_env('GEMINI_API_KEY')}")
logger.info(f"GEMINI_API_KEYS: {_mask_env('GEMINI_API_KEYS')}")
@dataclass
class APIKey:
    """Mutable state tracked for a single API key.

    Instances are created and updated in place by ``APIKeyManager``; the
    timestamps are naive datetimes holding UTC wall-clock time.
    """

    key: str  # the raw API key string
    last_used: Optional[datetime] = None  # when the key was last handed out
    is_available: bool = True  # False while the key is considered rate-limited
    rate_limit_reset: Optional[datetime] = None  # when a rate-limited key may be reused
class APIKeyManager:
    """Singleton that hands out Gemini API keys and tracks per-key rate limits.

    Keys are read once, on first instantiation, from ``GEMINI_API_KEY`` (a
    single key, takes precedence) or ``GEMINI_API_KEYS`` (comma-separated list).
    """

    _instance = None
    # NOTE(review): this lock is never acquired by any method of this class;
    # it is kept so external references to APIKeyManager._lock keep working.
    _lock = asyncio.Lock()

    def __new__(cls):
        # Build and initialise the instance only on the first call; later calls
        # return the cached instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        """Set up per-instance state and load keys from the environment."""
        self.keys: List[APIKey] = []
        # NOTE(review): not read by any method in this class; retained for compatibility.
        self._current_index = 0
        self._load_api_keys()

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Same value as the deprecated ``datetime.utcnow()`` (Python 3.12+); kept
        naive so stored timestamps stay comparable with values set previously.
        """
        from datetime import timezone
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _clean_key(raw: str) -> str:
        """Strip surrounding whitespace and quote characters from a raw key value."""
        return raw.strip().strip('"\'').strip()

    @classmethod
    def _parse_key_list(cls, raw: str) -> List[str]:
        """Split a comma-separated key list into unique, non-empty cleaned keys.

        Order of first appearance is preserved. Duplicates are dropped because
        ``mark_key_unavailable`` marks a single entry; a repeated key would
        otherwise still be handed out after being rate-limited.
        """
        cleaned = (cls._clean_key(part) for part in raw.split(','))
        return list(dict.fromkeys(key for key in cleaned if key))

    def _load_api_keys(self):
        """Populate ``self.keys`` from the environment.

        ``GEMINI_API_KEY`` wins when it holds a non-empty key (after removing
        whitespace and quotes); otherwise ``GEMINI_API_KEYS`` is parsed. If
        neither yields a usable key, a warning is logged and ``self.keys`` is
        left unchanged.
        """
        single_key = self._clean_key(os.getenv('GEMINI_API_KEY', ''))
        if single_key:
            self.keys = [APIKey(key=single_key)]
            logger.info("Loaded 1 API key from GEMINI_API_KEY")
            return

        # Fall back to GEMINI_API_KEYS if GEMINI_API_KEY is unset or empty.
        keys = self._parse_key_list(os.getenv('GEMINI_API_KEYS', ''))
        if keys:
            self.keys = [APIKey(key=key) for key in keys]
            logger.info(f"Loaded {len(keys)} API keys from GEMINI_API_KEYS")
            return

        logger.warning("No API keys found in environment variables")

    def get_available_key(self) -> Optional[str]:
        """Get an available API key, considering rate limits.

        Keys are scanned in load order. A rate-limited key whose reset time has
        passed is re-enabled on the way. The first usable key has ``last_used``
        stamped and is returned.

        Returns:
            The key string, or ``None`` if every key is still rate-limited or
            no keys were loaded.
        """
        now = self._utcnow()
        for key_obj in self.keys:
            if not key_obj.is_available:
                # Still blocked: no reset time recorded, or reset not reached yet.
                if key_obj.rate_limit_reset is None or now < key_obj.rate_limit_reset:
                    continue
                key_obj.is_available = True
                key_obj.rate_limit_reset = None
            key_obj.last_used = now
            return key_obj.key
        return None

    def mark_key_unavailable(self, key: str, retry_after_seconds: int = 60):
        """Mark a key as unavailable due to rate limiting.

        Args:
            key: The key string previously returned by ``get_available_key``.
            retry_after_seconds: Delay before the key becomes eligible again.

        Unknown keys are ignored apart from a logged warning.
        """
        for key_obj in self.keys:
            if key_obj.key == key:
                key_obj.is_available = False
                key_obj.rate_limit_reset = self._utcnow() + timedelta(seconds=retry_after_seconds)
                logger.warning(f"Rate limit hit for API key. Will retry after {retry_after_seconds} seconds")
                return
        logger.warning("Tried to mark unknown API key as unavailable")
# Module-level shared instance; APIKeyManager() always returns this same object.
api_key_manager: APIKeyManager = APIKeyManager()