# File size: 6,990 Bytes
# ded29b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
import time
from typing import List, Optional
from dotenv import load_dotenv
import logging
import sys

# Send every log record to standard output; no file handler is installed.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(stream=sys.stdout)],
)
logger = logging.getLogger(name=__name__)

# Populate the process environment from a .env file (python-dotenv) so that
# settings such as GEMINI_API_KEYS can be read through os.environ.
load_dotenv()

# API key management constants
KEY_COOLDOWN = 60  # Seconds a failed key is skipped before it is retried

class ApiKeyManager:
    """
    Manages the rotation and cooldown of API keys to prevent rate limiting issues.

    Keys are handed out round-robin by ``get_next_api_key``. A key reported
    through ``mark_key_as_failed`` is skipped until more than ``key_cooldown``
    seconds have passed since its failure. If every key is cooling down,
    ``get_next_api_key`` blocks (``time.sleep``) until the earliest one is due.
    """

    def __init__(self, api_keys: Optional[List[str]] = None,
                 key_cooldown: Optional[float] = None):
        """
        Initialize the API key manager with a list of API keys

        Args:
            api_keys: List of API keys. If None, keys are read from the
                comma-separated ``GEMINI_API_KEYS`` environment variable
                (whitespace stripped, empty entries dropped). The given
                sequence is copied and duplicate keys are removed, keeping
                the first occurrence.
            key_cooldown: Seconds a failed key stays out of rotation. Defaults
                to the module-level ``KEY_COOLDOWN``.

        Raises:
            ValueError: If ``api_keys`` is None and ``GEMINI_API_KEYS`` is
                unset/empty, or if no API keys remain.
        """
        if api_keys is None:
            # Load API keys from environment variable
            env_keys = os.environ.get('GEMINI_API_KEYS', '')
            if not env_keys:
                logger.error("No API keys found in environment variables")
                raise ValueError("No API keys found in environment variables")

            api_keys = [key.strip() for key in env_keys.split(',') if key.strip()]
            logger.info(f"Loaded {len(api_keys)} API keys from environment variables")

        provided_keys = list(api_keys)
        # Cooldown bookkeeping is keyed by key value, so a repeated key would
        # make the "all keys in cooldown" state unreachable and let rotation
        # raise unexpectedly. dict.fromkeys de-duplicates while preserving
        # order; building a new list also stops later mutation of the
        # caller's list from affecting this manager.
        unique_keys = list(dict.fromkeys(provided_keys))
        duplicate_count = len(provided_keys) - len(unique_keys)
        if duplicate_count:
            logger.warning(f"Ignoring {duplicate_count} duplicate API key(s)")

        self.api_keys = unique_keys
        self.key_cooldown = KEY_COOLDOWN if key_cooldown is None else key_cooldown
        self.current_index = 0
        self.failed_keys = {}  # Maps key to time.time() of its most recent failure

        if not self.api_keys:
            logger.error("No valid API keys provided")
            raise ValueError("No API keys provided. Please add at least one valid API key.")

        # Log only a masked prefix of each key for debugging
        for position, key in enumerate(self.api_keys, start=1):
            logger.info(f"API key {position}: {self._mask_key(key)}")

        logger.info(f"ApiKeyManager initialized with {len(self.api_keys)} keys")

    @staticmethod
    def _mask_key(key: str) -> str:
        """Return a log-safe form of *key*: its first 8 characters plus '...' (or just '...')."""
        return key[:8] + "..." if len(key) > 8 else "..."

    def _is_cooling_down(self, key: str, now: float) -> bool:
        """Return True if *key* is recorded as failed and its cooldown has not elapsed at *now*."""
        failed_at = self.failed_keys.get(key)
        # Mirrors the release rule in get_next_api_key: a key is released
        # once strictly more than key_cooldown seconds have passed.
        return failed_at is not None and now - failed_at <= self.key_cooldown

    def _count_available(self, now: float) -> int:
        """Return how many managed keys are not cooling down at *now*."""
        return sum(1 for key in self.api_keys if not self._is_cooling_down(key, now))

    def get_next_api_key(self) -> str:
        """
        Get the next working API key in rotation

        Keys whose cooldown has expired are released first. If every managed
        key is still cooling down, this call sleeps until the earliest one is
        due and returns it.

        Returns:
            str: The next available API key

        Raises:
            ValueError: If no available key can be found (not expected once
                at least one managed key is out of cooldown).
        """
        current_time = time.time()

        # Release keys that have completed their cooldown period
        for key in list(self.failed_keys):
            if current_time - self.failed_keys[key] > self.key_cooldown:
                logger.info(f"API key {self._mask_key(key)} has cooled down after {self.key_cooldown} seconds and is being retried")
                del self.failed_keys[key]

        # Only managed keys count here, so a stray entry in failed_keys can
        # never trigger the wait branch or be handed out.
        cooling_keys = [key for key in self.api_keys if key in self.failed_keys]
        if len(cooling_keys) == len(self.api_keys):
            logger.warning("All API keys are currently in cooldown. Waiting for the first key to be available again...")
            # Find the key with the earliest failure time
            earliest_key = min(cooling_keys, key=self.failed_keys.__getitem__)
            wait_time = self.key_cooldown - (current_time - self.failed_keys[earliest_key]) + 1
            wait_time = max(1, int(wait_time))
            logger.info(f"Waiting {wait_time} seconds for cooldown...")
            time.sleep(wait_time)
            # Remove this key from the failed keys to retry it
            logger.info(f"Cooldown complete for {self._mask_key(earliest_key)}, removing from failed keys")
            del self.failed_keys[earliest_key]
            return earliest_key

        # Round-robin: examine each key at most once, starting at current_index
        for _ in range(len(self.api_keys)):
            api_key = self.api_keys[self.current_index]
            self.current_index = (self.current_index + 1) % len(self.api_keys)
            if api_key not in self.failed_keys:
                logger.info(f"Using API key: {self._mask_key(api_key)}")
                return api_key

        # Should be unreachable: the check above guarantees a free managed key
        logger.error("Could not find an available API key, which is unexpected")
        raise ValueError("Could not find an available API key, which is unexpected")

    def mark_key_as_failed(self, api_key: str) -> None:
        """
        Mark an API key as temporarily unavailable

        Keys not managed by this instance are ignored (with a warning) so
        they cannot distort the availability bookkeeping.

        Args:
            api_key: The API key to mark as failed
        """
        if api_key not in self.api_keys:
            logger.warning(f"Ignoring failure report for unknown API key {self._mask_key(api_key)}")
            return

        now = time.time()
        self.failed_keys[api_key] = now
        logger.warning(f"API key {self._mask_key(api_key)} marked as temporarily unavailable")
        available_keys = self._count_available(now)
        cooling_count = len(self.api_keys) - available_keys
        logger.info(f"{available_keys} keys immediately available. {cooling_count} keys in cooldown.")

        # If we're running out of keys, log a more urgent warning
        if available_keys <= 1:
            logger.warning(f"WARNING: Only {available_keys} API keys remaining. Consider adding more keys.")

    def has_working_keys(self) -> bool:
        """
        Check if there are any working keys left

        A key whose cooldown has already elapsed counts as working, even if
        get_next_api_key has not yet released it.

        Returns:
            bool: True if there are working keys, False otherwise
        """
        available_keys = self._count_available(time.time())
        logger.info(f"Checking API key availability: {available_keys} keys available")
        return available_keys > 0

    def get_api_keys_status(self) -> dict:
        """
        Get status information about API keys

        Returns:
            dict: ``total_keys``, ``available_keys`` and ``cooldown_keys``
            counts, plus ``available_key_details`` / ``cooldown_key_details``
            lists of ``{"key_id", "masked_key"}`` dicts (cooldown entries also
            carry ``cooldown_remaining_seconds``). Keys whose cooldown has
            elapsed are reported as available.
        """
        current_time = time.time()
        available_keys = []
        cooldown_keys = []

        for i, key in enumerate(self.api_keys):
            masked_key = self._mask_key(key)
            if self._is_cooling_down(key, current_time):
                cooldown_remaining = max(0, self.key_cooldown - (current_time - self.failed_keys[key]))
                cooldown_keys.append({
                    "key_id": i,
                    "masked_key": masked_key,
                    "cooldown_remaining_seconds": int(cooldown_remaining)
                })
            else:
                available_keys.append({
                    "key_id": i,
                    "masked_key": masked_key
                })

        return {
            "total_keys": len(self.api_keys),
            "available_keys": len(available_keys),
            "cooldown_keys": len(cooldown_keys),
            "available_key_details": available_keys,
            "cooldown_key_details": cooldown_keys
        }