Spaces:
Running
Running
File size: 8,064 Bytes
"""
Asta API Key Pool Manager
Manages multiple Asta API keys with key rotation and per-key error tracking.
"""
import os
import random
import time
from pathlib import Path
from typing import List, Optional, Dict
from threading import Lock
class AstaAPIKeyPool:
    """
    Asta API Key Pool Manager

    Features:
    1. Load multiple API keys from an explicit list, the ASTA_API_KEY
       environment variable (comma-separated), or a file (one key per line)
    2. Rotate keys in rounds: every key is handed out once per round; within a
       round the next key is picked at random from the best-scoring half of
       the keys not yet used
    3. Track each key's request, success and error counters
    4. Deprioritize keys that have failed, especially within the last minute

    Every public method holds ``self.lock`` while reading or mutating the
    shared state (``keys``, ``used_indices``, ``key_status``).
    """

    def __init__(self, pool_path: Optional[str] = None, keys: Optional[List[str]] = None):
        """
        Initialize the API key pool.

        Key sources are checked in this order, and the first one that is set
        is used exclusively (the others are ignored even if it yields no
        usable key):
        1. ``keys`` argument (non-empty list)
        2. ``ASTA_API_KEY`` environment variable (comma-separated)
        3. ``pool_path`` file (one key per line, ``#`` lines are comments)

        Note: the environment variable takes precedence over an explicit
        ``pool_path``.

        Args:
            pool_path: API keys file path (one key per line)
            keys: keys provided directly; take priority over every other source

        Raises:
            ValueError: if no non-blank key could be obtained.
            FileNotFoundError: if ``pool_path`` is used and cannot be found.
        """
        self.keys: List[str] = []
        self.used_indices: List[int] = []  # indices handed out in the current round
        self.key_status: Dict[str, Dict] = {}  # per-key counters, keyed by API key
        self.lock = Lock()  # guards keys / used_indices / key_status

        env_keys = os.environ.get("ASTA_API_KEY")
        if keys:
            self.keys = [k.strip() for k in keys if k.strip()]
        elif env_keys:
            self.keys = [k.strip() for k in env_keys.split(",") if k.strip()]
        elif pool_path:
            self._load_from_file(pool_path)

        # Covers both "no source given" and "the chosen source was all blanks".
        if not self.keys:
            raise ValueError(
                "No API keys available. Provide keys via pool_path, keys parameter, "
                "or ASTA_API_KEY environment variable."
            )

        for key in self.keys:
            self.key_status[key] = {
                'error_count': 0,          # total errors reported via mark_error()
                'last_error_time': None,   # time.time() of the latest error
                'consecutive_errors': 0,   # errors since the last mark_success()
                'total_requests': 0,       # times handed out by get_key()
                'successful_requests': 0,  # successes reported via mark_success()
            }

    def _load_from_file(self, pool_path: str):
        """
        Load API keys from a text file into ``self.keys``.

        An absolute ``pool_path`` is used as-is. A relative one is looked up
        first under ``<3 levels above this module>/shared/configs`` and then
        under ``<2 levels above this module>/configs``. Blank lines and lines
        starting with ``#`` are skipped; surrounding whitespace is stripped.

        Raises:
            FileNotFoundError: if no candidate path exists (every tried path
                is listed in the message).
            ValueError: if the file contains no keys.
        """
        path = Path(pool_path)
        if path.is_absolute():
            candidates = [path]
        else:
            module_dir = Path(__file__).parent
            candidates = [
                module_dir.parent.parent / "shared" / "configs" / pool_path,
                module_dir.parent / "configs" / pool_path,
            ]

        resolved = next((p for p in candidates if p.exists()), None)
        if resolved is None:
            tried = ", ".join(str(p) for p in candidates)
            raise FileNotFoundError(
                f"API key pool file not found: {pool_path} (tried: {tried})"
            )

        # "utf-8-sig" drops a leading BOM (common in files saved by Windows
        # editors). str.strip() does not remove U+FEFF, so with plain "utf-8"
        # the first line would be corrupted. BOM-less UTF-8 decodes identically.
        with open(resolved, 'r', encoding='utf-8-sig') as f:
            stripped = (line.strip() for line in f)
            self.keys = [line for line in stripped if line and not line.startswith('#')]

        if not self.keys:
            raise ValueError(f"No valid API keys found in pool file: {pool_path}")

    def get_key(self) -> str:
        """
        Return the next API key and count it as a request.

        Strategy:
        1. Every key is handed out once per round; when all keys have been
           used, a new round starts.
        2. Keys not yet used in this round are scored (higher is better):
           success rate, minus 0.1 per recorded error, minus 0.5 if the latest
           error happened less than 60 seconds ago. Keys with no requests yet
           count as a 1.0 success rate.
        3. One key is chosen at random from the best-scoring half (at least
           one) of those candidates.

        Returns:
            The selected API key; its ``total_requests`` counter has already
            been incremented.

        Raises:
            ValueError: if the pool holds no keys.
        """
        with self.lock:
            if not self.keys:
                raise ValueError("No API keys available in pool")

            # Set lookup keeps this O(n) instead of O(n^2) over a list.
            used = set(self.used_indices)
            available_indices = [i for i in range(len(self.keys)) if i not in used]
            if not available_indices:
                # Round complete: start a new one with every key available.
                self.used_indices = []
                available_indices = list(range(len(self.keys)))

            now = time.time()
            key_scores = []
            for idx in available_indices:
                status = self.key_status[self.keys[idx]]
                total = status['total_requests']
                success_rate = (status['successful_requests'] / total) if total > 0 else 1.0
                last_error = status['last_error_time']
                recent_error_penalty = 0.5 if last_error and now - last_error < 60 else 0
                score = success_rate - (status['error_count'] * 0.1) - recent_error_penalty
                key_scores.append((idx, score))

            # Best first (stable sort). Picking randomly among the top half
            # prefers good keys without always reusing the single best one.
            key_scores.sort(key=lambda item: item[1], reverse=True)
            top_n = max(1, len(key_scores) // 2)
            selected_idx, _ = random.choice(key_scores[:top_n])

            self.used_indices.append(selected_idx)
            selected_key = self.keys[selected_idx]
            self.key_status[selected_key]['total_requests'] += 1
            return selected_key

    def mark_success(self, key: str):
        """
        Record a successful request made with ``key``.

        Increments ``successful_requests`` and resets ``consecutive_errors``.
        Keys that are not in the pool are ignored.
        """
        with self.lock:
            status = self.key_status.get(key)
            if status is not None:
                status['successful_requests'] += 1
                status['consecutive_errors'] = 0

    def mark_error(self, key: str, error_type: str = "rate_limit"):
        """
        Record a failed request made with ``key``.

        Increments ``error_count`` and ``consecutive_errors`` and stamps
        ``last_error_time`` with the current time, which lowers the key's
        score in get_key(). Keys that are not in the pool are ignored.

        Args:
            key: failed API key
            error_type: error type ("rate_limit", "auth_error", "server_error",
                "other"). Currently not used: every error type is recorded the
                same way.
        """
        with self.lock:
            status = self.key_status.get(key)
            if status is not None:
                status['error_count'] += 1
                status['consecutive_errors'] += 1
                status['last_error_time'] = time.time()

    def get_status(self) -> Dict:
        """
        Return a snapshot of the pool's counters (for debugging).

        ``success_rate`` is reported as 0.0 for a key with no requests yet
        (get_key() scores such keys as 1.0). The raw API keys are used as the
        ``keys_status`` dictionary keys, so be careful when logging this.
        """
        with self.lock:
            keys_status = {}
            for key, status in self.key_status.items():
                total = status['total_requests']
                keys_status[key] = {
                    'error_count': status['error_count'],
                    'successful_requests': status['successful_requests'],
                    'total_requests': total,
                    'success_rate': (
                        status['successful_requests'] / total if total > 0 else 0.0
                    ),
                    'consecutive_errors': status['consecutive_errors'],
                    'last_error_time': status['last_error_time'],
                }
            return {
                'total_keys': len(self.keys),
                'current_round_progress': f"{len(self.used_indices)}/{len(self.keys)}",
                'keys_status': keys_status,
            }

    def reset_round(self):
        """Reset the current rotation so the next get_key() starts a new round."""
        with self.lock:
            self.used_indices = []
|