File size: 8,064 Bytes
90c099b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
"""
Asta API Key Pool Manager

Manage multiple Asta API keys, implement key rotation and error handling.
"""
import os
import random
import time
from pathlib import Path
from typing import List, Optional, Dict
from threading import Lock


class AstaAPIKeyPool:
    """
    Asta API Key Pool Manager

    Features:
    1. Load multiple API keys from an explicit list, the ASTA_API_KEY
       environment variable, or a file
    2. Rotate keys in rounds: each key is handed out once per round
    3. Track each key's usage counters and errors
    4. Within a round, prefer keys with a better success/error record and
       deprioritize keys that errored within the last minute

    All methods that read or mutate pool state do so while holding ``self.lock``.
    """

    # Shared by both "no keys" failure paths in __init__.
    _NO_KEYS_MESSAGE = (
        "No API keys available. Provide keys via pool_path, keys parameter, "
        "or ASTA_API_KEY environment variable."
    )
    # Scoring parameters used by get_key (see _score).
    _ERROR_COUNT_PENALTY = 0.1       # subtracted per recorded error
    _RECENT_ERROR_PENALTY = 0.5      # subtracted if the last error is recent
    _RECENT_ERROR_WINDOW_SECONDS = 60

    def __init__(self, pool_path: Optional[str] = None, keys: Optional[List[str]] = None):
        """
        Initialize API Key Pool

        Key sources are checked in this order; the first non-empty one wins:
        1. the ``keys`` argument
        2. the ``ASTA_API_KEY`` environment variable (comma-separated)
        3. the ``pool_path`` file
        Note: a non-empty ``ASTA_API_KEY`` takes precedence over ``pool_path``.

        Keys are stripped, empty entries are dropped, and duplicates are
        collapsed (first occurrence kept) so every distinct key gets exactly
        one rotation slot and one status record.

        Args:
            pool_path: API keys file path (one key per line, '#' lines ignored)
            keys: keys provided directly (highest priority)

        Raises:
            ValueError: if no usable key can be obtained.
            FileNotFoundError: if ``pool_path`` is used but cannot be located.
        """
        self.keys: List[str] = []
        self.used_indices: List[int] = []  # indices handed out in the current round
        self.key_status: Dict[str, Dict] = {}  # per-key counters (see below)
        self.lock = Lock()  # guards keys, used_indices and key_status

        env_keys = os.environ.get("ASTA_API_KEY")
        if keys:
            self.keys = self._clean_keys(keys)
        elif env_keys:
            self.keys = self._clean_keys(env_keys.split(","))
        elif pool_path:
            self._load_from_file(pool_path)
        else:
            raise ValueError(self._NO_KEYS_MESSAGE)

        # e.g. keys=["", "  "] is truthy but yields nothing usable
        if not self.keys:
            raise ValueError(self._NO_KEYS_MESSAGE)

        # Per-key counters:
        #   error_count          - total errors reported via mark_error
        #   last_error_time      - time.time() of the latest error, or None
        #   consecutive_errors   - errors since the last mark_success
        #   total_requests       - times the key was handed out by get_key
        #   successful_requests  - successes reported via mark_success
        for key in self.keys:
            self.key_status[key] = {
                'error_count': 0,
                'last_error_time': None,
                'consecutive_errors': 0,
                'total_requests': 0,
                'successful_requests': 0,
            }

    @staticmethod
    def _clean_keys(raw_keys: List[str]) -> List[str]:
        """Strip keys, drop empty ones, and de-duplicate preserving first-seen order."""
        return list(dict.fromkeys(k.strip() for k in raw_keys if k.strip()))

    def _load_from_file(self, pool_path: str):
        """
        Load API keys from a file into ``self.keys``.

        An absolute ``pool_path`` is used as-is. A relative one is resolved
        against these locations, in order, and the first existing file wins:
        1. <three levels above this module>/shared/configs/
        2. <two levels above this module>/configs/
        3. the current working directory

        Blank lines and lines starting with '#' (after stripping) are ignored.

        Raises:
            FileNotFoundError: if no candidate file exists.
            ValueError: if the file contains no usable keys.
        """
        requested = Path(pool_path)
        if requested.is_absolute():
            candidates = [requested]
        else:
            module_dir = Path(__file__).parent
            candidates = [
                module_dir.parent.parent / "shared" / "configs" / pool_path,
                module_dir.parent / "configs" / pool_path,
                requested,  # relative to the current working directory
            ]

        path = next((c for c in candidates if c.is_file()), None)
        if path is None:
            tried = ", ".join(str(c) for c in candidates)
            raise FileNotFoundError(
                f"API key pool file not found: {pool_path} (tried: {tried})"
            )

        with open(path, 'r', encoding='utf-8') as f:
            stripped = [line.strip() for line in f]

        self.keys = self._clean_keys([line for line in stripped if not line.startswith('#')])

        if not self.keys:
            raise ValueError(f"No valid API keys found in pool file: {pool_path}")

    def _score(self, status: Dict, now: float) -> float:
        """Score a key's status record; higher is better (caller must hold the lock)."""
        total = status['total_requests']
        # a key that was never handed out gets the neutral prior 1.0
        success_rate = (status['successful_requests'] / total) if total > 0 else 1.0

        recent_error_penalty = 0.0
        last_error = status['last_error_time']
        if last_error and now - last_error < self._RECENT_ERROR_WINDOW_SECONDS:
            recent_error_penalty = self._RECENT_ERROR_PENALTY

        return (
            success_rate
            - status['error_count'] * self._ERROR_COUNT_PENALTY
            - recent_error_penalty
        )

    def get_key(self) -> str:
        """
        Get the next API key (rotation strategy).

        Strategy:
        1. Each key is handed out at most once per round. Once every key has
           been handed out, a new round starts (``used_indices`` is cleared).
        2. Keys not yet used this round are scored as
           success_rate - 0.1 * error_count - (0.5 if it errored < 60 s ago).
        3. One key is chosen uniformly at random from the top half by score
           (always at least one candidate). This spreads load while still
           favouring healthier keys.

        Side effect: increments the chosen key's ``total_requests``.

        Returns:
            The selected API key.

        Raises:
            ValueError: if the pool holds no keys.
        """
        with self.lock:
            key_count = len(self.keys)
            if not key_count:
                raise ValueError("No API keys available in pool")

            # current round complete -> start a new one
            if len(self.used_indices) >= key_count:
                self.used_indices = []

            used = set(self.used_indices)  # O(1) membership instead of list scans
            available_indices = [i for i in range(key_count) if i not in used]

            if not available_indices:
                # Defensive: get_key only appends distinct in-range indices, but
                # an externally modified used_indices must not wedge the pool.
                available_indices = list(range(key_count))
                self.used_indices = []

            now = time.time()
            key_scores = [
                (idx, self._score(self.key_status[self.keys[idx]], now))
                for idx in available_indices
            ]
            # stable sort: equal scores keep index order
            key_scores.sort(key=lambda item: item[1], reverse=True)

            top_n = max(1, len(key_scores) // 2)
            selected_idx, _ = random.choice(key_scores[:top_n])

            self.used_indices.append(selected_idx)

            selected_key = self.keys[selected_idx]
            self.key_status[selected_key]['total_requests'] += 1
            return selected_key

    def mark_success(self, key: str):
        """Record a successful request for ``key`` and reset its consecutive-error streak.

        Unknown keys are ignored.
        """
        with self.lock:
            status = self.key_status.get(key)
            if status is not None:
                status['successful_requests'] += 1
                status['consecutive_errors'] = 0

    def mark_error(self, key: str, error_type: str = "rate_limit"):
        """
        Record a failed request for ``key``.

        Increments error counters and stamps ``last_error_time``; get_key then
        deprioritizes the key. Unknown keys are ignored.

        Args:
            key: failed API key
            error_type: error type ("rate_limit", "auth_error", "server_error",
                "other"). NOTE: currently accepted but not used; all error
                types are treated identically.
        """
        with self.lock:
            status = self.key_status.get(key)
            if status is not None:
                status['error_count'] += 1
                status['consecutive_errors'] += 1
                status['last_error_time'] = time.time()

    def get_status(self) -> Dict:
        """
        Return a snapshot of pool status (for debugging).

        ``success_rate`` is reported as 0.0 for keys with no requests yet. This
        differs from the neutral 1.0 that get_key uses when scoring such keys.

        NOTE(review): the snapshot is keyed by the raw API keys, so avoid
        logging it verbatim.
        """
        with self.lock:
            return {
                'total_keys': len(self.keys),
                'current_round_progress': f"{len(self.used_indices)}/{len(self.keys)}",
                'keys_status': {
                    key: {
                        'error_count': status['error_count'],
                        'successful_requests': status['successful_requests'],
                        'total_requests': status['total_requests'],
                        'success_rate': (
                            status['successful_requests'] / status['total_requests']
                            if status['total_requests'] > 0 else 0.0
                        ),
                        'consecutive_errors': status['consecutive_errors'],
                        'last_error_time': status['last_error_time'],
                    }
                    for key, status in self.key_status.items()
                }
            }

    def reset_round(self):
        """Reset the current rotation, forcing the next get_key to start a new round."""
        with self.lock:
            self.used_indices = []