Spaces:
Paused
Paused
Mirrowel committed on
Commit ·
1337d48
1
Parent(s): 304ae92
feat(embeddings): Implement request batching for performance
Browse files
This commit introduces an asynchronous batching mechanism for the `/v1/embeddings` endpoint to improve performance and efficiency.
A new `EmbeddingBatcher` class manages a queue of incoming requests. A background worker gathers these requests into batches, limited by size or a timeout, and sends them as a single API call to the provider. This significantly reduces the number of network requests under high concurrency.
Key changes:
- Add `EmbeddingBatcher` class in a new `src/proxy_app/batch_manager.py` module.
- Integrate the batcher into the FastAPI app lifecycle.
- Refactor the `/v1/embeddings` endpoint to use the new batcher.
- Update documentation to note that this feature is a work in progress.
- DOCUMENTATION.md +1 -1
- README.md +1 -1
- src/proxy_app/batch_manager.py +81 -0
- src/proxy_app/main.py +24 -9
DOCUMENTATION.md
CHANGED
|
@@ -142,7 +142,7 @@ The application uses FastAPI's `lifespan` context manager to manage the `Rotatin
|
|
| 142 |
#### Endpoints
|
| 143 |
|
| 144 |
* `POST /v1/chat/completions`: The main endpoint for chat requests.
|
| 145 |
-
* `POST /v1/embeddings`: The endpoint for creating embeddings.
|
| 146 |
* `GET /v1/models`: Returns a list of all available models from configured providers.
|
| 147 |
* `GET /v1/providers`: Returns a list of all configured providers.
|
| 148 |
* `POST /v1/token-count`: Calculates the token count for a given message payload.
|
|
|
|
| 142 |
#### Endpoints
|
| 143 |
|
| 144 |
* `POST /v1/chat/completions`: The main endpoint for chat requests.
|
| 145 |
+
* `POST /v1/embeddings`: The endpoint for creating embeddings (not fully functional yet).
|
| 146 |
* `GET /v1/models`: Returns a list of all available models from configured providers.
|
| 147 |
* `GET /v1/providers`: Returns a list of all configured providers.
|
| 148 |
* `POST /v1/token-count`: Calculates the token count for a given message payload.
|
README.md
CHANGED
|
@@ -208,7 +208,7 @@ curl -X POST http://127.0.0.1:8000/v1/chat/completions \
|
|
| 208 |
### Available API Endpoints
|
| 209 |
|
| 210 |
- `POST /v1/chat/completions`: The main endpoint for making chat requests.
|
| 211 |
-
- `POST /v1/embeddings`: The endpoint for creating embeddings.
|
| 212 |
- `GET /v1/models`: Returns a list of all available models from your configured providers.
|
| 213 |
- `GET /v1/providers`: Returns a list of all configured providers.
|
| 214 |
- `POST /v1/token-count`: Calculates the token count for a given message payload.
|
|
|
|
| 208 |
### Available API Endpoints
|
| 209 |
|
| 210 |
- `POST /v1/chat/completions`: The main endpoint for making chat requests.
|
| 211 |
+
- `POST /v1/embeddings`: The endpoint for creating embeddings (not fully functional yet).
|
| 212 |
- `GET /v1/models`: Returns a list of all available models from your configured providers.
|
| 213 |
- `GET /v1/providers`: Returns a list of all configured providers.
|
| 214 |
- `POST /v1/token-count`: Calculates the token count for a given message payload.
|
src/proxy_app/batch_manager.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from typing import List, Dict, Any, Tuple
|
| 3 |
+
import time
|
| 4 |
+
from rotator_library import RotatingClient
|
| 5 |
+
|
| 6 |
+
class EmbeddingBatcher:
    """Coalesce concurrent single-input embedding requests into batched calls.

    Callers submit one request at a time through :meth:`add_request`. A
    background task collects queued requests until ``batch_size`` requests
    are gathered or ``timeout`` seconds have passed since the first one
    arrived, sends them to ``client.aembedding`` and hands every caller its
    own slice of the response.

    Requests are only merged into the same provider call when they share the
    same ``model`` and the same optional pass-through parameters, so no
    request is ever sent with another request's settings.
    """

    # Optional request fields forwarded to the provider unchanged.
    _PASSTHROUGH_KEYS = ("input_type", "dimensions", "user")

    def __init__(self, client: "RotatingClient", batch_size: int = 64, timeout: float = 0.1):
        """Create the batcher and start its worker task.

        Must be called while an event loop is running, because the worker is
        started with ``asyncio.create_task``.

        Args:
            client: Object exposing an async ``aembedding(**kwargs)`` method.
            batch_size: Maximum number of requests gathered per batch.
            timeout: Maximum time in seconds to keep collecting requests
                after the first request of a batch has arrived.
        """
        self.client = client
        self.batch_size = batch_size
        self.timeout = timeout
        self.queue = asyncio.Queue()
        # Futures handed out by add_request() that are not resolved yet, so
        # stop() can release their waiters instead of leaving them hanging.
        self._pending = set()
        self._closed = False
        self.worker_task = asyncio.create_task(self._batch_worker())

    @staticmethod
    def _single_input(value: Any) -> Any:
        """Return the one input carried by a request.

        Accepts either a plain string or a one-element list/tuple. Indexing a
        plain string with ``[0]`` would silently embed only its first
        character, so strings are returned untouched.

        Raises:
            ValueError: If ``value`` is neither a string nor a one-element
                list/tuple.
        """
        if isinstance(value, str):
            return value
        if isinstance(value, (list, tuple)) and len(value) == 1:
            return value[0]
        raise ValueError(
            "EmbeddingBatcher expects exactly one input per request "
            "(a string or a one-element list)."
        )

    @classmethod
    def _group_key(cls, request: Dict[str, Any]) -> Tuple[Any, ...]:
        """Build a hashable key; requests with equal keys may share one call."""
        extras = tuple(
            (key, repr(request[key]))
            for key in cls._PASSTHROUGH_KEYS
            if key in request
        )
        return (repr(request["model"]),) + extras

    @staticmethod
    def _fail(futures: List[asyncio.Future], exc: BaseException) -> None:
        """Set ``exc`` on every future that has not been resolved or cancelled."""
        for future in futures:
            # Resolving a done/cancelled future raises InvalidStateError.
            if not future.done():
                future.set_exception(exc)

    async def add_request(self, request_data: Dict[str, Any]) -> Any:
        """Queue one embedding request and wait for its individual result.

        Args:
            request_data: Request fields. Must contain ``"model"`` and an
                ``"input"`` that is a string or a one-element list.

        Returns:
            A dict with the keys ``object``, ``model``, ``data`` (a
            one-element list) and ``usage``.

        Raises:
            RuntimeError: If the batcher was stopped before or while the
                request was waiting.
            KeyError: If ``"model"`` is missing.
            ValueError: If ``"input"`` does not carry exactly one input.
            Exception: Whatever the provider call raised for this request's
                group.
        """
        if self._closed:
            raise RuntimeError("EmbeddingBatcher has been stopped.")
        if "model" not in request_data:
            raise KeyError("model")

        # Validate up front so a malformed request fails only its own caller
        # instead of the whole batch; copy so later caller-side mutation of
        # the dict cannot affect the queued request.
        normalized = dict(request_data)
        normalized["input"] = self._single_input(request_data.get("input"))

        future = asyncio.get_running_loop().create_future()
        self._pending.add(future)
        future.add_done_callback(self._pending.discard)
        await self.queue.put((normalized, future))
        return await future

    async def _batch_worker(self):
        """Run forever: gather a batch, split it into groups, dispatch each."""
        while True:
            batch, futures = await self._gather_batch()
            try:
                groups: Dict[Any, Tuple[list, list]] = {}
                for request, future in zip(batch, futures):
                    # The waiter may already be gone (e.g. cancelled); do not
                    # spend a provider call on it.
                    if future.done():
                        continue
                    requests, waiters = groups.setdefault(
                        self._group_key(request), ([], [])
                    )
                    requests.append(request)
                    waiters.append(future)

                for requests, waiters in groups.values():
                    await self._dispatch(requests, waiters)
            except Exception as exc:
                # Last line of defence: the worker must never die, otherwise
                # every later add_request() would wait forever.
                self._fail(futures, exc)

    async def _dispatch(self, requests: List[Dict[str, Any]], futures: List[asyncio.Future]) -> None:
        """Send one homogeneous group to the provider and resolve its futures.

        Any exception raised while calling the provider or reading the
        response is set on the futures of this group only.
        """
        try:
            first = requests[0]
            batched_request = {
                "model": first["model"],
                "input": [request["input"] for request in requests],
            }
            # All requests of a group carry identical pass-through values
            # (that is what the group key guarantees), so the first is enough.
            for key in self._PASSTHROUGH_KEYS:
                if key in first:
                    batched_request[key] = first[key]

            response = await self.client.aembedding(**batched_request)

            items = response.data
            if len(items) != len(futures):
                raise RuntimeError(
                    f"Embedding provider returned {len(items)} results "
                    f"for {len(futures)} inputs."
                )

            # Assumes the provider returns results in input order.
            # NOTE(review): each item is forwarded as-is, so any per-item
            # "index" field still refers to its position in the batch.
            results = [
                {
                    "object": response.object,
                    "model": response.model,
                    "data": [item],
                    "usage": response.usage,  # Usage is for the whole group
                }
                for item in items
            ]
        except Exception as exc:
            self._fail(futures, exc)
            return

        for future, result in zip(futures, results):
            if not future.done():
                future.set_result(result)

    async def _gather_batch(self) -> Tuple[List[Dict[str, Any]], List[asyncio.Future]]:
        """Collect the next batch of queued requests.

        Blocks until at least one request is available, then keeps collecting
        until ``batch_size`` requests are gathered or ``timeout`` seconds have
        elapsed since the first request was taken from the queue.

        Returns:
            Two parallel lists: the request dicts and their futures.
        """
        # Wait for the first item without a deadline so an idle batcher
        # sleeps instead of waking up every `timeout` seconds.
        request, future = await self.queue.get()
        batch = [request]
        futures = [future]

        # Monotonic clock: wall-clock adjustments must not stretch or shrink
        # the collection window.
        deadline = time.monotonic() + self.timeout
        while len(batch) < self.batch_size:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                request, future = await asyncio.wait_for(self.queue.get(), timeout=remaining)
            except asyncio.TimeoutError:
                break
            batch.append(request)
            futures.append(future)

        return batch, futures

    async def stop(self):
        """Stop the worker and release every caller that is still waiting.

        Requests that were queued or in flight when the batcher stops fail
        with ``RuntimeError``; later calls to :meth:`add_request` are rejected.
        """
        self._closed = True
        self.worker_task.cancel()
        try:
            await self.worker_task
        except asyncio.CancelledError:
            pass

        self._fail(
            list(self._pending),
            RuntimeError("EmbeddingBatcher stopped before the request was processed."),
        )
