Spaces:
Sleeping
Sleeping
| import asyncio | |
| from sentence_transformers import SentenceTransformer | |
| from .model import predict | |
class DynamicBatcher:
    """Collects concurrent requests into batches for efficient inference.

    Flow:
    1. Endpoint calls submit(texts) → creates a Future, puts (texts, future) on queue
    2. Background _worker loop collects items from queue
    3. When max_batch_size requests are collected OR max_wait_ms has elapsed
       since the first request of the batch arrived → flush:
       - Flatten all texts into one list
       - Run predict(model, all_texts) via run_in_executor (non-blocking)
       - Split results back by request, set each future's result
    4. submit() awaits its future and returns the result to the endpoint

    Requests whose caller already gave up (their submit() was cancelled) are
    skipped when a batch is resolved, and stop() cancels every request that
    has not been answered, so no awaiting caller is left hanging.
    """

    def __init__(
        self,
        model: "SentenceTransformer",
        max_batch_size: int,
        max_wait_ms: int,
        *,
        predict_fn=None,
    ):
        """Initialize the batcher.

        Args:
            model: Passed as the first argument to the inference function.
            max_batch_size: Maximum number of *requests* (not texts) per batch.
            max_wait_ms: How long, in milliseconds, to keep collecting more
                requests after the first request of a batch has arrived.
            predict_fn: Optional ``(model, texts) -> embeddings`` callable used
                instead of ``predict``. Its result is sliced per request, so
                it should return one embedding per input text, in order.
        """
        self._model = model
        self._max_batch_size = max_batch_size
        self._max_wait_ms = max_wait_ms
        # Chosen here rather than as a default argument so the module-level
        # ``predict`` is only looked up when no override is supplied.
        self._predict_fn = predict if predict_fn is None else predict_fn
        self._queue = asyncio.Queue()
        self._worker_task = None

    def start(self) -> None:
        """Launch the background worker as an asyncio task.

        Must be called while an event loop is running. Calling it again while
        the worker is still alive is a no-op.
        """
        if self._worker_task is not None and not self._worker_task.done():
            return  # a second worker would be orphaned: stop() only cancels one
        self._worker_task = asyncio.create_task(self._worker())

    async def stop(self) -> None:
        """Stop the background worker and cancel every request still waiting.

        Safe to call even if start() was never called. Requests still in the
        queue have their futures cancelled, so their submit() calls raise
        asyncio.CancelledError instead of waiting forever.
        """
        task, self._worker_task = self._worker_task, None
        try:
            if task is not None:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        finally:
            # Fail every request that never reached the worker.
            while True:
                try:
                    _, future = self._queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                # cancel() is a no-op on an already-done future, unlike
                # set_exception(), which would raise InvalidStateError.
                future.cancel()

    async def submit(self, texts: list[str]) -> list[list[float]]:
        """Submit a request for batched inference. Called by the /predict endpoint.

        Returns this request's slice of the batch result (one entry per text,
        assuming the inference function returns one embedding per text).
        Re-raises whatever the inference function raised, and raises
        asyncio.CancelledError if the batcher is stopped before answering.
        """
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((texts, future))
        return await future

    async def _worker(self) -> None:
        """Background loop that collects and processes batches until cancelled."""
        while True:
            batch: list[tuple[list[str], asyncio.Future]] = []
            try:
                await self._fill_batch(batch)
                await self._run_batch(batch)
            except Exception as exc:
                # Hand inference errors (e.g. raised by predict()) to every
                # caller still waiting, and keep the worker alive for later
                # batches.
                for _, future in batch:
                    if not future.done():
                        future.set_exception(exc)
            finally:
                # Normally every future is resolved by now. Any that are not
                # (e.g. stop() cancelled this task mid-collection or
                # mid-inference) are cancelled so their callers don't hang.
                for _, future in batch:
                    if not future.done():
                        future.cancel()

    async def _fill_batch(self, batch: list[tuple[list[str], asyncio.Future]]) -> None:
        """Append queued requests to ``batch`` until it is full or the wait runs out.

        Appends in place (rather than returning a new list) so that, if this
        coroutine is cancelled part-way, the caller still sees every request
        already taken off the queue and can cancel its future.
        """
        loop = asyncio.get_running_loop()
        batch.append(await self._queue.get())
        # The wait budget starts when the first request arrives and is shared
        # by the whole batch. asyncio timeouts are in seconds, hence / 1000.
        deadline = loop.time() + self._max_wait_ms / 1000
        while len(batch) < self._max_batch_size:
            if not self._queue.empty():
                # Take anything already queued without waiting.
                batch.append(self._queue.get_nowait())
                continue
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(self._queue.get(), timeout=remaining))
            except asyncio.TimeoutError:
                break

    async def _run_batch(self, batch: list[tuple[list[str], asyncio.Future]]) -> None:
        """Run inference once for the whole batch and resolve each request's future."""
        all_texts: list[str] = []
        offsets = [0]  # request i owns all_texts[offsets[i]:offsets[i + 1]]
        for texts, _ in batch:
            all_texts.extend(texts)
            offsets.append(len(all_texts))
        loop = asyncio.get_running_loop()
        # Inference blocks, so run it in the default executor to keep the loop free.
        all_embeddings = await loop.run_in_executor(
            None, self._predict_fn, self._model, all_texts
        )
        for (_, future), start, end in zip(batch, offsets, offsets[1:]):
            # A caller whose submit() was cancelled has a done future; skip it
            # rather than raise InvalidStateError and kill the worker.
            if not future.done():
                future.set_result(all_embeddings[start:end])