# -*- coding: utf-8 -*-
# Time      :2024/11/17 15:33
# Author    :Hui Huang
import asyncio
import uuid
from typing import Callable, List, Any, Awaitable, Tuple, Optional
from asyncio import Queue


class BatchProcessor:
    """Batch Processor for handling asynchronous requests in batches.

    This class manages a queue of requests and processes them in batches
    using multiple worker tasks.

    Attributes:
        processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
            The function used for processing requests in batches.
        num_workers (int): The number of worker tasks to process requests.
        batch_size (int): The maximum number of requests to process in a single batch.
        wait_timeout (float): The maximum time (in seconds) to wait for the next request when filling a batch.
        request_queue (Queue): The queue holding incoming requests.
        loop (asyncio.AbstractEventLoop): The event loop used to create worker tasks.
        worker_tasks (List[asyncio.Task]): The list of worker tasks.
    """

    def __init__(
            self,
            processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
            num_workers: int,
            batch_size: int,
            wait_timeout: float = 0.05
    ) -> None:
        """Initialize the BatchProcessor with the given processing function, number of workers, and batch size.

        Args:
            processing_function (Callable[[List[Any]], Awaitable[List[Any]]]):
                The function used for processing requests in batches.
            num_workers (int): The number of worker tasks to process requests.
            batch_size (int): The maximum number of requests to process in a single batch.
            wait_timeout (float): The maximum time (in seconds) to wait for the next request when filling a batch. Defaults to 0.05.
        """
        self.processing_function = processing_function
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.wait_timeout = wait_timeout
        self.request_queue: Queue = Queue()
        self.loop = asyncio.get_running_loop()
        self.worker_tasks = [
            self.loop.create_task(self.batch_processor(i)) for i in range(num_workers)
        ]
        # Schedule a no-op task that yields control so the worker tasks can start running
        self.loop.create_task(self._log_workers_started())

    async def _log_workers_started(self):
        """Yield control once so the newly created worker tasks can begin executing."""
        await asyncio.sleep(0)

    async def batch_processor(self, worker_id: int):
        """Worker task that processes requests from the queue in batches.

        Args:
            worker_id (int): The identifier for the worker task.
        """

        while True:
            requests: List[Tuple[Any, asyncio.Future]] = []
            try:
                while len(requests) < self.batch_size:
                    request = await asyncio.wait_for(
                        self.request_queue.get(), timeout=self.wait_timeout
                    )
                    requests.append(request)
            except asyncio.TimeoutError:
                pass

            if requests:
                all_requests = [
                    req[0] for req in requests
                ]  # Extract the actual input data from each request tuple
                futures = [req[1] for req in requests]  # Extract the futures to resolve
                try:
                    results = await self.processing_function(all_requests)

                    for future, result in zip(futures, results):
                        # Skip futures the caller has already cancelled or resolved,
                        # otherwise set_result would raise InvalidStateError and kill the worker
                        if not future.done():
                            future.set_result(result)
                except Exception as e:
                    for future in futures:
                        if not future.done():
                            future.set_exception(e)

    async def add_request(self, single_input: Any) -> asyncio.Future:
        """Add a new request to the queue.

        Args:
            single_input (Any): The input data for processing.

        Returns:
            asyncio.Future: A future that resolves to the processed result for this input.
        """
        future = self.loop.create_future()
        self.request_queue.put_nowait((single_input, future))
        return future

    async def shutdown(self):
        """Shutdown the batch processor by cancelling all worker tasks."""
        for task in self.worker_tasks:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                print("Worker task cancelled.")


class AsyncBatchEngine:

    def __init__(
            self,
            processing_function: Callable[[List[Any]], Awaitable[List[Any]]],
            batch_size: int = 32,
            wait_timeout: float = 0.01,
    ):
        """
        Initialize the AsyncBatchEngine with a processing function, number of workers, and batch size.

        Args:
            processing_function (Callable[[List[Any]], Awaitable[List[Any]]]): The batch processing function.
            batch_size (int): The maximum number of requests to process in a single batch.
        """
        self._processing_function = processing_function
        self._batch_size = batch_size
        self._is_running = False
        self._batch_processor = None
        self._wait_timeout = wait_timeout

    async def start(self):
        """Start the engine by initializing the batch processor and worker tasks."""
        if self._is_running:
            return

        self._batch_processor = BatchProcessor(
            processing_function=self._processing_function,
            batch_size=self._batch_size,
            wait_timeout=self._wait_timeout,
            num_workers=1
        )
        self._is_running = True

    async def stop(self):
        """Stop the engine by shutting down the batch processor and worker tasks."""
        self._check_running()
        self._is_running = False
        if self._batch_processor is not None:
            await self._batch_processor.shutdown()

    def _check_running(self):
        """Check if the engine is running.

        Raises:
            ValueError: If the engine is not running.
        """
        if not self._is_running:
            raise ValueError(
                "The engine is not running. "
                "You must start the engine before using it."
            )

    async def add_request(self, single_input: Any, request_id: Optional[str] = None) -> dict:
        """Asynchronously add a request to be processed and wait for its result.

        The engine is started automatically on first use if it is not already running.

        Args:
            single_input (Any): The input data for processing.
            request_id (Optional[str]): Optional request identifier to avoid data mix-up.
                A random UUID is generated if not provided.

        Returns:
            dict: A dictionary containing the request_id and the processed feature.
        """
        if not self._is_running:
            await self.start()

        if request_id is None:
            request_id = str(uuid.uuid4())  # Assign a unique ID if not provided
        future = await self._batch_processor.add_request(single_input=single_input)  # type: ignore
        result = await future
        return dict(
            request_id=request_id,
            feature=result
        )
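

if __name__ == "__main__":
    # Minimal runnable sketch (added example, not part of the original module).
    # `fake_embedding_batch` is a hypothetical stand-in for a real batched model
    # call such as an embedding or TTS forward pass.
    async def fake_embedding_batch(inputs: List[Any]) -> List[Any]:
        await asyncio.sleep(0.01)  # simulate model latency
        return [f"feature-of-{item}" for item in inputs]

    async def _demo() -> None:
        engine = AsyncBatchEngine(
            processing_function=fake_embedding_batch,
            batch_size=4,
            wait_timeout=0.01,
        )
        await engine.start()
        # Concurrent requests are transparently grouped into one batched call
        outputs = await asyncio.gather(
            *(engine.add_request(text) for text in ["a", "b", "c"])
        )
        for output in outputs:
            print(output["request_id"], output["feature"])
        await engine.stop()

    asyncio.run(_demo())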