from typing import Any
from collections.abc import Generator

import openai
from smolagents import (
    ChatMessage,
    ChatMessageStreamDelta,
    OpenAIModel,
    TokenUsage,
    Tool,
)
from smolagents.models import (
    ChatMessageToolCallStreamDelta,
    remove_content_after_stop_sequences,
)


class SwitchableOpenAIModel(OpenAIModel):
    """OpenAI-compatible model that falls back through a list of model ids.

    The first model in ``model_list`` is used until the server answers with
    ``openai.RateLimitError``; the request is then retried with the next model
    in the list. The switch is sticky: later calls keep using the fallback
    model. Once the list is exhausted, the rate-limit error propagates.

    Parameters:
        model_list (`list[str]`):
            Ordered model identifiers to try on the server (e.g. ``["gpt-5", "gpt-4o"]``).
        api_base (`str`, *optional*):
            The base URL of the OpenAI-compatible API server.
        api_key (`str`, *optional*):
            The API key to use for authentication.
        organization (`str`, *optional*):
            The organization to use for the API request.
        project (`str`, *optional*):
            The project to use for the API request.
        client_kwargs (`dict[str, Any]`, *optional*):
            Additional keyword arguments to pass to the OpenAI client
            (like organization, project, max_retries etc.).
        custom_role_conversions (`dict[str, str]`, *optional*):
            Custom role conversion mapping to convert message roles in others.
            Useful for specific models that do not support specific message
            roles like "system".
        flatten_messages_as_text (`bool`, default `False`):
            Whether to flatten messages as text.
        **kwargs:
            Additional keyword arguments to forward to the underlying OpenAI
            API completion call, for instance `temperature`.
    """

    def __init__(
        self,
        model_list: list[str],
        api_base: str | None = None,
        api_key: str | None = None,
        organization: str | None = None,
        project: str | None = None,
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        flatten_messages_as_text: bool = False,
        **kwargs,
    ):
        self.model_list = model_list
        # Index of the model currently in use; advanced on RateLimitError.
        self.model_index = 0
        super().__init__(
            model_id=self.model_list[self.model_index],
            api_base=api_base,
            api_key=api_key,
            organization=organization,
            project=project,
            client_kwargs=client_kwargs,
            custom_role_conversions=custom_role_conversions,
            flatten_messages_as_text=flatten_messages_as_text,
            **kwargs,
        )

    def generate_stream(
        self,
        messages: list[ChatMessage | dict],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> Generator[ChatMessageStreamDelta]:
        """Stream completion deltas, falling back to the next model on rate limits.

        Yields:
            `ChatMessageStreamDelta`: content/tool-call deltas, plus a final
            delta carrying the token usage when the server reports it.

        Raises:
            openai.RateLimitError: when every model in ``model_list`` is rate-limited.
            ValueError: when an event carries neither content, tool calls, nor
                a finish reason.
        """
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_list[self.model_index],
            custom_role_conversions=self.custom_role_conversions,
            convert_images_to_image_urls=True,
            **kwargs,
        )
        self._apply_rate_limit()
        try:
            for event in self.client.chat.completions.create(
                **completion_kwargs,
                stream=True,
                stream_options={"include_usage": True},
            ):
                if event.usage:
                    yield ChatMessageStreamDelta(
                        content="",
                        token_usage=TokenUsage(
                            input_tokens=event.usage.prompt_tokens,
                            output_tokens=event.usage.completion_tokens,
                        ),
                    )
                if event.choices:
                    choice = event.choices[0]
                    if choice.delta:
                        yield ChatMessageStreamDelta(
                            content=choice.delta.content,
                            tool_calls=[
                                ChatMessageToolCallStreamDelta(
                                    index=delta.index,
                                    id=delta.id,
                                    type=delta.type,
                                    function=delta.function,
                                )
                                for delta in choice.delta.tool_calls
                            ]
                            if choice.delta.tool_calls
                            else None,
                        )
                    else:
                        if not getattr(choice, "finish_reason", None):
                            raise ValueError(f"No content or tool calls in event: {event}")
        except openai.RateLimitError:
            if self.model_index >= len(self.model_list) - 1:
                # No fallback model left: surface the rate-limit error.
                raise
            self.model_index += 1
            print(f"Switching to model {self.model_list[self.model_index]}")
            # BUGFIX: inside a generator, `return self.generate_stream(...)`
            # only sets StopIteration.value — the caller would receive no
            # deltas from the retry. `yield from` re-emits them properly.
            # NOTE(review): if the rate limit fires mid-stream, deltas yielded
            # before the error are followed by a full restart on the new model.
            yield from self.generate_stream(
                messages=messages,
                stop_sequences=stop_sequences,
                response_format=response_format,
                tools_to_call_from=tools_to_call_from,
                **kwargs,
            )

    def generate(
        self,
        messages: list[ChatMessage | dict],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Run a non-streaming completion, falling back to the next model on rate limits.

        Returns:
            `ChatMessage`: the assistant message, with content truncated at any
            stop sequence when the backend does not support the stop parameter.

        Raises:
            openai.RateLimitError: when every model in ``model_list`` is rate-limited.
        """
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_list[self.model_index],
            custom_role_conversions=self.custom_role_conversions,
            convert_images_to_image_urls=True,
            **kwargs,
        )
        self._apply_rate_limit()
        try:
            response = self.client.chat.completions.create(**completion_kwargs)
        except openai.RateLimitError:
            if self.model_index >= len(self.model_list) - 1:
                # No fallback model left: surface the rate-limit error.
                raise
            self.model_index += 1
            print(f"Switching to model {self.model_list[self.model_index]}")
            return self.generate(
                messages=messages,
                stop_sequences=stop_sequences,
                response_format=response_format,
                tools_to_call_from=tools_to_call_from,
                **kwargs,
            )
        content = response.choices[0].message.content
        # Emulate stop sequences client-side when the backend ignores them.
        if stop_sequences is not None and not self.supports_stop_parameter:
            content = remove_content_after_stop_sequences(content, stop_sequences)
        return ChatMessage(
            role=response.choices[0].message.role,
            content=content,
            tool_calls=response.choices[0].message.tool_calls,
            raw=response,
            token_usage=TokenUsage(
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
            ),
        )