from typing import Any
from copy import deepcopy

from smolagents import ChatMessage, Tool, tool_role_conversions, MessageRole, ApiModel
from smolagents.models import supports_stop_parameter
from smolagents.utils import make_image_url, encode_image_base64
from ollama import Client
from ollama import Message as OllamaMessage

from tools_registry import ToolsRegistry

# Default context window (tokens) sent to Ollama via options["num_ctx"].
DEFAULT_NUM_CTX = 16384


class OllamaModel(ApiModel):
    """smolagents ``ApiModel`` backed by an Ollama server.

    Converts smolagents chat messages into ``ollama.Message`` objects
    (merging consecutive same-role messages), resolves tool definitions
    through a tools registry, and maps Ollama chat responses back into
    ``ChatMessage`` instances.
    """

    def __init__(
        self,
        host: str = "http://localhost:11434",
        model_id: str = "gemma3:12b",
        timeout: int = 120,
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        flatten_messages_as_text: bool = True,
        tools_registry=ToolsRegistry,
        **kwargs,
    ):
        """Create an Ollama-backed model.

        Args:
            host: Base URL of the Ollama server.
            model_id: Ollama model tag to chat with.
            timeout: Request timeout in seconds, forwarded to the client.
            client_kwargs: Extra keyword arguments for ``ollama.Client``;
                ``host``/``timeout`` given here override same-named keys.
            custom_role_conversions: Optional role-name remapping applied
                when cleaning message lists.
            flatten_messages_as_text: If True (default), message content
                lists are collapsed to plain text strings.
            tools_registry: Registry exposing ``get_all_tools()`` used to
                resolve tool schemas by name.
            **kwargs: Forwarded to ``ApiModel.__init__``.
        """
        self.client_kwargs = {**(client_kwargs or {}), "host": host, "timeout": timeout}
        self.client = Client(**self.client_kwargs)
        # NOTE(review): the default is the ToolsRegistry *class*, not an
        # instance, and get_all_tools() is later called on it directly —
        # confirm it is a classmethod/staticmethod, or pass an instance.
        self.tools_registry = tools_registry
        super().__init__(
            model_id=model_id,
            custom_role_conversions=custom_role_conversions,
            client=self.client,
            flatten_messages_as_text=flatten_messages_as_text,
            **kwargs,
        )

    def _get_clean_message_list(
        self,
        message_list: list[dict[str, str | list[dict]]],
        role_conversions: dict[MessageRole, MessageRole] | dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        flatten_messages_as_text: bool = False,
    ) -> list[OllamaMessage]:
        """Normalize a chat message list into ``ollama.Message`` objects.

        Subsequent messages with the same role are concatenated into a
        single message, so the output is chat-template compatible.

        Args:
            message_list (`list[dict[str, str]]`): List of chat messages.
            role_conversions (`dict[MessageRole, MessageRole]`, *optional*):
                Mapping to convert roles.
            convert_images_to_image_urls (`bool`, default `False`):
                Whether to convert images to image URLs.
            flatten_messages_as_text (`bool`, default `False`):
                Whether to flatten messages as text.

        Raises:
            ValueError: If a message carries an unsupported role.
        """
        # Fix: avoid a mutable default argument ({}) shared across calls.
        role_conversions = role_conversions or {}
        output_message_list: list[OllamaMessage] = []
        message_list = deepcopy(message_list)  # Avoid modifying the original list
        for message in message_list:
            role = message["role"]
            if role not in MessageRole.roles():
                raise ValueError(
                    f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
                )
            if role in role_conversions:
                message["role"] = role_conversions[role]  # type: ignore
            # Encode images if needed (either inline base64 or as image URLs).
            if isinstance(message["content"], list):
                for element in message["content"]:
                    assert isinstance(
                        element, dict
                    ), "Error: this element should be a dict:" + str(element)
                    if element["type"] == "image":
                        # Images cannot survive text flattening.
                        assert (
                            not flatten_messages_as_text
                        ), f"Cannot use images with {flatten_messages_as_text=}"
                        if convert_images_to_image_urls:
                            element.update(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": make_image_url(
                                            encode_image_base64(element.pop("image"))
                                        )
                                    },
                                }
                            )
                        else:
                            element["image"] = encode_image_base64(element["image"])
            if (
                len(output_message_list) > 0
                and message["role"] == output_message_list[-1].role
            ):
                # Same role as the previous message: merge instead of appending.
                assert isinstance(
                    message["content"], list
                ), "Error: wrong content:" + str(message["content"])
                if flatten_messages_as_text:
                    output_message_list[-1].content += (
                        "\n" + message["content"][0]["text"]
                    )
                else:
                    # NOTE(review): this branch treats Message.content as a
                    # list of content dicts — confirm the installed ollama
                    # Message model accepts non-string content.
                    for el in message["content"]:
                        if (
                            el["type"] == "text"
                            and output_message_list[-1].content[-1]["type"] == "text"
                        ):
                            # Merge consecutive text messages rather than creating new ones
                            output_message_list[-1].content[-1]["text"] += (
                                "\n" + el["text"]
                            )
                        else:
                            output_message_list[-1].content.append(el)
            else:
                if flatten_messages_as_text:
                    content = message["content"][0]["text"]
                else:
                    content = message["content"]
                output_message_list.append(
                    OllamaMessage(role=message["role"], content=content)
                )
        return output_message_list

    def _prepare_completion_kwargs(
        self,
        messages: list[dict[str, str | list[dict]]],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        stream: bool = True,
        keep_alive: float | str | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Prepare parameters required for model invocation, handling parameter priorities.

        Parameter priority from high to low:
        1. Explicitly passed kwargs
        2. Specific parameters (stop_sequences, grammar, etc.)
        3. Default values in self.kwargs
        """
        # Clean and standardize the message list.
        flatten_messages_as_text = kwargs.pop(
            "flatten_messages_as_text", self.flatten_messages_as_text
        )
        messages = self._get_clean_message_list(
            messages,
            role_conversions=custom_role_conversions or tool_role_conversions,
            convert_images_to_image_urls=convert_images_to_image_urls,
            flatten_messages_as_text=flatten_messages_as_text,
        )
        generation_options = {"num_ctx": DEFAULT_NUM_CTX}
        # Use self.kwargs as the base configuration.
        completion_kwargs = {
            **self.kwargs,
            "messages": messages,
        }
        # Some models do not support the stop parameter; skip it for those.
        if stop_sequences is not None:
            if supports_stop_parameter(self.model_id or ""):
                generation_options["stop"] = stop_sequences
        # Define Ollama's parameters.
        completion_kwargs["model"] = self.model_id
        if tools_to_call_from:
            # Resolve each tool's Ollama-facing definition by name.
            completion_kwargs["tools"] = [
                self.tools_registry.get_all_tools()[tool.name]
                for tool in tools_to_call_from
            ]
        else:
            completion_kwargs["tools"] = None
        completion_kwargs["stream"] = stream
        # Ollama expresses constrained output via "format" (e.g. JSON schema).
        completion_kwargs["format"] = grammar if grammar is not None else None
        completion_kwargs["options"] = generation_options
        completion_kwargs["keep_alive"] = keep_alive
        # Finally, use the passed-in kwargs to override all settings.
        completion_kwargs.update(kwargs)
        return completion_kwargs

    def generate(
        self,
        messages: list[dict[str, str | list[dict]]],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Run one non-streaming chat completion and return a ``ChatMessage``.

        Also records input/output token counts from the Ollama response on
        ``self.last_input_token_count`` / ``self.last_output_token_count``.
        """
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            grammar=grammar,
            tools_to_call_from=tools_to_call_from,
            convert_images_to_image_urls=True,
            custom_role_conversions=self.custom_role_conversions,
            stream=False,
            **kwargs,
        )
        response = self.client.chat(**completion_kwargs)
        self.last_input_token_count = response.prompt_eval_count
        self.last_output_token_count = response.eval_count
        response_json_dumped = response.message.model_dump()
        # Ollama does not assign tool-call ids/types; synthesize the
        # OpenAI-style fields that ChatMessage.from_dict expects. Use .get()
        # so a dump lacking the key cannot raise KeyError.
        if response_json_dumped.get("tool_calls"):
            for tool_call_id, tool_call in enumerate(
                response_json_dumped["tool_calls"]
            ):
                tool_call["type"] = "function"
                # NOTE(review): id is an int index; OpenAI-style ids are
                # usually strings — confirm ChatMessage.from_dict accepts int.
                tool_call["id"] = tool_call_id
        return ChatMessage.from_dict(response_json_dumped, raw=None)