|
src/proxy_app/main.py
CHANGED
|
@@ -36,6 +36,7 @@ sys.path.append(str(Path(__file__).resolve().parent.parent))
|
|
| 36 |
|
| 37 |
from rotator_library import RotatingClient, PROVIDER_PLUGINS
|
| 38 |
from proxy_app.request_logger import log_request_response
|
|
|
|
| 39 |
|
| 40 |
# Configure logging
|
| 41 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -67,11 +68,15 @@ if not api_keys:
|
|
| 67 |
@asynccontextmanager
|
| 68 |
async def lifespan(app: FastAPI):
|
| 69 |
"""Manage the RotatingClient's lifecycle with the app's lifespan."""
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
| 72 |
yield
|
| 73 |
-
await
|
| 74 |
-
|
|
|
|
| 75 |
|
| 76 |
# --- FastAPI App Setup ---
|
| 77 |
app = FastAPI(lifespan=lifespan)
|
|
@@ -81,6 +86,10 @@ def get_rotating_client(request: Request) -> RotatingClient:
|
|
| 81 |
"""Dependency to get the rotating client instance from the app state."""
|
| 82 |
return request.app.state.rotating_client
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
async def verify_api_key(auth: str = Depends(api_key_header)):
|
| 85 |
"""Dependency to verify the proxy API key."""
|
| 86 |
if not auth or auth != f"Bearer {PROXY_API_KEY}":
|
|
@@ -267,21 +276,27 @@ async def chat_completions(
|
|
| 267 |
async def embeddings(
|
| 268 |
request: Request,
|
| 269 |
body: EmbeddingRequest,
|
| 270 |
-
|
| 271 |
_ = Depends(verify_api_key)
|
| 272 |
):
|
| 273 |
"""
|
| 274 |
OpenAI-compatible endpoint for creating embeddings.
|
|
|
|
| 275 |
"""
|
| 276 |
try:
|
| 277 |
request_data = body.model_dump(exclude_none=True)
|
| 278 |
|
| 279 |
-
#
|
| 280 |
-
if isinstance(request_data.get("input"),
|
| 281 |
-
request_data["input"]
|
|
|
|
|
|
|
| 282 |
|
| 283 |
-
|
| 284 |
|
|
|
|
|
|
|
|
|
|
| 285 |
if ENABLE_REQUEST_LOGGING:
|
| 286 |
response_summary = {
|
| 287 |
"model": response.model,
|
|
|
|
| 36 |
|
| 37 |
from rotator_library import RotatingClient, PROVIDER_PLUGINS
|
| 38 |
from proxy_app.request_logger import log_request_response
|
| 39 |
+
from proxy_app.batch_manager import EmbeddingBatcher
|
| 40 |
|
| 41 |
# Configure logging
|
| 42 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 68 |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the RotatingClient and EmbeddingBatcher with the app's lifespan.

    Both objects are created before the application starts serving and are
    stored on ``app.state`` as ``rotating_client`` and ``embedding_batcher``.
    On shutdown the batcher is stopped first and the client is closed
    afterwards.
    """
    client = RotatingClient(api_keys=api_keys)
    batcher = EmbeddingBatcher(client=client)
    app.state.rotating_client = client
    app.state.embedding_batcher = batcher
    print("RotatingClient and EmbeddingBatcher initialized.")
    try:
        yield
    finally:
        # try/finally so cleanup also runs when the lifespan is exited with
        # an exception, and so a failing batcher shutdown cannot prevent the
        # client from being closed.
        try:
            await batcher.stop()
        finally:
            await client.close()
        print("RotatingClient and EmbeddingBatcher closed.")
|
| 80 |
|
| 81 |
# --- FastAPI App Setup ---
|
| 82 |
app = FastAPI(lifespan=lifespan)
|
|
|
|
| 86 |
"""Dependency to get the rotating client instance from the app state."""
|
| 87 |
return request.app.state.rotating_client
|
| 88 |
|
| 89 |
+
def get_embedding_batcher(request: Request) -> EmbeddingBatcher:
    """Dependency that returns the embedding batcher stored on the app state."""
    app_state = request.app.state
    return app_state.embedding_batcher
|
| 92 |
+
|
| 93 |
async def verify_api_key(auth: str = Depends(api_key_header)):
|
| 94 |
"""Dependency to verify the proxy API key."""
|
| 95 |
if not auth or auth != f"Bearer {PROXY_API_KEY}":
|
|
|
|
| 276 |
async def embeddings(
|
| 277 |
request: Request,
|
| 278 |
body: EmbeddingRequest,
|
| 279 |
+
batcher: EmbeddingBatcher = Depends(get_embedding_batcher),
|
| 280 |
_ = Depends(verify_api_key)
|
| 281 |
):
|
| 282 |
"""
|
| 283 |
OpenAI-compatible endpoint for creating embeddings.
|
| 284 |
+
This endpoint uses a batching manager to group requests for efficiency.
|
| 285 |
"""
|
| 286 |
try:
|
| 287 |
request_data = body.model_dump(exclude_none=True)
|
| 288 |
|
| 289 |
+
# The batcher expects a single string input per request
|
| 290 |
+
if isinstance(request_data.get("input"), list):
|
| 291 |
+
if len(request_data["input"]) > 1:
|
| 292 |
+
raise HTTPException(status_code=400, detail="Batching multiple inputs in a single request is not supported when using the server-side batcher. Please send one input string per request.")
|
| 293 |
+
request_data["input"] = request_data["input"][0]
|
| 294 |
|
| 295 |
+
response_data = await batcher.add_request(request_data)
|
| 296 |
|
| 297 |
+
# The batcher returns a dict, not a Pydantic model, so we construct it
|
| 298 |
+
response = litellm.EmbeddingResponse(**response_data)
|
| 299 |
+
|
| 300 |
if ENABLE_REQUEST_LOGGING:
|
| 301 |
response_summary = {
|
| 302 |
"model": response.model,
|