Mirrowel committed on
Commit
64e385a
·
1 Parent(s): dbb8b44

feat: Add support for embeddings endpoint

Browse files

This commit introduces a new OpenAI-compatible `/v1/embeddings` endpoint to the proxy, enabling the creation of text embeddings.

Key changes include:
- A new `aembedding` method in `RotatingClient` to handle API calls, key rotation, and rate limit errors for embedding requests.
- The `UsageManager` is updated to correctly track token usage from embedding responses.
- Request logging is enhanced to categorize logs into `completions` and `embeddings` subdirectories for better organization.
- A sanitizer for embedding requests is added to filter unsupported parameters before calling the OpenAI API.

src/proxy_app/main.py CHANGED
@@ -10,9 +10,18 @@ import logging
10
  from pathlib import Path
11
  import sys
12
  import json
13
- from typing import AsyncGenerator, Any
 
14
  import argparse
15
 
 
 
 
 
 
 
 
 
16
  # --- Argument Parsing ---
17
  parser = argparse.ArgumentParser(description="API Key Proxy Server")
18
  parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to.")
@@ -188,7 +197,8 @@ async def streaming_response_wrapper(
188
  log_request_response(
189
  request_data=request_data,
190
  response_data=full_response,
191
- is_streaming=True
 
192
  )
193
 
194
  @app.post("/v1/chat/completions")
@@ -219,7 +229,8 @@ async def chat_completions(
219
  log_request_response(
220
  request_data=request_data,
221
  response_data=response.dict(),
222
- is_streaming=False
 
223
  )
224
  return response
225
 
@@ -234,7 +245,53 @@ async def chat_completions(
234
  log_request_response(
235
  request_data=request_data,
236
  response_data={"error": str(e)},
237
- is_streaming=request_data.get("stream", False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  )
239
  raise HTTPException(status_code=500, detail=str(e))
240
 
 
10
  from pathlib import Path
11
  import sys
12
  import json
13
from typing import AsyncGenerator, Any, List, Optional, Union
14
+ from pydantic import BaseModel
15
  import argparse
16
 
17
# --- Pydantic Models ---
class EmbeddingRequest(BaseModel):
    """Request body for the OpenAI-compatible /v1/embeddings endpoint.

    Mirrors the OpenAI embeddings API: `model` and `input` are required;
    the remaining fields are optional knobs forwarded to the provider
    (None values are stripped before the call).
    """
    # Provider-prefixed model id, e.g. "openai/text-embedding-3-small"
    # (the prefix is how RotatingClient selects the key pool).
    model: str
    # The OpenAI API accepts either one string or a list of strings;
    # accepting both keeps standard OpenAI client libraries compatible.
    input: Union[str, List[str]]
    input_type: Optional[str] = None  # provider-specific hint — TODO confirm which providers honor it
    dimensions: Optional[int] = None  # only some models support this (sanitizer strips it otherwise)
    user: Optional[str] = None        # end-user identifier for abuse monitoring
25
  # --- Argument Parsing ---
26
  parser = argparse.ArgumentParser(description="API Key Proxy Server")
27
  parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to.")
 
197
  log_request_response(
198
  request_data=request_data,
199
  response_data=full_response,
200
+ is_streaming=True,
201
+ log_type="completion"
202
  )
203
 
204
  @app.post("/v1/chat/completions")
 
229
  log_request_response(
230
  request_data=request_data,
231
  response_data=response.dict(),
232
+ is_streaming=False,
233
+ log_type="completion"
234
  )
235
  return response
236
 
 
245
  log_request_response(
246
  request_data=request_data,
247
  response_data={"error": str(e)},
248
+ is_streaming=request_data.get("stream", False),
249
+ log_type="completion"
250
+ )
251
+ raise HTTPException(status_code=500, detail=str(e))
252
+
253
@app.post("/v1/embeddings")
async def embeddings(
    request: Request,
    body: EmbeddingRequest,
    client: RotatingClient = Depends(get_rotating_client),
    _ = Depends(verify_api_key)
):
    """
    OpenAI-compatible endpoint for creating embeddings.

    Delegates to RotatingClient.aembedding (which handles key rotation and
    retries) and, when request logging is enabled, writes a compact summary
    of the response into the "embedding" log category. Any failure is
    surfaced to the caller as a 500 with the provider error message.
    """
    try:
        # Drop None-valued optional fields so only explicitly-set
        # parameters are forwarded to the provider.
        request_data = body.dict(exclude_none=True)
        response = await client.aembedding(**request_data)

        if ENABLE_REQUEST_LOGGING:
            # Log a summary rather than the full response: embedding
            # vectors are large and rarely useful in request logs.
            response_summary = {
                "model": response.model,
                "object": response.object,
                "usage": response.usage.dict(),
                "data_count": len(response.data),
                # Guard against an empty data list before indexing.
                "embedding_dimensions": len(response.data[0].embedding) if response.data else 0
            }
            log_request_response(
                request_data=request_data,
                response_data=response_summary,
                is_streaming=False,
                log_type="embedding"
            )
        return response

    except Exception as e:
        logging.error(f"Embedding request failed: {e}")
        if ENABLE_REQUEST_LOGGING:
            # Re-read the raw body for the failure log; falls back to a
            # placeholder if the body is not valid JSON.
            try:
                request_data = await request.json()
            except json.JSONDecodeError:
                request_data = {"error": "Could not parse request body"}
            log_request_response(
                request_data=request_data,
                response_data={"error": str(e)},
                is_streaming=False,
                log_type="embedding"
            )
        raise HTTPException(status_code=500, detail=str(e))
297
 
src/proxy_app/request_logger.py CHANGED
@@ -3,18 +3,39 @@ import os
3
  from datetime import datetime
4
  from pathlib import Path
5
  import uuid
 
6
 
7
  LOGS_DIR = Path(__file__).resolve().parent.parent.parent / "logs"
 
 
 
 
8
  LOGS_DIR.mkdir(exist_ok=True)
 
 
9
 
10
- def log_request_response(request_data: dict, response_data: dict, is_streaming: bool):
 
 
 
 
 
11
  """
12
- Logs the request and response data to a single file in the logs directory.
13
  """
14
  try:
15
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
16
  unique_id = uuid.uuid4()
17
- filename = LOGS_DIR / f"{timestamp}_{unique_id}.json"
 
 
 
 
 
 
 
 
 
18
 
19
  log_content = {
20
  "request": request_data,
 
3
  from datetime import datetime
4
  from pathlib import Path
5
  import uuid
6
+ from typing import Literal
7
 
8
  LOGS_DIR = Path(__file__).resolve().parent.parent.parent / "logs"
9
+ COMPLETIONS_LOGS_DIR = LOGS_DIR / "completions"
10
+ EMBEDDINGS_LOGS_DIR = LOGS_DIR / "embeddings"
11
+
12
+ # Create directories if they don't exist
13
  LOGS_DIR.mkdir(exist_ok=True)
14
+ COMPLETIONS_LOGS_DIR.mkdir(exist_ok=True)
15
+ EMBEDDINGS_LOGS_DIR.mkdir(exist_ok=True)
16
 
17
+ def log_request_response(
18
+ request_data: dict,
19
+ response_data: dict,
20
+ is_streaming: bool,
21
+ log_type: Literal["completion", "embedding"]
22
+ ):
23
  """
24
+ Logs the request and response data to a file in the appropriate log directory.
25
  """
26
  try:
27
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
28
  unique_id = uuid.uuid4()
29
+
30
+ if log_type == "completion":
31
+ target_dir = COMPLETIONS_LOGS_DIR
32
+ elif log_type == "embedding":
33
+ target_dir = EMBEDDINGS_LOGS_DIR
34
+ else:
35
+ # Fallback to the main logs directory if log_type is invalid
36
+ target_dir = LOGS_DIR
37
+
38
+ filename = target_dir / f"{timestamp}_{unique_id}.json"
39
 
40
  log_content = {
41
  "request": request_data,
src/rotator_library/client.py CHANGED
@@ -305,6 +305,87 @@ class RotatingClient:
305
 
306
  raise Exception("Failed to complete the request: No available API keys for the provider or all keys failed.")
307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  def token_count(self, model: str, text: str = None, messages: List[Dict[str, str]] = None) -> int:
309
  """Calculates the number of tokens for a given text or list of messages."""
310
  if not model:
 
305
 
306
  raise Exception("Failed to complete the request: No available API keys for the provider or all keys failed.")
307
 
308
    async def aembedding(self, **kwargs) -> Any:
        """
        Performs an embedding call with smart key rotation and retry logic.

        Accepts the same keyword arguments as litellm.aembedding; 'model'
        is required and must carry a provider prefix (the text before the
        first '/') matching a configured key pool.

        Raises:
            ValueError: if 'model' is missing or its provider has no keys.
            Exception: the last provider error once every key has failed,
                or a generic failure if no key could be used at all.
        """
        model = kwargs.get("model")
        if not model:
            raise ValueError("'model' is a required parameter.")

        # The provider is encoded as the prefix of the model name.
        provider = model.split('/')[0]
        if provider not in self.api_keys:
            raise ValueError(f"No API keys configured for provider: {provider}")

        keys_for_provider = self.api_keys[provider]
        tried_keys = set()
        last_exception = None

        # Outer loop: rotate across the provider's keys until every key
        # has been tried once.
        while len(tried_keys) < len(keys_for_provider):
            current_key = None
            key_acquired = False
            try:
                keys_to_try = [k for k in keys_for_provider if k not in tried_keys]
                if not keys_to_try:
                    break

                current_key = await self.usage_manager.acquire_key(
                    available_keys=keys_to_try,
                    model=model
                )
                key_acquired = True
                tried_keys.add(current_key)

                # Translate provider-specific kwargs, then strip params the
                # target model does not support.
                litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
                litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)

                # Inner loop: retry transient failures on the same key.
                for attempt in range(self.max_retries):
                    try:
                        lib_logger.info(f"Attempting embedding call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")

                        response = await litellm.aembedding(api_key=current_key, **litellm_kwargs)

                        await self.usage_manager.record_success(current_key, model, response)
                        await self.usage_manager.release_key(current_key, model)
                        key_acquired = False  # already released; skip the finally-release
                        return response

                    except Exception as e:
                        last_exception = e
                        log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)

                        classified_error = classify_error(e)

                        # Caller-side errors: no key rotation can help.
                        if classified_error.error_type in ['invalid_request', 'context_window_exceeded']:
                            lib_logger.error(f"Unrecoverable error '{classified_error.error_type}' with key ...{current_key[-4:]}. Failing request.")
                            raise last_exception

                        # Transient errors: back off and retry this key.
                        if classified_error.error_type in ['server_error', 'api_connection']:
                            await self.usage_manager.record_failure(current_key, model, classified_error)

                            if attempt >= self.max_retries - 1:
                                lib_logger.warning(f"Key ...{current_key[-4:]} failed on final retry for {classified_error.error_type}. Trying next key.")
                                break

                            # Exponential backoff with jitter; honor a
                            # server-supplied retry-after when present.
                            base_wait = 5 if classified_error.error_type == 'api_connection' else 1
                            wait_time = classified_error.retry_after or (base_wait * (2 ** attempt)) + random.uniform(0, 1)

                            lib_logger.warning(f"Key ...{current_key[-4:]} encountered a {classified_error.error_type}. Retrying in {wait_time:.2f} seconds...")
                            await asyncio.sleep(wait_time)
                            continue

                        # Key-specific errors (e.g. rate limit / auth):
                        # record and move straight to the next key.
                        await self.usage_manager.record_failure(current_key, model, classified_error)
                        lib_logger.warning(f"Key ...{current_key[-4:]} encountered '{classified_error.error_type}'. Trying next key.")
                        break
            finally:
                # Always release a key we acquired but did not release on
                # the success path, even if an exception propagates.
                if key_acquired and current_key:
                    await self.usage_manager.release_key(current_key, model)

        if last_exception:
            raise last_exception

        raise Exception("Failed to complete the request: No available API keys for the provider or all keys failed.")
+ raise Exception("Failed to complete the request: No available API keys for the provider or all keys failed.")
388
+
389
  def token_count(self, model: str, text: str = None, messages: List[Dict[str, str]] = None) -> int:
390
  """Calculates the number of tokens for a given text or list of messages."""
391
  if not model:
src/rotator_library/request_sanitizer.py CHANGED
@@ -4,6 +4,9 @@ def sanitize_request_payload(payload: Dict[str, Any], model: str) -> Dict[str, A
4
  """
5
  Removes unsupported parameters from the request payload based on the model.
6
  """
 
 
 
7
  if payload.get("thinking") == {"type": "enabled", "budget_tokens": -1}:
8
  if model not in ["gemini/gemini-2.5-pro", "gemini/gemini-2.5-flash"]:
9
  del payload["thinking"]
 
4
  """
5
  Removes unsupported parameters from the request payload based on the model.
6
  """
7
+ if "dimensions" in payload and not model.startswith("openai/text-embedding-3"):
8
+ del payload["dimensions"]
9
+
10
  if payload.get("thinking") == {"type": "enabled", "budget_tokens": -1}:
11
  if model not in ["gemini/gemini-2.5-pro", "gemini/gemini-2.5-flash"]:
12
  del payload["thinking"]
src/rotator_library/usage_manager.py CHANGED
@@ -244,11 +244,17 @@ class UsageManager:
244
  if completion_response and hasattr(completion_response, 'usage') and completion_response.usage:
245
  usage = completion_response.usage
246
  daily_model_data["prompt_tokens"] += usage.prompt_tokens
247
- daily_model_data["completion_tokens"] += usage.completion_tokens
248
  lib_logger.info(f"Recorded usage from final stream object for key ...{key[-4:]}")
249
  try:
250
- cost = litellm.completion_cost(completion_response=completion_response)
251
- daily_model_data["approx_cost"] += cost
 
 
 
 
 
 
252
  except Exception as e:
253
  lib_logger.warning(f"Could not calculate cost for model {model}: {e}")
254
  else:
 
244
  if completion_response and hasattr(completion_response, 'usage') and completion_response.usage:
245
  usage = completion_response.usage
246
  daily_model_data["prompt_tokens"] += usage.prompt_tokens
247
+ daily_model_data["completion_tokens"] += getattr(usage, 'completion_tokens', 0) # Not present in embedding responses
248
  lib_logger.info(f"Recorded usage from final stream object for key ...{key[-4:]}")
249
  try:
250
+ # Differentiate cost calculation based on response type
251
+ if isinstance(completion_response, litellm.EmbeddingResponse):
252
+ cost = litellm.embedding_cost(embedding_response=completion_response)
253
+ else:
254
+ cost = litellm.completion_cost(completion_response=completion_response)
255
+
256
+ if cost is not None:
257
+ daily_model_data["approx_cost"] += cost
258
  except Exception as e:
259
  lib_logger.warning(f"Could not calculate cost for model {model}: {e}")
260
  else: