Spaces:
Sleeping
Sleeping
| """ | |
| API Key Middleware - Automatic key selection and rotation | |
| Automatically selects and injects Gemini API keys for requests. | |
| Handles quota errors with automatic key rotation and retry. | |
| """ | |
import json
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from core.database import async_session_maker
from services.gemini_service.api_key_config import APIKeyServiceConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Track key cooldowns in memory: maps an API key's index (its position in the
# configured key list) to the naive UTC datetime at which its cooldown ends.
# Entries are added when a key is marked as cooling down and are removed
# lazily, the next time an expired entry is checked.
_key_cooldowns: Dict[int, datetime] = {}
class APIKeyMiddleware(BaseHTTPMiddleware):
    """
    Middleware for automatic API key management.

    Features:
    - Automatic key selection based on the configured rotation strategy
    - Quota error (HTTP 429) detection with a single retry on another key
    - Key cooldown management (in memory, per process)
    - Usage tracking
    """

    # Path prefixes that identify a request as a Gemini request.
    _GEMINI_PATH_PREFIXES = ("/gemini/", "/api/gemini")

    # Key index at which the next round-robin selection starts looking.
    # Kept on the class so rotation continues from one request to the next.
    _round_robin_cursor: int = 0

    def __init__(self, app: ASGIApp):
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """
        Process a request with automatic API key injection.

        Flow:
        1. Pass non-Gemini requests straight through.
        2. Select the best available key.
        3. Inject it into ``request.state`` (``gemini_api_key`` and
           ``gemini_key_index``).
        4. On a 429 response, when retry is enabled, put the key in
           cooldown and retry once with a different key.
        5. Record usage for the key that produced the returned response.

        Returns:
            The downstream response, or a 503 JSON response when no API
            key can be selected.
        """
        # Only handle Gemini requests
        if not self._is_gemini_request(request):
            return await call_next(request)

        # Select API key
        try:
            key_index, api_key = await self._select_api_key()
        except ValueError as e:
            # No keys available
            logger.error("No API keys available: %s", e)
            return Response(
                # json.dumps escapes quotes/backslashes in the message, so
                # the body is always valid JSON.
                content=json.dumps({"detail": str(e)}),
                status_code=503,
                media_type="application/json",
            )
        request.state.gemini_api_key = api_key
        request.state.gemini_key_index = key_index

        # Process request
        response = await call_next(request)

        # Handle quota errors
        if response.status_code == 429 and APIKeyServiceConfig._retry_on_quota_error:
            logger.warning("Quota error on key %s, attempting retry", key_index)
            self._mark_cooldown(key_index)

            # Only the key selection is guarded: a ValueError raised by the
            # downstream app during the retry must not be mistaken for
            # "no keys available".
            try:
                retry_index, retry_key = await self._select_api_key(
                    exclude_index=key_index
                )
            except ValueError:
                # No other keys available; the 429 response is returned and
                # recorded against the original key below.
                logger.error("All API keys in cooldown or exhausted")
            else:
                # Record the failed attempt before switching keys, otherwise
                # the first key's quota error would never be tracked.
                await self._track_usage(key_index, False, response.status_code)

                key_index = retry_index
                request.state.gemini_api_key = retry_key
                request.state.gemini_key_index = key_index

                logger.info("Retrying with key %s", key_index)
                # NOTE(review): this re-invokes the downstream app with the
                # same request object; if the first call already consumed
                # the request body, the retry may not see it. Confirm that
                # calling call_next twice is supported by the Starlette
                # version in use.
                response = await call_next(request)

                # A quota error on the retry key puts that key in cooldown
                # too, so later requests skip it.
                if response.status_code == 429:
                    self._mark_cooldown(key_index)

        # Track usage
        success = response.status_code < 400
        await self._track_usage(key_index, success, response.status_code)
        return response

    def _is_gemini_request(self, request: Request) -> bool:
        """Check if the request path starts with a Gemini path prefix."""
        return request.url.path.startswith(self._GEMINI_PATH_PREFIXES)

    async def _select_api_key(self, exclude_index: Optional[int] = None) -> tuple[int, str]:
        """
        Select the best available API key.

        Args:
            exclude_index: Key index to exclude (e.g., after quota error)

        Returns:
            Tuple of (key_index, api_key)

        Raises:
            ValueError: If no keys are configured, or every key is either
                excluded or in cooldown
        """
        keys = APIKeyServiceConfig.get_api_keys()
        if not keys:
            raise ValueError("No API keys configured")

        # Filter out excluded and cooldown keys (result is in ascending order)
        available_indices = [
            i
            for i in range(len(keys))
            if i != exclude_index and not self._is_in_cooldown(i)
        ]
        if not available_indices:
            raise ValueError("All API keys in cooldown")

        # Select based on strategy
        if APIKeyServiceConfig._rotation_strategy == "round_robin":
            selected_index = self._next_round_robin(available_indices)
        else:  # least_used
            # Imported here rather than at module level, as in the original.
            from services.api_key_manager import get_least_used_key

            # Get usage stats from DB
            async with async_session_maker() as db:
                try:
                    selected_index, _ = await get_least_used_key(db)
                    if selected_index not in available_indices:
                        # Fallback to first available
                        selected_index = available_indices[0]
                except Exception as e:
                    # Best effort: a stats failure must not block the request.
                    logger.error("Error getting least used key: %s", e)
                    selected_index = available_indices[0]

        logger.debug("Selected API key index %s", selected_index)
        return selected_index, keys[selected_index]

    @classmethod
    def _next_round_robin(cls, available_indices: list[int]) -> int:
        """
        Pick the next key in round-robin order.

        Args:
            available_indices: Non-empty, ascending list of usable key indices

        Returns:
            The first available index at or after the cursor, wrapping to
            the first available index when none is left. The cursor is then
            advanced past the returned index.
        """
        cursor = cls._round_robin_cursor
        selected = next(
            (i for i in available_indices if i >= cursor),
            available_indices[0],
        )
        cls._round_robin_cursor = selected + 1
        return selected

    @staticmethod
    def _utcnow() -> datetime:
        """
        Return the current UTC time as a naive datetime.

        Equivalent to the deprecated ``datetime.utcnow()``; the result is
        kept naive so it compares with values already stored in
        ``_key_cooldowns``.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _is_in_cooldown(self, key_index: int) -> bool:
        """
        Check if a key is in its cooldown period.

        An expired cooldown entry is removed as a side effect.
        """
        cooldown_until = _key_cooldowns.get(key_index)
        if cooldown_until is None:
            return False
        if self._utcnow() > cooldown_until:
            # Cooldown expired
            _key_cooldowns.pop(key_index, None)
            return False
        return True

    def _mark_cooldown(self, key_index: int):
        """Put a key in cooldown for the configured number of seconds."""
        cooldown_seconds = APIKeyServiceConfig._cooldown_seconds
        cooldown_until = self._utcnow() + timedelta(seconds=cooldown_seconds)
        _key_cooldowns[key_index] = cooldown_until
        logger.info("Key %s in cooldown until %s", key_index, cooldown_until)

    async def _track_usage(self, key_index: int, success: bool, status_code: int):
        """
        Record one use of an API key.

        Failures are recorded with an ``"HTTP <status>"`` error message.
        Any exception raised while recording is logged and suppressed, so
        tracking never fails the request.
        """
        try:
            async with async_session_maker() as db:
                from services.api_key_manager import record_usage

                error_message = f"HTTP {status_code}" if not success else None
                await record_usage(db, key_index, success, error_message)
                await db.commit()
        except Exception as e:
            logger.error("Failed to track usage: %s", e)