apigateway / services /gemini_service /api_key_middleware.py
jebin2's picture
ref
cfe2de7
"""
API Key Middleware - Automatic key selection and rotation
Automatically selects and injects Gemini API keys for requests.
Handles quota errors with automatic key rotation and retry.
"""
import time
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from core.database import async_session_maker
from services.gemini_service.api_key_config import APIKeyServiceConfig
logger = logging.getLogger(__name__)
# Track key cooldowns in memory.
# Maps a key index (a position in the list returned by
# APIKeyServiceConfig.get_api_keys()) to the naive UTC datetime at which that
# key's cooldown ends. Entries are written when a key hits a quota error and
# are removed lazily the first time a lookup finds them expired.
# This is a module-level dict, so the state is per process and is lost on restart.
_key_cooldowns: Dict[int, datetime] = {}
class APIKeyMiddleware(BaseHTTPMiddleware):
    """
    Middleware for automatic API key management.

    Features:
    - Automatic key selection based on strategy
    - Quota error detection and recovery
    - Key cooldown management
    - Usage tracking

    The selected key and its index are exposed to downstream handlers as
    ``request.state.gemini_api_key`` and ``request.state.gemini_key_index``.
    """

    # Path prefixes that mark a request as a Gemini request. Kept as a tuple so
    # it can be handed directly to ``str.startswith``.
    _GEMINI_PATH_PREFIXES = ("/gemini/", "/api/gemini")

    # Index returned by the most recent round-robin selection (-1 = none yet).
    # Stored on the class so rotation continues across requests.
    _last_round_robin_index = -1

    def __init__(self, app: ASGIApp):
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """
        Process request with automatic API key injection.

        Flow:
        1. Check if Gemini request (others pass straight through)
        2. Select best available key
        3. Inject into request state
        4. Handle response (quota errors): on HTTP 429, put the key in
           cooldown and retry once with a different key when retries are
           enabled
        5. Record usage for every key that was actually used

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            The downstream response, or a 503 JSON response when no API key
            can be selected.
        """
        # Only handle Gemini requests
        if not self._is_gemini_request(request):
            return await call_next(request)

        # Select API key
        try:
            key_index, api_key = await self._select_api_key()
        except ValueError as e:
            # No keys available
            logger.error("No API keys available: %s", e)
            # Local import keeps this block self-contained; json.dumps escapes
            # quotes/backslashes so the body is always valid JSON.
            import json
            return Response(
                content=json.dumps({"detail": str(e)}),
                status_code=503,
                media_type="application/json",
            )
        request.state.gemini_api_key = api_key
        request.state.gemini_key_index = key_index

        # Process request
        response = await call_next(request)

        # Handle quota errors
        if response.status_code == 429 and APIKeyServiceConfig._retry_on_quota_error:
            logger.warning("Quota error on key %s, attempting retry", key_index)
            # Mark key in cooldown
            self._mark_cooldown(key_index)
            # Try to select a different key. Only the selection is guarded so a
            # ValueError raised downstream during the retry is not mistaken for
            # "no keys left".
            try:
                retry_index, retry_key = await self._select_api_key(exclude_index=key_index)
            except ValueError:
                # No other keys available; fall through with the 429 response.
                logger.error("All API keys in cooldown or exhausted")
            else:
                # Record the quota failure against the key that caused it
                # before key_index is replaced by the retry key.
                await self._track_usage(key_index, False, response.status_code)
                key_index = retry_index
                request.state.gemini_api_key = retry_key
                request.state.gemini_key_index = key_index
                # NOTE(review): this calls call_next a second time for the same
                # request. The first downstream call may already have consumed
                # the request body, and the first response is dropped without
                # being read - confirm this is safe for requests with bodies.
                logger.info("Retrying with key %s", key_index)
                response = await call_next(request)
                if response.status_code == 429:
                    # The retry key is out of quota as well; rest it too.
                    self._mark_cooldown(key_index)

        # Track usage of the key that produced the final response
        success = response.status_code < 400
        await self._track_usage(key_index, success, response.status_code)
        return response

    def _is_gemini_request(self, request: Request) -> bool:
        """Check if request is for Gemini service (by URL path prefix)."""
        return request.url.path.startswith(self._GEMINI_PATH_PREFIXES)

    async def _select_api_key(self, exclude_index: Optional[int] = None) -> tuple[int, str]:
        """
        Select best available API key.

        Args:
            exclude_index: Key index to exclude (e.g., after quota error)

        Returns:
            Tuple of (key_index, api_key)

        Raises:
            ValueError: If no keys are configured, or all remaining keys are
                excluded or in cooldown
        """
        keys = APIKeyServiceConfig.get_api_keys()
        if not keys:
            raise ValueError("No API keys configured")

        # Filter out excluded and cooldown keys
        available_indices = [
            i
            for i in range(len(keys))
            if i != exclude_index and not self._is_in_cooldown(i)
        ]
        if not available_indices:
            raise ValueError("All API keys in cooldown")

        # Select based on strategy
        if APIKeyServiceConfig._rotation_strategy == "round_robin":
            selected_index = self._next_round_robin(available_indices)
        else:  # least_used
            # Get usage stats from DB
            async with async_session_maker() as db:
                from services.api_key_manager import get_least_used_key
                try:
                    selected_index, _ = await get_least_used_key(db)
                    if selected_index not in available_indices:
                        # Fallback to first available
                        selected_index = available_indices[0]
                except Exception as e:
                    # Deliberate best-effort: a stats failure must not block
                    # the request, so fall back to the first available key.
                    logger.error("Error getting least used key: %s", e)
                    selected_index = available_indices[0]

        logger.debug("Selected API key index %s", selected_index)
        return selected_index, keys[selected_index]

    def _next_round_robin(self, available_indices: list) -> int:
        """
        Pick the next key index in rotation order.

        Args:
            available_indices: Non-empty, ascending list of usable key indices

        Returns:
            The first available index after the previously selected one,
            wrapping around to the lowest available index.
        """
        last = APIKeyMiddleware._last_round_robin_index
        selected = next(
            (i for i in available_indices if i > last),
            available_indices[0],
        )
        APIKeyMiddleware._last_round_robin_index = selected
        return selected

    def _is_in_cooldown(self, key_index: int) -> bool:
        """
        Check if key is in cooldown period.

        Expired entries are removed from the cooldown table as a side effect.
        """
        cooldown_until = _key_cooldowns.get(key_index)
        if cooldown_until is None:
            return False
        # Cooldown deadlines are stored as naive UTC, so compare against naive
        # UTC "now". utcnow() is deprecated from Python 3.12 but is kept so the
        # stored values stay comparable.
        if datetime.utcnow() > cooldown_until:
            # Cooldown expired
            _key_cooldowns.pop(key_index, None)
            return False
        return True

    def _mark_cooldown(self, key_index: int):
        """Mark key as in cooldown for the configured number of seconds."""
        cooldown_seconds = APIKeyServiceConfig._cooldown_seconds
        cooldown_until = datetime.utcnow() + timedelta(seconds=cooldown_seconds)
        _key_cooldowns[key_index] = cooldown_until
        logger.info("Key %s in cooldown until %s", key_index, cooldown_until)

    async def _track_usage(self, key_index: int, success: bool, status_code: int):
        """
        Track API key usage.

        Best-effort: any failure while recording is logged and swallowed so
        that usage tracking can never fail the request itself.

        Args:
            key_index: Index of the key that served the request
            success: Whether the response status was below 400
            status_code: HTTP status, recorded as "HTTP <code>" on failure
        """
        try:
            async with async_session_maker() as db:
                from services.api_key_manager import record_usage
                error_message = f"HTTP {status_code}" if not success else None
                await record_usage(db, key_index, success, error_message)
                await db.commit()
        except Exception as e:
            logger.error("Failed to track usage: %s", e)