jebin2 commited on
Commit
43df312
Β·
1 Parent(s): 1c302c7

Phase 3: Implement API Key Middleware

Browse files

Core implementation:
- Created api_key_config.py: Configuration for key rotation strategies
- Created api_key_middleware.py: Automatic key selection and quota handling
- Integrated into app.py middleware chain

Features:
- Automatic API key selection (least_used or round_robin)
- Quota error detection (429) with automatic retry
- Key cooldown management (60s after quota error)
- Transparent key rotation (app doesn't know which key)
- Usage tracking per key
- No code changes needed in endpoints

Middleware order: Auth β†’ Audit β†’ API Key β†’ Credit β†’ Application

Benefits:
- Zero downtime on quota errors
- Automatic load balancing across keys
- Better observability of key usage
- No manual key management needed

Next: Testing and Phase 4 (Payment Transaction Manager)

app.py CHANGED
@@ -132,6 +132,16 @@ async def lifespan(app: FastAPI):
132
  )
133
  logger.info("βœ… Audit Service configured")
134
 
 
 
 
 
 
 
 
 
 
 
135
  # Check for RESET_DB environment variable
136
  if os.getenv("RESET_DB", "").lower() == "true":
137
  logger.warning(f"RESET_DB is set to true. Skipping download and clearing local database ({DB_FILENAME}).")
@@ -194,11 +204,15 @@ app.add_middleware(
194
  allow_headers=["*"],
195
  )
196
 
197
- # Add Audit Middleware (executes second - after auth, before credit)
198
  from services.audit_service import AuditMiddleware
199
  app.add_middleware(AuditMiddleware)
200
 
201
- # Add Credit Middleware (executes third - after auth and audit)
 
 
 
 
202
  from services.credit_service import CreditMiddleware
203
  app.add_middleware(CreditMiddleware)
204
 
 
132
  )
133
  logger.info("βœ… Audit Service configured")
134
 
135
+ # Register API Key Service configuration
136
+ from services.gemini_service import APIKeyServiceConfig
137
+ APIKeyServiceConfig.register(
138
+ rotation_strategy="least_used", # or "round_robin"
139
+ cooldown_seconds=60, # Wait 1 min after quota error
140
+ max_requests_per_minute=60,
141
+ retry_on_quota_error=True # Auto-retry with different key
142
+ )
143
+ logger.info("βœ… API Key Service configured")
144
+
145
  # Check for RESET_DB environment variable
146
  if os.getenv("RESET_DB", "").lower() == "true":
147
  logger.warning(f"RESET_DB is set to true. Skipping download and clearing local database ({DB_FILENAME}).")
 
204
  allow_headers=["*"],
205
  )
206
 
207
+ # Add Audit Middleware (executes second - after auth, before API key)
208
  from services.audit_service import AuditMiddleware
209
  app.add_middleware(AuditMiddleware)
210
 
211
+ # Add API Key Middleware (executes third - for Gemini requests)
212
+ from services.gemini_service import APIKeyMiddleware
213
+ app.add_middleware(APIKeyMiddleware)
214
+
215
+ # Add Credit Middleware (executes fourth - after auth, audit, and API key)
216
  from services.credit_service import CreditMiddleware
217
  app.add_middleware(CreditMiddleware)
218
 
services/gemini_service/__init__.py CHANGED
@@ -16,6 +16,7 @@ from services.gemini_service.api_client import (
16
  get_gemini_api_key,
17
  MOCK_MODE,
18
  MOCK_VIDEO_URL,
 
19
  )
20
 
21
  # Job Processor exports
@@ -28,6 +29,10 @@ from services.gemini_service.job_processor import (
28
  stop_worker,
29
  )
30
 
 
 
 
 
