Spaces:
Paused
Paused
Mirrowel committed on
Commit ·
982e3aa
1
Parent(s): 45e5783
Fix: Re-Add robust streaming wrapper for RotatingClient
Browse files
Implement `_safe_streaming_wrapper` to enhance streaming reliability for the `RotatingClient`. This new method:
- Buffers fragmented JSON chunks to reassemble complete objects.
- Gracefully handles client disconnections during active streams.
- Distinguishes between streamed content and API errors, raising `StreamedAPIError` for provider errors.
- Ensures accurate usage recording and timely key release for streaming operations.
- Sends the `[DONE]` signal upon successful stream completion.
- src/rotator_library/client.py +111 -34
src/rotator_library/client.py
CHANGED
|
@@ -84,6 +84,92 @@ class RotatingClient:
|
|
| 84 |
return None
|
| 85 |
return self._provider_instances[provider_name]
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
async def acompletion(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
|
| 88 |
kwargs = self._convert_model_params(**kwargs)
|
| 89 |
model = kwargs.get("model")
|
|
@@ -96,25 +182,10 @@ class RotatingClient:
|
|
| 96 |
if provider not in self.api_keys:
|
| 97 |
raise ValueError(f"No API keys configured for provider: {provider}")
|
| 98 |
|
| 99 |
-
async def _streaming_generator(_stream, _key, _model):
|
| 100 |
-
"""Generator that yields stream chunks and handles key release."""
|
| 101 |
-
try:
|
| 102 |
-
async for chunk in _stream:
|
| 103 |
-
if request and await request.is_disconnected():
|
| 104 |
-
lib_logger.warning("Client disconnected, cancelling stream.")
|
| 105 |
-
raise asyncio.CancelledError("Client disconnected")
|
| 106 |
-
yield f"data: {json.dumps(chunk.dict())}\n\n"
|
| 107 |
-
|
| 108 |
-
await self.usage_manager.record_success(_key, _model, _stream)
|
| 109 |
-
yield "data: [DONE]\n\n"
|
| 110 |
-
finally:
|
| 111 |
-
await self.usage_manager.release_key(_key, _model)
|
| 112 |
-
lib_logger.info(f"STREAM FINISHED and lock released for key ...{_key[-4:]}.")
|
| 113 |
-
|
| 114 |
keys_for_provider = self.api_keys[provider]
|
| 115 |
tried_keys = set()
|
| 116 |
last_exception = None
|
| 117 |
-
|
| 118 |
while len(tried_keys) < len(keys_for_provider):
|
| 119 |
current_key = None
|
| 120 |
key_acquired = False
|
|
@@ -135,58 +206,64 @@ class RotatingClient:
|
|
| 135 |
converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
|
| 136 |
if converted_settings is not None: litellm_kwargs["safety_settings"] = converted_settings
|
| 137 |
else: del litellm_kwargs["safety_settings"]
|
|
|
|
| 138 |
if provider == "gemini":
|
| 139 |
if provider_instance: provider_instance.handle_thinking_parameter(litellm_kwargs, model)
|
|
|
|
| 140 |
if "gemma-3" in model and "messages" in litellm_kwargs:
|
| 141 |
litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
|
|
|
|
| 142 |
litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
|
| 143 |
|
| 144 |
for attempt in range(self.max_retries):
|
| 145 |
try:
|
| 146 |
lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
|
| 147 |
-
|
|
|
|
|
|
|
| 148 |
|
| 149 |
response = await litellm.acompletion(api_key=current_key, **litellm_kwargs)
|
| 150 |
|
| 151 |
if is_streaming:
|
| 152 |
key_acquired = False
|
| 153 |
-
return
|
| 154 |
else:
|
| 155 |
await self.usage_manager.record_success(current_key, model, response)
|
| 156 |
-
key_acquired = False
|
| 157 |
await self.usage_manager.release_key(current_key, model)
|
|
|
|
| 158 |
return response
|
| 159 |
|
| 160 |
-
except
|
|
|
|
| 161 |
last_exception = e
|
| 162 |
-
if isinstance(e, asyncio.CancelledError): raise e
|
| 163 |
-
|
| 164 |
log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
|
| 165 |
classified_error = classify_error(e)
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
-
|
| 168 |
-
|
|
|
|
| 169 |
|
| 170 |
if request and await request.is_disconnected():
|
| 171 |
lib_logger.warning(f"Client disconnected. Aborting retries for key ...{current_key[-4:]}.")
|
| 172 |
raise last_exception
|
| 173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
if classified_error.error_type in ['server_error', 'api_connection']:
|
| 175 |
await self.usage_manager.record_failure(current_key, model, classified_error)
|
| 176 |
-
if attempt >= self.max_retries - 1:
|
|
|
|
| 177 |
wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
|
| 178 |
await asyncio.sleep(wait_time)
|
| 179 |
continue
|
| 180 |
-
|
| 181 |
await self.usage_manager.record_failure(current_key, model, classified_error)
|
| 182 |
break
|
| 183 |
-
|
| 184 |
-
except (litellm.InvalidRequestError, litellm.ContextWindowExceededError, asyncio.CancelledError) as e:
|
| 185 |
-
raise e
|
| 186 |
-
except Exception as e:
|
| 187 |
-
last_exception = e
|
| 188 |
-
lib_logger.error(f"An unexpected error occurred with key ...{current_key[-4:] if current_key else 'N/A'}: {e}")
|
| 189 |
-
continue
|
| 190 |
finally:
|
| 191 |
if key_acquired and current_key:
|
| 192 |
await self.usage_manager.release_key(current_key, model)
|
|
@@ -194,7 +271,7 @@ class RotatingClient:
|
|
| 194 |
if last_exception:
|
| 195 |
raise last_exception
|
| 196 |
|
| 197 |
-
raise Exception("Failed to complete the request: No available API keys or all keys failed.")
|
| 198 |
|
| 199 |
async def aembedding(self, request: Optional[Any] = None, **kwargs) -> Any:
|
| 200 |
kwargs = self._convert_model_params(**kwargs)
|
|
|
|
| 84 |
return None
|
| 85 |
return self._provider_instances[provider_name]
|
| 86 |
|
| 87 |
+
async def _safe_streaming_wrapper(self, stream: Any, key: str, model: str, request: Optional[Any] = None) -> AsyncGenerator[Any, None]:
|
| 88 |
+
"""
|
| 89 |
+
A hybrid wrapper for streaming that buffers fragmented JSON, handles client disconnections gracefully,
|
| 90 |
+
and distinguishes between content and streamed errors.
|
| 91 |
+
"""
|
| 92 |
+
usage_recorded = False
|
| 93 |
+
stream_completed = False
|
| 94 |
+
stream_iterator = stream.__aiter__()
|
| 95 |
+
json_buffer = ""
|
| 96 |
+
|
| 97 |
+
try:
|
| 98 |
+
while True:
|
| 99 |
+
if request and await request.is_disconnected():
|
| 100 |
+
lib_logger.warning(f"Client disconnected. Aborting stream for key ...{key[-4:]}.")
|
| 101 |
+
# Do not yield [DONE] because the client is gone.
|
| 102 |
+
# The 'finally' block will handle key release.
|
| 103 |
+
break
|
| 104 |
+
|
| 105 |
+
try:
|
| 106 |
+
chunk = await stream_iterator.__anext__()
|
| 107 |
+
if json_buffer:
|
| 108 |
+
lib_logger.warning(f"Discarding incomplete JSON buffer: {json_buffer}")
|
| 109 |
+
json_buffer = ""
|
| 110 |
+
|
| 111 |
+
yield f"data: {json.dumps(chunk.dict())}\n\n"
|
| 112 |
+
|
| 113 |
+
if not usage_recorded and hasattr(chunk, 'usage') and chunk.usage:
|
| 114 |
+
await self.usage_manager.record_success(key, model, chunk)
|
| 115 |
+
usage_recorded = True
|
| 116 |
+
|
| 117 |
+
except StopAsyncIteration:
|
| 118 |
+
stream_completed = True
|
| 119 |
+
if json_buffer:
|
| 120 |
+
lib_logger.warning(f"Stream ended with incomplete data in buffer: {json_buffer}")
|
| 121 |
+
break
|
| 122 |
+
|
| 123 |
+
except Exception as e:
|
| 124 |
+
try:
|
| 125 |
+
raw_chunk = str(e).split("Received chunk:")[-1].strip()
|
| 126 |
+
json_buffer += raw_chunk
|
| 127 |
+
parsed_data = json.loads(json_buffer)
|
| 128 |
+
|
| 129 |
+
lib_logger.info(f"Successfully reassembled JSON from buffer: {json_buffer}")
|
| 130 |
+
|
| 131 |
+
if "error" in parsed_data:
|
| 132 |
+
lib_logger.warning(f"Reassembled object is an API error. Passing it to the client and raising internally.")
|
| 133 |
+
yield f"data: {json.dumps(parsed_data)}\n\n"
|
| 134 |
+
# Signal the error to the outer retry loop so it can try the next key.
|
| 135 |
+
raise StreamedAPIError("Provider error received in stream", data=parsed_data)
|
| 136 |
+
else:
|
| 137 |
+
yield f"data: {json.dumps(parsed_data)}\n\n"
|
| 138 |
+
|
| 139 |
+
json_buffer = ""
|
| 140 |
+
except json.JSONDecodeError:
|
| 141 |
+
lib_logger.info(f"Buffer still incomplete. Waiting for more chunks: {json_buffer}")
|
| 142 |
+
continue
|
| 143 |
+
except StreamedAPIError:
|
| 144 |
+
# Re-raise to be caught by the outer handler
|
| 145 |
+
raise
|
| 146 |
+
except Exception as buffer_exc:
|
| 147 |
+
lib_logger.error(f"Error during stream buffering logic: {buffer_exc}. Discarding buffer.")
|
| 148 |
+
json_buffer = ""
|
| 149 |
+
continue
|
| 150 |
+
|
| 151 |
+
except StreamedAPIError:
|
| 152 |
+
# This is caught by the acompletion retry logic.
|
| 153 |
+
# We re-raise it to ensure it's not caught by the generic 'except Exception'.
|
| 154 |
+
raise
|
| 155 |
+
|
| 156 |
+
except Exception as e:
|
| 157 |
+
# Catch any other unexpected errors during streaming.
|
| 158 |
+
lib_logger.error(f"An unexpected error occurred during the stream for key ...{key[-4:]}: {e}")
|
| 159 |
+
# We still need to raise it so the client knows something went wrong.
|
| 160 |
+
raise
|
| 161 |
+
|
| 162 |
+
finally:
|
| 163 |
+
# Only record usage if the stream completed successfully and usage wasn't already recorded.
|
| 164 |
+
if stream_completed and not usage_recorded:
|
| 165 |
+
await self.usage_manager.record_success(key, model, stream)
|
| 166 |
+
|
| 167 |
+
await self.usage_manager.release_key(key, model)
|
| 168 |
+
lib_logger.info(f"STREAM FINISHED and lock released for key ...{key[-4:]}.")
|
| 169 |
+
|
| 170 |
+
if stream_completed:
|
| 171 |
+
yield "data: [DONE]\n\n"
|
| 172 |
+
|
| 173 |
async def acompletion(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
|
| 174 |
kwargs = self._convert_model_params(**kwargs)
|
| 175 |
model = kwargs.get("model")
|
|
|
|
| 182 |
if provider not in self.api_keys:
|
| 183 |
raise ValueError(f"No API keys configured for provider: {provider}")
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
keys_for_provider = self.api_keys[provider]
|
| 186 |
tried_keys = set()
|
| 187 |
last_exception = None
|
| 188 |
+
|
| 189 |
while len(tried_keys) < len(keys_for_provider):
|
| 190 |
current_key = None
|
| 191 |
key_acquired = False
|
|
|
|
| 206 |
converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
|
| 207 |
if converted_settings is not None: litellm_kwargs["safety_settings"] = converted_settings
|
| 208 |
else: del litellm_kwargs["safety_settings"]
|
| 209 |
+
|
| 210 |
if provider == "gemini":
|
| 211 |
if provider_instance: provider_instance.handle_thinking_parameter(litellm_kwargs, model)
|
| 212 |
+
|
| 213 |
if "gemma-3" in model and "messages" in litellm_kwargs:
|
| 214 |
litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
|
| 215 |
+
|
| 216 |
litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
|
| 217 |
|
| 218 |
for attempt in range(self.max_retries):
|
| 219 |
try:
|
| 220 |
lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
|
| 221 |
+
|
| 222 |
+
if pre_request_callback:
|
| 223 |
+
await pre_request_callback()
|
| 224 |
|
| 225 |
response = await litellm.acompletion(api_key=current_key, **litellm_kwargs)
|
| 226 |
|
| 227 |
if is_streaming:
|
| 228 |
key_acquired = False
|
| 229 |
+
return self._safe_streaming_wrapper(response, current_key, model, request)
|
| 230 |
else:
|
| 231 |
await self.usage_manager.record_success(current_key, model, response)
|
|
|
|
| 232 |
await self.usage_manager.release_key(current_key, model)
|
| 233 |
+
key_acquired = False
|
| 234 |
return response
|
| 235 |
|
| 236 |
+
except (StreamedAPIError, APIConnectionError) as e:
|
| 237 |
+
# These errors are caught to allow retrying with the next key.
|
| 238 |
last_exception = e
|
|
|
|
|
|
|
| 239 |
log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
|
| 240 |
classified_error = classify_error(e)
|
| 241 |
+
await self.usage_manager.record_failure(current_key, model, classified_error)
|
| 242 |
+
lib_logger.warning(f"Key ...{current_key[-4:]} encountered '{classified_error.error_type}'. Trying next key.")
|
| 243 |
+
break # Break from retry loop, try next key
|
| 244 |
|
| 245 |
+
except Exception as e:
|
| 246 |
+
last_exception = e
|
| 247 |
+
log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
|
| 248 |
|
| 249 |
if request and await request.is_disconnected():
|
| 250 |
lib_logger.warning(f"Client disconnected. Aborting retries for key ...{current_key[-4:]}.")
|
| 251 |
raise last_exception
|
| 252 |
|
| 253 |
+
classified_error = classify_error(e)
|
| 254 |
+
if classified_error.error_type in ['invalid_request', 'context_window_exceeded']:
|
| 255 |
+
raise last_exception
|
| 256 |
+
|
| 257 |
if classified_error.error_type in ['server_error', 'api_connection']:
|
| 258 |
await self.usage_manager.record_failure(current_key, model, classified_error)
|
| 259 |
+
if attempt >= self.max_retries - 1:
|
| 260 |
+
break
|
| 261 |
wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
|
| 262 |
await asyncio.sleep(wait_time)
|
| 263 |
continue
|
| 264 |
+
|
| 265 |
await self.usage_manager.record_failure(current_key, model, classified_error)
|
| 266 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
finally:
|
| 268 |
if key_acquired and current_key:
|
| 269 |
await self.usage_manager.release_key(current_key, model)
|
|
|
|
| 271 |
if last_exception:
|
| 272 |
raise last_exception
|
| 273 |
|
| 274 |
+
raise Exception("Failed to complete the request: No available API keys for the provider or all keys failed.")
|
| 275 |
|
| 276 |
async def aembedding(self, request: Optional[Any] = None, **kwargs) -> Any:
|
| 277 |
kwargs = self._convert_model_params(**kwargs)
|