"""
API Key Middleware - Automatic key selection and rotation

Automatically selects and injects Gemini API keys for requests.
Handles quota errors with automatic key rotation and retry.
"""
import json
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from core.database import async_session_maker
from services.gemini_service.api_key_config import APIKeyServiceConfig

logger = logging.getLogger(__name__)

# Track key cooldowns in memory: key index -> timezone-aware UTC instant at
# which the cooldown ends. Lives in this process's module state only.
_key_cooldowns: Dict[int, datetime] = {}

# Index at which the "round_robin" strategy starts searching for the next key.
_round_robin_cursor: int = 0


def _pick_round_robin(available_indices: list[int], total_keys: int) -> int:
    """Pick the next key in rotation and advance the cursor.

    Returns the first index in ``available_indices`` at or after the cursor,
    wrapping around past ``total_keys - 1``, so that successive calls cycle
    through all usable keys instead of always returning the lowest one.
    """
    global _round_robin_cursor
    start = _round_robin_cursor % total_keys
    # Distance forward from the cursor (with wrap-around) decides the winner.
    chosen = min(available_indices, key=lambda i: (i - start) % total_keys)
    _round_robin_cursor = chosen + 1
    return chosen


class APIKeyMiddleware(BaseHTTPMiddleware):
    """
    Middleware for automatic API key management.

    Features:
    - Automatic key selection based on strategy
    - Quota error detection and recovery
    - Key cooldown management
    - Usage tracking
    """

    def __init__(self, app: ASGIApp):
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """
        Process request with automatic API key injection.

        Flow:
        1. Pass non-Gemini requests straight through.
        2. Select the best available key; respond 503 if none is available.
        3. Inject the key into ``request.state``.
        4. On a 429 response (and if retry is enabled), put the key in
           cooldown, record the failure, and retry once with another key.
        5. Record usage for the key that produced the final response.
        """
        # Only handle Gemini requests
        if not self._is_gemini_request(request):
            return await call_next(request)

        try:
            key_index, api_key = await self._select_api_key()
        except ValueError as e:
            # No keys available
            logger.error("No API keys available: %s", e)
            return self._unavailable_response(str(e))
        self._inject_key(request, key_index, api_key)

        response = await call_next(request)

        if response.status_code == 429 and APIKeyServiceConfig._retry_on_quota_error:
            logger.warning("Quota error on key %s, attempting retry", key_index)
            self._mark_cooldown(key_index)

            # Only key selection belongs in this try: errors raised by the
            # downstream app during the retry must propagate, not be
            # misreported as key exhaustion.
            try:
                retry_index, retry_key = await self._select_api_key(
                    exclude_index=key_index
                )
            except ValueError:
                # No other keys available; fall through with the 429.
                logger.error("All API keys in cooldown or exhausted")
            else:
                # Record the quota failure against the key that hit it,
                # since the final tracking below covers only the retry key.
                await self._track_usage(key_index, False, response.status_code)

                key_index = retry_index
                self._inject_key(request, key_index, retry_key)
                logger.info("Retrying with key %s", key_index)
                # NOTE(review): call_next is invoked a second time for the
                # same request; confirm the installed Starlette version
                # replays the request body to the downstream app on retry.
                response = await call_next(request)
                if response.status_code == 429:
                    # The retry key is over quota as well; cool it down too.
                    self._mark_cooldown(key_index)

        success = response.status_code < 400
        await self._track_usage(key_index, success, response.status_code)
        return response

    @staticmethod
    def _inject_key(request: Request, key_index: int, api_key: str) -> None:
        """Expose the chosen key and its index to downstream handlers."""
        request.state.gemini_api_key = api_key
        request.state.gemini_key_index = key_index

    @staticmethod
    def _unavailable_response(detail: str) -> Response:
        """Build a 503 JSON response; json.dumps escapes quotes/backslashes."""
        return Response(
            content=json.dumps({"detail": detail}),
            status_code=503,
            media_type="application/json",
        )

    def _is_gemini_request(self, request: Request) -> bool:
        """Check if request is for Gemini service."""
        path = request.url.path
        gemini_paths = ["/gemini/", "/api/gemini"]
        return any(path.startswith(p) for p in gemini_paths)

    async def _select_api_key(self, exclude_index: Optional[int] = None) -> tuple[int, str]:
        """
        Select best available API key.

        Args:
            exclude_index: Key index to exclude (e.g., after quota error)

        Returns:
            Tuple of (key_index, api_key)

        Raises:
            ValueError: If no keys are configured, or all usable keys are
                excluded or in cooldown.
        """
        keys = APIKeyServiceConfig.get_api_keys()
        if not keys:
            raise ValueError("No API keys configured")

        # Filter out excluded and cooldown keys
        available_indices = [
            i
            for i in range(len(keys))
            if i != exclude_index and not self._is_in_cooldown(i)
        ]
        if not available_indices:
            raise ValueError("All API keys in cooldown")

        if APIKeyServiceConfig._rotation_strategy == "round_robin":
            selected_index = _pick_round_robin(available_indices, len(keys))
        else:  # least_used
            # Fallback used whenever the DB lookup fails or returns a key
            # that is currently excluded or cooling down.
            selected_index = available_indices[0]
            try:
                from services.api_key_manager import get_least_used_key

                async with async_session_maker() as db:
                    least_used_index, _ = await get_least_used_key(db)
                if least_used_index in available_indices:
                    selected_index = least_used_index
            except Exception as e:
                # Best-effort: a stats/DB failure must not block the request.
                logger.error("Error getting least used key: %s", e)

        logger.debug("Selected API key index %s", selected_index)
        return selected_index, keys[selected_index]

    def _is_in_cooldown(self, key_index: int) -> bool:
        """Check if key is in cooldown period, evicting expired entries."""
        cooldown_until = _key_cooldowns.get(key_index)
        if cooldown_until is None:
            return False
        if datetime.now(timezone.utc) > cooldown_until:
            # Cooldown expired
            _key_cooldowns.pop(key_index, None)
            return False
        return True

    def _mark_cooldown(self, key_index: int):
        """Mark key as in cooldown for the configured number of seconds."""
        cooldown_seconds = APIKeyServiceConfig._cooldown_seconds
        cooldown_until = datetime.now(timezone.utc) + timedelta(seconds=cooldown_seconds)
        _key_cooldowns[key_index] = cooldown_until
        logger.info("Key %s in cooldown until %s", key_index, cooldown_until)

    async def _track_usage(self, key_index: int, success: bool, status_code: int):
        """Track API key usage (best-effort; failures are logged, not raised)."""
        try:
            async with async_session_maker() as db:
                from services.api_key_manager import record_usage

                error_message = None if success else f"HTTP {status_code}"
                await record_usage(db, key_index, success, error_message)
                await db.commit()
        except Exception as e:
            logger.error("Failed to track usage: %s", e)