Mirrowel committed on
Commit
1337d48
·
1 Parent(s): 304ae92

feat(embeddings): Implement request batching for performance

Browse files

This commit introduces an asynchronous batching mechanism for the `/v1/embeddings` endpoint to improve performance and efficiency.

A new `EmbeddingBatcher` class manages a queue of incoming requests. A background worker gathers these requests into batches, limited by size or a timeout, and sends them as a single API call to the provider. This significantly reduces the number of network requests under high concurrency.

Key changes:
- Add `EmbeddingBatcher` class in a new `src/proxy_app/batch_manager.py` module.
- Integrate the batcher into the FastAPI app lifecycle.
- Refactor the `/v1/embeddings` endpoint to use the new batcher.
- Update documentation to note that this feature is a work in progress.

DOCUMENTATION.md CHANGED
@@ -142,7 +142,7 @@ The application uses FastAPI's `lifespan` context manager to manage the `Rotatin
142
  #### Endpoints
143
 
144
  * `POST /v1/chat/completions`: The main endpoint for chat requests.
145
- * `POST /v1/embeddings`: The endpoint for creating embeddings.
146
  * `GET /v1/models`: Returns a list of all available models from configured providers.
147
  * `GET /v1/providers`: Returns a list of all configured providers.
148
  * `POST /v1/token-count`: Calculates the token count for a given message payload.
 
142
  #### Endpoints
143
 
144
  * `POST /v1/chat/completions`: The main endpoint for chat requests.
145
+ * `POST /v1/embeddings`: The endpoint for creating embeddings (work in progress — not yet fully functional).
146
  * `GET /v1/models`: Returns a list of all available models from configured providers.
147
  * `GET /v1/providers`: Returns a list of all configured providers.
148
  * `POST /v1/token-count`: Calculates the token count for a given message payload.
README.md CHANGED
@@ -208,7 +208,7 @@ curl -X POST http://127.0.0.1:8000/v1/chat/completions \
208
  ### Available API Endpoints
209
 
210
  - `POST /v1/chat/completions`: The main endpoint for making chat requests.
211
- - `POST /v1/embeddings`: The endpoint for creating embeddings.
212
  - `GET /v1/models`: Returns a list of all available models from your configured providers.
213
  - `GET /v1/providers`: Returns a list of all configured providers.
214
  - `POST /v1/token-count`: Calculates the token count for a given message payload.
 
208
  ### Available API Endpoints
209
 
210
  - `POST /v1/chat/completions`: The main endpoint for making chat requests.
211
+ - `POST /v1/embeddings`: The endpoint for creating embeddings (work in progress — not yet fully functional).
212
  - `GET /v1/models`: Returns a list of all available models from your configured providers.
213
  - `GET /v1/providers`: Returns a list of all configured providers.
214
  - `POST /v1/token-count`: Calculates the token count for a given message payload.
src/proxy_app/batch_manager.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from typing import List, Dict, Any, Tuple
3
+ import time
4
+ from rotator_library import RotatingClient
5
+
6
+ class EmbeddingBatcher:
7
+ def __init__(self, client: RotatingClient, batch_size: int = 64, timeout: float = 0.1):
8
+ self.client = client
9
+ self.batch_size = batch_size
10
+ self.timeout = timeout
11
+ self.queue = asyncio.Queue()
12
+ self.worker_task = asyncio.create_task(self._batch_worker())
13
+
14
+ async def add_request(self, request_data: Dict[str, Any]) -> Any:
15
+ future = asyncio.Future()
16
+ await self.queue.put((request_data, future))
17
+ return await future
18
+
19
+ async def _batch_worker(self):
20
+ while True:
21
+ batch, futures = await self._gather_batch()
22
+ if not batch:
23
+ continue
24
+
25
+ try:
26
+ # Assume all requests in a batch use the same model and other settings
27
+ model = batch[0]["model"]
28
+ inputs = [item["input"][0] for item in batch] # Extract single string input
29
+
30
+ batched_request = {
31
+ "model": model,
32
+ "input": inputs
33
+ }
34
+
35
+ # Pass through any other relevant parameters from the first request
36
+ for key in ["input_type", "dimensions", "user"]:
37
+ if key in batch[0]:
38
+ batched_request[key] = batch[0][key]
39
+
40
+ response = await self.client.aembedding(**batched_request)
41
+
42
+ # Distribute results back to the original requesters
43
+ for i, future in enumerate(futures):
44
+ # Create a new response object for each item in the batch
45
+ single_response_data = {
46
+ "object": response.object,
47
+ "model": response.model,
48
+ "data": [response.data[i]],
49
+ "usage": response.usage # Usage is for the whole batch
50
+ }
51
+ future.set_result(single_response_data)
52
+
53
+ except Exception as e:
54
+ for future in futures:
55
+ future.set_exception(e)
56
+
57
+ async def _gather_batch(self) -> Tuple[List[Dict[str, Any]], List[asyncio.Future]]:
58
+ batch = []
59
+ futures = []
60
+ start_time = time.time()
61
+
62
+ while len(batch) < self.batch_size and (time.time() - start_time) < self.timeout:
63
+ try:
64
+ # Wait for an item with a timeout
65
+ timeout = self.timeout - (time.time() - start_time)
66
+ if timeout <= 0:
67
+ break
68
+ request, future = await asyncio.wait_for(self.queue.get(), timeout=timeout)
69
+ batch.append(request)
70
+ futures.append(future)
71
+ except asyncio.TimeoutError:
72
+ break
73
+
74
+ return batch, futures
75
+
76
+ async def stop(self):
77
+ self.worker_task.cancel()
78
+ try:
79
+ await self.worker_task
80
+ except asyncio.CancelledError:
81
+ pass
src/proxy_app/main.py CHANGED
@@ -36,6 +36,7 @@ sys.path.append(str(Path(__file__).resolve().parent.parent))
36
 
