Spaces:
Sleeping
Sleeping
| """ | |
| API Key Manager - Round-robin management for multiple Gemini API keys. | |
| Selects the least-used key to avoid rate limiting. | |
| Tracks usage per key index (not the actual key for security). | |
| """ | |
| import os | |
| import logging | |
| from datetime import datetime | |
| from typing import Optional, List, Tuple | |
| from sqlalchemy import select | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
logger = logging.getLogger(__name__)

# Module-wide cache of the configured Gemini API keys. It stays None until
# get_api_keys() reads the environment for the first time; afterwards it holds
# that list (possibly empty) for the lifetime of the process.
_api_keys: Optional[List[str]] = None
def get_api_keys() -> List[str]:
    """Return the configured Gemini API keys, reading the environment only once.

    Keys come from ``GEMINI_API_KEYS`` (comma-separated). If that variable is
    unset, empty, or contains no non-blank entries (e.g. ``" , "``), the
    single-key variable ``GEMINI_API_KEY`` is used instead. Surrounding
    whitespace is stripped from every key and blank entries are dropped, on
    both paths.

    The result is stored in the module-level ``_api_keys`` cache on the first
    call and returned unchanged on later calls, so environment changes made
    after that point are not seen. An empty result is cached too.

    Returns:
        The list of configured keys; empty when none are configured.
    """
    global _api_keys
    if _api_keys is None:
        keys = [k.strip() for k in os.getenv("GEMINI_API_KEYS", "").split(",") if k.strip()]
        if not keys:
            # Fall back to the single-key variable whenever the multi-key
            # variable yields nothing usable, not only when it is unset.
            single_key = os.getenv("GEMINI_API_KEY", "").strip()
            if single_key:
                keys = [single_key]
        _api_keys = keys
        if keys:
            logger.info("Loaded %d API key(s)", len(keys))
        else:
            logger.warning("No Gemini API keys configured; set GEMINI_API_KEYS or GEMINI_API_KEY.")
    return _api_keys
def get_key_count() -> int:
    """Return how many API keys are configured, as reported by get_api_keys()."""
    configured_keys = get_api_keys()
    return len(configured_keys)
async def get_least_used_key(db: AsyncSession) -> Tuple[int, str]:
    """
    Pick the configured API key that has served the fewest requests.

    Usage rows from the ApiKeyUsage table are matched to keys by position
    (``key_index``). Keys are scanned in index order:

    * the first key that has no usage row is chosen immediately, and a zeroed
      ApiKeyUsage row is added to the session for it (not committed; the
      caller owns the transaction);
    * otherwise the key with the smallest ``total_requests`` wins, with ties
      going to the lowest index.

    Rows whose key_index lies outside the configured key list are ignored.

    Returns:
        Tuple of (key_index, api_key)

    Raises:
        ValueError: If no API keys are configured
    """
    from core.models import ApiKeyUsage

    api_keys = get_api_keys()
    if not api_keys:
        raise ValueError("No API keys configured. Set GEMINI_API_KEYS or GEMINI_API_KEY in environment.")

    stmt = select(ApiKeyUsage).order_by(ApiKeyUsage.total_requests)
    rows = (await db.execute(stmt)).scalars().all()
    usage_by_index = {row.key_index: row for row in rows}

    chosen = 0
    lowest = float('inf')
    for idx in range(len(api_keys)):
        row = usage_by_index.get(idx)
        if row is None:
            # Never-used key: register it in the session and take it right away.
            db.add(ApiKeyUsage(key_index=idx, total_requests=0, success_count=0, failure_count=0))
            chosen = idx
            break
        if row.total_requests < lowest:
            lowest = row.total_requests
            chosen = idx

    logger.debug(f"Selected API key index {chosen} (least used)")
    return chosen, api_keys[chosen]
async def record_usage(db: AsyncSession, key_index: int, success: bool, error_message: Optional[str] = None):
    """
    Update the usage counters for one API key after a request finishes.

    Loads the ApiKeyUsage row for ``key_index``, adding a zeroed row to the
    session if none exists. It then:

    * increments ``total_requests`` and either ``success_count`` or
      ``failure_count``;
    * sets ``last_used_at`` from ``datetime.utcnow()``;
    * on a failure with a non-empty message, stores the message's first 1000
      characters in ``last_error``. Otherwise ``last_error`` is left untouched.

    Nothing is committed here; the caller owns the transaction.

    Args:
        db: Database session
        key_index: Index of the key used
        success: Whether the request succeeded
        error_message: Error message if request failed
    """
    from core.models import ApiKeyUsage

    lookup = await db.execute(select(ApiKeyUsage).where(ApiKeyUsage.key_index == key_index))
    row = lookup.scalar_one_or_none()
    if not row:
        row = ApiKeyUsage(key_index=key_index, total_requests=0, success_count=0, failure_count=0)
        db.add(row)

    row.total_requests += 1
    if not success:
        row.failure_count += 1
        if error_message:
            # Keep the stored error bounded in size.
            row.last_error = error_message[:1000]
    else:
        row.success_count += 1
    row.last_used_at = datetime.utcnow()

    outcome = 'success' if success else 'failure'
    logger.debug(f"Recorded {outcome} for key index {key_index}")
async def get_all_usage_stats(db: AsyncSession) -> List[dict]:
    """
    Report usage statistics for every configured API key.

    One dict is produced per configured key, in key-index order. Each dict has
    these keys:

    * ``key_index``
    * ``total_requests``, ``success_count``, ``failure_count``
    * ``last_error``
    * ``last_used_at``: ISO-8601 string, or None

    Keys that have no ApiKeyUsage row yet are reported with zero counts and
    None for ``last_error`` and ``last_used_at``. Rows for indices beyond the
    configured keys are not reported. The session is only queried, never
    modified.

    Returns:
        List of per-key stats dicts
    """
    from core.models import ApiKeyUsage

    key_total = len(get_api_keys())
    stmt = select(ApiKeyUsage).order_by(ApiKeyUsage.key_index)
    rows = (await db.execute(stmt)).scalars().all()
    by_index = {row.key_index: row for row in rows}

    def _stats_for(idx: int) -> dict:
        # Build the stats dict for one key index, defaulting when no row exists.
        row = by_index.get(idx)
        if row is None:
            return {
                "key_index": idx,
                "total_requests": 0,
                "success_count": 0,
                "failure_count": 0,
                "last_error": None,
                "last_used_at": None,
            }
        return {
            "key_index": idx,
            "total_requests": row.total_requests,
            "success_count": row.success_count,
            "failure_count": row.failure_count,
            "last_error": row.last_error,
            "last_used_at": row.last_used_at.isoformat() if row.last_used_at else None,
        }

    return [_stats_for(idx) for idx in range(key_total)]