# llm-api-proxy/src/proxy_app/batch_manager.py
# feat(embeddings): Implement request batching for performance
# (Mirrowel, commit 1337d48)
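"""Batch individual embedding requests into single upstream calls.

Incoming requests are queued; the background worker flushes a batch when
it reaches ``batch_size`` items or when ``timeout`` seconds have passed
since the first request of the window, whichever comes first.
"""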
import asyncio
import time
from typing import Any, Dict, List, Tuple

from rotator_library import RotatingClient


class EmbeddingBatcher:
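    """Coalesces concurrent embedding requests into batched upstream calls."""
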
    def __init__(self, client: RotatingClient, batch_size: int = 64, timeout: float = 0.1):
        self.client = client
        self.batch_size = batch_size
        self.timeout = timeout
        # Queue of (request_data, future) pairs awaiting batching.
        self.queue: asyncio.Queue = asyncio.Queue()
        # NOTE: create_task() requires a running event loop, so the batcher
        # must be constructed from async context (e.g. an app startup hook).
        self.worker_task = asyncio.create_task(self._batch_worker())

    async def add_request(self, request_data: Dict[str, Any]) -> Any:
        """Enqueue one embedding request and await its individual result."""
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((request_data, future))
        return await future

    async def _batch_worker(self):
        """Continuously gather batches from the queue and dispatch them."""
        while True:
            batch, futures = await self._gather_batch()
            if not batch:
                continue
            try:
                # Assumption: all requests in a window share the same model and
                # settings; callers must not mix models on one batcher instance.
                model = batch[0]["model"]
                # Accept either a bare string or a single-item list as "input";
                # indexing a bare string with [0] would yield its first character.
                inputs = [
                    item["input"] if isinstance(item["input"], str) else item["input"][0]
                    for item in batch
                ]
                batched_request = {
                    "model": model,
                    "input": inputs,
                }
                # Pass through other relevant parameters from the first request.
                for key in ("input_type", "dimensions", "user"):
                    if key in batch[0]:
                        batched_request[key] = batch[0][key]
                response = await self.client.aembedding(**batched_request)
                # Distribute results back to the original requesters.
                for i, future in enumerate(futures):
                    if future.done():  # Requester cancelled or disconnected.
                        continue
                    # Re-wrap each embedding as a single-item response.
                    single_response_data = {
                        "object": response.object,
                        "model": response.model,
                        "data": [response.data[i]],
                        # Caveat: usage covers the whole batch, so token counts
                        # are duplicated across the split responses.
                        "usage": response.usage,
                    }
                    future.set_result(single_response_data)
            except Exception as e:
                # Fail every waiter in the batch with the same error.
                for future in futures:
                    if not future.done():
                        future.set_exception(e)

    async def _gather_batch(self) -> Tuple[List[Dict[str, Any]], List[asyncio.Future]]:
        """Collect up to ``batch_size`` requests within the ``timeout`` window."""
        # Block until the first request arrives so an idle queue does not
        # spin the worker; the timeout window starts from that first item.
        request, future = await self.queue.get()
        batch = [request]
        futures = [future]
        # time.monotonic() is immune to wall-clock adjustments.
        deadline = time.monotonic() + self.timeout
        while len(batch) < self.batch_size:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                request, future = await asyncio.wait_for(self.queue.get(), timeout=remaining)
            except asyncio.TimeoutError:
                break
            batch.append(request)
            futures.append(future)
        return batch, futures

    async def stop(self):
        """Cancel the worker and fail any requests still queued."""
        self.worker_task.cancel()
        try:
            await self.worker_task
        except asyncio.CancelledError:
            pass
        # Requests left in the queue would otherwise await forever.
        while not self.queue.empty():
            _, future = self.queue.get_nowait()
            if not future.done():
                future.set_exception(RuntimeError("EmbeddingBatcher stopped"))
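

# A minimal smoke test with a stub client, as a usage sketch. Assumptions:
# RotatingClient.aembedding returns an object exposing .object, .model,
# .data, and .usage (OpenAI-style); the stub below only mimics that shape
# for illustration and is not part of the real rotator_library API.
if __name__ == "__main__":
    from types import SimpleNamespace

    class _StubClient:
        async def aembedding(self, **kwargs):
            data = [
                SimpleNamespace(index=i, embedding=[0.0, 0.0])
                for i, _ in enumerate(kwargs["input"])
            ]
            return SimpleNamespace(
                object="list",
                model=kwargs["model"],
                data=data,
                usage=SimpleNamespace(total_tokens=len(data)),
            )

    async def _demo():
        batcher = EmbeddingBatcher(_StubClient(), batch_size=4, timeout=0.05)
        # Three concurrent callers are coalesced into one upstream call.
        results = await asyncio.gather(
            *(batcher.add_request({"model": "stub-model", "input": ["hello"]})
              for _ in range(3))
        )
        print([r["data"][0].index for r in results])  # -> [0, 1, 2]
        await batcher.stop()

    asyncio.run(_demo())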