37
  from rotator_library import RotatingClient, PROVIDER_PLUGINS
38
  from proxy_app.request_logger import log_request_response
 
39
 
40
  # Configure logging
41
  logging.basicConfig(level=logging.INFO)
@@ -67,11 +68,15 @@ if not api_keys:
67
  @asynccontextmanager
68
  async def lifespan(app: FastAPI):
69
  """Manage the RotatingClient's lifecycle with the app's lifespan."""
70
- app.state.rotating_client = RotatingClient(api_keys=api_keys)
71
- print("RotatingClient initialized.")
 
 
 
72
  yield
73
- await app.state.rotating_client.close()
74
- print("RotatingClient closed.")
 
75
 
76
  # --- FastAPI App Setup ---
77
  app = FastAPI(lifespan=lifespan)
@@ -81,6 +86,10 @@ def get_rotating_client(request: Request) -> RotatingClient:
81
  """Dependency to get the rotating client instance from the app state."""
82
  return request.app.state.rotating_client
83
 
 
 
 
 
84
  async def verify_api_key(auth: str = Depends(api_key_header)):
85
  """Dependency to verify the proxy API key."""
86
  if not auth or auth != f"Bearer {PROXY_API_KEY}":
@@ -267,21 +276,27 @@ async def chat_completions(
267
  async def embeddings(
268
  request: Request,
269
  body: EmbeddingRequest,
270
- client: RotatingClient = Depends(get_rotating_client),
271
  _ = Depends(verify_api_key)
272
  ):
273
  """
274
  OpenAI-compatible endpoint for creating embeddings.
 
275
  """
276
  try:
277
  request_data = body.model_dump(exclude_none=True)
278
 
279
- # Ensure input is always a list for the client
280
- if isinstance(request_data.get("input"), str):
281
- request_data["input"] = [request_data["input"]]
 
 
282
 
283
- response = await client.aembedding(**request_data)
284
 
 
 
 
285
  if ENABLE_REQUEST_LOGGING:
286
  response_summary = {
287
  "model": response.model,
 
36
 
37
  from rotator_library import RotatingClient, PROVIDER_PLUGINS
38
  from proxy_app.request_logger import log_request_response
39
+ from proxy_app.batch_manager import EmbeddingBatcher
40
 
41
  # Configure logging
42
  logging.basicConfig(level=logging.INFO)
 
68
  @asynccontextmanager
69
  async def lifespan(app: FastAPI):
70
  """Manage the RotatingClient's lifecycle with the app's lifespan."""
71
+ client = RotatingClient(api_keys=api_keys)
72
+ batcher = EmbeddingBatcher(client=client)
73
+ app.state.rotating_client = client
74
+ app.state.embedding_batcher = batcher
75
+ print("RotatingClient and EmbeddingBatcher initialized.")
76
  yield
77
+ await batcher.stop()
78
+ await client.close()
79
+ print("RotatingClient and EmbeddingBatcher closed.")
80
 
81
  # --- FastAPI App Setup ---
82
  app = FastAPI(lifespan=lifespan)
 
86
  """Dependency to get the rotating client instance from the app state."""
87
  return request.app.state.rotating_client
88
 
89
+ def get_embedding_batcher(request: Request) -> EmbeddingBatcher:
90
+ """Dependency to get the embedding batcher instance from the app state."""
91
+ return request.app.state.embedding_batcher
92
+
93
  async def verify_api_key(auth: str = Depends(api_key_header)):
94
  """Dependency to verify the proxy API key."""
95
  if not auth or auth != f"Bearer {PROXY_API_KEY}":
 
276
  async def embeddings(
277
  request: Request,
278
  body: EmbeddingRequest,
279
+ batcher: EmbeddingBatcher = Depends(get_embedding_batcher),
280
  _ = Depends(verify_api_key)
281
  ):
282
  """
283
  OpenAI-compatible endpoint for creating embeddings.
284
+ This endpoint uses a batching manager to group requests for efficiency.
285
  """
286
  try:
287
  request_data = body.model_dump(exclude_none=True)
288
 
289
+ # The batcher expects a single string input per request
290
+ if isinstance(request_data.get("input"), list):
291
+ if len(request_data["input"]) > 1:
292
+ raise HTTPException(status_code=400, detail="Batching multiple inputs in a single request is not supported when using the server-side batcher. Please send one input string per request.")
293
+ request_data["input"] = request_data["input"][0]
294
 
295
+ response_data = await batcher.add_request(request_data)
296
 
297
+ # The batcher returns a dict, not a Pydantic model, so we construct it
298
+ response = litellm.EmbeddingResponse(**response_data)
299
+
300
  if ENABLE_REQUEST_LOGGING:
301
  response_summary = {
302
  "model": response.model,