File size: 10,250 Bytes
c001f24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
"""
Unified API Key Manager with automatic failover and rotation.

This module manages multiple API keys for each service and automatically
switches to backup keys when one fails due to rate limiting or errors.
"""

import os
import time
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import threading
import logging

logger = logging.getLogger(__name__)

@dataclass
class APIKeyStatus:
    """Track usage statistics and block state for a single API key.

    A key is blocked after ``failure_threshold`` consecutive failures (see
    :meth:`mark_failure`). It is unblocked automatically the first time
    :meth:`is_available` is called after ``blocked_until`` has passed. A
    successful call resets the consecutive-failure counter and clears any block.
    """
    key: str
    service: str
    last_used: Optional[datetime] = None
    failure_count: int = 0  # consecutive failures since the last success/unblock
    last_failure: Optional[datetime] = None
    is_blocked: bool = False
    blocked_until: Optional[datetime] = None  # naive local time from datetime.now()
    total_requests: int = 0
    successful_requests: int = 0

    def mark_success(self) -> None:
        """Record a successful API call and clear any failure/block state."""
        self.last_used = datetime.now()
        self.total_requests += 1
        self.successful_requests += 1
        # A success proves the key works again, so forget earlier failures.
        self.failure_count = 0
        self.is_blocked = False
        self.blocked_until = None

    def mark_failure(self, block_duration_minutes: int = 5, failure_threshold: int = 3) -> None:
        """Record a failed API call and block the key once failures pile up.

        Args:
            block_duration_minutes: How long the key stays blocked once the
                threshold is reached, measured from this failure.
            failure_threshold: Number of consecutive failures that triggers a
                block. Defaults to 3.
        """
        # One timestamp keeps last_used, last_failure and blocked_until consistent.
        now = datetime.now()
        self.last_used = now
        self.last_failure = now
        self.total_requests += 1
        self.failure_count += 1

        if self.failure_count >= failure_threshold:
            self.is_blocked = True
            self.blocked_until = now + timedelta(minutes=block_duration_minutes)
            logger.warning(
                "API key for %s blocked until %s after %s failures",
                self.service, self.blocked_until, self.failure_count,
            )

    def is_available(self) -> bool:
        """Return True if this key may be used right now.

        Side effect: if the key is blocked and its cooldown has expired, the
        block is lifted and the consecutive-failure counter is reset. A key
        that is blocked with no ``blocked_until`` stays blocked.
        """
        if not self.is_blocked:
            return True

        # Lift the block once the cooldown period has elapsed.
        if self.blocked_until is not None and datetime.now() > self.blocked_until:
            self.is_blocked = False
            self.blocked_until = None
            self.failure_count = 0
            logger.info("API key for %s unblocked after cooldown period", self.service)
            return True

        return False

    def get_success_rate(self) -> float:
        """Return the success rate as a percentage (100.0 when unused)."""
        if self.total_requests == 0:
            return 100.0
        return (self.successful_requests / self.total_requests) * 100


