File size: 3,066 Bytes
1337d48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import asyncio
from typing import List, Dict, Any, Tuple
import time
from rotator_library import RotatingClient

class EmbeddingBatcher:
    """Coalesces single-input embedding requests into batched API calls.

    Callers submit one request at a time via :meth:`add_request`; a background
    worker collects up to ``batch_size`` requests (waiting at most ``timeout``
    seconds after the first arrives), issues one ``aembedding`` call for the
    whole batch, and distributes the per-item results back to each caller.

    NOTE: all requests sent to one batcher instance are assumed to share the
    same model and optional settings (``input_type``, ``dimensions``, ``user``);
    callers must not mix models on a single instance.
    """

    def __init__(self, client: "RotatingClient", batch_size: int = 64, timeout: float = 0.1):
        """
        Args:
            client: API client exposing an async ``aembedding(**kwargs)`` method.
            batch_size: Maximum number of requests combined into one API call.
            timeout: Seconds to keep accepting requests after the first one of
                a batch arrives before flushing a partial batch.

        Must be called from within a running event loop: the background worker
        task is started here via ``asyncio.create_task``.
        """
        self.client = client
        self.batch_size = batch_size
        self.timeout = timeout
        self.queue: asyncio.Queue = asyncio.Queue()
        self.worker_task = asyncio.create_task(self._batch_worker())

    async def add_request(self, request_data: Dict[str, Any]) -> Any:
        """Enqueue one embedding request and wait for its individual result.

        Args:
            request_data: Request payload; must contain ``model`` and a
                single-element ``input`` list.

        Returns:
            A response-shaped dict holding only this request's embedding in
            ``data``. ``usage`` reflects the whole batch, not this item alone.
        """
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((request_data, future))
        return await future

    async def _batch_worker(self):
        """Background loop: gather a batch, call the API, fan results out."""
        while True:
            # _gather_batch blocks until at least one request exists, so the
            # batch is never empty here.
            batch, futures = await self._gather_batch()

            try:
                # Class invariant: every request in the batch shares the first
                # one's model and optional settings.
                batched_request = {
                    "model": batch[0]["model"],
                    # Each request carries a single-element input list.
                    "input": [item["input"][0] for item in batch],
                }
                for key in ("input_type", "dimensions", "user"):
                    if key in batch[0]:
                        batched_request[key] = batch[0][key]

                response = await self.client.aembedding(**batched_request)

                # Hand each caller a single-item response. A caller may have
                # cancelled its await while we were busy, so guard done() to
                # avoid InvalidStateError killing the worker.
                for i, future in enumerate(futures):
                    if future.done():
                        continue
                    future.set_result({
                        "object": response.object,
                        "model": response.model,
                        "data": [response.data[i]],
                        "usage": response.usage,  # usage covers the whole batch
                    })
            except asyncio.CancelledError:
                # Worker shutdown mid-batch: don't leave these callers hanging.
                for future in futures:
                    if not future.done():
                        future.cancel()
                raise
            except Exception as e:
                # Propagate the batch failure to every caller in it.
                for future in futures:
                    if not future.done():
                        future.set_exception(e)

    async def _gather_batch(self) -> Tuple[List[Dict[str, Any]], List[asyncio.Future]]:
        """Collect the next batch of ``(request, future)`` pairs.

        Blocks until at least one request arrives, then keeps accepting
        requests until the batch is full or ``timeout`` seconds have elapsed
        since the first one. Always returns a non-empty batch.
        """
        # Wait for the first request with no deadline so an idle batcher
        # sleeps instead of spinning once per timeout interval.
        request, future = await self.queue.get()
        batch = [request]
        futures = [future]
        # monotonic clock: deadline math is immune to wall-clock adjustments.
        deadline = time.monotonic() + self.timeout

        while len(batch) < self.batch_size:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                request, future = await asyncio.wait_for(self.queue.get(), timeout=remaining)
            except asyncio.TimeoutError:
                break
            batch.append(request)
            futures.append(future)

        return batch, futures

    async def stop(self):
        """Cancel the worker and fail any requests still waiting in the queue."""
        self.worker_task.cancel()
        try:
            await self.worker_task
        except asyncio.CancelledError:
            pass
        # Requests enqueued but never batched would otherwise hang forever.
        while not self.queue.empty():
            _, future = self.queue.get_nowait()
            if not future.done():
                future.cancel()