File size: 6,523 Bytes
43df312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfe2de7
43df312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
API Key Middleware - Automatic key selection and rotation

Automatically selects and injects Gemini API keys for requests.
Handles quota errors with automatic key rotation and retry.
"""
import json
import logging
import time
from datetime import datetime, timedelta
from typing import Optional, Dict

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from core.database import async_session_maker
from services.gemini_service.api_key_config import APIKeyServiceConfig

logger = logging.getLogger(name=__name__)


# In-process map: API key index -> datetime at which that key's cooldown ends.
# Module-level, so it is shared by every middleware instance in this process
# and is lost on restart (nothing here persists it).
_key_cooldowns: Dict[int, datetime] = dict()


class APIKeyMiddleware(BaseHTTPMiddleware):
    """
    Middleware for automatic Gemini API key management.

    For requests whose path starts with ``/gemini/`` or ``/api/gemini`` it:

    - selects a key using ``APIKeyServiceConfig._rotation_strategy``
      (``"round_robin"``, otherwise least-used) and exposes it to handlers as
      ``request.state.gemini_api_key`` / ``request.state.gemini_key_index``;
    - on an HTTP 429 response (when ``APIKeyServiceConfig._retry_on_quota_error``
      is truthy) puts the key into an in-memory cooldown and retries once with
      a different key;
    - records per-key success/failure via ``services.api_key_manager``.
    """

    # Path prefixes handled by this middleware (str.startswith accepts a tuple).
    _GEMINI_PATH_PREFIXES = ("/gemini/", "/api/gemini")

    # Index the next round-robin search starts from. Declared at class level so
    # it has a default even before the first per-instance assignment.
    _rr_cursor: int = 0

    def __init__(self, app: ASGIApp):
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """
        Process a request with automatic API key injection.

        Flow:
        1. Pass non-Gemini requests straight through.
        2. Select an available key; answer 503 (JSON ``detail``) if none.
        3. Inject the key into ``request.state`` and call the downstream app.
        4. On a 429, cool the key down and retry once with another key.
        5. Record usage for the key that produced the final response.
        """
        if not self._is_gemini_request(request):
            return await call_next(request)

        try:
            key_index, api_key = await self._select_api_key()
        except ValueError as e:
            logger.error("No API keys available: %s", e)
            # json.dumps escapes quotes/backslashes in the message; a hand-built
            # JSON string would become invalid for such input.
            return Response(
                content=json.dumps({"detail": str(e)}),
                status_code=503,
                media_type="application/json",
            )
        self._inject_key(request, key_index, api_key)

        response = await call_next(request)

        if response.status_code == 429 and APIKeyServiceConfig._retry_on_quota_error:
            logger.warning("Quota error on key %s, attempting retry", key_index)
            self._mark_cooldown(key_index)

            # Only key selection belongs inside the try: a ValueError raised by
            # the downstream app during the retry must not be reported as
            # "no keys available".
            try:
                retry_index, retry_key = await self._select_api_key(exclude_index=key_index)
            except ValueError:
                logger.error("All API keys in cooldown or exhausted")
            else:
                # Record the quota failure against the key that actually hit it
                # before usage tracking moves on to the retry key.
                await self._track_usage(key_index, False, response.status_code)
                key_index = retry_index
                self._inject_key(request, key_index, retry_key)
                logger.info("Retrying with key %s", key_index)
                # NOTE(review): this re-runs the downstream app with the same
                # Request and drops the first response unread. Whether a request
                # body can be read a second time depends on the Starlette
                # version - confirm for POST endpoints.
                response = await call_next(request)

        success = response.status_code < 400
        await self._track_usage(key_index, success, response.status_code)
        return response

    @staticmethod
    def _inject_key(request: Request, key_index: int, api_key: str) -> None:
        """Expose the chosen key and its index to downstream handlers."""
        request.state.gemini_api_key = api_key
        request.state.gemini_key_index = key_index

    def _is_gemini_request(self, request: Request) -> bool:
        """Return True if the request path starts with a Gemini route prefix."""
        return request.url.path.startswith(self._GEMINI_PATH_PREFIXES)

    async def _select_api_key(self, exclude_index: Optional[int] = None) -> tuple[int, str]:
        """
        Select the best available API key.

        Args:
            exclude_index: Key index to skip (e.g. the key that just hit a quota error).

        Returns:
            Tuple of (key_index, api_key).

        Raises:
            ValueError: If no keys are configured or every candidate is
                excluded / in cooldown.
        """
        keys = APIKeyServiceConfig.get_api_keys()
        if not keys:
            raise ValueError("No API keys configured")

        # The exclusion is checked first, so _is_in_cooldown (which also prunes
        # expired entries) is only consulted for non-excluded keys.
        available_indices = [
            i for i in range(len(keys))
            if i != exclude_index and not self._is_in_cooldown(i)
        ]
        if not available_indices:
            raise ValueError("All API keys in cooldown")

        if APIKeyServiceConfig._rotation_strategy == "round_robin":
            selected_index = self._next_round_robin(available_indices, len(keys))
        else:  # least_used
            selected_index = await self._least_used_index(available_indices)

        logger.debug("Selected API key index %s", selected_index)
        return selected_index, keys[selected_index]

    def _next_round_robin(self, available_indices: list[int], key_count: int) -> int:
        """
        Return the first available index at or after the cursor and advance it.

        The search wraps around ``key_count``, so successive calls cycle
        through all available keys instead of always returning the lowest one.
        """
        start = self._rr_cursor % key_count
        selected = min(available_indices, key=lambda i: (i - start) % key_count)
        self._rr_cursor = (selected + 1) % key_count
        return selected

    async def _least_used_index(self, available_indices: list[int]) -> int:
        """
        Return the least-used key index reported by the usage store.

        Falls back to the first available index if the lookup fails (including
        session/DB errors) or returns a key that is excluded / cooling down.
        """
        from services.api_key_manager import get_least_used_key

        try:
            async with async_session_maker() as db:
                selected_index, _ = await get_least_used_key(db)
        except Exception as e:
            logger.error("Error getting least used key: %s", e)
            return available_indices[0]

        if selected_index not in available_indices:
            return available_indices[0]
        return selected_index

    def _is_in_cooldown(self, key_index: int) -> bool:
        """Return True if the key is still cooling down; prune expired entries."""
        cooldown_until = _key_cooldowns.get(key_index)
        if cooldown_until is None:
            return False
        if datetime.utcnow() > cooldown_until:
            # Cooldown expired: drop the entry so the map does not keep stale keys.
            _key_cooldowns.pop(key_index, None)
            return False
        return True

    def _mark_cooldown(self, key_index: int) -> None:
        """Put a key into cooldown for ``APIKeyServiceConfig._cooldown_seconds``."""
        cooldown_seconds = APIKeyServiceConfig._cooldown_seconds
        cooldown_until = datetime.utcnow() + timedelta(seconds=cooldown_seconds)
        _key_cooldowns[key_index] = cooldown_until
        logger.info("Key %s in cooldown until %s", key_index, cooldown_until)

    async def _track_usage(self, key_index: int, success: bool, status_code: int) -> None:
        """
        Record one use of a key (best-effort).

        Failed calls are recorded with an ``"HTTP <status>"`` error message.
        Any failure while recording is logged and swallowed so that usage
        tracking never breaks the response.
        """
        try:
            async with async_session_maker() as db:
                from services.api_key_manager import record_usage
                error_message = None if success else f"HTTP {status_code}"
                await record_usage(db, key_index, success, error_message)
                await db.commit()
        except Exception as e:
            logger.error("Failed to track usage: %s", e)