Mirrowel committed on
Commit
da8d5f4
·
1 Parent(s): c55fc2a

feat(client): add pre-request callback support to API client

Browse files

Introduce optional asynchronous callbacks that execute before each API request attempt, allowing for custom preprocessing such as validation or logging.

- Includes a new `PreRequestCallbackError` exception for handling callback failures.
- Adds `abort_on_callback_error` parameter to control whether exceptions in callbacks stop the request or are logged as warnings.
- Extends support to both completion and embedding methods, with detailed docstrings for usage.

This feature enables users to inject custom logic into the request pipeline without modifying core client behavior.

src/rotator_library/client.py CHANGED
@@ -20,7 +20,7 @@ lib_logger.propagate = False
20
 
21
  from .usage_manager import UsageManager
22
  from .failure_logger import log_failure
23
- from .error_handler import classify_error, AllProviders, NoAvailableKeysError
24
  from .providers import PROVIDER_PLUGINS
25
  from .request_sanitizer import sanitize_request_payload
26
  from .cooldown_manager import CooldownManager
@@ -36,7 +36,7 @@ class RotatingClient:
36
  A client that intelligently rotates and retries API keys using LiteLLM,
37
  with support for both streaming and non-streaming responses.
38
  """
39
- def __init__(self, api_keys: Dict[str, List[str]], max_retries: int = 2, usage_file_path: str = "key_usage.json", configure_logging: bool = True, global_timeout: int = 30):
40
  os.environ["LITELLM_LOG"] = "ERROR"
41
  litellm.set_verbose = False
42
  litellm.drop_params = True
@@ -56,6 +56,7 @@ class RotatingClient:
56
  self.api_keys = api_keys
57
  self.max_retries = max_retries
58
  self.global_timeout = global_timeout
 
59
  self.usage_manager = UsageManager(file_path=usage_file_path)
60
  self._model_list_cache = {}
61
  self._provider_plugins = PROVIDER_PLUGINS
@@ -300,7 +301,7 @@ class RotatingClient:
300
  if stream_completed and (not request or not await request.is_disconnected()):
301
  yield "data: [DONE]\n\n"
302
 
303
- async def _execute_with_retry(self, api_call: callable, request: Optional[Any], **kwargs) -> Any:
304
  """A generic retry mechanism for non-streaming API calls."""
305
  model = kwargs.get("model")
306
  if not model:
@@ -374,6 +375,16 @@ class RotatingClient:
374
  for attempt in range(self.max_retries):
375
  try:
376
  lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
 
 
 
 
 
 
 
 
 
 
377
  response = await api_call(
378
  api_key=current_key,
379
  **litellm_kwargs,
@@ -462,7 +473,7 @@ class RotatingClient:
462
  # Return None to indicate failure without propagating a disruptive exception.
463
  return None
464
 
465
- async def _streaming_acompletion_with_retry(self, request: Optional[Any], **kwargs) -> AsyncGenerator[str, None]:
466
  """A dedicated generator for retrying streaming completions with full request preparation and per-key retries."""
467
  model = kwargs.get("model")
468
  provider = model.split('/')[0]
@@ -527,6 +538,16 @@ class RotatingClient:
527
  for attempt in range(self.max_retries):
528
  try:
529
  lib_logger.info(f"Attempting stream with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
 
 
 
 
 
 
 
 
 
 
530
  response = await litellm.acompletion(
531
  api_key=current_key,
532
  **litellm_kwargs,
@@ -700,16 +721,40 @@ class RotatingClient:
700
  yield f"data: {json.dumps(error_data)}\n\n"
701
  yield "data: [DONE]\n\n"
702
 
703
- def acompletion(self, request: Optional[Any] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
704
- """Dispatcher for completion requests."""
 
 
 
 
 
 
 
 
 
 
 
 
705
  if kwargs.get("stream"):
706
- return self._streaming_acompletion_with_retry(request, **kwargs)
707
  else:
708
- return self._execute_with_retry(litellm.acompletion, request, **kwargs)
709
 
710
- def aembedding(self, request: Optional[Any] = None, **kwargs) -> Any:
711
- """Executes an embedding request with retry logic."""
712
- return self._execute_with_retry(litellm.aembedding, request, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
713
 
714
  def token_count(self, **kwargs) -> int:
715
  """Calculates the number of tokens for a given text or list of messages."""
 
20
 
21
  from .usage_manager import UsageManager
22
  from .failure_logger import log_failure
23
+ from .error_handler import PreRequestCallbackError, classify_error, AllProviders, NoAvailableKeysError
24
  from .providers import PROVIDER_PLUGINS
25
  from .request_sanitizer import sanitize_request_payload
26
  from .cooldown_manager import CooldownManager
 
36
  A client that intelligently rotates and retries API keys using LiteLLM,
37
  with support for both streaming and non-streaming responses.
38
  """
39
+ def __init__(self, api_keys: Dict[str, List[str]], max_retries: int = 2, usage_file_path: str = "key_usage.json", configure_logging: bool = True, global_timeout: int = 30, abort_on_callback_error: bool = True):
40
  os.environ["LITELLM_LOG"] = "ERROR"
41
  litellm.set_verbose = False
42
  litellm.drop_params = True
 
56
  self.api_keys = api_keys
57
  self.max_retries = max_retries
58
  self.global_timeout = global_timeout
59
+ self.abort_on_callback_error = abort_on_callback_error
60
  self.usage_manager = UsageManager(file_path=usage_file_path)
61
  self._model_list_cache = {}
62
  self._provider_plugins = PROVIDER_PLUGINS
 
301
  if stream_completed and (not request or not await request.is_disconnected()):
