Mirrowel committed on
Commit
00b549c
·
1 Parent(s): 80dbe0b

feat: Implement global provider cooldown for IP rate limits

Browse files

Add a `CooldownManager` to handle IP-based rate limiting (HTTP 429 errors).
When a provider returns a 429 `RateLimitError`, it is put into a global cooldown state for a duration specified by `retry_after` or a default.
Subsequent requests to that provider will pause until the cooldown expires.

This prevents repeatedly hitting a provider that has already rate-limited our IP, improving reliability and reducing unnecessary retries.

src/rotator_library/client.py CHANGED
@@ -20,6 +20,7 @@ from .failure_logger import log_failure
20
  from .error_handler import classify_error, AllProviders
21
  from .providers import PROVIDER_PLUGINS
22
  from .request_sanitizer import sanitize_request_payload
 
23
 
24
  class StreamedAPIError(Exception):
25
  """Custom exception to signal an API error received over a stream."""
@@ -46,6 +47,7 @@ class RotatingClient:
46
  self._provider_instances = {}
47
  self.http_client = httpx.AsyncClient()
48
  self.all_providers = AllProviders()
 
49
 
50
  async def __aenter__(self):
51
  return self
@@ -190,6 +192,11 @@ class RotatingClient:
190
  current_key = None
191
  key_acquired = False
192
  try:
 
 
 
 
 
193
  keys_to_try = [k for k in keys_for_provider if k not in tried_keys]
194
  if not keys_to_try:
195
  break
@@ -236,14 +243,19 @@ class RotatingClient:
236
  key_acquired = False
237
  return response
238
 
239
- except (StreamedAPIError, APIConnectionError) as e:
240
- # These errors are caught to allow retrying with the next key.
241
  last_exception = e
242
  log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
243
  classified_error = classify_error(e)
 
 
 
 
 
 
244
  await self.usage_manager.record_failure(current_key, model, classified_error)
245
  lib_logger.warning(f"Key ...{current_key[-4:]} encountered '{classified_error.error_type}'. Trying next key.")
246
- break # Break from retry loop, try next key
247
 
248
  except Exception as e:
249
  last_exception = e
 
20
  from .error_handler import classify_error, AllProviders
21
  from .providers import PROVIDER_PLUGINS
22
  from .request_sanitizer import sanitize_request_payload
23
+ from .cooldown_manager import CooldownManager
24
 
25
  class StreamedAPIError(Exception):
26
  """Custom exception to signal an API error received over a stream."""
 
47
  self._provider_instances = {}
48
  self.http_client = httpx.AsyncClient()
49
  self.all_providers = AllProviders()
50
+ self.cooldown_manager = CooldownManager()
51
 
52
  async def __aenter__(self):
53
  return self
 
192
  current_key = None
193
  key_acquired = False
194
  try:
195
+ if await self.cooldown_manager.is_cooling_down(provider):
196
+ remaining_time = await self.cooldown_manager.get_cooldown_remaining(provider)
197
+ lib_logger.warning(f"Provider {provider} is in cooldown. Waiting for {remaining_time:.2f} seconds.")
198
+ await asyncio.sleep(remaining_time)
199
+
200
  keys_to_try = [k for k in keys_for_provider if k not in tried_keys]
201
  if not keys_to_try:
202
  break
 
243
  key_acquired = False
244
  return response
245
 
246
+ except (StreamedAPIError, APIConnectionError, litellm.RateLimitError) as e:
 
247
  last_exception = e
248
  log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
249
  classified_error = classify_error(e)
250
+
251
+ if classified_error.error_type == 'rate_limit' and classified_error.status_code == 429:
252
+ cooldown_duration = classified_error.retry_after or 60
253
+ await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
254
+ lib_logger.error(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
255
+
256
  await self.usage_manager.record_failure(current_key, model, classified_error)
257
  lib_logger.warning(f"Key ...{current_key[-4:]} encountered '{classified_error.error_type}'. Trying next key.")
258
+ break
259
 
260
  except Exception as e:
261
  last_exception = e
src/rotator_library/cooldown_manager.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import time
3
+ from typing import Dict
4
+
5
class CooldownManager:
    """
    Manages global cooldown periods for API providers to handle IP-based rate limiting.

    Each provider name maps to a deadline on the monotonic clock. While the
    deadline lies in the future the provider is considered to be cooling down.
    All access to the deadline table is serialized through an asyncio lock.
    """

    def __init__(self):
        # provider name -> deadline expressed in time.monotonic() seconds.
        self._cooldowns: Dict[str, float] = {}
        self._lock = asyncio.Lock()

    def _remaining_unlocked(self, provider: str) -> float:
        """
        Return the seconds left on a provider's cooldown (0 if none).

        The caller must already hold ``self._lock``. Expired entries are
        removed here so the table does not grow without bound.
        """
        deadline = self._cooldowns.get(provider)
        if deadline is None:
            return 0
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            del self._cooldowns[provider]
            return 0
        return remaining

    async def is_cooling_down(self, provider: str) -> bool:
        """Checks if a provider is currently in a cooldown period."""
        async with self._lock:
            return self._remaining_unlocked(provider) > 0

    async def start_cooldown(self, provider: str, duration: float):
        """
        Initiates or extends a cooldown period for a provider.

        The new deadline is the current time plus ``duration`` seconds. An
        existing cooldown is only ever extended, never shortened: if the
        provider already has a later deadline, that deadline is kept.
        """
        async with self._lock:
            new_deadline = time.monotonic() + duration
            current_deadline = self._cooldowns.get(provider)
            # Keep the later of the two deadlines so that a short retry hint
            # arriving during a long cooldown cannot cut the cooldown short.
            if current_deadline is None or new_deadline > current_deadline:
                self._cooldowns[provider] = new_deadline

    async def get_cooldown_remaining(self, provider: str) -> float:
        """
        Returns the remaining cooldown time in seconds for a provider.
        Returns 0 if the provider is not in a cooldown period.
        """
        async with self._lock:
            return self._remaining_unlocked(provider)