class APIKeyManager:
    """
    Manages multiple API keys for different services with automatic failover.

    Supports multiple keys per service and automatically rotates to backup keys
    when one fails or hits rate limits. Keys are read from environment
    variables when the manager is constructed (see ``_load_keys_from_env``).
    Every public method serializes access to the key state through
    ``self.lock``.

    Typical usage::

        key, idx = manager.get_key('gemini')
        if key is not None:
            ...  # perform the request with `key`
            manager.mark_success('gemini', idx)  # or mark_failure on error
    """

    def __init__(self):
        # service name -> ordered list of key trackers
        self.keys: Dict[str, List[APIKeyStatus]] = {}
        # service name -> index of the key to try first on the next get_key()
        self.current_index: Dict[str, int] = {}
        # threading.Lock is not reentrant: methods holding it must not call
        # other locking methods.
        self.lock = threading.Lock()
        self._load_keys_from_env()

    def _load_keys_from_env(self):
        """Register keys for the known services from environment variables.

        - ``nvidia``: ``NVIDIA_API_KEY`` and ``NVIDIA_API_KEY_<n>``
        - ``gemini``: ``GEMINI_API_KEY[_<n>]`` followed by ``GOOGLE_API_KEY[_<n>]``
        - ``openrouter``: ``OPENROUTER_API_KEY[_<n>]``

        A service is registered only when at least one key is found.
        """
        # NVIDIA API Keys
        nvidia_keys = self._get_keys_from_env('NVIDIA_API_KEY')
        if nvidia_keys:
            self.register_service('nvidia', nvidia_keys)

        # Gemini accepts keys under either variable name. The same key is often
        # exported as both, so de-duplicate the combined list (first occurrence
        # wins) to avoid registering one key twice.
        gemini_keys = self._get_keys_from_env('GEMINI_API_KEY')
        google_keys = self._get_keys_from_env('GOOGLE_API_KEY')
        all_gemini_keys = list(dict.fromkeys(gemini_keys + google_keys))
        if all_gemini_keys:
            self.register_service('gemini', all_gemini_keys)

        # OpenRouter API Keys (for Nova)
        openrouter_keys = self._get_keys_from_env('OPENROUTER_API_KEY')
        if openrouter_keys:
            self.register_service('openrouter', openrouter_keys)

        logger.info(
            "Loaded API keys: NVIDIA=%d, Gemini=%d, OpenRouter=%d",
            len(nvidia_keys), len(all_gemini_keys), len(openrouter_keys),
        )

    def _get_keys_from_env(self, base_name: str, max_index: int = 10) -> List[str]:
        """
        Collect API keys from environment variables, in priority order.

        Looks up ``BASE_NAME`` first, then ``BASE_NAME_1`` .. ``BASE_NAME_<max_index>``.
        Values are stripped of surrounding whitespace. Unset or blank variables
        are skipped, and duplicate values are dropped (first occurrence wins).

        Example:
        - GEMINI_API_KEY      → index 0
        - GEMINI_API_KEY_1    → index 1
        - GEMINI_API_KEY_2    → index 2
        (indices shift down when a variable is missing or duplicated)

        Args:
            base_name: Environment variable prefix, e.g. ``'GEMINI_API_KEY'``.
            max_index: Highest numbered suffix to probe (default 10).

        Returns:
            Unique, non-empty key strings in lookup order.
        """
        names = [base_name] + [f"{base_name}_{i}" for i in range(1, max_index + 1)]
        # Strip so keys pasted with a trailing newline/space still authenticate.
        values = (os.environ.get(name, '').strip() for name in names)
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(value for value in values if value))

    def register_service(self, service: str, api_keys: List[str]):
        """Register (or replace) the API keys for a service.

        Any previously registered keys and their statistics are discarded, and
        rotation restarts at the first key.
        """
        with self.lock:
            self.keys[service] = [
                APIKeyStatus(key=key, service=service)
                for key in api_keys
            ]
            self.current_index[service] = 0
            logger.info("Registered %d API key(s) for service: %s", len(api_keys), service)

    def get_key(self, service: str) -> Tuple[Optional[str], int]:
        """
        Get an available API key for the specified service.

        Keys are probed round-robin starting at the service's current index.
        Expired blocks are lifted as a side effect of the availability check.

        Returns:
            ``(api_key, key_index)``, or ``(None, -1)`` when the service is
            unknown, has no keys, or every key is currently blocked.
        """
        with self.lock:
            if service not in self.keys or not self.keys[service]:
                logger.warning("No API keys registered for service: %s", service)
                return None, -1

            service_keys = self.keys[service]
            start_index = self.current_index[service]

            # Try to find an available key, starting from current index
            for attempt in range(len(service_keys)):
                current_idx = (start_index + attempt) % len(service_keys)
                key_status = service_keys[current_idx]

                if key_status.is_available():
                    self.current_index[service] = current_idx
                    logger.debug("Using API key %d/%d for %s",
                                 current_idx + 1, len(service_keys), service)
                    return key_status.key, current_idx

            # All keys are blocked
            logger.error("All API keys for %s are currently blocked or unavailable", service)
            return None, -1

    def mark_success(self, service: str, key_index: int):
        """Record a successful call for the key at ``key_index``.

        Unknown services and out-of-range indices (such as the ``-1`` returned
        by ``get_key`` on failure) are ignored.
        """
        with self.lock:
            if service in self.keys and 0 <= key_index < len(self.keys[service]):
                self.keys[service][key_index].mark_success()
                logger.debug("API key %d for %s marked as successful", key_index + 1, service)

                # Move to next key for load balancing (round-robin)
                self.current_index[service] = (key_index + 1) % len(self.keys[service])

    def mark_failure(self, service: str, key_index: int, block_duration_minutes: int = 5):
        """Record a failed call for the key at ``key_index``.

        The key may become blocked for ``block_duration_minutes``. Unknown
        services and out-of-range indices are ignored.
        """
        with self.lock:
            if service in self.keys and 0 <= key_index < len(self.keys[service]):
                self.keys[service][key_index].mark_failure(block_duration_minutes)
                logger.warning("API key %d for %s marked as failed", key_index + 1, service)

                # Move to next key immediately
                self.current_index[service] = (key_index + 1) % len(self.keys[service])

    def get_service_status(self, service: str) -> Dict:
        """Return a status summary for a service, including per-key details.

        An unregistered service is reported as unavailable with no keys.
        Expired blocks are lifted as a side effect of the availability check.
        """
        with self.lock:
            service_keys = self.keys.get(service)
            if service_keys is None:
                return {
                    'service': service,
                    'available': False,
                    'total_keys': 0,
                    'available_keys': 0,
                    'blocked_keys': 0,
                    'keys': [],
                }

            key_infos = []
            for i, k in enumerate(service_keys):
                # Evaluate is_available() exactly once per key. It can lift an
                # expired block, and a single evaluation keeps the aggregate
                # counts consistent with the per-key rows.
                available = k.is_available()
                key_infos.append({
                    'index': i,
                    'is_available': available,
                    'is_blocked': k.is_blocked,
                    'failure_count': k.failure_count,
                    'total_requests': k.total_requests,
                    'success_rate': round(k.get_success_rate(), 2),
                    'blocked_until': k.blocked_until.isoformat() if k.blocked_until else None,
                })

            available_keys = sum(1 for info in key_infos if info['is_available'])
            blocked_keys = sum(1 for info in key_infos if info['is_blocked'])

            return {
                'service': service,
                'available': available_keys > 0,
                'total_keys': len(service_keys),
                'available_keys': available_keys,
                'blocked_keys': blocked_keys,
                'keys': key_infos,
            }

    def get_all_services_status(self) -> Dict[str, Dict]:
        """Get status for all registered services."""
        # Snapshot the service names under the lock. register_service may add a
        # service concurrently, and iterating a dict that changes size raises
        # RuntimeError. The lock is released before get_service_status()
        # re-acquires it, because threading.Lock is not reentrant.
        with self.lock:
            services = list(self.keys)
        return {service: self.get_service_status(service) for service in services}

    def reset_service(self, service: str):
        """Unblock every key of a service and reset its consecutive-failure count.

        Request counters, success rates and timestamps are left untouched.
        Unknown services are ignored.
        """
        with self.lock:
            if service in self.keys:
                for key_status in self.keys[service]:
                    key_status.is_blocked = False
                    key_status.blocked_until = None
                    key_status.failure_count = 0
                logger.info("Reset all keys for service: %s", service)


# Global singleton instance, created lazily by get_api_key_manager().
_api_key_manager = None
# Guards first-time construction so concurrent callers share a single instance.
_api_key_manager_lock = threading.Lock()


def get_api_key_manager() -> APIKeyManager:
    """Return the process-wide APIKeyManager, creating it on first use.

    Construction happens at most once even when several threads call this
    concurrently. Otherwise two managers could be built, and the key
    block/failure state recorded in one of them would be lost.
    """
    global _api_key_manager
    # Fast path: once the instance exists, no locking is needed.
    if _api_key_manager is None:
        with _api_key_manager_lock:
            # Re-check: another thread may have created it while we waited.
            if _api_key_manager is None:
                _api_key_manager = APIKeyManager()
    return _api_key_manager