Spaces:
Runtime error
Runtime error
File size: 7,544 Bytes
853cf7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
from typing import Any
from collections.abc import Generator
from smolagents import (
OpenAIModel,
ChatMessage,
ChatMessageStreamDelta,
Tool,
TokenUsage
)
from smolagents.models import (
ChatMessageToolCallStreamDelta,
ChatMessageStreamDelta,
remove_content_after_stop_sequences
)
import openai
class SwitchableOpenAIModel(OpenAIModel):
    """OpenAI-compatible model that falls back to other models on rate limits.

    Requests are first sent with ``model_list[0]``. When the server rejects a
    request with ``openai.RateLimitError``, the model advances to the next
    entry of ``model_list`` (printing a notice) and retries the same request.
    The switch is sticky: later requests keep using the model that was last
    switched to. If the last model in the list is rate-limited too, the
    ``openai.RateLimitError`` is re-raised to the caller.

    Parameters:
        model_list (`list[str]`):
            Model identifiers to use on the server, in order of preference
            (e.g. `["gpt-5", "gpt-5-mini"]`). A single string is accepted and
            treated as a one-element list.
        api_base (`str`, *optional*):
            The base URL of the OpenAI-compatible API server.
        api_key (`str`, *optional*):
            The API key to use for authentication.
        organization (`str`, *optional*):
            The organization to use for the API request.
        project (`str`, *optional*):
            The project to use for the API request.
        client_kwargs (`dict[str, Any]`, *optional*):
            Additional keyword arguments to pass to the OpenAI client (like
            organization, project, max_retries etc.).
        custom_role_conversions (`dict[str, str]`, *optional*):
            Custom role conversion mapping to convert message roles into others.
            Useful for specific models that do not support specific message
            roles like "system".
        flatten_messages_as_text (`bool`, default `False`):
            Whether to flatten messages as text.
        **kwargs:
            Additional keyword arguments to forward to the underlying OpenAI
            API completion call, for instance `temperature`.

    Raises:
        ValueError: If `model_list` is empty.
    """

    def __init__(
        self,
        model_list: list[str] | str,
        api_base: str | None = None,
        api_key: str | None = None,
        organization: str | None = None,
        project: str | None = None,
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        flatten_messages_as_text: bool = False,
        **kwargs,
    ):
        # A bare string would otherwise be indexed character by character
        # (model_list[0] == "g" for "gpt-5").
        if isinstance(model_list, str):
            model_list = [model_list]
        self.model_list: list[str] = list(model_list)
        if not self.model_list:
            raise ValueError("model_list must contain at least one model identifier")
        # Position of the model currently in use; it only ever moves forward.
        self.model_index = 0
        super().__init__(
            model_id=self.model_list[self.model_index],
            api_base=api_base,
            api_key=api_key,
            organization=organization,
            project=project,
            client_kwargs=client_kwargs,
            custom_role_conversions=custom_role_conversions,
            flatten_messages_as_text=flatten_messages_as_text,
            **kwargs,
        )

    def _switch_to_next_model(self) -> bool:
        """Advance to the next model in `model_list`; return False if none is left."""
        if self.model_index >= len(self.model_list) - 1:
            return False
        self.model_index += 1
        # Keep the inherited `model_id` in sync with the active model so that
        # inherited code reading it sees the model actually being called.
        # NOTE(review): presumably smolagents derives stop-parameter support
        # from `model_id`; confirm against the installed version.
        self.model_id = self.model_list[self.model_index]
        print(f"Switching to model {self.model_list[self.model_index]}")
        return True

    def _create_completion(
        self,
        messages: list[ChatMessage | dict],
        stop_sequences: list[str] | None,
        response_format: dict[str, str] | None,
        tools_to_call_from: list[Tool] | None,
        prepare_kwargs: dict[str, Any],
        **request_kwargs,
    ) -> Any:
        """Send a chat-completion request, falling back through `model_list`.

        Args:
            messages, stop_sequences, response_format, tools_to_call_from:
                Forwarded to `_prepare_completion_kwargs`.
            prepare_kwargs: Extra caller kwargs, also forwarded to
                `_prepare_completion_kwargs`.
            **request_kwargs: Passed straight to
                `client.chat.completions.create` (e.g. `stream=True`).

        Returns:
            Whatever `client.chat.completions.create` returns for the first
            model that is not rate-limited.

        Raises:
            openai.RateLimitError: If every remaining model is rate-limited.
        """
        # Bounded retry: each failed attempt advances model_index, and the loop
        # stops once the last model has been tried.
        while True:
            # Rebuilt on every attempt because the `model` entry changes.
            completion_kwargs = self._prepare_completion_kwargs(
                messages=messages,
                stop_sequences=stop_sequences,
                response_format=response_format,
                tools_to_call_from=tools_to_call_from,
                model=self.model_list[self.model_index],
                custom_role_conversions=self.custom_role_conversions,
                convert_images_to_image_urls=True,
                **prepare_kwargs,
            )
            self._apply_rate_limit()
            try:
                return self.client.chat.completions.create(
                    **completion_kwargs, **request_kwargs
                )
            except openai.RateLimitError:
                if not self._switch_to_next_model():
                    raise

    @staticmethod
    def _stream_deltas(event) -> Generator[ChatMessageStreamDelta]:
        """Translate one streamed completion chunk into smolagents stream deltas.

        Raises:
            ValueError: If a choice carries neither a delta nor a finish reason.
        """
        if event.usage:
            yield ChatMessageStreamDelta(
                content="",
                token_usage=TokenUsage(
                    input_tokens=event.usage.prompt_tokens,
                    output_tokens=event.usage.completion_tokens,
                ),
            )
        if not event.choices:
            return
        choice = event.choices[0]
        if choice.delta:
            tool_calls = (
                [
                    ChatMessageToolCallStreamDelta(
                        index=delta.index,
                        id=delta.id,
                        type=delta.type,
                        function=delta.function,
                    )
                    for delta in choice.delta.tool_calls
                ]
                if choice.delta.tool_calls
                else None
            )
            yield ChatMessageStreamDelta(
                content=choice.delta.content,
                tool_calls=tool_calls,
            )
        elif not getattr(choice, "finish_reason", None):
            raise ValueError(f"No content or tool calls in event: {event}")

    def generate_stream(
        self,
        messages: list[ChatMessage | dict],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> Generator[ChatMessageStreamDelta]:
        """Stream a completion, switching models if the request is rate-limited.

        The model fallback happens while opening the stream, before any delta
        is yielded. This way, output from two different models is never mixed
        in one response.

        Yields:
            `ChatMessageStreamDelta` objects for content/tool-call chunks, plus
            one carrying token usage when the server reports it.

        Raises:
            openai.RateLimitError: If every remaining model is rate-limited.
            ValueError: If a streamed choice has no delta and no finish reason.
        """
        # Previously this was `return self.generate_stream(...)` inside the
        # generator, which discarded the fallback stream entirely; the shared
        # retry loop avoids that.
        stream = self._create_completion(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            prepare_kwargs=kwargs,
            stream=True,
            stream_options={"include_usage": True},
        )
        for event in stream:
            yield from self._stream_deltas(event)

    def generate(
        self,
        messages: list[ChatMessage | dict],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Return one completion, switching models if the request is rate-limited.

        Returns:
            A `ChatMessage` built from the first choice of the response, with
            the raw response and its token usage attached.

        Raises:
            openai.RateLimitError: If every remaining model is rate-limited.
        """
        response = self._create_completion(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            prepare_kwargs=kwargs,
        )
        message = response.choices[0].message
        content = message.content
        # Emulate stop sequences client-side when the model cannot take them.
        if stop_sequences is not None and not self.supports_stop_parameter:
            content = remove_content_after_stop_sequences(content, stop_sequences)
        return ChatMessage(
            role=message.role,
            content=content,
            tool_calls=message.tool_calls,
            raw=response,
            token_usage=TokenUsage(
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
            ),
        )
|