Spaces:
Paused
Paused
Mirrowel
committed on
Commit
·
da8d5f4
1
Parent(s):
c55fc2a
feat(client): add pre-request callback support to API client
Browse files
Introduce optional asynchronous callbacks that execute before each API request attempt, allowing for custom preprocessing such as validation or logging.
- Includes a new `PreRequestCallbackError` exception for handling callback failures.
- Adds `abort_on_callback_error` parameter to control whether exceptions in callbacks stop the request or are logged as warnings.
- Extends support to both completion and embedding methods, with detailed docstrings for usage.
This feature enables users to inject custom logic into the request pipeline without modifying core client behavior.
src/rotator_library/client.py
CHANGED
|
@@ -20,7 +20,7 @@ lib_logger.propagate = False
|
|
| 20 |
|
| 21 |
from .usage_manager import UsageManager
|
| 22 |
from .failure_logger import log_failure
|
| 23 |
-
from .error_handler import classify_error, AllProviders, NoAvailableKeysError
|
| 24 |
from .providers import PROVIDER_PLUGINS
|
| 25 |
from .request_sanitizer import sanitize_request_payload
|
| 26 |
from .cooldown_manager import CooldownManager
|
|
@@ -36,7 +36,7 @@ class RotatingClient:
|
|
| 36 |
A client that intelligently rotates and retries API keys using LiteLLM,
|
| 37 |
with support for both streaming and non-streaming responses.
|
| 38 |
"""
|
| 39 |
-
def __init__(self, api_keys: Dict[str, List[str]], max_retries: int = 2, usage_file_path: str = "key_usage.json", configure_logging: bool = True, global_timeout: int = 30):
|
| 40 |
os.environ["LITELLM_LOG"] = "ERROR"
|
| 41 |
litellm.set_verbose = False
|
| 42 |
litellm.drop_params = True
|
|
@@ -56,6 +56,7 @@ class RotatingClient:
|
|
| 56 |
self.api_keys = api_keys
|
| 57 |
self.max_retries = max_retries
|
| 58 |
self.global_timeout = global_timeout
|
|
|
|
| 59 |
self.usage_manager = UsageManager(file_path=usage_file_path)
|
| 60 |
self._model_list_cache = {}
|
| 61 |
self._provider_plugins = PROVIDER_PLUGINS
|
|
@@ -300,7 +301,7 @@ class RotatingClient:
|
|
| 300 |
if stream_completed and (not request or not await request.is_disconnected()):
|
| 301 |
yield "data: [DONE]\n\n"
|
| 302 |
|
| 303 |
-
async def _execute_with_retry(self, api_call: callable, request: Optional[Any], **kwargs) -> Any:
|
| 304 |
"""A generic retry mechanism for non-streaming API calls."""
|
| 305 |
model = kwargs.get("model")
|
| 306 |
if not model:
|
|
@@ -374,6 +375,16 @@ class RotatingClient:
|
|
| 374 |
for attempt in range(self.max_retries):
|
| 375 |
try:
|
| 376 |
lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
response = await api_call(
|
| 378 |
api_key=current_key,
|
| 379 |
**litellm_kwargs,
|
|
@@ -462,7 +473,7 @@ class RotatingClient:
|
|
| 462 |
# Return None to indicate failure without propagating a disruptive exception.
|
| 463 |
return None
|
| 464 |
|
| 465 |
-
async def _streaming_acompletion_with_retry(self, request: Optional[Any], **kwargs) -> AsyncGenerator[str, None]:
|
| 466 |
"""A dedicated generator for retrying streaming completions with full request preparation and per-key retries."""
|
| 467 |
model = kwargs.get("model")
|
| 468 |
provider = model.split('/')[0]
|
|
@@ -527,6 +538,16 @@ class RotatingClient:
|
|
| 527 |
for attempt in range(self.max_retries):
|
| 528 |
try:
|
| 529 |
lib_logger.info(f"Attempting stream with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
response = await litellm.acompletion(
|
| 531 |
api_key=current_key,
|
| 532 |
**litellm_kwargs,
|
|
@@ -700,16 +721,40 @@ class RotatingClient:
|
|
| 700 |
yield f"data: {json.dumps(error_data)}\n\n"
|
| 701 |
yield "data: [DONE]\n\n"
|
| 702 |
|
| 703 |
-
def acompletion(self, request: Optional[Any] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
|
| 704 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 705 |
if kwargs.get("stream"):
|
| 706 |
-
return self._streaming_acompletion_with_retry(request, **kwargs)
|
| 707 |
else:
|
| 708 |
-
return self._execute_with_retry(litellm.acompletion, request, **kwargs)
|
| 709 |
|
| 710 |
-
def aembedding(self, request: Optional[Any] = None, **kwargs) -> Any:
|
| 711 |
-
"""
|
| 712 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
|
| 714 |
def token_count(self, **kwargs) -> int:
|
| 715 |
"""Calculates the number of tokens for a given text or list of messages."""
|
|
|
|
| 20 |
|
| 21 |
from .usage_manager import UsageManager
|
| 22 |
from .failure_logger import log_failure
|
| 23 |
+
from .error_handler import PreRequestCallbackError, classify_error, AllProviders, NoAvailableKeysError
|
| 24 |
from .providers import PROVIDER_PLUGINS
|
| 25 |
from .request_sanitizer import sanitize_request_payload
|
| 26 |
from .cooldown_manager import CooldownManager
|
|
|
|
| 36 |
A client that intelligently rotates and retries API keys using LiteLLM,
|
| 37 |
with support for both streaming and non-streaming responses.
|
| 38 |
"""
|
| 39 |
+
def __init__(self, api_keys: Dict[str, List[str]], max_retries: int = 2, usage_file_path: str = "key_usage.json", configure_logging: bool = True, global_timeout: int = 30, abort_on_callback_error: bool = True):
|
| 40 |
os.environ["LITELLM_LOG"] = "ERROR"
|
| 41 |
litellm.set_verbose = False
|
| 42 |
litellm.drop_params = True
|
|
|
|
| 56 |
self.api_keys = api_keys
|
| 57 |
self.max_retries = max_retries
|
| 58 |
self.global_timeout = global_timeout
|
| 59 |
+
self.abort_on_callback_error = abort_on_callback_error
|
| 60 |
self.usage_manager = UsageManager(file_path=usage_file_path)
|
| 61 |
self._model_list_cache = {}
|
| 62 |
self._provider_plugins = PROVIDER_PLUGINS
|
|
|
|
| 301 |
if stream_completed and (not request or not await request.is_disconnected()):
|
| 302 |
yield "data: [DONE]\n\n"
|
| 303 |
|
| 304 |
+
async def _execute_with_retry(self, api_call: callable, request: Optional[Any], pre_request_callback: Optional[callable] = None, **kwargs) -> Any:
|
| 305 |
"""A generic retry mechanism for non-streaming API calls."""
|
| 306 |
model = kwargs.get("model")
|
| 307 |
if not model:
|
|
|
|
| 375 |
for attempt in range(self.max_retries):
|
| 376 |
try:
|
| 377 |
lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
|
| 378 |
+
|
| 379 |
+
if pre_request_callback:
|
| 380 |
+
try:
|
| 381 |
+
await pre_request_callback(request, litellm_kwargs)
|
| 382 |
+
except Exception as e:
|
| 383 |
+
if self.abort_on_callback_error:
|
| 384 |
+
raise PreRequestCallbackError(f"Pre-request callback failed: {e}") from e
|
| 385 |
+
else:
|
| 386 |
+
lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
|
| 387 |
+
|
| 388 |
response = await api_call(
|
| 389 |
api_key=current_key,
|
| 390 |
**litellm_kwargs,
|
|
|
|
| 473 |
# Return None to indicate failure without propagating a disruptive exception.
|
| 474 |
return None
|
| 475 |
|
| 476 |
+
async def _streaming_acompletion_with_retry(self, request: Optional[Any], pre_request_callback: Optional[callable] = None, **kwargs) -> AsyncGenerator[str, None]:
|
| 477 |
"""A dedicated generator for retrying streaming completions with full request preparation and per-key retries."""
|
| 478 |
model = kwargs.get("model")
|
| 479 |
provider = model.split('/')[0]
|
|
|
|
| 538 |
for attempt in range(self.max_retries):
|
| 539 |
try:
|
| 540 |
lib_logger.info(f"Attempting stream with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
|
| 541 |
+
|
| 542 |
+
if pre_request_callback:
|
| 543 |
+
try:
|
| 544 |
+
await pre_request_callback(request, litellm_kwargs)
|
| 545 |
+
except Exception as e:
|
| 546 |
+
if self.abort_on_callback_error:
|
| 547 |
+
raise PreRequestCallbackError(f"Pre-request callback failed: {e}") from e
|
| 548 |
+
else:
|
| 549 |
+
lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
|
| 550 |
+
|
| 551 |
response = await litellm.acompletion(
|
| 552 |
api_key=current_key,
|
| 553 |
**litellm_kwargs,
|
|
|
|
| 721 |
yield f"data: {json.dumps(error_data)}\n\n"
|
| 722 |
yield "data: [DONE]\n\n"
|
| 723 |
|
| 724 |
+
def acompletion(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
|
| 725 |
+
"""
|
| 726 |
+
Dispatcher for completion requests.
|
| 727 |
+
|
| 728 |
+
Args:
|
| 729 |
+
request: Optional request object, used for client disconnect checks and logging.
|
| 730 |
+
pre_request_callback: Optional async callback function to be called before each API request attempt.
|
| 731 |
+
The callback will receive the `request` object and the prepared request `kwargs` as arguments.
|
| 732 |
+
This can be used for custom logic such as request validation, logging, or rate limiting.
|
| 733 |
+
If the callback raises an exception and `abort_on_callback_error` is True, the completion request is aborted with a `PreRequestCallbackError`; otherwise the failure is logged as a warning and the request proceeds.
|
| 734 |
+
|
| 735 |
+
Returns:
|
| 736 |
+
The completion response object, or an async generator for streaming responses, or None if all retries fail.
|
| 737 |
+
"""
|
| 738 |
if kwargs.get("stream"):
|
| 739 |
+
return self._streaming_acompletion_with_retry(request=request, pre_request_callback=pre_request_callback, **kwargs)
|
| 740 |
else:
|
| 741 |
+
return self._execute_with_retry(litellm.acompletion, request=request, pre_request_callback=pre_request_callback, **kwargs)
|
| 742 |
|
| 743 |
+
def aembedding(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Any:
|
| 744 |
+
"""
|
| 745 |
+
Executes an embedding request with retry logic.
|
| 746 |
+
|
| 747 |
+
Args:
|
| 748 |
+
request: Optional request object, used for client disconnect checks and logging.
|
| 749 |
+
pre_request_callback: Optional async callback function to be called before each API request attempt.
|
| 750 |
+
The callback will receive the `request` object and the prepared request `kwargs` as arguments.
|
| 751 |
+
This can be used for custom logic such as request validation, logging, or rate limiting.
|
| 752 |
+
If the callback raises an exception and `abort_on_callback_error` is True, the embedding request is aborted with a `PreRequestCallbackError`; otherwise the failure is logged as a warning and the request proceeds.
|
| 753 |
+
|
| 754 |
+
Returns:
|
| 755 |
+
The embedding response object, or None if all retries fail.
|
| 756 |
+
"""
|
| 757 |
+
return self._execute_with_retry(litellm.aembedding, request=request, pre_request_callback=pre_request_callback, **kwargs)
|
| 758 |
|
| 759 |
def token_count(self, **kwargs) -> int:
|
| 760 |
"""Calculates the number of tokens for a given text or list of messages."""
|
src/rotator_library/error_handler.py
CHANGED
|
@@ -7,6 +7,10 @@ class NoAvailableKeysError(Exception):
|
|
| 7 |
"""Raised when no API keys are available for a request after waiting."""
|
| 8 |
pass
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
class ClassifiedError:
|
| 11 |
"""A structured representation of a classified error."""
|
| 12 |
def __init__(self, error_type: str, original_exception: Exception, status_code: Optional[int] = None, retry_after: Optional[int] = None):
|
|
@@ -79,6 +83,13 @@ def classify_error(e: Exception) -> ClassifiedError:
|
|
| 79 |
"""
|
| 80 |
status_code = getattr(e, 'status_code', None)
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
if isinstance(e, RateLimitError):
|
| 83 |
retry_after = get_retry_after(e)
|
| 84 |
return ClassifiedError(
|
|
|
|
| 7 |
"""Raised when no API keys are available for a request after waiting."""
|
| 8 |
pass
|
| 9 |
|
| 10 |
+
class PreRequestCallbackError(Exception):
|
| 11 |
+
"""Raised when a pre-request callback fails."""
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
class ClassifiedError:
|
| 15 |
"""A structured representation of a classified error."""
|
| 16 |
def __init__(self, error_type: str, original_exception: Exception, status_code: Optional[int] = None, retry_after: Optional[int] = None):
|
|
|
|
| 83 |
"""
|
| 84 |
status_code = getattr(e, 'status_code', None)
|
| 85 |
|
| 86 |
+
if isinstance(e, PreRequestCallbackError):
|
| 87 |
+
return ClassifiedError(
|
| 88 |
+
error_type='pre_request_callback_error',
|
| 89 |
+
original_exception=e,
|
| 90 |
+
status_code=400 # Treat as a bad request
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
if isinstance(e, RateLimitError):
|
| 94 |
retry_after = get_retry_after(e)
|
| 95 |
return ClassifiedError(
|