Spaces:
Paused
Paused
| import asyncio | |
| from typing import List, Dict, Any, Tuple | |
| import time | |
| from rotator_library import RotatingClient | |
class EmbeddingBatcher:
    """Coalesce concurrent embedding requests into batched provider calls.

    Callers submit one request at a time through ``add_request``. A background
    worker does the following each cycle:

    * collects up to ``batch_size`` queued requests, waiting at most
      ``timeout`` seconds after the first one arrives;
    * groups them by ``model`` and the pass-through parameters;
    * issues one ``client.aembedding`` call per group;
    * hands each caller a response containing only its own embeddings.

    Must be constructed while an event loop is running, because ``__init__``
    starts the worker with ``asyncio.create_task``.
    """

    # Request parameters forwarded to the provider. Requests that differ in
    # any of these (or in "model") are never merged into the same call.
    _PASSTHROUGH_KEYS = ("input_type", "dimensions", "user")

    def __init__(self, client: "RotatingClient", batch_size: int = 64, timeout: float = 0.1):
        """
        Args:
            client: Object exposing ``async aembedding(**kwargs)``.
            batch_size: Maximum number of queued requests merged per cycle.
            timeout: Seconds to keep collecting after the first request of a
                cycle arrives.

        Raises:
            ValueError: If ``batch_size`` is less than 1 (the worker would
                otherwise never make progress).
            RuntimeError: If no event loop is running (from ``create_task``).
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self.client = client
        self.batch_size = batch_size
        self.timeout = timeout
        self.queue = asyncio.Queue()
        self._stopped = False
        self.worker_task = asyncio.create_task(self._batch_worker())

    async def add_request(self, request_data: Dict[str, Any]) -> Any:
        """Queue one embedding request and wait for its individual response.

        ``request_data`` must contain ``"model"`` and ``"input"``. ``input``
        may be a string, a single pre-tokenized list of ints, or a list of
        strings / token lists.

        Returns:
            A dict with ``object``, ``model``, ``data`` and ``usage``:

            * ``data`` holds only this request's embeddings, in input order.
              Dict items are re-indexed from 0.
            * ``usage`` is that of the whole provider call, which may include
              other callers' inputs.

        Raises:
            RuntimeError: If the batcher has been stopped.
            KeyError / ValueError / TypeError: If the request is malformed.
            Exception: Whatever the provider call raised for this request's group.
        """
        if self._stopped:
            raise RuntimeError("EmbeddingBatcher stopped")
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((request_data, future))
        return await future

    async def _batch_worker(self):
        """Gather and process batches forever until the task is cancelled."""
        while True:
            batch, futures = await self._gather_batch()
            if not batch:
                continue
            try:
                await self._process_batch(batch, futures)
            except asyncio.CancelledError:
                # Shutting down mid-batch: release callers instead of leaving
                # them awaiting futures nobody will ever resolve.
                for future in futures:
                    self._fail(future, RuntimeError("EmbeddingBatcher stopped"))
                raise
            except Exception as exc:
                # Last-resort guard: one bad batch must not kill the worker,
                # otherwise every later add_request() would hang forever.
                for future in futures:
                    self._fail(future, exc)

    async def _process_batch(
        self, batch: List[Dict[str, Any]], futures: List[asyncio.Future]
    ) -> None:
        """Group one gathered batch and run each group's provider call concurrently."""
        groups: Dict[Any, List[Tuple[Dict[str, Any], List[Any], asyncio.Future]]] = {}
        for request, future in zip(batch, futures):
            if future.done():
                # The caller was cancelled while queued; nobody is waiting.
                continue
            try:
                inputs = self._normalize_input(request["input"])
                params = tuple(
                    (key, request[key]) for key in self._PASSTHROUGH_KEYS if key in request
                )
                groups.setdefault((request["model"], params), []).append(
                    (request, inputs, future)
                )
            except (KeyError, TypeError, ValueError) as exc:
                # A malformed request fails only its own caller, not the batch.
                self._fail(future, exc)
        if groups:
            await asyncio.gather(*(self._run_group(members) for members in groups.values()))

    async def _run_group(
        self, members: List[Tuple[Dict[str, Any], List[Any], asyncio.Future]]
    ) -> None:
        """Issue one provider call for requests sharing model/params and fan results out."""
        first = members[0][0]
        flat_inputs = [item for _, inputs, _ in members for item in inputs]
        batched_request: Dict[str, Any] = {"model": first["model"], "input": flat_inputs}
        # Every member of the group has identical pass-through values.
        for key in self._PASSTHROUGH_KEYS:
            if key in first:
                batched_request[key] = first[key]
        try:
            response = await self.client.aembedding(**batched_request)
            # Assumes response.data is ordered like the "input" list sent.
            data = list(response.data)
            if len(data) != len(flat_inputs):
                raise RuntimeError(
                    f"embedding provider returned {len(data)} results "
                    f"for {len(flat_inputs)} inputs"
                )
            offset = 0
            for _, inputs, future in members:
                own = data[offset:offset + len(inputs)]
                offset += len(inputs)
                self._resolve(future, {
                    "object": response.object,
                    "model": response.model,
                    "data": [self._reindex(item, i) for i, item in enumerate(own)],
                    "usage": response.usage,  # Usage is for the whole provider call
                })
        except Exception as exc:
            # Futures already resolved above are skipped by _fail.
            for _, _, future in members:
                self._fail(future, exc)

    async def _gather_batch(self) -> Tuple[List[Dict[str, Any]], List[asyncio.Future]]:
        """Collect up to ``batch_size`` queued requests.

        Blocks without a timeout until the first request arrives, so an idle
        batcher does not spin. After that it keeps collecting for at most
        ``timeout`` seconds, measured on the loop's monotonic clock.
        """
        batch: List[Dict[str, Any]] = []
        futures: List[asyncio.Future] = []
        try:
            request, future = await self.queue.get()
            batch.append(request)
            futures.append(future)
            loop = asyncio.get_running_loop()
            deadline = loop.time() + self.timeout
            while len(batch) < self.batch_size:
                if not self.queue.empty():
                    request, future = self.queue.get_nowait()
                else:
                    remaining = deadline - loop.time()
                    if remaining <= 0:
                        break
                    try:
                        request, future = await asyncio.wait_for(
                            self.queue.get(), timeout=remaining
                        )
                    except asyncio.TimeoutError:
                        break
                batch.append(request)
                futures.append(future)
        except asyncio.CancelledError:
            # Requests already taken off the queue would otherwise be lost.
            for future in futures:
                self._fail(future, RuntimeError("EmbeddingBatcher stopped"))
            raise
        return batch, futures

    async def stop(self):
        """Stop the worker and fail every request still waiting in the queue."""
        self._stopped = True
        self.worker_task.cancel()
        try:
            await self.worker_task
        except asyncio.CancelledError:
            pass
        while not self.queue.empty():
            _, future = self.queue.get_nowait()
            self._fail(future, RuntimeError("EmbeddingBatcher stopped"))

    @staticmethod
    def _normalize_input(raw: Any) -> List[Any]:
        """Return a request's ``input`` as a non-empty list of individual inputs."""
        if isinstance(raw, str):
            return [raw]
        if isinstance(raw, (list, tuple)) and raw:
            if all(isinstance(token, int) for token in raw):
                # A flat list of ints is one pre-tokenized input, not many.
                return [raw]
            return list(raw)
        raise ValueError("embedding request 'input' must be a string or a non-empty list")

    @staticmethod
    def _reindex(item: Any, index: int) -> Any:
        """Rewrite a dict item's ``index`` to be relative to its own request."""
        if isinstance(item, dict) and "index" in item:
            return {**item, "index": index}
        # Non-dict items are passed through unchanged.
        return item

    @staticmethod
    def _resolve(future: asyncio.Future, result: Any) -> None:
        """Set a result unless the future is already done (e.g. caller cancelled)."""
        if not future.done():
            future.set_result(result)

    @staticmethod
    def _fail(future: asyncio.Future, exc: BaseException) -> None:
        """Set an exception unless the future is already done."""
        if not future.done():
            future.set_exception(exc)