31
  __all__ = [
32
  # API Client
33
  'GeminiService',
@@ -36,6 +41,7 @@ __all__ = [
36
  'get_gemini_api_key',
37
  'MOCK_MODE',
38
  'MOCK_VIDEO_URL',
 
39
 
40
  # Job Processor
41
  'GeminiJobProcessor',
@@ -44,4 +50,8 @@ __all__ = [
44
  'get_priority_for_job_type',
45
  'start_worker',
46
  'stop_worker',
 
 
 
 
47
  ]
 
16
  get_gemini_api_key,
17
  MOCK_MODE,
18
  MOCK_VIDEO_URL,
19
+ GeminiAPIClient, # Added
20
  )
21
 
22
  # Job Processor exports
 
29
  stop_worker,
30
  )
31
 
32
+ # API Key Middleware exports # Added
33
+ from services.gemini_service.api_key_config import APIKeyServiceConfig # Added
34
+ from services.gemini_service.api_key_middleware import APIKeyMiddleware # Added
35
+
36
  __all__ = [
37
  # API Client
38
  'GeminiService',
 
41
  'get_gemini_api_key',
42
  'MOCK_MODE',
43
  'MOCK_VIDEO_URL',
44
+ 'GeminiAPIClient',
45
 
46
  # Job Processor
47
  'GeminiJobProcessor',
 
50
  'get_priority_for_job_type',
51
  'start_worker',
52
  'stop_worker',
53
+
54
+ # API Key Middleware
55
+ 'APIKeyServiceConfig',
56
+ 'APIKeyMiddleware',
57
  ]
