Mirrowel committed on
Commit
f8d59cd
·
1 Parent(s): 1d838ea

perf(rotator): prevent indefinite waits for API key acquisition

Browse files

Introduces a timeout for API key acquisition and enhances streaming error handling.

- Implements a `wait_timeout` in `UsageManager.acquire_key` to prevent indefinite
waiting when all keys are busy or on cooldown.
- Introduces `NoAvailableKeysError` to explicitly signal when no key can be
acquired within the timeout.
- Modifies `RotatingClient.acompletion` to gracefully handle `NoAvailableKeysError`
and other general exceptions by yielding structured error chunks for streaming
clients instead of raising exceptions directly.
- Adjusts key release logic in `RotatingClient` to ensure proper resource management.
- Refines logging for key acquisition wait re-evaluations.

src/rotator_library/client.py CHANGED
@@ -17,7 +17,7 @@ lib_logger.propagate = False
17
 
18
  from .usage_manager import UsageManager
19
  from .failure_logger import log_failure
20
- from .error_handler import classify_error, AllProviders
21
  from .providers import PROVIDER_PLUGINS
22
  from .request_sanitizer import sanitize_request_payload
23
  from .cooldown_manager import CooldownManager
@@ -313,121 +313,125 @@ class RotatingClient:
313
  keys_for_provider = self.api_keys[provider]
314
  tried_keys = set()
315
  last_exception = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
