Mirrowel committed on
Commit
b5b51f2
·
1 Parent(s): 978e4a8

feat: Improve streaming robustness and key management

Browse files

Refactor streaming logic to enhance stability and resource cleanup.

- **Proxy (`main.py`):** Added client disconnection detection during streaming to
terminate the response early, preventing unnecessary resource consumption.
The `request` object is now passed to the streaming wrapper for this check.
- **Rotator Library (`client.py`):** Removed the complex `_safe_streaming_wrapper`
for a more direct streaming approach. Ensured API keys are reliably released
upon stream completion, budget exhaustion, or API errors, improving resource
management and preventing potential deadlocks. The `request` object is now
also passed to the underlying provider call for better context.

src/proxy_app/main.py CHANGED
@@ -111,6 +111,7 @@ async def verify_api_key(auth: str = Depends(api_key_header)):
111
  return auth
112
 
113
  async def streaming_response_wrapper(
 
114
  request_data: dict,
115
  response_stream: AsyncGenerator[str, None]
116
  ) -> AsyncGenerator[str, None]:
@@ -124,6 +125,9 @@ async def streaming_response_wrapper(
124
 
125
  try:
126
  async for chunk_str in response_stream:
 
 
 
127
  yield chunk_str
128
  if chunk_str.strip() and chunk_str.startswith("data:"):
129
  content = chunk_str[len("data:"):].strip()
@@ -242,13 +246,13 @@ async def chat_completions(
242
  response = await client.acompletion(request=request, **request_data)
243
 
244
  if is_streaming:
245
- # Wrap the streaming response to enable logging after it's complete
246
  return StreamingResponse(
247
- streaming_response_wrapper(request_data, response),
248
  media_type="text/event-stream"
249
  )
250
  else:
251
- # For non-streaming, log immediately
252
  if ENABLE_REQUEST_LOGGING:
253
  log_request_response(
254
  request_data=request_data,
 
111
  return auth
112
 
113
  async def streaming_response_wrapper(
114
+ request: Request,
115
  request_data: dict,
116
  response_stream: AsyncGenerator[str, None]
117
  ) -> AsyncGenerator[str, None]:
 
125
 
126
  try:
127
  async for chunk_str in response_stream:
128
+ if await request.is_disconnected():
129
+ logging.warning("Client disconnected, stopping stream.")
130
+ break
131
  yield chunk_str
132
  if chunk_str.strip() and chunk_str.startswith("data:"):
133
  content = chunk_str[len("data:"):].strip()
 
246
  response = await client.acompletion(request=request, **request_data)
247
 
248
  if is_streaming:
249
+ # For streaming, the response is the generator.
250
  return StreamingResponse(
251
+ streaming_response_wrapper(request, request_data, response),
252
  media_type="text/event-stream"
253
  )
254
  else:
255
+ # For non-streaming, the response is the completed object.
256
  if ENABLE_REQUEST_LOGGING:
257
  log_request_response(
258
  request_data=request_data,
src/rotator_library/client.py CHANGED
@@ -84,103 +84,7 @@ class RotatingClient:
84
  return None
85
  return self._provider_instances[provider_name]
86
 
87
- async def _safe_streaming_wrapper(self, stream: Any, key: str, model: str) -> AsyncGenerator[Any, None]:
88
- """
89
- A definitive hybrid wrapper for streaming responses that ensures usage is recorded
90
- and the key lock is released only after the stream is fully consumed. It handles
91
- fragmented JSON by buffering and intelligently distinguishing between content and
92
- errors, feeding actual errors back into the main retry logic.
93
- """
94
- usage_recorded = False
95
- stream_completed = False
96
- stream_iterator = stream.__aiter__()
97
- json_buffer = ""
98
-
99
- try:
100
- while True:
101
- try:
102
- # 1. Await the next item from the stream iterator.
103
- chunk = await stream_iterator.__anext__()
104
-
105
- # 2. If we receive a valid chunk while the buffer has content,
106
- # it implies the buffered data was an unrecoverable fragment.
107
- # Log it, discard the buffer, and proceed with the valid chunk.
108
- if json_buffer:
109
- lib_logger.warning(f"Discarding incomplete JSON buffer because a valid chunk was received: {json_buffer}")
110
- json_buffer = ""
111
-
112
- # 3. This is the "happy path" where the chunk is valid.
113
- # Yield it in the Server-Sent Events (SSE) format.
114
- yield f"data: {json.dumps(chunk.dict())}\n\n"
115
-
116
- # 4. Try to record usage from the valid chunk itself.
117
- if not usage_recorded and hasattr(chunk, 'usage') and chunk.usage:
118
- await self.usage_manager.record_success(key, model, chunk)
119
- usage_recorded = True
120
- lib_logger.info(f"Recorded usage from stream chunk for key ...{key[-4:]}")
121
-
122
- except StopAsyncIteration:
123
- # 5. The stream has ended successfully.
124
- stream_completed = True
125
- if json_buffer:
126
- lib_logger.warning(f"Stream ended with incomplete data in buffer: {json_buffer}")
127
- break
128
-
129
- except Exception as e:
130
- # 6. An exception occurred, indicating a potentially malformed or fragmented chunk.
131
- lib_logger.info(f"Malformed chunk detected for key ...{key[-4:]}. Attempting to buffer and reassemble.")
132
-
133
- try:
134
- # 6a. The raw chunk string is usually in the exception message from litellm.
135
- # We extract it here. This is fragile but necessary.
136
- raw_chunk = str(e).split("Received chunk:")[-1].strip()
137
- json_buffer += raw_chunk
138
-
139
- # 6b. Try to parse the entire buffer.
140
- try:
141
- parsed_data = json.loads(json_buffer)
142
- # If successful, we have a complete JSON object.
143
- lib_logger.info(f"Successfully reassembled JSON from buffer: {json_buffer}")
144
-
145
- # 6b. INTELLIGENTLY INSPECT the reassembled object.
146
- if "error" in parsed_data:
147
- # This is a provider error. Log it and pass it through the stream.
148
- lib_logger.warning(f"Reassembled object is an API error. Passing it to the client.")
149
- yield f"data: {json.dumps(parsed_data)}\n\n"
150
- else:
151
- # This is a valid content chunk that was fragmented.
152
- lib_logger.info("Reassembled object is a valid content chunk.")
153
- yield f"data: {json.dumps(parsed_data)}\n\n"
154
-
155
- json_buffer = "" # Clear buffer after successful processing.
156
- except json.JSONDecodeError:
157
- # The buffer is still not a complete JSON object.
158
- # We'll continue to the next loop iteration to get more chunks.
159
- lib_logger.info(f"Buffer is still not a complete JSON object. Waiting for more chunks.")
160
- continue
161
- except Exception as buffer_exc:
162
- # If our own buffering logic fails, log it and reset to prevent getting stuck.
163
- lib_logger.error(f"Error during stream buffering logic: {buffer_exc}. Discarding buffer.")
164
- json_buffer = "" # Reset buffer on error
165
- continue
166
- finally:
167
- # 7. This block only runs if the stream completes successfully (no StreamedAPIError).
168
- if not usage_recorded:
169
- # If usage wasn't found in any chunk, try to get it from the final stream object.
170
- await self.usage_manager.record_success(key, model, stream)
171
- # 8. Release the key so it can be used by other requests.
172
- await self.usage_manager.release_key(key, model)
173
- lib_logger.info(f"STREAM FINISHED and lock released for key ...{key[-4:]}.")
174
- # 9. Only send the [DONE] message if the stream completed without being aborted.
175
- if stream_completed:
176
- yield "data: [DONE]\n\n"
177
-
178
-
179
  async def acompletion(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
180
- """
181
- Performs a completion call with smart key rotation and retry logic.
182
- It will try each available key in sequence if the previous one fails.
183
- """
184
  kwargs = self._convert_model_params(**kwargs)
185
  model = kwargs.get("model")
186
  is_streaming = kwargs.get("stream", False)
@@ -188,133 +92,109 @@ class RotatingClient:
188
  if not model:
189
  raise ValueError("'model' is a required parameter.")
190
 
191
- provider = kwargs.get("model").split('/')[0]
192
  if provider not in self.api_keys:
193
  raise ValueError(f"No API keys configured for provider: {provider}")
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  keys_for_provider = self.api_keys[provider]
196
  tried_keys = set()
197
  last_exception = None
198
-
199
  while len(tried_keys) < len(keys_for_provider):
200
  current_key = None
201
  key_acquired = False
202
  try:
203
  keys_to_try = [k for k in keys_for_provider if k not in tried_keys]
204
  if not keys_to_try:
205
- break
206
 
207
- current_key = await self.usage_manager.acquire_key(
208
- available_keys=keys_to_try,
209
- model=model
210
- )
211
  key_acquired = True
212
  tried_keys.add(current_key)
213
 
214
- # Prepare litellm_kwargs once per key, not on every retry
215
  litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
216
-
217
  provider_instance = self._get_provider_instance(provider)
218
  if provider_instance:
219
-
220
- # Ensure safety_settings are present, defaulting to lowest if not provided
221
  if "safety_settings" not in litellm_kwargs:
222
- litellm_kwargs["safety_settings"] = {
223
- "harassment": "BLOCK_NONE",
224
- "hate_speech": "BLOCK_NONE",
225
- "sexually_explicit": "BLOCK_NONE",
226
- "dangerous_content": "BLOCK_NONE",
227
- }
228
-
229
  converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
230
-
231
- if converted_settings is not None:
232
- litellm_kwargs["safety_settings"] = converted_settings
233
- else:
234
- # If conversion returns None, remove it to avoid sending empty settings
235
- del litellm_kwargs["safety_settings"]
236
-
237
  if provider == "gemini":
238
- if provider_instance:
239
- provider_instance.handle_thinking_parameter(litellm_kwargs, model)
240
-
241
  if "gemma-3" in model and "messages" in litellm_kwargs:
242
- new_messages = [
243
- {"role": "user", "content": m["content"]} if m.get("role") == "system" else m
244
- for m in litellm_kwargs["messages"]
245
- ]
246
- litellm_kwargs["messages"] = new_messages
247
-
248
  litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
249
 
250
  for attempt in range(self.max_retries):
251
  try:
252
  lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
253
-
254
- if pre_request_callback:
255
- await pre_request_callback()
256
 
257
  response = await litellm.acompletion(api_key=current_key, **litellm_kwargs)
258
 
259
  if is_streaming:
260
- # For streaming, the wrapper takes responsibility for the key.
261
  key_acquired = False
262
- # The wrapper will now yield all chunks, including any provider errors,
263
- # and will handle releasing the key internally.
264
- return self._safe_streaming_wrapper(response, current_key, model)
265
  else:
266
- # For non-streaming, record and release here.
267
  await self.usage_manager.record_success(current_key, model, response)
 
268
  await self.usage_manager.release_key(current_key, model)
269
- key_acquired = False # Key has been released
270
  return response
271
 
272
  except Exception as e:
273
  last_exception = e
274
- log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
275
 
276
- classified_error = classify_error(e)
277
-
278
  classified_error = classify_error(e)
279
 
280
  if classified_error.error_type in ['invalid_request', 'context_window_exceeded']:
281
- # These errors are not recoverable by rotating keys, so fail fast.
282
- lib_logger.error(f"Unrecoverable error '{classified_error.error_type}' with key ...{current_key[-4:]}. Failing request.")
283
  raise last_exception
284
-
285
  if request and await request.is_disconnected():
286
  lib_logger.warning(f"Client disconnected. Aborting retries for key ...{current_key[-4:]}.")
287
  raise last_exception
288
 
289
  if classified_error.error_type in ['server_error', 'api_connection']:
290
- # These are temporary, so record the failure and retry with backoff.
291
  await self.usage_manager.record_failure(current_key, model, classified_error)
292
-
293
- if attempt >= self.max_retries - 1:
294
- lib_logger.warning(f"Key ...{current_key[-4:]} failed on final retry for {classified_error.error_type}. Trying next key.")
295
- break
296
-
297
- # Use a longer cooldown for API connection errors
298
- base_wait = 5 if classified_error.error_type == 'api_connection' else 1
299
- wait_time = classified_error.retry_after or (base_wait * (2 ** attempt)) + random.uniform(0, 1)
300
-
301
- lib_logger.warning(f"Key ...{current_key[-4:]} encountered a {classified_error.error_type}. Retrying in {wait_time:.2f} seconds...")
302
  await asyncio.sleep(wait_time)
303
  continue
304
-
305
- # For other errors (rate_limit, authentication, unknown), record failure and try the next key.
306
  await self.usage_manager.record_failure(current_key, model, classified_error)
307
- lib_logger.warning(f"Key ...{current_key[-4:]} encountered '{classified_error.error_type}'. Trying next key.")
308
  break
 
 
 
 
 
 
 
309
  finally:
310
- # This block ensures the key is always released if it was acquired but not passed to the wrapper.
311
  if key_acquired and current_key:
312
  await self.usage_manager.release_key(current_key, model)
313
 
314
  if last_exception:
315
  raise last_exception
316
 
317
- raise Exception("Failed to complete the request: No available API keys for the provider or all keys failed.")
318
 
319
  async def aembedding(self, request: Optional[Any] = None, **kwargs) -> Any:
320
  """
 
84
  return None
85
  return self._provider_instances[provider_name]
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  async def acompletion(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
 
 
 
 
88
  kwargs = self._convert_model_params(**kwargs)
89
  model = kwargs.get("model")
90
  is_streaming = kwargs.get("stream", False)
 
92
  if not model:
93
  raise ValueError("'model' is a required parameter.")
94
 
95
+ provider = model.split('/')[0]
96
  if provider not in self.api_keys:
97
  raise ValueError(f"No API keys configured for provider: {provider}")
98
 
99
+ async def _streaming_generator(_stream, _key, _model):
100
+ """Generator that yields stream chunks and handles key release."""
101
+ try:
102
+ async for chunk in _stream:
103
+ if request and await request.is_disconnected():
104
+ lib_logger.warning("Client disconnected, cancelling stream.")
105
+ raise asyncio.CancelledError("Client disconnected")
106
+ yield f"data: {json.dumps(chunk.dict())}\n\n"
107
+
108
+ await self.usage_manager.record_success(_key, _model, _stream)
109
+ yield "data: [DONE]\n\n"
110
+ finally:
111
+ await self.usage_manager.release_key(_key, _model)
112
+ lib_logger.info(f"STREAM FINISHED and lock released for key ...{_key[-4:]}.")
113
+
114
  keys_for_provider = self.api_keys[provider]
115
  tried_keys = set()
116
  last_exception = None
117
+
118
  while len(tried_keys) < len(keys_for_provider):
119
  current_key = None
120
  key_acquired = False
121
  try:
122
  keys_to_try = [k for k in keys_for_provider if k not in tried_keys]
123
  if not keys_to_try:
124
+ break
125
 
126
+ current_key = await self.usage_manager.acquire_key(available_keys=keys_to_try, model=model)
 
 
 
127
  key_acquired = True
128
  tried_keys.add(current_key)
129
 
 
130
  litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
 
131
  provider_instance = self._get_provider_instance(provider)
132
  if provider_instance:
 
 
133
  if "safety_settings" not in litellm_kwargs:
134
+ litellm_kwargs["safety_settings"] = {"harassment": "BLOCK_NONE", "hate_speech": "BLOCK_NONE", "sexually_explicit": "BLOCK_NONE", "dangerous_content": "BLOCK_NONE"}
 
 
 
 
 
 
135
  converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
136
+ if converted_settings is not None: litellm_kwargs["safety_settings"] = converted_settings
137
+ else: del litellm_kwargs["safety_settings"]
 
 
 
 
 
138
  if provider == "gemini":
139
+ if provider_instance: provider_instance.handle_thinking_parameter(litellm_kwargs, model)
 
 
140
  if "gemma-3" in model and "messages" in litellm_kwargs:
141
+ litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
 
 
 
 
 
142
  litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
143
 
144
  for attempt in range(self.max_retries):
145
  try:
146
  lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
147
+ if pre_request_callback: await pre_request_callback()
 
 
148
 
149
  response = await litellm.acompletion(api_key=current_key, **litellm_kwargs)
150
 
151
  if is_streaming:
 
152
  key_acquired = False
153
+ return _streaming_generator(response, current_key, model)
 
 
154
  else:
 
155
  await self.usage_manager.record_success(current_key, model, response)
156
+ key_acquired = False
157
  await self.usage_manager.release_key(current_key, model)
 
158
  return response
159
 
160
  except Exception as e:
161
  last_exception = e
162
+ if isinstance(e, asyncio.CancelledError): raise e
163
 
164
+ log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
 
165
  classified_error = classify_error(e)
166
 
167
  if classified_error.error_type in ['invalid_request', 'context_window_exceeded']:
 
 
168
  raise last_exception
169
+
170
  if request and await request.is_disconnected():
171
  lib_logger.warning(f"Client disconnected. Aborting retries for key ...{current_key[-4:]}.")
172
  raise last_exception
173
 
174
  if classified_error.error_type in ['server_error', 'api_connection']:
 
175
  await self.usage_manager.record_failure(current_key, model, classified_error)
176
+ if attempt >= self.max_retries - 1: break
177
+ wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
 
 
 
 
 
 
 
 
178
  await asyncio.sleep(wait_time)
179
  continue
180
+
 
181
  await self.usage_manager.record_failure(current_key, model, classified_error)
 
182
  break
183
+
184
+ except (litellm.InvalidRequestError, litellm.ContextWindowExceededError, asyncio.CancelledError) as e:
185
+ raise e
186
+ except Exception as e:
187
+ last_exception = e
188
+ lib_logger.error(f"An unexpected error occurred with key ...{current_key[-4:] if current_key else 'N/A'}: {e}")
189
+ continue
190
  finally:
 
191
  if key_acquired and current_key:
192
  await self.usage_manager.release_key(current_key, model)
193
 
194
  if last_exception:
195
  raise last_exception
196
 
197
+ raise Exception("Failed to complete the request: No available API keys or all keys failed.")
198
 
199
  async def aembedding(self, request: Optional[Any] = None, **kwargs) -> Any:
200
  """