services/gemini_service/api_key_config.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API Key Service Configuration
3
+
4
+ Configures automatic API key selection and rotation via middleware.
5
+ """
6
+ from typing import List, Optional
7
+ import os
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class APIKeyServiceConfig:
14
+ """Configuration for API key middleware."""
15
+
16
+ _rotation_strategy: str = "least_used" # or "round_robin"
17
+ _cooldown_seconds: int = 60
18
+ _max_requests_per_minute: int = 60
19
+ _retry_on_quota_error: bool = True
20
+ _api_keys: Optional[List[str]] = None
21
+
22
+ @classmethod
23
+ def register(
24
+ cls,
25
+ rotation_strategy: str = "least_used",
26
+ cooldown_seconds: int = 60,
27
+ max_requests_per_minute: int = 60,
28
+ retry_on_quota_error: bool = True
29
+ ) -> None:
30
+ """
31
+ Register API key service configuration.
32
+
33
+ Args:
34
+ rotation_strategy: "least_used" or "round_robin"
35
+ cooldown_seconds: Time to wait before reusing a key after quota error
36
+ max_requests_per_minute: Rate limit per key
37
+ retry_on_quota_error: Auto-retry with different key on 429
38
+
39
+ Example:
40
+ APIKeyServiceConfig.register(
41
+ rotation_strategy="least_used",
42
+ cooldown_seconds=60,
43
+ retry_on_quota_error=True
44
+ )
45
+ """
46
+ cls._rotation_strategy = rotation_strategy
47
+ cls._cooldown_seconds = cooldown_seconds
48
+ cls._max_requests_per_minute = max_requests_per_minute
49
+ cls._retry_on_quota_error = retry_on_quota_error
50
+
51
+ # Load API keys from env
52
+ cls._load_api_keys()
53
+
54
+ logger.info(
55
+ f"API Key Service configured: "
56
+ f"keys={len(cls._api_keys or [])}, "
57
+ f"strategy={rotation_strategy}, "
58
+ f"retry={retry_on_quota_error}"
59
+ )
60
+
61
+ @classmethod
62
+ def _load_api_keys(cls):
63
+ """Load API keys from environment variables."""
64
+ keys_str = os.getenv("GEMINI_API_KEYS", "")
65
+ if not keys_str:
66
+ # Fallback to single key
67
+ single_key = os.getenv("GEMINI_API_KEY", "")
68
+ if single_key:
69
+ cls._api_keys = [single_key]
70
+ else:
71
+ cls._api_keys = []
72
+ logger.warning("No Gemini API keys configured!")
73
+ else:
74
+ cls._api_keys = [k.strip() for k in keys_str.split(",") if k.strip()]
75
+
76
+ if cls._api_keys:
77
+ logger.info(f"Loaded {len(cls._api_keys)} Gemini API key(s)")
78
+
79
+ @classmethod
80
+ def get_api_keys(cls) -> List[str]:
81
+ """Get loaded API keys."""
82
+ if cls._api_keys is None:
83
+ cls._load_api_keys()
84
+ return cls._api_keys or []
85
+
86
+ @classmethod
87
+ def get_key_count(cls) -> int:
88
+ """Get number of available keys."""
89
+ return len(cls.get_api_keys())
90
+
91
+ @classmethod
92
+ def get_config(cls) -> dict:
93
+ """Get current configuration."""
94
+ return {
95
+ "key_count": cls.get_key_count(),
96
+ "rotation_strategy": cls._rotation_strategy,
97
+ "cooldown_seconds": cls._cooldown_seconds,
98
+ "max_requests_per_minute": cls._max_requests_per_minute,
99
+ "retry_on_quota_error": cls._retry_on_quota_error
100
+ }
services/gemini_service/api_key_middleware.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API Key Middleware - Automatic key selection and rotation
3
+
4
+ Automatically selects and injects Gemini API keys for requests.
5
+ Handles quota errors with automatic key rotation and retry.
6
+ """
7
+ import time
8
+ import logging
9
+ from datetime import datetime, timedelta
10
+ from typing import Optional, Dict
11
+ from fastapi import Request, Response
12
+ from starlette.middleware.base import BaseHTTPMiddleware
13
+ from starlette.types import ASGIApp
14
+
15
+ from core.database import async_session_maker
16
+ from services.gemini_service.api_key_config import APIKeyServiceConfig
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ # Track key cooldowns in memory
22
+ _key_cooldowns: Dict[int, datetime] = {}
23
+
24
+
25
+ class APIKeyMiddleware(BaseHTTPMiddleware):
26
+ """
27
+ Middleware for automatic API key management.
28
+
29
+ Features:
30
+ - Automatic key selection based on strategy
31
+ - Quota error detection and recovery
32
+ - Key cooldown management
33
+ - Usage tracking
34
+ """
35
+
36
+ def __init__(self, app: ASGIApp):
37
+ super().__init__(app)
38
+
39
+ async def dispatch(self, request: Request, call_next):
40
+ """
41
+ Process request with automatic API key injection.
42
+
43
+ Flow:
44
+ 1. Check if Gemini request
45
+ 2. Select best available key
46
+ 3. Inject into request state
47
+ 4. Handle response (quota errors)
48
+ """
49
+ # Only handle Gemini requests
50
+ if not self._is_gemini_request(request):
51
+ return await call_next(request)
52
+
53
+ # Select API key
54
+ try:
55
+ key_index, api_key = await self._select_api_key()
56
+ request.state.gemini_api_key = api_key
57
+ request.state.gemini_key_index = key_index
58
+ except ValueError as e:
59
+ # No keys available
60
+ logger.error(f"No API keys available: {e}")
61
+ return Response(
62
+ content=f'{{"detail": "{str(e)}"}}',
63
+ status_code=503,
64
+ media_type="application/json"
65
+ )
66
+
67
+ # Process request
68
+ response = await call_next(request)
69
+
70
+ # Handle quota errors
71
+ if response.status_code == 429 and API KeyServiceConfig._retry_on_quota_error:
72
+ logger.warning(f"Quota error on key {key_index}, attempting retry")
73
+
74
+ # Mark key in cooldown
75
+ self._mark_cooldown(key_index)
76
+
77
+ # Try to select different key
78
+ try:
79
+ key_index, api_key = await self._select_api_key(exclude_index=key_index)
80
+ request.state.gemini_api_key = api_key
81
+ request.state.gemini_key_index = key_index
82
+
83
+ # Retry request
84
+ logger.info(f"Retrying with key {key_index}")
85
+ response = await call_next(request)
86
+ except ValueError:
87
+ # No other keys available
88
+ logger.error("All API keys in cooldown or exhausted")
89
+
90
+ # Track usage
91
+ success = response.status_code < 400
92
+ await self._track_usage(key_index, success, response.status_code)
93
+
94
+ return response
95
+
96
+ def _is_gemini_request(self, request: Request) -> bool:
97
+ """Check if request is for Gemini service."""
98
+ path = request.url.path
99
+ gemini_paths = ["/gemini/", "/api/gemini"]
100
+ return any(path.startswith(p) for p in gemini_paths)
101
+
102
+ async def _select_api_key(self, exclude_index: Optional[int] = None) -> tuple[int, str]:
103
+ """
104
+ Select best available API key.
105
+
106
+ Args:
107
+ exclude_index: Key index to exclude (e.g., after quota error)
108
+
109
+ Returns:
110
+ Tuple of (key_index, api_key)
111
+
112
+ Raises:
113
+ ValueError: If no keys available
114
+ """
115
+ keys = APIKeyServiceConfig.get_api_keys()
116
+ if not keys:
117
+ raise ValueError("No API keys configured")
118
+
119
+ # Filter out excluded and cooldown keys
120
+ available_indices = []
121
+ for i in range(len(keys)):
122
+ if i == exclude_index:
123
+ continue
124
+ if self._is_in_cooldown(i):
125
+ continue
126
+ available_indices.append(i)
127
+
128
+ if not available_indices:
129
+ raise ValueError("All API keys in cooldown")
130
+
131
+ # Select based on strategy
132
+ if APIKeyServiceConfig._rotation_strategy == "round_robin":
133
+ # Simple round-robin
134
+ selected_index = available_indices[0]
135
+ else: # least_used
136
+ # Get usage stats from DB
137
+ async with async_session_maker() as db:
138
+ from services.api_key_manager import get_least_used_key
139
+ try:
140
+ selected_index, _ = await get_least_used_key(db)
141
+ if selected_index not in available_indices:
142
+ # Fallback to first available
143
+ selected_index = available_indices[0]
144
+ except Exception as e:
145
+ logger.error(f"Error getting least used key: {e}")
146
+ selected_index = available_indices[0]
147
+
148
+ logger.debug(f"Selected API key index {selected_index}")
149
+ return selected_index, keys[selected_index]
150
+
151
+ def _is_in_cooldown(self, key_index: int) -> bool:
152
+ """Check if key is in cooldown period."""
153
+ if key_index not in _key_cooldowns:
154
+ return False
155
+
156
+ cooldown_until = _key_cooldowns[key_index]
157
+ if datetime.utcnow() > cooldown_until:
158
+ # Cooldown expired
159
+ del _key_cooldowns[key_index]
160
+ return False
161
+
162
+ return True
163
+
164
+ def _mark_cooldown(self, key_index: int):
165
+ """Mark key as in cooldown."""
166
+ cooldown_seconds = APIKeyServiceConfig._cooldown_seconds
167
+ cooldown_until = datetime.utcnow() + timedelta(seconds=cooldown_seconds)
168
+ _key_cooldowns[key_index] = cooldown_until
169
+ logger.info(f"Key {key_index} in cooldown until {cooldown_until}")
170
+
171
+ async def _track_usage(self, key_index: int, success: bool, status_code: int):
172
+ """Track API key usage."""
173
+ try:
174
+ async with async_session_maker() as db:
175
+ from services.api_key_manager import record_usage
176
+ error_message = f"HTTP {status_code}" if not success else None
177
+ await record_usage(db, key_index, success, error_message)
178
+ await db.commit()
179
+ except Exception as e:
180
+ logger.error(f"Failed to track usage: {e}")