- while len(tried_keys) < len(keys_for_provider):
318
- current_key = None
319
- key_acquired = False
320
- try:
321
- if await self.cooldown_manager.is_cooling_down(provider):
322
- remaining_time = await self.cooldown_manager.get_cooldown_remaining(provider)
323
- lib_logger.warning(f"Provider {provider} is in cooldown. Waiting for {remaining_time:.2f} seconds.")
324
- await asyncio.sleep(remaining_time)
325
-
326
- keys_to_try = [k for k in keys_for_provider if k not in tried_keys]
327
- if not keys_to_try:
328
- break
329
-
330
- current_key = await self.usage_manager.acquire_key(available_keys=keys_to_try, model=model)
331
- key_acquired = True
332
- tried_keys.add(current_key)
333
-
334
- # --- Full Request Preparation Logic ---
335
- litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
336
- provider_instance = self._get_provider_instance(provider)
337
- if provider_instance:
338
- if "safety_settings" in litellm_kwargs:
339
- converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
340
- if converted_settings is not None:
341
- litellm_kwargs["safety_settings"] = converted_settings
342
- else:
343
- del litellm_kwargs["safety_settings"]
344
-
345
- if provider == "gemini" and provider_instance:
346
- provider_instance.handle_thinking_parameter(litellm_kwargs, model)
347
-
348
- if "gemma-3" in model and "messages" in litellm_kwargs:
349
- litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
350
-
351
- litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
352
- # --- End of Request Preparation ---
353
-
354
- for attempt in range(self.max_retries):
355
- try:
356
- lib_logger.info(f"Attempting stream with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
357
- response = await litellm.acompletion(api_key=current_key, **litellm_kwargs)
358
-
359
- key_acquired = False # Wrapper now handles the key release
360
- stream_generator = self._safe_streaming_wrapper(response, current_key, model, request)
361
-
362
- async for chunk in stream_generator:
363
- yield chunk
364
- return # Successful stream, exit the entire retry mechanism
365
-
366
- except (StreamedAPIError, litellm.RateLimitError) as e:
367
- last_exception = e
368
- log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
369
- classified_error = classify_error(e)
370
- error_message = str(e).split('\n')[0]
371
- lib_logger.warning(f"Key ...{current_key[-4:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message}. Rotating key.")
372
-
373
- if classified_error.error_type == 'rate_limit' and classified_error.status_code == 429:
374
- cooldown_duration = classified_error.retry_after or 60
375
- await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
376
- lib_logger.warning(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
377
-
378
- await self.usage_manager.record_failure(current_key, model, classified_error)
379
- lib_logger.info(f"Key ...{current_key[-4:]} failed during stream initiation. Trying next key.")
380
- break # Break inner loop to try next key
381
-
382
- except (APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError) as e:
383
- last_exception = e
384
- log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
385
- classified_error = classify_error(e)
386
- await self.usage_manager.record_failure(current_key, model, classified_error)
387
-
388
- if attempt >= self.max_retries - 1:
389
- error_message = str(e).split('\n')[0]
390
- lib_logger.warning(f"Key ...{current_key[-4:]} failed after {self.max_retries} retries with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message}. Rotating key.")
391
- break # Move to the next key
392
-
393
- wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
394
- error_message = str(e).split('\n')[0]
395
- lib_logger.warning(f"Key ...{current_key[-4:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message}. Retrying in {wait_time:.2f} seconds.")
396
- await asyncio.sleep(wait_time)
397
- continue # Retry with the same key
398
-
399
- except Exception as e:
400
- last_exception = e
401
- log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
402
- classified_error = classify_error(e)
403
- error_message = str(e).split('\n')[0]
404
- lib_logger.warning(f"Key ...{current_key[-4:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message}. Rotating key.")
405
-
406
- if classified_error.status_code == 429:
407
- cooldown_duration = classified_error.retry_after or 60
408
- await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
409
- lib_logger.warning(f"IP-based rate limit detected for {provider} from generic stream exception. Starting a {cooldown_duration}-second global cooldown.")
410
-
411
- if classified_error.error_type in ['invalid_request', 'context_window_exceeded', 'authentication']:
412
- raise last_exception # Do not retry for these errors
413
-
414
- await self.usage_manager.record_failure(current_key, model, classified_error)
415
- break # Try next key for other errors
416
 
417
- finally:
418
- if key_acquired and current_key:
419
- await self.usage_manager.release_key(current_key, model)
420
-
421
- if last_exception:
422
- # After trying all keys, if an exception was caught, we need to inform the client.
423
- # We can't raise it directly as the stream is already open.
424
- # Instead, we yield a final error message.
425
- error_data = {"error": {"message": f"Failed to complete the streaming request after trying all keys. Last error: {str(last_exception)}", "type": "proxy_error"}}
426
  yield f"data: {json.dumps(error_data)}\n\n"
427
  yield "data: [DONE]\n\n"
428
- else:
429
- # If all keys were tried and none succeeded (e.g., all were busy), raise a generic error.
430
- raise Exception("Failed to complete the streaming request: No available API keys for the provider or all keys failed.")
431
 
432
  def acompletion(self, request: Optional[Any] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
433
  """Dispatcher for completion requests."""
 
17
 
18
  from .usage_manager import UsageManager
19
  from .failure_logger import log_failure
20
+ from .error_handler import classify_error, AllProviders, NoAvailableKeysError
21
  from .providers import PROVIDER_PLUGINS
22
  from .request_sanitizer import sanitize_request_payload
23
  from .cooldown_manager import CooldownManager
 
313
  keys_for_provider = self.api_keys[provider]
314
  tried_keys = set()
315
  last_exception = None
316
+ try:
317
+ while len(tried_keys) < len(keys_for_provider):
318
+ current_key = None
319
+ key_acquired = False
320
+ try:
321
+ if await self.cooldown_manager.is_cooling_down(provider):
322
+ remaining_time = await self.cooldown_manager.get_cooldown_remaining(provider)
323
+ lib_logger.warning(f"Provider {provider} is in cooldown. Waiting for {remaining_time:.2f} seconds.")
324
+ await asyncio.sleep(remaining_time)
325
+
326
+ keys_to_try = [k for k in keys_for_provider if k not in tried_keys]
327
+ if not keys_to_try:
328
+ lib_logger.warning(f"All keys for provider {provider} have been tried. No more keys to rotate to.")
329
+ break
330
+
331
+ lib_logger.info(f"Acquiring key for model {model}. Tried keys: {len(tried_keys)}/{len(keys_for_provider)}")
332
+ current_key = await self.usage_manager.acquire_key(available_keys=keys_to_try, model=model)
333
+ key_acquired = True
334
+ tried_keys.add(current_key)
335
+
336
+ litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
337
+ provider_instance = self._get_provider_instance(provider)
338
+ if provider_instance:
339
+ if "safety_settings" in litellm_kwargs:
340
+ converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
341
+ if converted_settings is not None:
342
+ litellm_kwargs["safety_settings"] = converted_settings
343
+ else:
344
+ del litellm_kwargs["safety_settings"]
345
+
346
+ if provider == "gemini" and provider_instance:
347
+ provider_instance.handle_thinking_parameter(litellm_kwargs, model)
348
 
349
+ if "gemma-3" in model and "messages" in litellm_kwargs:
350
+ litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
351
+
352
+ litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
353
+
354
+ for attempt in range(self.max_retries):
355
+ try:
356
+ lib_logger.info(f"Attempting stream with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
357
+ response = await litellm.acompletion(api_key=current_key, **litellm_kwargs)
358
+
359
+ key_acquired = False
360
+ stream_generator = self._safe_streaming_wrapper(response, current_key, model, request)
361
+
362
+ async for chunk in stream_generator:
363
+ yield chunk
364
+ return
365
+
366
+ except (StreamedAPIError, litellm.RateLimitError) as e:
367
+ last_exception = e
368
+ log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
369
+ classified_error = classify_error(e)
370
+ lib_logger.warning(f"Key ...{current_key[-4:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {str(e)}. Rotating key.")
371
+
372
+ if classified_error.error_type == 'rate_limit' and classified_error.status_code == 429:
373
+ cooldown_duration = classified_error.retry_after or 60
374
+ await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
375
+ lib_logger.warning(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
376
+
377
+ await self.usage_manager.record_failure(current_key, model, classified_error)
378
+ break
379
+
380
+ except (APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError) as e:
381
+ last_exception = e
382
+ log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
383
+ classified_error = classify_error(e)
384
+ await self.usage_manager.record_failure(current_key, model, classified_error)
385
+
386
+ if attempt >= self.max_retries - 1:
387
+ lib_logger.warning(f"Key ...{current_key[-4:]} failed after {self.max_retries} retries with {classified_error.error_type}. Rotating key.")
388
+ break
389
+
390
+ wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
391
+ lib_logger.warning(f"Key ...{current_key[-4:]} failed with {classified_error.error_type}. Retrying in {wait_time:.2f} seconds.")
392
+ await asyncio.sleep(wait_time)
393
+ continue
394
+
395
+ except Exception as e:
396
+ last_exception = e
397
+ log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
398
+ classified_error = classify_error(e)
399
+ lib_logger.warning(f"Key ...{current_key[-4:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {str(e)}. Rotating key.")
400
+
401
+ if classified_error.status_code == 429:
402
+ cooldown_duration = classified_error.retry_after or 60
403
+ await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
404
+ lib_logger.warning(f"IP-based rate limit detected for {provider} from generic stream exception. Starting a {cooldown_duration}-second global cooldown.")
405
+
406
+ if classified_error.error_type in ['invalid_request', 'context_window_exceeded', 'authentication']:
407
+ raise last_exception
408
+
409
+ await self.usage_manager.record_failure(current_key, model, classified_error)
410
+ break
411
+
412
+ finally:
413
+ if key_acquired and current_key:
414
+ await self.usage_manager.release_key(current_key, model)
415
+
416
+ if last_exception:
417
+ error_data = {"error": {"message": f"Failed to complete the streaming request. Last error: {str(last_exception)}", "type": "proxy_error"}}
418
+ yield f"data: {json.dumps(error_data)}\n\n"
419
+ else:
420
+ error_data = {"error": {"message": "Failed to complete the streaming request: No available API keys after rotation.", "type": "proxy_error"}}
421
+ yield f"data: {json.dumps(error_data)}\n\n"
422
+
423
+ yield "data: [DONE]\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
 
425
+ except NoAvailableKeysError as e:
426
+ lib_logger.error(f"A streaming request failed because no keys were available: {e}")
427
+ error_data = {"error": {"message": str(e), "type": "proxy_busy"}}
428
+ yield f"data: {json.dumps(error_data)}\n\n"
429
+ yield "data: [DONE]\n\n"
430
+ except Exception as e:
431
+ lib_logger.error(f"An unhandled exception occurred in streaming retry logic: {e}")
432
+ error_data = {"error": {"message": f"An unexpected error occurred: {str(e)}", "type": "proxy_internal_error"}}
 
433
  yield f"data: {json.dumps(error_data)}\n\n"
434
  yield "data: [DONE]\n\n"
 
 
 
435
 
436
  def acompletion(self, request: Optional[Any] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
437
  """Dispatcher for completion requests."""
src/rotator_library/error_handler.py CHANGED
@@ -3,6 +3,10 @@ from typing import Optional, Dict, Any
3
 
4
  from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError, BadRequestError, OpenAIError, InternalServerError, Timeout, ContextWindowExceededError
5
 
 
 
 
 
6
  class ClassifiedError:
7
  """A structured representation of a classified error."""
8
  def __init__(self, error_type: str, original_exception: Exception, status_code: Optional[int] = None, retry_after: Optional[int] = None):
 
3
 
4
  from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError, BadRequestError, OpenAIError, InternalServerError, Timeout, ContextWindowExceededError
5
 
6
+ class NoAvailableKeysError(Exception):
7
+ """Raised when no API keys are available for a request after waiting."""
8
+ pass
9
+
10
  class ClassifiedError:
11
  """A structured representation of a classified error."""
12
  def __init__(self, error_type: str, original_exception: Exception, status_code: Optional[int] = None, retry_after: Optional[int] = None):
src/rotator_library/usage_manager.py CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional, Set
8
  import aiofiles
9
  import litellm
10
 
11
- from .error_handler import ClassifiedError
12
 
13
  lib_logger = logging.getLogger('rotator_library')
14
  lib_logger.propagate = False
@@ -136,14 +136,14 @@ class UsageManager:
136
  await self._lazy_init()
137
  self._initialize_key_states(available_keys)
138
 
139
- while True:
 
140
  tier1_keys, tier2_keys = [], []
141
  async with self._data_lock:
142
  now = time.time()
143
  for key in available_keys:
144
  key_data = self._usage_data.get(key, {})
145
 
146
- # Skip keys on global or model-specific cooldown
147
  if (key_data.get("key_cooldown_until") or 0) > now or \
148
  (key_data.get("model_cooldowns", {}).get(model) or 0) > now:
149
  continue
@@ -156,11 +156,9 @@ class UsageManager:
156
  elif model not in key_state["models_in_use"]:
157
  tier2_keys.append((key, usage_count))
158
 
159
- # Sort keys by usage count (ascending)
160
  tier1_keys.sort(key=lambda x: x[1])
161
  tier2_keys.sort(key=lambda x: x[1])
162
 
163
- # Attempt to acquire from Tier 1 (completely free)
164
  for key, _ in tier1_keys:
165
  state = self.key_states[key]
166
  async with state["lock"]:
@@ -169,7 +167,6 @@ class UsageManager:
169
  lib_logger.info(f"Acquired Tier 1 key ...{key[-4:]} for model {model}")
170
  return key
171
 
172
- # Attempt to acquire from Tier 2 (in use by other models)
173
  for key, _ in tier2_keys:
174
  state = self.key_states[key]
175
  async with state["lock"]:
@@ -178,26 +175,28 @@ class UsageManager:
178
  lib_logger.info(f"Acquired Tier 2 key ...{key[-4:]} for model {model}")
179
  return key
180
 
181
- # If no key is available, wait for one to be released
182
  lib_logger.info("All eligible keys are currently locked for this model. Waiting...")
183
 
184
- # Create a combined list of all potentially usable keys to wait on
185
  all_potential_keys = tier1_keys + tier2_keys
186
  if not all_potential_keys:
187
- lib_logger.warning("No keys are eligible at all (all on cooldown). Waiting before re-evaluating.")
188
- await asyncio.sleep(5)
189
  continue
190
 
191
- # Wait on the condition of the best available key
192
  best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0]
193
  wait_condition = self.key_states[best_wait_key]["condition"]
194
 
195
  try:
196
  async with wait_condition:
197
- await asyncio.wait_for(wait_condition.wait(), timeout=self.wait_timeout)
 
 
 
198
  lib_logger.info("Notified that a key was released. Re-evaluating...")
199
  except asyncio.TimeoutError:
200
- lib_logger.warning("Wait timed out. Re-evaluating for any available key.")
 
 
201
 
202
 
203
  async def release_key(self, key: str, model: str):
 
8
  import aiofiles
9
  import litellm
10
 
11
+ from .error_handler import ClassifiedError, NoAvailableKeysError
12
 
13
  lib_logger = logging.getLogger('rotator_library')
14
  lib_logger.propagate = False
 
136
  await self._lazy_init()
137
  self._initialize_key_states(available_keys)
138
 
139
+ start_time = time.time()
140
+ while time.time() - start_time < self.wait_timeout:
141
  tier1_keys, tier2_keys = [], []
142
  async with self._data_lock:
143
  now = time.time()
144
  for key in available_keys:
145
  key_data = self._usage_data.get(key, {})
146
 
 
147
  if (key_data.get("key_cooldown_until") or 0) > now or \
148
  (key_data.get("model_cooldowns", {}).get(model) or 0) > now:
149
  continue
 
156
  elif model not in key_state["models_in_use"]:
157
  tier2_keys.append((key, usage_count))
158
 
 
159
  tier1_keys.sort(key=lambda x: x[1])
160
  tier2_keys.sort(key=lambda x: x[1])
161
 
 
162
  for key, _ in tier1_keys:
163
  state = self.key_states[key]
164
  async with state["lock"]:
 
167
  lib_logger.info(f"Acquired Tier 1 key ...{key[-4:]} for model {model}")
168
  return key
169
 
 
170
  for key, _ in tier2_keys:
171
  state = self.key_states[key]
172
  async with state["lock"]:
 
175
  lib_logger.info(f"Acquired Tier 2 key ...{key[-4:]} for model {model}")
176
  return key
177
 
 
178
  lib_logger.info("All eligible keys are currently locked for this model. Waiting...")
179
 
 
180
  all_potential_keys = tier1_keys + tier2_keys
181
  if not all_potential_keys:
182
+ lib_logger.warning("No keys are eligible (all on cooldown). Waiting before re-evaluating.")
183
+ await asyncio.sleep(1)
184
  continue
185
 
 
186
  best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0]
187
  wait_condition = self.key_states[best_wait_key]["condition"]
188
 
189
  try:
190
  async with wait_condition:
191
+ remaining_timeout = self.wait_timeout - (time.time() - start_time)
192
+ if remaining_timeout <= 0:
193
+ break
194
+ await asyncio.wait_for(wait_condition.wait(), timeout=min(1, remaining_timeout))
195
  lib_logger.info("Notified that a key was released. Re-evaluating...")
196
  except asyncio.TimeoutError:
197
+ lib_logger.debug("Wait timed out. Re-evaluating for any available key.")
198
+
199
+ raise NoAvailableKeysError(f"Could not acquire a key for model {model} within the {self.wait_timeout}s timeout.")
200
 
201
 
202
  async def release_key(self, key: str, model: str):