File size: 3,388 Bytes
e5902c3
 
15a2751
 
bab921d
e5902c3
15a2751
 
 
 
bab921d
e5902c3
15a2751
 
 
 
 
 
 
e5902c3
15a2751
 
 
 
 
 
 
 
 
e5902c3
bab921d
15a2751
 
 
 
bab921d
 
 
 
 
 
 
 
 
15a2751
 
 
 
 
bab921d
15a2751
e5902c3
 
bab921d
15a2751
bab921d
 
 
 
 
 
 
 
 
 
 
 
15a2751
bab921d
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import time
from collections import deque
from dataclasses import dataclass
from typing import List, Tuple
from .print_style import PrintStyle

@dataclass
class CallRecord:
    """Bookkeeping entry for a single rate-limited call.

    Instances are mutable on purpose: ``output_tokens`` is filled in after
    the call has finished, once the amount of produced output is known.
    """

    # Time (seconds, as returned by time.time()) attached to the call.
    timestamp: float
    # Number of input tokens the call was admitted with.
    input_tokens: int
    # Number of output tokens; starts at zero and is updated afterwards.
    output_tokens: int = 0

class RateLimiter:
    """Sliding-window limiter for call count, input tokens and output tokens.

    Every admitted call is kept as a record in ``call_records`` (oldest
    first). Records whose window has elapsed are dropped, and a new call
    blocks with ``time.sleep`` while admitting it would exceed one of the
    configured limits. A limit that is zero or negative disables that check.
    """

    def __init__(self, max_calls: int, max_input_tokens: int, max_output_tokens: int, window_seconds: int = 60):
        """Store the limits and start with an empty history.

        Args:
            max_calls: Maximum number of calls per window (<= 0 disables).
            max_input_tokens: Maximum summed input tokens per window
                (<= 0 disables).
            max_output_tokens: Maximum summed output tokens per window
                (<= 0 disables).
            window_seconds: Length of the sliding window in seconds.
        """
        self.max_calls = max_calls
        self.max_input_tokens = max_input_tokens
        self.max_output_tokens = max_output_tokens
        self.window_seconds = window_seconds
        # Records in admission order, oldest at the left end.
        self.call_records: deque = deque()

    def _clean_old_records(self, current_time: float):
        """Drop every record whose window has fully elapsed at ``current_time``."""
        # Uses the same "timestamp + window" expression as the wait
        # computation in _wait_if_needed, so a record counts as expired
        # exactly when the remaining wait for it is <= 0 (no busy spinning
        # on the boundary).
        while self.call_records and self.call_records[0].timestamp + self.window_seconds <= current_time:
            self.call_records.popleft()

    def _get_counts(self) -> Tuple[int, int, int]:
        """Return ``(calls, input_tokens, output_tokens)`` over the stored records."""
        calls = len(self.call_records)
        input_tokens = sum(record.input_tokens for record in self.call_records)
        output_tokens = sum(record.output_tokens for record in self.call_records)
        return calls, input_tokens, output_tokens

    def _wait_if_needed(self, current_time: float, new_input_tokens: int) -> float:
        """Block until a call with ``new_input_tokens`` fits within the limits.

        Args:
            current_time: Timestamp (``time.time()`` scale) to evaluate the
                window against on the first pass.
            new_input_tokens: Input tokens of the call about to be admitted.

        Returns:
            The timestamp at which the call was admitted: ``current_time``
            if no waiting was necessary, otherwise the time read after the
            last sleep.
        """
        while True:
            self._clean_old_records(current_time)
            calls, input_tokens, output_tokens = self._get_counts()

            wait_reasons = []
            if self.max_calls > 0 and calls >= self.max_calls:
                wait_reasons.append("max calls")
            if self.max_input_tokens > 0 and input_tokens + new_input_tokens > self.max_input_tokens:
                wait_reasons.append("max input tokens")
            if self.max_output_tokens > 0 and output_tokens >= self.max_output_tokens:
                wait_reasons.append("max output tokens")

            if not wait_reasons:
                return current_time

            if not self.call_records:
                # Nothing is left to expire, so waiting cannot help. This is
                # reached when a single request alone is larger than
                # max_input_tokens; admit it instead of raising IndexError
                # on the empty deque or looping forever.
                return current_time

            oldest_record = self.call_records[0]
            wait_time = oldest_record.timestamp + self.window_seconds - current_time
            if wait_time > 0:
                PrintStyle(font_color="yellow", padding=True).print(f"Rate limit exceeded. Waiting for {wait_time:.2f} seconds due to: {', '.join(wait_reasons)}")
                time.sleep(wait_time)
            current_time = time.time()

    def limit_call_and_input(self, input_token_count: int) -> "CallRecord":
        """Admit one call, waiting first if a limit would be exceeded.

        Args:
            input_token_count: Input tokens consumed by the call.

        Returns:
            The record appended for this call; its ``output_tokens`` start
            at zero.
        """
        # Stamp the record with the admission time (after any waiting), not
        # the time the request arrived; otherwise a call that had to wait
        # would leave the window too early.
        admitted_at = self._wait_if_needed(time.time(), input_token_count)
        new_record = CallRecord(admitted_at, input_token_count)
        self.call_records.append(new_record)
        return new_record

    def set_output_tokens(self, output_token_count: int):
        """Add ``output_token_count`` to the most recently admitted record.

        Does nothing when there are no records. Never waits.

        Returns:
            ``self``, so calls can be chained.
        """
        if self.call_records:
            self.call_records[-1].output_tokens += output_token_count
        return self

# Example usage: a module-level limiter shared by rate_limited_function below.
rate_limiter = RateLimiter(
    max_calls=5,
    max_input_tokens=1000,
    max_output_tokens=2000,
)

def rate_limited_function(input_token_count: int, output_token_count: int):
    """Show how a call is wrapped with the module-level ``rate_limiter``.

    Args:
        input_token_count: Input tokens to register before the work starts.
        output_token_count: Output tokens to record once the work is done.
    """
    # Register the call and its input tokens up front; this step may wait.
    rate_limiter.limit_call_and_input(input_token_count)

    # Placeholder for the actual work of the function.
    start_message = f"Function called with {input_token_count} input tokens"
    print(start_message)

    # Record the produced output afterwards; this step does not wait.
    rate_limiter.set_output_tokens(output_token_count)
    done_message = f"Function completed with {output_token_count} output tokens"
    print(done_message)