FinalAssignmentTemplate / rate_limiter.py
petrov826's picture
rename APICallCounter to RateLimiter and refactor
891518c verified
raw
history blame
1.75 kB
import time
class RateLimiter:
"""
litellmのAPI呼び出し回数をカウントして、
一定回数のAPI呼び出しが行われるたびに一定秒間待機する
"""
def __init__(self, calls_per_wait: int = 10, seconds_to_wait: int = 60):
"""
RateLimiterを初期化する
Args:
calls_per_wait (int): 待機するまでの呼び出し回数
seconds_to_wait (int): 待機する秒数
"""
self.calls_per_wait = calls_per_wait
self.seconds_to_wait = seconds_to_wait
self.successful_calls = 0 # APIの呼び出し成功回数
self.failed_calls = 0 # APIの呼び出し失敗回数
def increment_successful_calls(self, kwargs, completion_obj, start_time, end_time):
"""API呼び出しが成功したときに、呼び出し成功回数を1増やす。"""
self.successful_calls += 1
self.sleep_if_needed()
def increment_failed_calls(self, kwargs, completion_obj, start_time, end_time):
"""API呼び出しが失敗したときに、呼び出し失敗回数を1増やす。"""
self.failed_calls += 1
self.sleep_if_needed()
def sleep_if_needed(self):
"""
呼び出し回数が`calls_per_wait`を超えていたら、
APIの呼び出し回数をリセットして`seconds_to_wait`秒待機する
"""
if (self.successful_calls + self.failed_calls) >= self.calls_per_wait:
self.successful_calls = 0
self.failed_calls = 0
time.sleep(self.seconds_to_wait)
print(f"呼び出し回数がf{self.calls_per_wait}回を超えたので、f{self.seconds_to_wait}秒待機します")