"""rate_limit_models.py: OpenAIServerModel subclasses with rate-limit handling."""
from typing import Generator

from tenacity import (  # exponential-backoff retry helpers
    retry,
    stop_after_attempt,
    wait_exponential,
)

from smolagents import ChatMessage, ChatMessageStreamDelta, OpenAIServerModel, Tool

# Alternative approach, kept for reference but disabled: sleep after each call in
# proportion to the previous call's token usage. Re-enabling it would also require
# `import time` and `from typing import Dict, List, Optional`.
# class SleepPerLastTokenOpenAIServerModel(OpenAIServerModel):
# def __init__(self, sleep_factor: float = 0.01, **kwargs):
# super().__init__(**kwargs)
# self.sleep_factor = sleep_factor
#
# def __call__(
# self,
# messages: List[Dict[str, str]],
# stop_sequences: Optional[List[str]] = None,
# grammar: Optional[str] = None,
# tools_to_call_from: Optional[List[Tool]] = None,
# **kwargs,
# ) -> ChatMessage:
# if self.last_input_token_count is not None:
# sleep_time = (
# self.last_input_token_count + self.last_output_token_count
# ) * self.sleep_factor
# print(f"Sleeping for {sleep_time:.2f} seconds...")
# time.sleep(sleep_time)
#
# return super().__call__(
# messages, stop_sequences, grammar, tools_to_call_from, **kwargs
# )


class ExponentialBackoffOpenAIServerModel(OpenAIServerModel):
    """OpenAIServerModel whose calls are retried with exponential backoff.

    tenacity retries each call for up to 7 attempts, with waits that double
    between attempts (1 s, 2 s, 4 s, ...) and are capped at 128 s.
    """

    @retry(wait=wait_exponential(max=128), stop=stop_after_attempt(7))
    def generate(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        result = super().generate(
            messages, stop_sequences, response_format, tools_to_call_from, **kwargs
        )
        # Treat an empty completion as a failure so tenacity retries it as well.
        # (Note: responses consisting only of tool calls also have content=None.)
        if result.content is None:
            raise ValueError("LLM API generated content=None")
        return result

    @retry(wait=wait_exponential(max=128), stop=stop_after_attempt(7))
    def generate_stream(
        self,
        messages: list[ChatMessage],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> Generator[ChatMessageStreamDelta, None, None]:
        # `yield from` forwards each streamed delta; the original `yield` handed the
        # caller the generator object itself rather than its deltas.
        # Caveat: since this is a generator function, the retry decorator only fires
        # if creating the stream raises; errors raised mid-stream reach the caller.
        yield from super().generate_stream(
            messages, stop_sequences, response_format, tools_to_call_from, **kwargs
        )
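

# Minimal usage sketch, not part of the original file. Assumptions: the
# OPENAI_API_KEY environment variable is set, and "gpt-4o-mini" is a model id
# your endpoint accepts; swap in your own model id / api_base as needed.
if __name__ == "__main__":
    model = ExponentialBackoffOpenAIServerModel(model_id="gpt-4o-mini")

    # Blocking call: retried automatically on API errors or empty completions.
    reply = model.generate([ChatMessage(role="user", content="Say hello.")])
    print(reply.content)

    # Streaming call: deltas arrive incrementally; content may be None for
    # non-text deltas, so guard before printing.
    for delta in model.generate_stream(
        [ChatMessage(role="user", content="Count to three.")]
    ):
        if delta.content:
            print(delta.content, end="", flush=True)
    print()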