Spaces:
Sleeping
Sleeping
File size: 6,523 Bytes
43df312 cfe2de7 43df312 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
"""
API Key Middleware - Automatic key selection and rotation
Automatically selects and injects Gemini API keys for requests.
Handles quota errors with automatic key rotation and retry.
"""
import json
import logging
import time
from datetime import datetime, timedelta
from typing import Optional, Dict

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from core.database import async_session_maker
from services.gemini_service.api_key_config import APIKeyServiceConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# In-memory cooldown table: key index -> datetime at which the cooldown ends.
# Entries are written by APIKeyMiddleware._mark_cooldown and deleted by
# APIKeyMiddleware._is_in_cooldown once they have expired. The stored values
# come from datetime.utcnow(), i.e. they are naive datetimes. The table lives
# only in this module-global dict; nothing in this file persists it.
_key_cooldowns: Dict[int, datetime] = {}
class APIKeyMiddleware(BaseHTTPMiddleware):
    """
    Middleware for automatic Gemini API key management.

    Features:
    - Automatic key selection based on the configured rotation strategy
    - Quota error (HTTP 429) detection with one retry on a different key
    - Key cooldown management (module-level ``_key_cooldowns`` table)
    - Usage tracking via ``services.api_key_manager.record_usage``
    """

    # Request paths (prefix match) that this middleware handles.
    _GEMINI_PATH_PREFIXES = ("/gemini/", "/api/gemini")

    # Index of the key most recently handed out by the "round_robin" strategy.
    # Kept on the class so that all instances in this process share one
    # rotation; -1 means no key has been handed out yet.
    _round_robin_cursor: int = -1

    def __init__(self, app: ASGIApp):
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """
        Process a request with automatic API key injection.

        Flow:
        1. Pass non-Gemini requests straight through.
        2. Select an available key; answer 503 if there is none.
        3. Expose the key on ``request.state.gemini_api_key`` and its index
           on ``request.state.gemini_key_index``.
        4. On a 429 response (when retry is enabled) put the key in cooldown
           and retry once with a different key.
        5. Record usage for the key(s) that were tried.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            The downstream response, or a 503 JSON response when no key is
            available for the first attempt.
        """
        # Only handle Gemini requests
        if not self._is_gemini_request(request):
            return await call_next(request)

        # Select API key
        try:
            key_index, api_key = await self._select_api_key()
        except ValueError as e:
            # No keys available
            logger.error("No API keys available: %s", e)
            return Response(
                # json.dumps escapes quotes/backslashes in the message, so
                # the body is always valid JSON.
                content=json.dumps({"detail": str(e)}, ensure_ascii=False),
                status_code=503,
                media_type="application/json",
            )
        request.state.gemini_api_key = api_key
        request.state.gemini_key_index = key_index

        # Process request
        response = await call_next(request)

        # Handle quota errors
        if response.status_code == 429 and APIKeyServiceConfig._retry_on_quota_error:
            logger.warning("Quota error on key %s, attempting retry", key_index)
            # Mark key in cooldown
            self._mark_cooldown(key_index)

            # Only key selection is guarded: a ValueError raised by the
            # downstream application during the retry must not be mistaken
            # for "no keys left".
            try:
                retry_index, retry_key = await self._select_api_key(
                    exclude_index=key_index
                )
            except ValueError:
                # No other keys available; the original 429 is returned.
                logger.error("All API keys in cooldown or exhausted")
            else:
                # Record the failed attempt on the first key before switching,
                # otherwise its quota error would never reach usage tracking.
                await self._track_usage(key_index, False, response.status_code)

                key_index = retry_index
                request.state.gemini_api_key = retry_key
                request.state.gemini_key_index = key_index

                # NOTE(review): this forwards the same Request object a second
                # time and drops the first response without reading it. Whether
                # the downstream handler can still read the request body on the
                # second pass depends on the Starlette version -- confirm.
                logger.info("Retrying with key %s", key_index)
                response = await call_next(request)

        # Track usage of the key that produced the final response
        success = response.status_code < 400
        await self._track_usage(key_index, success, response.status_code)
        return response

    def _is_gemini_request(self, request: Request) -> bool:
        """Return True if the request path starts with a Gemini prefix."""
        return request.url.path.startswith(self._GEMINI_PATH_PREFIXES)

    async def _select_api_key(self, exclude_index: Optional[int] = None) -> tuple[int, str]:
        """
        Select the best available API key.

        Keys that are excluded or currently in cooldown are skipped. With the
        "round_robin" strategy the next available key after the one handed
        out last is chosen (wrapping around). Any other strategy value is
        treated as "least used": the database is asked for the least used
        key, falling back to the first available key when that key is not
        available or the lookup fails.

        Args:
            exclude_index: Key index to exclude (e.g., after quota error)

        Returns:
            Tuple of (key_index, api_key)

        Raises:
            ValueError: If no keys are configured or all are unavailable
        """
        keys = APIKeyServiceConfig.get_api_keys()
        if not keys:
            raise ValueError("No API keys configured")

        # Filter out excluded and cooldown keys; indices stay in ascending
        # order, which the round-robin step relies on.
        available_indices = [
            i
            for i in range(len(keys))
            if i != exclude_index and not self._is_in_cooldown(i)
        ]
        if not available_indices:
            raise ValueError("All API keys in cooldown")

        # Select based on strategy
        if APIKeyServiceConfig._rotation_strategy == "round_robin":
            selected_index = self._next_round_robin(available_indices)
        else:  # least_used
            # Get usage stats from DB
            async with async_session_maker() as db:
                from services.api_key_manager import get_least_used_key

                try:
                    selected_index, _ = await get_least_used_key(db)
                except Exception as e:
                    # Best effort: a failed stats lookup must not block the
                    # request, so fall back to the first available key.
                    logger.error("Error getting least used key: %s", e)
                    selected_index = available_indices[0]
                else:
                    if selected_index not in available_indices:
                        # Fallback to first available
                        selected_index = available_indices[0]

        logger.debug("Selected API key index %s", selected_index)
        return selected_index, keys[selected_index]

    def _next_round_robin(self, available_indices: list) -> int:
        """Return the next available index after the last one handed out.

        Wraps around to the first available index when no higher index is
        available, and remembers the choice for the next call.
        """
        last = APIKeyMiddleware._round_robin_cursor
        chosen = next(
            (i for i in available_indices if i > last),
            available_indices[0],
        )
        APIKeyMiddleware._round_robin_cursor = chosen
        return chosen

    def _is_in_cooldown(self, key_index: int) -> bool:
        """Check if a key is in its cooldown period.

        An expired entry is removed from ``_key_cooldowns`` as a side effect.
        """
        cooldown_until = _key_cooldowns.get(key_index)
        if cooldown_until is None:
            return False
        if datetime.utcnow() > cooldown_until:
            # Cooldown expired
            _key_cooldowns.pop(key_index, None)
            return False
        return True

    def _mark_cooldown(self, key_index: int):
        """Put a key in cooldown for the configured number of seconds."""
        cooldown_seconds = APIKeyServiceConfig._cooldown_seconds
        cooldown_until = datetime.utcnow() + timedelta(seconds=cooldown_seconds)
        _key_cooldowns[key_index] = cooldown_until
        logger.info("Key %s in cooldown until %s", key_index, cooldown_until)

    async def _track_usage(self, key_index: int, success: bool, status_code: int):
        """Record one use of a key; failures here are logged, never raised.

        Args:
            key_index: Index of the key that was used.
            success: Whether the request succeeded.
            status_code: HTTP status code; reported as "HTTP <code>" in the
                error message when ``success`` is False.
        """
        try:
            async with async_session_maker() as db:
                from services.api_key_manager import record_usage

                error_message = None if success else f"HTTP {status_code}"
                await record_usage(db, key_index, success, error_message)
                await db.commit()
        except Exception as e:
            # Usage tracking is best effort and must not break the response.
            logger.error("Failed to track usage: %s", e)
|