# GPA_DEMO/models/glm_speech_tokenizer/batch_processor.py
# Repository metadata (scrape residue): author "wanglamao", commit "init" (528efee)
# -*- coding: utf-8 -*-
# Time :2024/11/17 15:33
# Author :Hui Huang
import asyncio
import uuid
from typing import Callable, List, Any, Awaitable, Tuple
from asyncio import Queue
class BatchProcessor:
    """Batch incoming asyncio requests and process them with worker tasks.

    Requests are enqueued as ``(input, future)`` pairs; each worker pulls up
    to ``batch_size`` pending requests, runs ``processing_function`` on the
    whole batch, and resolves each request's future with its result.

    Attributes:
        processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
            Coroutine mapping a list of inputs to a list of outputs
            (same order, one output per input).
        num_workers (int): Number of concurrent worker tasks.
        batch_size (int): Maximum number of requests processed in one batch.
        wait_timeout (float): Seconds a worker waits for additional requests
            before dispatching a partially filled batch.
        request_queue (Queue): Queue holding pending ``(input, future)`` pairs.
        loop (asyncio.AbstractEventLoop): Loop the worker tasks run on.
        worker_tasks (List[asyncio.Task]): The running worker tasks.
    """

    def __init__(
            self,
            processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
            num_workers: int,
            batch_size: int,
            wait_timeout: float = 0.05
    ) -> None:
        """Create the processor and immediately start its worker tasks.

        Must be called from inside a running event loop
        (``asyncio.get_running_loop()`` raises otherwise).

        Args:
            processing_function: Coroutine used to process request batches.
            num_workers: Number of worker tasks to spawn.
            batch_size: Maximum number of requests per batch.
            wait_timeout: Seconds to wait for additional requests while
                filling a batch that already holds at least one entry.
        """
        self.processing_function = processing_function
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.wait_timeout = wait_timeout
        self.request_queue: Queue = Queue()
        self.loop = asyncio.get_running_loop()
        self.worker_tasks = [
            self.loop.create_task(self.batch_processor(i)) for i in range(num_workers)
        ]
        # Schedule one no-op task so the loop gets a chance to start the workers.
        self.loop.create_task(self._log_workers_started())

    async def _log_workers_started(self):
        # Pure scheduling point: yields control so the worker tasks created
        # in __init__ actually begin running.
        await asyncio.sleep(0)

    async def batch_processor(self, worker_id: int):
        """Worker loop: collect a batch from the queue and process it.

        Blocks (without a timeout) until the first request arrives, so an
        idle worker sleeps instead of waking every ``wait_timeout`` seconds;
        it then keeps collecting until the batch is full or ``wait_timeout``
        elapses without a new request.

        Args:
            worker_id (int): Identifier of this worker task (currently only
                useful for debugging; not otherwise used).
        """
        while True:
            batch: List[Tuple[Any, asyncio.Future]] = []
            # First request: wait indefinitely — no busy polling on an
            # empty queue. Cancellation during this await ends the worker.
            batch.append(await self.request_queue.get())
            try:
                while len(batch) < self.batch_size:
                    batch.append(
                        await asyncio.wait_for(
                            self.request_queue.get(), timeout=self.wait_timeout
                        )
                    )
            except asyncio.TimeoutError:
                # No further request arrived in time: dispatch what we have.
                pass
            inputs = [item for item, _ in batch]
            futures = [fut for _, fut in batch]
            try:
                results = await self.processing_function(inputs)
            except Exception as exc:
                # Propagate the failure to every caller in this batch.
                # Skip futures the caller already cancelled/resolved, so a
                # stale future cannot crash the worker with InvalidStateError.
                for fut in futures:
                    if not fut.done():
                        fut.set_exception(exc)
            else:
                for fut, result in zip(futures, results):
                    if not fut.done():
                        fut.set_result(result)

    async def add_request(self, single_input: Any):
        """Enqueue one input for batched processing.

        Args:
            single_input (Any): The input data for processing.

        Returns:
            asyncio.Future: Resolves with this input's result, or with the
            exception raised while processing its batch.
        """
        future = self.loop.create_future()
        self.request_queue.put_nowait((single_input, future))
        return future

    async def shutdown(self):
        """Cancel all worker tasks and wait for every one of them to finish.

        All workers are awaited (not just the last one), and the expected
        CancelledError from each cancelled task is suppressed via
        ``return_exceptions=True``.
        """
        for task in self.worker_tasks:
            task.cancel()
        await asyncio.gather(*self.worker_tasks, return_exceptions=True)
class AsyncBatchEngine:
    """Lazily-started front-end around a single-worker :class:`BatchProcessor`.

    The underlying processor is created on :meth:`start` (or implicitly on
    the first :meth:`add_request`) and torn down by :meth:`stop`.
    """

    def __init__(
            self,
            processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
            batch_size: int = 32,
            wait_timeout: float = 0.01,
    ):
        """Store configuration; no processor is created yet.

        Args:
            processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
                The batch processing coroutine.
            batch_size (int): Maximum number of requests per batch.
            wait_timeout (float): Seconds a worker waits for extra requests
                before dispatching a partial batch.
        """
        self._processing_function = processing_function
        self._batch_size = batch_size
        self._wait_timeout = wait_timeout
        self._batch_processor = None
        self._is_running = False

    async def start(self):
        """Create the batch processor and its worker task (idempotent)."""
        if self._is_running:
            return
        self._batch_processor = BatchProcessor(
            processing_function=self._processing_function,
            batch_size=self._batch_size,
            wait_timeout=self._wait_timeout,
            num_workers=1
        )
        self._is_running = True

    async def stop(self):
        """Shut down the processor; raises ValueError if never started."""
        self._check_running()
        self._is_running = False
        processor = self._batch_processor
        if processor is not None:
            await processor.shutdown()

    def _check_running(self):
        """Raise ValueError unless the engine has been started."""
        if self._is_running:
            return
        raise ValueError(
            "The engine is not running. "
            "You must start the engine before using it."
        )

    async def add_request(self, single_input: Any, request_id: str = None) -> dict:
        """Submit one input and await its batched result.

        Starts the engine automatically if it is not running yet.

        Args:
            single_input (Any): The input data for processing.
            request_id (str): Optional identifier; a fresh UUID4 string is
                generated when omitted.

        Returns:
            dict: ``{"request_id": ..., "feature": <result>}``.
        """
        if not self._is_running:
            await self.start()
        rid = request_id if request_id is not None else str(uuid.uuid4())
        future = await self._batch_processor.add_request(single_input=single_input)  # type: ignore
        feature = await future
        return {"request_id": rid, "feature": feature}