302
  yield "data: [DONE]\n\n"
303
 
304
+ async def _execute_with_retry(self, api_call: callable, request: Optional[Any], pre_request_callback: Optional[callable] = None, **kwargs) -> Any:
305
  """A generic retry mechanism for non-streaming API calls."""
306
  model = kwargs.get("model")
307
  if not model:
 
375
  for attempt in range(self.max_retries):
376
  try:
377
  lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
378
+
379
+ if pre_request_callback:
380
+ try:
381
+ await pre_request_callback(request, litellm_kwargs)
382
+ except Exception as e:
383
+ if self.abort_on_callback_error:
384
+ raise PreRequestCallbackError(f"Pre-request callback failed: {e}") from e
385
+ else:
386
+ lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
387
+
388
  response = await api_call(
389
  api_key=current_key,
390
  **litellm_kwargs,
 
473
  # Return None to indicate failure without propagating a disruptive exception.
474
  return None
475
 
476
+ async def _streaming_acompletion_with_retry(self, request: Optional[Any], pre_request_callback: Optional[callable] = None, **kwargs) -> AsyncGenerator[str, None]:
477
  """A dedicated generator for retrying streaming completions with full request preparation and per-key retries."""
478
  model = kwargs.get("model")
479
  provider = model.split('/')[0]
 
538
  for attempt in range(self.max_retries):
539
  try:
540
  lib_logger.info(f"Attempting stream with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
541
+
542
+ if pre_request_callback:
543
+ try:
544
+ await pre_request_callback(request, litellm_kwargs)
545
+ except Exception as e:
546
+ if self.abort_on_callback_error:
547
+ raise PreRequestCallbackError(f"Pre-request callback failed: {e}") from e
548
+ else:
549
+ lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
550
+
551
  response = await litellm.acompletion(
552
  api_key=current_key,
553
  **litellm_kwargs,
 
721
  yield f"data: {json.dumps(error_data)}\n\n"
722
  yield "data: [DONE]\n\n"
723
 
724
+ def acompletion(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
725
+ """
726
+ Dispatcher for completion requests.
727
+
728
+ Args:
729
+ request: Optional request object, used for client disconnect checks and logging.
730
+ pre_request_callback: Optional async callback function to be called before each API request attempt.
731
+ The callback will receive the `request` object and the prepared request `kwargs` as arguments.
732
+ This can be used for custom logic such as request validation, logging, or rate limiting.
733
+ If the callback raises an exception and `abort_on_callback_error` is True (the default), the completion request is aborted and a `PreRequestCallbackError` propagates; when False, the error is logged as a warning and the request proceeds.
734
+
735
+ Returns:
736
+ The completion response object, or an async generator for streaming responses, or None if all retries fail.
737
+ """
738
  if kwargs.get("stream"):
739
+ return self._streaming_acompletion_with_retry(request=request, pre_request_callback=pre_request_callback, **kwargs)
740
  else:
741
+ return self._execute_with_retry(litellm.acompletion, request=request, pre_request_callback=pre_request_callback, **kwargs)
742
 
743
+ def aembedding(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Any:
744
+ """
745
+ Executes an embedding request with retry logic.
746
+
747
+ Args:
748
+ request: Optional request object, used for client disconnect checks and logging.
749
+ pre_request_callback: Optional async callback function to be called before each API request attempt.
750
+ The callback will receive the `request` object and the prepared request `kwargs` as arguments.
751
+ This can be used for custom logic such as request validation, logging, or rate limiting.
752
+ If the callback raises an exception and `abort_on_callback_error` is True (the default), the embedding request is aborted and a `PreRequestCallbackError` propagates; when False, the error is logged as a warning and the request proceeds.
753
+
754
+ Returns:
755
+ The embedding response object, or None if all retries fail.
756
+ """
757
+ return self._execute_with_retry(litellm.aembedding, request=request, pre_request_callback=pre_request_callback, **kwargs)
758
 
759
  def token_count(self, **kwargs) -> int:
760
  """Calculates the number of tokens for a given text or list of messages."""
src/rotator_library/error_handler.py CHANGED
@@ -7,6 +7,10 @@ class NoAvailableKeysError(Exception):
7
  """Raised when no API keys are available for a request after waiting."""
8
  pass
9
 
 
 
 
 
10
  class ClassifiedError:
11
  """A structured representation of a classified error."""
12
  def __init__(self, error_type: str, original_exception: Exception, status_code: Optional[int] = None, retry_after: Optional[int] = None):
@@ -79,6 +83,13 @@ def classify_error(e: Exception) -> ClassifiedError:
79
  """
80
  status_code = getattr(e, 'status_code', None)
81
 
 
 
 
 
 
 
 
82
  if isinstance(e, RateLimitError):
83
  retry_after = get_retry_after(e)
84
  return ClassifiedError(
 
7
  """Raised when no API keys are available for a request after waiting."""
8
  pass
9
 
10
+ class PreRequestCallbackError(Exception):
11
+ """Raised when a pre-request callback fails."""
12
+ pass
13
+
14
  class ClassifiedError:
15
  """A structured representation of a classified error."""
16
  def __init__(self, error_type: str, original_exception: Exception, status_code: Optional[int] = None, retry_after: Optional[int] = None):
 
83
  """
84
  status_code = getattr(e, 'status_code', None)
85
 
86
+ if isinstance(e, PreRequestCallbackError):
87
+ return ClassifiedError(
88
+ error_type='pre_request_callback_error',
89
+ original_exception=e,
90
+ status_code=400 # Treat as a bad request
91
+ )
92
+
93
  if isinstance(e, RateLimitError):
94
  retry_after = get_retry_after(e)
95
  return ClassifiedError(