Spaces:
Paused
Paused
Mirrowel committed on
Commit ·
00b549c
1
Parent(s): 80dbe0b
feat: Implement global provider cooldown for IP rate limits
Browse files
Add a `CooldownManager` to handle IP-based rate limiting (HTTP 429 errors).
When a provider returns a 429 `RateLimitError`, it is put into a global cooldown state for a duration specified by `retry_after` or a default.
Subsequent requests to that provider will pause until the cooldown expires.
This prevents continuously hitting rate-limited IPs, improving reliability and reducing unnecessary retries.
src/rotator_library/client.py
CHANGED
|
@@ -20,6 +20,7 @@ from .failure_logger import log_failure
|
|
| 20 |
from .error_handler import classify_error, AllProviders
|
| 21 |
from .providers import PROVIDER_PLUGINS
|
| 22 |
from .request_sanitizer import sanitize_request_payload
|
|
|
|
| 23 |
|
| 24 |
class StreamedAPIError(Exception):
|
| 25 |
"""Custom exception to signal an API error received over a stream."""
|
|
@@ -46,6 +47,7 @@ class RotatingClient:
|
|
| 46 |
self._provider_instances = {}
|
| 47 |
self.http_client = httpx.AsyncClient()
|
| 48 |
self.all_providers = AllProviders()
|
|
|
|
| 49 |
|
| 50 |
async def __aenter__(self):
|
| 51 |
return self
|
|
@@ -190,6 +192,11 @@ class RotatingClient:
|
|
| 190 |
current_key = None
|
| 191 |
key_acquired = False
|
| 192 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
keys_to_try = [k for k in keys_for_provider if k not in tried_keys]
|
| 194 |
if not keys_to_try:
|
| 195 |
break
|
|
@@ -236,14 +243,19 @@ class RotatingClient:
|
|
| 236 |
key_acquired = False
|
| 237 |
return response
|
| 238 |
|
| 239 |
-
except (StreamedAPIError, APIConnectionError) as e:
|
| 240 |
-
# These errors are caught to allow retrying with the next key.
|
| 241 |
last_exception = e
|
| 242 |
log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
|
| 243 |
classified_error = classify_error(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
await self.usage_manager.record_failure(current_key, model, classified_error)
|
| 245 |
lib_logger.warning(f"Key ...{current_key[-4:]} encountered '{classified_error.error_type}'. Trying next key.")
|
| 246 |
-
break
|
| 247 |
|
| 248 |
except Exception as e:
|
| 249 |
last_exception = e
|
|
|
|
| 20 |
from .error_handler import classify_error, AllProviders
|
| 21 |
from .providers import PROVIDER_PLUGINS
|
| 22 |
from .request_sanitizer import sanitize_request_payload
|
| 23 |
+
from .cooldown_manager import CooldownManager
|
| 24 |
|
| 25 |
class StreamedAPIError(Exception):
|
| 26 |
"""Custom exception to signal an API error received over a stream."""
|
|
|
|
| 47 |
self._provider_instances = {}
|
| 48 |
self.http_client = httpx.AsyncClient()
|
| 49 |
self.all_providers = AllProviders()
|
| 50 |
+
self.cooldown_manager = CooldownManager()
|
| 51 |
|
| 52 |
async def __aenter__(self):
|
| 53 |
return self
|
|
|
|
| 192 |
current_key = None
|
| 193 |
key_acquired = False
|
| 194 |
try:
|
| 195 |
+
if await self.cooldown_manager.is_cooling_down(provider):
|
| 196 |
+
remaining_time = await self.cooldown_manager.get_cooldown_remaining(provider)
|
| 197 |
+
lib_logger.warning(f"Provider {provider} is in cooldown. Waiting for {remaining_time:.2f} seconds.")
|
| 198 |
+
await asyncio.sleep(remaining_time)
|
| 199 |
+
|
| 200 |
keys_to_try = [k for k in keys_for_provider if k not in tried_keys]
|
| 201 |
if not keys_to_try:
|
| 202 |
break
|
|
|
|
| 243 |
key_acquired = False
|
| 244 |
return response
|
| 245 |
|
| 246 |
+
except (StreamedAPIError, APIConnectionError, litellm.RateLimitError) as e:
|
|
|
|
| 247 |
last_exception = e
|
| 248 |
log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
|
| 249 |
classified_error = classify_error(e)
|
| 250 |
+
|
| 251 |
+
if classified_error.error_type == 'rate_limit' and classified_error.status_code == 429:
|
| 252 |
+
cooldown_duration = classified_error.retry_after or 60
|
| 253 |
+
await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
|
| 254 |
+
lib_logger.error(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
|
| 255 |
+
|
| 256 |
await self.usage_manager.record_failure(current_key, model, classified_error)
|
| 257 |
lib_logger.warning(f"Key ...{current_key[-4:]} encountered '{classified_error.error_type}'. Trying next key.")
|
| 258 |
+
break
|
| 259 |
|
| 260 |
except Exception as e:
|
| 261 |
last_exception = e
|
src/rotator_library/cooldown_manager.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import time
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
class CooldownManager:
    """
    Manages global cooldown periods for API providers to handle IP-based rate limiting.

    Each provider name maps to a deadline on the monotonic clock. While the
    deadline lies in the future the provider is considered to be cooling
    down, and callers are expected to wait for the remaining time before
    sending further requests to it.

    Deadlines are measured with ``time.monotonic()`` rather than wall-clock
    time so that system clock adjustments (NTP sync, manual changes) can
    neither lengthen nor cancel a running cooldown.
    """

    def __init__(self):
        # provider name -> monotonic-clock deadline at which its cooldown ends.
        self._cooldowns: Dict[str, float] = {}
        # Guards every read/modify/write of ``_cooldowns``.
        self._lock = asyncio.Lock()

    def _remaining_unlocked(self, provider: str) -> float:
        """
        Return the seconds left on ``provider``'s cooldown (0.0 if none).

        The caller must already hold ``self._lock``. Entries whose deadline
        has passed are removed so the mapping does not accumulate stale keys.
        """
        deadline = self._cooldowns.get(provider)
        if deadline is None:
            return 0.0
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            # Cooldown is over: forget it instead of keeping a dead entry.
            del self._cooldowns[provider]
            return 0.0
        return remaining

    async def is_cooling_down(self, provider: str) -> bool:
        """Checks if a provider is currently in a cooldown period."""
        async with self._lock:
            return self._remaining_unlocked(provider) > 0

    async def start_cooldown(self, provider: str, duration: float):
        """
        Initiates or extends a cooldown period for a provider.

        The new deadline is the current time plus ``duration`` seconds. If
        the provider already has a cooldown that ends later than that, the
        existing deadline is kept: a cooldown is only ever extended, never
        shortened, by a subsequent call.

        Args:
            provider: Name of the provider to put into cooldown.
            duration: Cooldown length in seconds (int or float).
        """
        async with self._lock:
            new_deadline = time.monotonic() + duration
            current_deadline = self._cooldowns.get(provider)
            if current_deadline is None or new_deadline > current_deadline:
                self._cooldowns[provider] = new_deadline

    async def get_cooldown_remaining(self, provider: str) -> float:
        """
        Returns the remaining cooldown time in seconds for a provider.

        Returns 0.0 if the provider is not in a cooldown period (either it
        never had one or the cooldown has already expired).
        """
        async with self._lock:
            return self._remaining_unlocked(provider)
|