Mirrowel committed on
Commit
982e3aa
·
1 Parent(s): 45e5783

Fix: Re-Add robust streaming wrapper for RotatingClient

Browse files

Implement `_safe_streaming_wrapper` to enhance streaming reliability for the `RotatingClient`. This new method:
- Buffers fragmented JSON chunks to reassemble complete objects.
- Gracefully handles client disconnections during active streams.
- Distinguishes between streamed content and API errors, raising `StreamedAPIError` for provider errors.
- Ensures accurate usage recording and timely key release for streaming operations.
- Sends the `[DONE]` signal upon successful stream completion.

Files changed (1) hide show
  1. src/rotator_library/client.py +111 -34
src/rotator_library/client.py CHANGED
@@ -84,6 +84,92 @@ class RotatingClient:
84
  return None
85
  return self._provider_instances[provider_name]
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  async def acompletion(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
88
  kwargs = self._convert_model_params(**kwargs)
89
  model = kwargs.get("model")
@@ -96,25 +182,10 @@ class RotatingClient:
96
  if provider not in self.api_keys:
97
  raise ValueError(f"No API keys configured for provider: {provider}")
98
 
99
- async def _streaming_generator(_stream, _key, _model):
100
- """Generator that yields stream chunks and handles key release."""
101
- try:
102
- async for chunk in _stream:
103
- if request and await request.is_disconnected():
104
- lib_logger.warning("Client disconnected, cancelling stream.")
105
- raise asyncio.CancelledError("Client disconnected")
106
- yield f"data: {json.dumps(chunk.dict())}\n\n"
107
-
108
- await self.usage_manager.record_success(_key, _model, _stream)
109
- yield "data: [DONE]\n\n"
110
- finally:
111
- await self.usage_manager.release_key(_key, _model)
112
- lib_logger.info(f"STREAM FINISHED and lock released for key ...{_key[-4:]}.")
113
-
114
  keys_for_provider = self.api_keys[provider]
115
  tried_keys = set()
116
  last_exception = None
117
-
118
  while len(tried_keys) < len(keys_for_provider):
119
  current_key = None
120
  key_acquired = False
@@ -135,58 +206,64 @@ class RotatingClient:
135
  converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
136
  if converted_settings is not None: litellm_kwargs["safety_settings"] = converted_settings
137
  else: del litellm_kwargs["safety_settings"]
 
138
  if provider == "gemini":
139
  if provider_instance: provider_instance.handle_thinking_parameter(litellm_kwargs, model)
 
140
  if "gemma-3" in model and "messages" in litellm_kwargs:
141
  litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
 
142
  litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
143
 
144
  for attempt in range(self.max_retries):
145
  try:
146
  lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
147
- if pre_request_callback: await pre_request_callback()
 
 
148
 
149
  response = await litellm.acompletion(api_key=current_key, **litellm_kwargs)
150
 
151
  if is_streaming:
152
  key_acquired = False
153
- return _streaming_generator(response, current_key, model)
154
  else:
155
  await self.usage_manager.record_success(current_key, model, response)
156
- key_acquired = False
157
  await self.usage_manager.release_key(current_key, model)
 
158
  return response
159
 
160
- except Exception as e:
 
161
  last_exception = e
162
- if isinstance(e, asyncio.CancelledError): raise e
163
-
164
  log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
165
  classified_error = classify_error(e)
 
 
 
166
 
167
- if classified_error.error_type in ['invalid_request', 'context_window_exceeded']:
168
- raise last_exception
 
169
 
170
  if request and await request.is_disconnected():
171
  lib_logger.warning(f"Client disconnected. Aborting retries for key ...{current_key[-4:]}.")
172
  raise last_exception
173
 
 
 
 
 
174
  if classified_error.error_type in ['server_error', 'api_connection']:
175
  await self.usage_manager.record_failure(current_key, model, classified_error)
176
- if attempt >= self.max_retries - 1: break
 
177
  wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
178
  await asyncio.sleep(wait_time)
179
  continue
180
-
181
  await self.usage_manager.record_failure(current_key, model, classified_error)
182
  break
183
-
184
- except (litellm.InvalidRequestError, litellm.ContextWindowExceededError, asyncio.CancelledError) as e:
185
- raise e
186
- except Exception as e:
187
- last_exception = e
188
- lib_logger.error(f"An unexpected error occurred with key ...{current_key[-4:] if current_key else 'N/A'}: {e}")
189
- continue
190
  finally:
191
  if key_acquired and current_key:
192
  await self.usage_manager.release_key(current_key, model)
@@ -194,7 +271,7 @@ class RotatingClient:
194
  if last_exception:
195
  raise last_exception
196
 
197
- raise Exception("Failed to complete the request: No available API keys or all keys failed.")
198
 
199
  async def aembedding(self, request: Optional[Any] = None, **kwargs) -> Any:
200
  kwargs = self._convert_model_params(**kwargs)
 
84
  return None
85
  return self._provider_instances[provider_name]
86
 
87
    async def _safe_streaming_wrapper(self, stream: Any, key: str, model: str, request: Optional[Any] = None) -> AsyncGenerator[Any, None]:
        """
        A hybrid wrapper for streaming that buffers fragmented JSON, handles client
        disconnections gracefully, and distinguishes between content and streamed errors.

        Yields SSE-formatted strings ("data: <json>\\n\\n"), ending with
        "data: [DONE]\\n\\n" only when the upstream stream completed normally.

        Args:
            stream: The async-iterable response from litellm.acompletion(stream=True).
            key: The API key currently holding the usage-manager lock for this stream.
            model: The model name, used for usage accounting.
            request: Optional inbound request object exposing `is_disconnected()`
                (e.g. a Starlette/FastAPI Request) — checked before each chunk.

        Raises:
            StreamedAPIError: When a reassembled JSON object from the provider
                contains an "error" key.

        Note:
            The key is ALWAYS released in the `finally` block, regardless of how
            the stream ends (completion, disconnect, or error).
        """
        usage_recorded = False        # True once usage stats have been recorded for this stream
        stream_completed = False      # True only when the upstream iterator ends via StopAsyncIteration
        stream_iterator = stream.__aiter__()
        json_buffer = ""              # Accumulates raw text of fragmented JSON chunks across exceptions

        try:
            while True:
                # Poll for client disconnect before pulling the next chunk.
                if request and await request.is_disconnected():
                    lib_logger.warning(f"Client disconnected. Aborting stream for key ...{key[-4:]}.")
                    # Do not yield [DONE] because the client is gone.
                    # The 'finally' block will handle key release.
                    break

                try:
                    chunk = await stream_iterator.__anext__()
                    # A well-formed chunk arriving means any partial buffer is stale junk.
                    if json_buffer:
                        lib_logger.warning(f"Discarding incomplete JSON buffer: {json_buffer}")
                        json_buffer = ""

                    yield f"data: {json.dumps(chunk.dict())}\n\n"

                    # Record usage from the first chunk that carries a usage payload.
                    if not usage_recorded and hasattr(chunk, 'usage') and chunk.usage:
                        await self.usage_manager.record_success(key, model, chunk)
                        usage_recorded = True

                except StopAsyncIteration:
                    # Normal end of stream.
                    stream_completed = True
                    if json_buffer:
                        lib_logger.warning(f"Stream ended with incomplete data in buffer: {json_buffer}")
                    break

                except Exception as e:
                    # The underlying parser choked — attempt to reassemble the raw chunk
                    # text embedded in the exception message into a complete JSON object.
                    # NOTE(review): relies on the exception message containing
                    # "Received chunk: <raw text>" — confirm this matches the
                    # litellm version in use.
                    try:
                        raw_chunk = str(e).split("Received chunk:")[-1].strip()
                        json_buffer += raw_chunk
                        parsed_data = json.loads(json_buffer)

                        lib_logger.info(f"Successfully reassembled JSON from buffer: {json_buffer}")

                        if "error" in parsed_data:
                            lib_logger.warning(f"Reassembled object is an API error. Passing it to the client and raising internally.")
                            yield f"data: {json.dumps(parsed_data)}\n\n"
                            # Signal the error to the outer retry loop so it can try the next key.
                            # NOTE(review): by this point content may already have been
                            # yielded to the client, so a retry on a fresh key could
                            # duplicate output — verify the consumer handles this.
                            raise StreamedAPIError("Provider error received in stream", data=parsed_data)
                        else:
                            yield f"data: {json.dumps(parsed_data)}\n\n"

                        json_buffer = ""
                    except json.JSONDecodeError:
                        # Still a fragment — keep the buffer and wait for more chunks.
                        lib_logger.info(f"Buffer still incomplete. Waiting for more chunks: {json_buffer}")
                        continue
                    except StreamedAPIError:
                        # Re-raise to be caught by the outer handler
                        raise
                    except Exception as buffer_exc:
                        # Reassembly itself failed for a non-JSON reason — drop the
                        # buffer and keep streaming rather than killing the stream.
                        lib_logger.error(f"Error during stream buffering logic: {buffer_exc}. Discarding buffer.")
                        json_buffer = ""
                        continue

        except StreamedAPIError:
            # This is caught by the acompletion retry logic.
            # We re-raise it to ensure it's not caught by the generic 'except Exception'.
            raise

        except Exception as e:
            # Catch any other unexpected errors during streaming.
            lib_logger.error(f"An unexpected error occurred during the stream for key ...{key[-4:]}: {e}")
            # We still need to raise it so the client knows something went wrong.
            raise

        finally:
            # Only record usage if the stream completed successfully and usage wasn't already recorded.
            # NOTE(review): here the full `stream` object (not a chunk) is passed to
            # record_success — confirm the usage manager accepts both forms.
            if stream_completed and not usage_recorded:
                await self.usage_manager.record_success(key, model, stream)

            # Release the key unconditionally so it can be rotated back into the pool.
            await self.usage_manager.release_key(key, model)
            lib_logger.info(f"STREAM FINISHED and lock released for key ...{key[-4:]}.")

        # Send the terminator only on a clean completion; disconnects and errors
        # deliberately omit it.
        if stream_completed:
            yield "data: [DONE]\n\n"
173
  async def acompletion(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
174
  kwargs = self._convert_model_params(**kwargs)
175
  model = kwargs.get("model")
 
182
  if provider not in self.api_keys:
183
  raise ValueError(f"No API keys configured for provider: {provider}")
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  keys_for_provider = self.api_keys[provider]
186
  tried_keys = set()
187
  last_exception = None
188
+
189
  while len(tried_keys) < len(keys_for_provider):
190
  current_key = None
191
  key_acquired = False
 
206
  converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
207
  if converted_settings is not None: litellm_kwargs["safety_settings"] = converted_settings
208
  else: del litellm_kwargs["safety_settings"]
209
+
210
  if provider == "gemini":
211
  if provider_instance: provider_instance.handle_thinking_parameter(litellm_kwargs, model)
212
+
213
  if "gemma-3" in model and "messages" in litellm_kwargs:
214
  litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
215
+
216
  litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
217
 
218
  for attempt in range(self.max_retries):
219
  try:
220
  lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
221
+
222
+ if pre_request_callback:
223
+ await pre_request_callback()
224
 
225
  response = await litellm.acompletion(api_key=current_key, **litellm_kwargs)
226
 
227
  if is_streaming:
228
  key_acquired = False
229
+ return self._safe_streaming_wrapper(response, current_key, model, request)
230
  else:
231
  await self.usage_manager.record_success(current_key, model, response)
 
232
  await self.usage_manager.release_key(current_key, model)
233
+ key_acquired = False
234
  return response
235
 
236
+ except (StreamedAPIError, APIConnectionError) as e:
237
+ # These errors are caught to allow retrying with the next key.
238
  last_exception = e
 
 
239
  log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
240
  classified_error = classify_error(e)
241
+ await self.usage_manager.record_failure(current_key, model, classified_error)
242
+ lib_logger.warning(f"Key ...{current_key[-4:]} encountered '{classified_error.error_type}'. Trying next key.")
243
+ break # Break from retry loop, try next key
244
 
245
+ except Exception as e:
246
+ last_exception = e
247
+ log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
248
 
249
  if request and await request.is_disconnected():
250
  lib_logger.warning(f"Client disconnected. Aborting retries for key ...{current_key[-4:]}.")
251
  raise last_exception
252
 
253
+ classified_error = classify_error(e)
254
+ if classified_error.error_type in ['invalid_request', 'context_window_exceeded']:
255
+ raise last_exception
256
+
257
  if classified_error.error_type in ['server_error', 'api_connection']:
258
  await self.usage_manager.record_failure(current_key, model, classified_error)
259
+ if attempt >= self.max_retries - 1:
260
+ break
261
  wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
262
  await asyncio.sleep(wait_time)
263
  continue
264
+
265
  await self.usage_manager.record_failure(current_key, model, classified_error)
266
  break
 
 
 
 
 
 
 
267
  finally:
268
  if key_acquired and current_key:
269
  await self.usage_manager.release_key(current_key, model)
 
271
  if last_exception:
272
  raise last_exception
273
 
274
+ raise Exception("Failed to complete the request: No available API keys for the provider or all keys failed.")
275
 
276
  async def aembedding(self, request: Optional[Any] = None, **kwargs) -> Any:
277
  kwargs = self._convert_model_params(**kwargs)