import time
from collections import deque
from dataclasses import dataclass
from typing import List, Optional, Tuple

from .print_style import PrintStyle


@dataclass
class CallRecord:
    """One admitted call tracked inside the rate-limit window."""

    timestamp: float  # time.time() value at which the call was admitted
    input_tokens: int
    output_tokens: int = 0  # Default to 0, will be set separately


class RateLimiter:
    """Sliding-window limiter on call count, input tokens and output tokens.

    Each limit is only enforced when it is > 0; a non-positive limit disables
    that check. Records older than ``window_seconds`` no longer count.
    """

    def __init__(self, max_calls: int, max_input_tokens: int, max_output_tokens: int, window_seconds: int = 60):
        self.max_calls = max_calls
        self.max_input_tokens = max_input_tokens
        self.max_output_tokens = max_output_tokens
        self.window_seconds = window_seconds
        self.call_records: deque = deque()

    def _clean_old_records(self, current_time: float):
        """Drop records (oldest first) that are at least window_seconds old."""
        # `>=` so a record expires exactly when the wait computed in
        # _wait_if_needed (timestamp + window_seconds) has elapsed; with `>`
        # the boundary case left the record in place and the loop re-spun.
        while self.call_records and current_time - self.call_records[0].timestamp >= self.window_seconds:
            self.call_records.popleft()

    def _get_counts(self) -> Tuple[int, int, int]:
        """Return (calls, input_tokens, output_tokens) summed over the window."""
        calls = len(self.call_records)
        input_tokens = sum(record.input_tokens for record in self.call_records)
        output_tokens = sum(record.output_tokens for record in self.call_records)
        return calls, input_tokens, output_tokens

    def _wait_if_needed(self, current_time: float, new_input_tokens: int) -> float:
        """Sleep until a call carrying new_input_tokens fits within the limits.

        Args:
            current_time: time.time() value at which the caller started.
            new_input_tokens: input tokens the pending call will consume.

        Returns:
            The time.time() value at which the call may proceed; later than
            ``current_time`` if this method slept.
        """
        while True:
            self._clean_old_records(current_time)
            calls, input_tokens, output_tokens = self._get_counts()
            wait_reasons = []
            if self.max_calls > 0 and calls >= self.max_calls:
                wait_reasons.append("max calls")
            if self.max_input_tokens > 0 and input_tokens + new_input_tokens > self.max_input_tokens:
                wait_reasons.append("max input tokens")
            if self.max_output_tokens > 0 and output_tokens >= self.max_output_tokens:
                wait_reasons.append("max output tokens")
            if not wait_reasons:
                break
            if not self.call_records:
                # Nothing left to expire, so the request alone exceeds a limit
                # (e.g. new_input_tokens > max_input_tokens). Waiting can never
                # satisfy it; admit it rather than crash on call_records[0].
                break
            oldest_record = self.call_records[0]
            wait_time = oldest_record.timestamp + self.window_seconds - current_time
            if wait_time > 0:
                PrintStyle(font_color="yellow", padding=True).print(f"Rate limit exceeded. Waiting for {wait_time:.2f} seconds due to: {', '.join(wait_reasons)}")
                time.sleep(wait_time)
            # Always refresh the clock so the next iteration sees elapsed time.
            current_time = time.time()
        return current_time

    def limit_call_and_input(self, input_token_count: int) -> CallRecord:
        """Block until a call fits the limits, then record and return it.

        Args:
            input_token_count: input tokens consumed by the call.

        Returns:
            The CallRecord appended to the window; pass it to
            set_output_tokens to credit output tokens to this exact call.
        """
        current_time = time.time()
        admitted_at = self._wait_if_needed(current_time, input_token_count)
        # Stamp with the admission time, not the pre-wait time: otherwise a
        # call that had to wait would look older, expire early, and let the
        # window exceed its limits.
        new_record = CallRecord(admitted_at, input_token_count)
        self.call_records.append(new_record)
        return new_record

    def set_output_tokens(self, output_token_count: int, record: Optional[CallRecord] = None) -> "RateLimiter":
        """Add output_token_count to a record's output tokens (does not wait).

        Despite the name, the count is accumulated (``+=``), not assigned.

        Args:
            output_token_count: output tokens to add.
            record: the record returned by limit_call_and_input. Defaults to
                the most recently admitted record, which may belong to a
                different call if calls interleave.

        Returns:
            self, to allow chaining.
        """
        if record is not None:
            record.output_tokens += output_token_count
        elif self.call_records:
            self.call_records[-1].output_tokens += output_token_count
        return self


# Example usage
rate_limiter = RateLimiter(max_calls=5, max_input_tokens=1000, max_output_tokens=2000)


def rate_limited_function(input_token_count: int, output_token_count: int):
    # First, limit the call and input tokens (this may wait)
    rate_limiter.limit_call_and_input(input_token_count)

    # Your function logic here
    print(f"Function called with {input_token_count} input tokens")

    # After processing, set the output tokens (this doesn't wait)
    rate_limiter.set_output_tokens(output_token_count)

    print(f"Function completed with {output_token_count} output tokens")