# Hugging Face Space status banner captured during extraction: "Spaces: Runtime error".
# Not part of the module's code; kept as a comment so the file stays importable.
| from typing import Any | |
| from copy import deepcopy | |
| from smolagents import ChatMessage, Tool, tool_role_conversions, MessageRole, ApiModel | |
| from smolagents.models import supports_stop_parameter | |
| from smolagents.utils import make_image_url, encode_image_base64 | |
| from ollama import Client | |
| from ollama import Message as OllamaMessage | |
| from tools_registry import ToolsRegistry | |
# Default context-window size (tokens) sent to Ollama as options["num_ctx"]
# in _prepare_completion_kwargs.
DEFAULT_NUM_CTX = 16384
class OllamaModel(ApiModel):
    """smolagents ``ApiModel`` implementation backed by a local Ollama server.

    Incoming chat messages are normalized into ``ollama.Message`` objects
    (consecutive same-role messages are merged) and dispatched through
    ``ollama.Client.chat``.
    """

    def __init__(
        self,
        host: str = "http://localhost:11434",
        model_id: str = "gemma3:12b",
        timeout: int = 120,
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        flatten_messages_as_text: bool = True,
        tools_registry=ToolsRegistry,
        **kwargs,
    ):
        """
        Args:
            host: Base URL of the Ollama server.
            model_id: Name of the Ollama model to invoke.
            timeout: Request timeout in seconds for the Ollama client.
            client_kwargs: Extra keyword arguments forwarded to ``ollama.Client``.
                ``host`` and ``timeout`` always take precedence over entries here.
            custom_role_conversions: Optional mapping used to rewrite message roles.
            flatten_messages_as_text: If True (the default), message content lists
                are collapsed to plain text; images are then unsupported.
            tools_registry: Provides ``get_all_tools()`` mapping tool names to
                Ollama tool schemas. NOTE(review): the default is the
                ``ToolsRegistry`` class itself, not an instance — presumably
                ``get_all_tools`` is a classmethod/staticmethod; confirm.
            **kwargs: Forwarded to ``ApiModel.__init__``.
        """
        # host/timeout intentionally override any duplicates in client_kwargs.
        self.client_kwargs = {**(client_kwargs or {}), "host": host, "timeout": timeout}
        self.client = Client(**self.client_kwargs)
        self.tools_registry = tools_registry
        super().__init__(
            model_id=model_id,
            custom_role_conversions=custom_role_conversions,
            client=self.client,
            flatten_messages_as_text=flatten_messages_as_text,
            **kwargs,
        )

    def _get_clean_message_list(
        self,
        message_list: list[dict[str, str | list[dict]]],
        role_conversions: dict[MessageRole, MessageRole] | dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        flatten_messages_as_text: bool = False,
    ) -> list[OllamaMessage]:
        """
        Normalize ``message_list`` into a list of ``ollama.Message`` objects.

        Subsequent messages with the same role are concatenated into a single
        message, so the output is compatible with transformers-style LLM chat
        templates that reject consecutive same-role turns.

        Args:
            message_list (`list[dict[str, str]]`): List of chat messages.
            role_conversions (`dict[MessageRole, MessageRole]`, *optional*):
                Mapping used to convert roles; ``None`` (default) means no
                conversion.
            convert_images_to_image_urls (`bool`, default `False`): Whether to
                convert images to base64 data URLs instead of raw base64 payloads.
            flatten_messages_as_text (`bool`, default `False`): Whether to keep
                only the text of each message. Incompatible with image content.

        Raises:
            ValueError: If a message carries a role outside ``MessageRole.roles()``.
        """
        # BUG FIX: the original signature used a mutable default (``={}``) that is
        # shared across all calls; use None plus an in-body fallback instead.
        role_conversions = role_conversions or {}
        output_message_list: list[OllamaMessage] = []
        message_list = deepcopy(message_list)  # Avoid modifying the original list
        for message in message_list:
            role = message["role"]
            if role not in MessageRole.roles():
                raise ValueError(
                    f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
                )
            if role in role_conversions:
                message["role"] = role_conversions[role]  # type: ignore

            # Encode images (as raw base64 or data URLs) if needed.
            if isinstance(message["content"], list):
                for element in message["content"]:
                    assert isinstance(
                        element, dict
                    ), "Error: this element should be a dict:" + str(element)
                    if element["type"] == "image":
                        assert (
                            not flatten_messages_as_text
                        ), f"Cannot use images with {flatten_messages_as_text=}"
                        if convert_images_to_image_urls:
                            element.update(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": make_image_url(
                                            encode_image_base64(element.pop("image"))
                                        )
                                    },
                                }
                            )
                        else:
                            element["image"] = encode_image_base64(element["image"])

            # NOTE(review): dict-style access on the previous OllamaMessage relies
            # on ollama-python's Message supporting __getitem__ — confirm against
            # the installed ollama version.
            if (
                len(output_message_list) > 0
                and message["role"] == output_message_list[-1]["role"]
            ):
                # Same role as the previous message: merge rather than append.
                assert isinstance(
                    message["content"], list
                ), "Error: wrong content:" + str(message["content"])
                if flatten_messages_as_text:
                    output_message_list[-1]["content"] += (
                        "\n" + message["content"][0]["text"]
                    )
                else:
                    for el in message["content"]:
                        if (
                            el["type"] == "text"
                            and output_message_list[-1]["content"][-1]["type"] == "text"
                        ):
                            # Merge consecutive text messages rather than creating new ones
                            output_message_list[-1]["content"][-1]["text"] += (
                                "\n" + el["text"]
                            )
                        else:
                            output_message_list[-1]["content"].append(el)
            else:
                if flatten_messages_as_text:
                    content = message["content"][0]["text"]
                else:
                    content = message["content"]
                output_message_list.append(
                    OllamaMessage(role=message["role"], content=content)
                )
        return output_message_list

    def _prepare_completion_kwargs(
        self,
        messages: list[dict[str, str | list[dict]]],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        stream: bool = True,
        keep_alive: float | str | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Prepare parameters required for model invocation, handling parameter priorities.

        Parameter priority from high to low:
            1. Explicitly passed kwargs
            2. Specific parameters (stop_sequences, grammar, etc.)
            3. Default values in self.kwargs

        Returns:
            A dict of keyword arguments suitable for ``ollama.Client.chat``.
        """
        # Clean and standardize the message list.
        flatten_messages_as_text = kwargs.pop(
            "flatten_messages_as_text", self.flatten_messages_as_text
        )
        messages = self._get_clean_message_list(
            messages,
            role_conversions=custom_role_conversions or tool_role_conversions,
            convert_images_to_image_urls=convert_images_to_image_urls,
            flatten_messages_as_text=flatten_messages_as_text,
        )
        generation_options: dict[str, Any] = {"num_ctx": DEFAULT_NUM_CTX}
        # Use self.kwargs as the base configuration.
        completion_kwargs: dict[str, Any] = {
            **self.kwargs,
            "messages": messages,
        }
        if stop_sequences is not None:
            # Some models do not support the stop parameter.
            if supports_stop_parameter(self.model_id or ""):
                generation_options["stop"] = stop_sequences
        # Map smolagents parameters onto Ollama's chat API.
        completion_kwargs["model"] = self.model_id
        if tools_to_call_from:
            # Resolve each smolagents Tool to its Ollama tool schema by name.
            completion_kwargs["tools"] = [
                self.tools_registry.get_all_tools()[tool.name]
                for tool in tools_to_call_from
            ]
        else:
            completion_kwargs["tools"] = None
        completion_kwargs["stream"] = stream
        # Ollama exposes structured output / grammar via the "format" field.
        completion_kwargs["format"] = grammar if grammar is not None else None
        completion_kwargs["options"] = generation_options
        completion_kwargs["keep_alive"] = keep_alive
        # Finally, the explicitly passed kwargs override all settings.
        completion_kwargs.update(kwargs)
        return completion_kwargs

    def generate(
        self,
        messages: list[dict[str, str | list[dict]]],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """
        Run a single non-streaming chat completion against the Ollama server.

        Args:
            messages: Chat history to send.
            stop_sequences: Optional stop strings (applied only if the model
                supports the stop parameter).
            grammar: Optional structured-output format string.
            tools_to_call_from: Tools the model may call; resolved through
                ``self.tools_registry``.
            **kwargs: Extra overrides forwarded to the Ollama client.

        Returns:
            The model reply converted to a smolagents ``ChatMessage``.
        """
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            grammar=grammar,
            tools_to_call_from=tools_to_call_from,
            convert_images_to_image_urls=True,
            custom_role_conversions=self.custom_role_conversions,
            stream=False,
            **kwargs,
        )
        response = self.client.chat(**completion_kwargs)
        # Token accounting for smolagents' usage tracking.
        self.last_input_token_count = response.prompt_eval_count
        self.last_output_token_count = response.eval_count
        response_json_dumped = response.message.model_dump()
        if response_json_dumped["tool_calls"]:
            for tool_call_index, tool_call in enumerate(
                response_json_dumped["tool_calls"]
            ):
                tool_call["type"] = "function"
                # BUG FIX: tool-call ids are strings in the OpenAI-style schema
                # smolagents follows (ChatMessageToolCall.id: str); the original
                # stored the raw int enumerate index.
                tool_call["id"] = str(tool_call_index)
        return ChatMessage.from_dict(response_json_dumped, raw=None)