Spaces:
Paused
Paused
File size: 3,388 Bytes
e5902c3 15a2751 bab921d e5902c3 15a2751 bab921d e5902c3 15a2751 e5902c3 15a2751 e5902c3 bab921d 15a2751 bab921d 15a2751 bab921d 15a2751 e5902c3 bab921d 15a2751 bab921d 15a2751 bab921d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import time
from collections import deque
from dataclasses import dataclass
from typing import List, Tuple
from .print_style import PrintStyle
@dataclass
class CallRecord:
    """One admitted call tracked by :class:`RateLimiter`.

    Built positionally as ``CallRecord(timestamp, input_tokens)``;
    ``output_tokens`` starts at 0 and is filled in afterwards through
    ``RateLimiter.set_output_tokens``.
    """

    timestamp: float  # time.time() value at which the call was admitted
    input_tokens: int  # input tokens charged to this call
    output_tokens: int = 0  # Default to 0, will be set separately


class RateLimiter:
    """Sliding-window limiter for call count, input tokens and output tokens.

    Each admitted call is kept as a :class:`CallRecord` in ``call_records``.
    Records older than ``window_seconds`` are discarded. A new call blocks
    (via ``time.sleep``) until admitting it keeps every enabled limit within
    bounds. A limit that is ``0`` or negative disables that particular check.

    Args:
        max_calls: Maximum number of calls per window.
        max_input_tokens: Maximum total input tokens per window.
        max_output_tokens: Maximum total output tokens per window.
        window_seconds: Length of the sliding window in seconds.
    """

    def __init__(
        self,
        max_calls: int,
        max_input_tokens: int,
        max_output_tokens: int,
        window_seconds: int = 60,
    ):
        self.max_calls = max_calls
        self.max_input_tokens = max_input_tokens
        self.max_output_tokens = max_output_tokens
        self.window_seconds = window_seconds
        # Records in admission order: oldest on the left, newest on the right.
        # NOTE(review): timestamps come from time.time() (wall clock), which can
        # jump; time.monotonic() would be more robust if nothing outside reads
        # CallRecord.timestamp as an epoch value — confirm before switching.
        self.call_records: deque = deque()

    def _clean_old_records(self, current_time: float):
        """Drop records that have fallen out of the sliding window."""
        while (
            self.call_records
            and current_time - self.call_records[0].timestamp > self.window_seconds
        ):
            self.call_records.popleft()

    def _get_counts(self) -> Tuple[int, int, int]:
        """Return ``(calls, input_tokens, output_tokens)`` for retained records."""
        calls = len(self.call_records)
        input_tokens = sum(record.input_tokens for record in self.call_records)
        output_tokens = sum(record.output_tokens for record in self.call_records)
        return calls, input_tokens, output_tokens

    def _wait_if_needed(self, current_time: float, new_input_tokens: int) -> float:
        """Block until a call consuming ``new_input_tokens`` fits in the window.

        Args:
            current_time: ``time.time()`` value at which the check starts.
            new_input_tokens: Input tokens the pending call will consume.

        Returns:
            The ``time.time()`` value at which the call is admitted. This equals
            ``current_time`` when no waiting was needed; otherwise it is the
            time read after the last sleep.
        """
        while True:
            self._clean_old_records(current_time)
            calls, input_tokens, output_tokens = self._get_counts()

            wait_reasons = []
            if self.max_calls > 0 and calls >= self.max_calls:
                wait_reasons.append("max calls")
            if (
                self.max_input_tokens > 0
                and input_tokens + new_input_tokens > self.max_input_tokens
            ):
                wait_reasons.append("max input tokens")
            # The pending call's output size is not known yet, so only the
            # already-recorded output total is compared against the limit.
            if self.max_output_tokens > 0 and output_tokens >= self.max_output_tokens:
                wait_reasons.append("max output tokens")

            if not wait_reasons:
                return current_time
            if not self.call_records:
                # Nothing is left to expire. The request by itself exceeds a
                # limit (only possible for input tokens), so waiting can never
                # succeed. Admit it rather than indexing an empty deque.
                return current_time

            # Sleep until the oldest retained record is due to leave the window.
            oldest_record = self.call_records[0]
            wait_time = oldest_record.timestamp + self.window_seconds - current_time
            if wait_time > 0:
                PrintStyle(font_color="yellow", padding=True).print(f"Rate limit exceeded. Waiting for {wait_time:.2f} seconds due to: {', '.join(wait_reasons)}")
                time.sleep(wait_time)
            current_time = time.time()

    def limit_call_and_input(self, input_token_count: int) -> CallRecord:
        """Wait for capacity, then record a new call with its input tokens.

        Args:
            input_token_count: Input tokens consumed by the call.

        Returns:
            The newly appended :class:`CallRecord`.
        """
        # Stamp the record with the admission time, not the pre-wait time.
        # Otherwise a call that had to sleep would expire from the window
        # early and the limiter would undercount.
        admitted_at = self._wait_if_needed(time.time(), input_token_count)
        new_record = CallRecord(admitted_at, input_token_count)
        self.call_records.append(new_record)
        return new_record

    def set_output_tokens(self, output_token_count: int):
        """Add ``output_token_count`` to the most recently recorded call.

        The count is accumulated (``+=``), so repeated calls add up. This
        does nothing when no call is recorded. Never blocks.

        Returns:
            ``self``, allowing chained calls.
        """
        if self.call_records:
            self.call_records[-1].output_tokens += output_token_count
        return self
# Example usage: a shared module-level limiter admitting, per window (the
# default window_seconds), at most 5 calls, 1000 input tokens and 2000
# output tokens.
rate_limiter = RateLimiter(
    max_input_tokens=1000,
    max_output_tokens=2000,
    max_calls=5,
)
def rate_limited_function(input_token_count: int, output_token_count: int):
    """Demonstrate two-phase use of the module-level ``rate_limiter``.

    The first phase reserves a call slot plus the input tokens and may block
    until the window has room. The second phase records the output tokens
    once the work is done and never blocks.
    """
    # Phase 1 — may sleep until the limiter admits the call.
    rate_limiter.limit_call_and_input(input_token_count)

    # Stand-in for the real work being rate limited.
    print(f"Function called with {input_token_count} input tokens")

    # Phase 2 — charge the produced output tokens to the latest record.
    rate_limiter.set_output_tokens(output_token_count)
    print(f"Function completed with {output_token_count} output tokens")
|