# -*- coding: utf-8 -*-
# Time :2024/11/17 15:33
# Author :Hui Huang
import asyncio
import uuid
from typing import Callable, List, Any, Awaitable, Optional, Tuple
from asyncio import Queue


class BatchProcessor:
    """Batch processor for handling asynchronous requests in batches.

    This class manages a queue of requests and processes them in batches
    using multiple worker tasks.

    Attributes:
        processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
            The function used for processing requests in batches.
        num_workers (int): The number of worker tasks to process requests.
        batch_size (int): The maximum number of requests to process in a single batch.
        wait_timeout (float): The maximum time (in seconds) a worker waits for the
            next request before processing a partial batch.
        request_queue (Queue): The queue holding incoming requests.
        loop (asyncio.AbstractEventLoop): The event loop used to create worker tasks.
        worker_tasks (List[asyncio.Task]): The list of worker tasks.
    """
    def __init__(
            self,
            processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
            num_workers: int,
            batch_size: int,
            wait_timeout: float = 0.05
    ) -> None:
        """Initialize the BatchProcessor with the given processing function, number of workers, and batch size.

        Args:
            processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
                The function used for processing requests in batches.
            num_workers (int): The number of worker tasks to process requests.
            batch_size (int): The maximum number of requests to process in a single batch.
            wait_timeout (float): The maximum time (in seconds) to wait for the next
                request before processing a partial batch.
        """
        self.processing_function = processing_function
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.wait_timeout = wait_timeout
        self.request_queue: Queue = Queue()
        self.loop = asyncio.get_running_loop()
        self.worker_tasks = [
            self.loop.create_task(self.batch_processor(i)) for i in range(num_workers)
        ]
        # Yield control once so the worker tasks get a chance to start.
        self.loop.create_task(self._log_workers_started())

    async def _log_workers_started(self):
        await asyncio.sleep(0)  # Yield control to ensure workers have started

    async def batch_processor(self, worker_id: int):
        """Worker task that processes requests from the queue in batches.

        Args:
            worker_id (int): The identifier for the worker task.
        """
        while True:
            requests: List[Tuple[Any, asyncio.Future]] = []
            try:
                # Collect requests until the batch is full or the queue stays
                # empty for longer than wait_timeout.
                while len(requests) < self.batch_size:
                    request = await asyncio.wait_for(
                        self.request_queue.get(), timeout=self.wait_timeout
                    )
                    requests.append(request)
            except asyncio.TimeoutError:
                # Timed out waiting for more input: process the partial batch.
                pass

            if requests:
                all_requests = [
                    req[0] for req in requests
                ]  # Extract the actual input data from each request tuple
                futures = [req[1] for req in requests]  # Extract the futures to resolve
                try:
                    results = await self.processing_function(all_requests)
                    for future, result in zip(futures, results):
                        future.set_result(result)
                except Exception as e:
                    # Propagate the failure to every caller in the batch.
                    for future in futures:
                        future.set_exception(e)

    async def add_request(self, single_input: Any) -> asyncio.Future:
        """Add a new request to the queue.

        Args:
            single_input (Any): The input data for processing.

        Returns:
            asyncio.Future: A future that resolves to the processed result.
        """
        future = self.loop.create_future()
        self.request_queue.put_nowait((single_input, future))
        return future

    async def shutdown(self):
        """Shutdown the batch processor by cancelling all worker tasks."""
        for task in self.worker_tasks:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                print("Worker task cancelled.")


class AsyncBatchEngine:
    def __init__(
            self,
            processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
            batch_size: int = 32,
            wait_timeout: float = 0.01,
    ):
        """Initialize the AsyncBatchEngine with a processing function, batch size, and wait timeout.

        Args:
            processing_function (Callable[[List[Any]], Awaitable[List[Any]]]): The batch processing function.
            batch_size (int): The maximum number of requests to process in a single batch.
            wait_timeout (float): The maximum time (in seconds) a worker waits for the
                next request before processing a partial batch.
        """
        self._processing_function = processing_function
        self._batch_size = batch_size
        self._is_running = False
        self._batch_processor: Optional[BatchProcessor] = None
        self._wait_timeout = wait_timeout

    async def start(self):
        """Start the engine by initializing the batch processor and worker tasks."""
        if self._is_running:
            return
        self._batch_processor = BatchProcessor(
            processing_function=self._processing_function,
            batch_size=self._batch_size,
            wait_timeout=self._wait_timeout,
            num_workers=1,
        )
        self._is_running = True

    async def stop(self):
        """Stop the engine by shutting down the batch processor and worker tasks."""
        self._check_running()
        self._is_running = False
        if self._batch_processor is not None:
            await self._batch_processor.shutdown()

    def _check_running(self):
        """Check if the engine is running.

        Raises:
            ValueError: If the engine is not running.
        """
        if not self._is_running:
            raise ValueError(
                "The engine is not running. "
                "You must start the engine before using it."
            )

    async def add_request(self, single_input: Any, request_id: Optional[str] = None) -> dict:
        """Asynchronously add a request to be processed.

        If the engine is not yet running, it is started automatically.

        Args:
            single_input (Any): The input data for processing.
            request_id (Optional[str]): Optional request identifier to avoid data mix-up.

        Returns:
            dict: A dict containing the request_id and the processed feature.
        """
        if not self._is_running:
            await self.start()
        if request_id is None:
            request_id = str(uuid.uuid4())  # Assign a unique ID if not provided
        future = await self._batch_processor.add_request(single_input=single_input)  # type: ignore
        result = await future
        return dict(
            request_id=request_id,
            feature=result
        )
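

# A minimal usage sketch, not part of the original module: `fake_embed` is a
# hypothetical stand-in for a real batch function (e.g. a model forward pass)
# that returns one result per input. It shows how concurrently submitted
# requests are grouped into a single batch by the worker task.
if __name__ == "__main__":
    async def fake_embed(inputs: List[Any]) -> List[Any]:
        # Simulate a batched model call: one result per input, in order.
        await asyncio.sleep(0.05)
        return [f"feature-of-{x}" for x in inputs]

    async def main():
        engine = AsyncBatchEngine(processing_function=fake_embed, batch_size=8)
        # Submitting requests concurrently lets the worker collect them into
        # one batch before the wait_timeout expires.
        results = await asyncio.gather(
            *[engine.add_request(f"input-{i}") for i in range(4)]
        )
        for res in results:
            print(res["request_id"], res["feature"])
        await engine.stop()

    asyncio.run(main())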