# DirtyFock — "Use predefined manager_agent" (commit 6dac870)
from typing import Any
from copy import deepcopy
from smolagents import ChatMessage, Tool, tool_role_conversions, MessageRole, ApiModel
from smolagents.models import supports_stop_parameter
from smolagents.utils import make_image_url, encode_image_base64
from ollama import Client
from ollama import Message as OllamaMessage
from tools_registry import ToolsRegistry
# Default context-window size (in tokens) sent to Ollama as `options.num_ctx`
# on every chat request; callers can override it by passing an `options` dict
# in kwargs, which is applied last in `_prepare_completion_kwargs`.
DEFAULT_NUM_CTX: int = 16384
class OllamaModel(ApiModel):
    """smolagents ``ApiModel`` implementation backed by a local Ollama server.

    Wraps :class:`ollama.Client` so smolagents agents can chat with models
    served by Ollama. Raw chat messages are normalized into
    :class:`ollama.Message` objects, tool schemas are resolved through a
    ``tools_registry``, and generation options (``num_ctx``, ``stop``) are
    forwarded via Ollama's ``options`` dict.
    """

    def __init__(
        self,
        host: str = "http://localhost:11434",
        model_id: str = "gemma3:12b",
        timeout: int = 120,
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        flatten_messages_as_text: bool = True,
        tools_registry=ToolsRegistry,
        **kwargs,
    ):
        """Create the Ollama client and initialize the smolagents base model.

        Args:
            host: Base URL of the Ollama server.
            model_id: Name of the Ollama model to chat with.
            timeout: Request timeout in seconds, passed to ``ollama.Client``.
            client_kwargs: Extra keyword arguments for ``ollama.Client``;
                ``host`` and ``timeout`` above always take precedence.
            custom_role_conversions: Optional role-name remapping applied when
                cleaning message lists.
            flatten_messages_as_text: When True (default), structured message
                content is collapsed into a single text string.
            tools_registry: Registry used to map smolagents tool names to
                Ollama tool schemas. NOTE(review): the default is the
                ``ToolsRegistry`` *class*, not an instance, so
                ``get_all_tools()`` is expected to be callable on the class
                itself — confirm against the registry's definition.
            **kwargs: Forwarded to ``ApiModel.__init__`` and later merged into
                every completion call via ``self.kwargs``.
        """
        # Explicit host/timeout win over duplicates inside client_kwargs.
        self.client_kwargs = {**(client_kwargs or {}), "host": host, "timeout": timeout}
        self.client = Client(**self.client_kwargs)
        self.tools_registry = tools_registry
        super().__init__(
            model_id=model_id,
            custom_role_conversions=custom_role_conversions,
            client=self.client,
            flatten_messages_as_text=flatten_messages_as_text,
            **kwargs,
        )

    def _get_clean_message_list(
        self,
        message_list: list[dict[str, str | list[dict]]],
        role_conversions: dict[MessageRole, MessageRole] | dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        flatten_messages_as_text: bool = False,
    ) -> list[OllamaMessage]:
        """Normalize raw chat messages into a list of ``ollama.Message``.

        Subsequent messages with the same role are concatenated into a single
        message, making the output compatible with transformers-style chat
        templates that reject repeated roles.

        Args:
            message_list (`list[dict[str, str]]`): List of chat messages.
            role_conversions: Mapping used to rewrite roles; ``None`` (the
                default) applies no conversions.
            convert_images_to_image_urls (`bool`, default `False`): Whether to
                convert images to base64 data URLs instead of raw base64.
            flatten_messages_as_text (`bool`, default `False`): Whether to
                flatten messages as text (images are then disallowed).

        Returns:
            Messages ready to pass to ``ollama.Client.chat``.

        Raises:
            ValueError: If a message carries an unsupported role.
        """
        # Fix: the original signature used a mutable default (`={}`); None is
        # the safe sentinel and keeps all existing call sites working.
        role_conversions = role_conversions or {}
        output_message_list: list[dict[str, str | list[dict]]] = []
        message_list = deepcopy(message_list)  # Avoid modifying the original list
        for message in message_list:
            role = message["role"]
            if role not in MessageRole.roles():
                raise ValueError(
                    f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
                )
            if role in role_conversions:
                message["role"] = role_conversions[role]  # type: ignore
            # Encode images (or turn them into data URLs) before any merging.
            if isinstance(message["content"], list):
                for element in message["content"]:
                    assert isinstance(
                        element, dict
                    ), "Error: this element should be a dict:" + str(element)
                    if element["type"] == "image":
                        assert (
                            not flatten_messages_as_text
                        ), f"Cannot use images with {flatten_messages_as_text=}"
                        if convert_images_to_image_urls:
                            element.update(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": make_image_url(
                                            encode_image_base64(element.pop("image"))
                                        )
                                    },
                                }
                            )
                        else:
                            element["image"] = encode_image_base64(element["image"])
            if (
                len(output_message_list) > 0
                and message["role"] == output_message_list[-1]["role"]
            ):
                # Same role as the previous entry: merge rather than append.
                # NOTE(review): item access on the previous entry relies on
                # ollama.Message being subscriptable (SubscriptableBaseModel
                # in recent ollama releases) — confirm the client version.
                assert isinstance(
                    message["content"], list
                ), "Error: wrong content:" + str(message["content"])
                if flatten_messages_as_text:
                    # NOTE(review): only content[0]["text"] is joined here;
                    # any further elements of this message are dropped.
                    output_message_list[-1]["content"] += (
                        "\n" + message["content"][0]["text"]
                    )
                else:
                    for el in message["content"]:
                        if (
                            el["type"] == "text"
                            and output_message_list[-1]["content"][-1]["type"] == "text"
                        ):
                            # Merge consecutive text messages rather than creating new ones
                            output_message_list[-1]["content"][-1]["text"] += (
                                "\n" + el["text"]
                            )
                        else:
                            output_message_list[-1]["content"].append(el)
            else:
                if flatten_messages_as_text:
                    # NOTE(review): as above, only content[0]["text"] is kept.
                    content = message["content"][0]["text"]
                else:
                    # NOTE(review): ollama.Message.content is typed as a
                    # string; passing the structured list through here depends
                    # on the client accepting it — verify the non-flatten path.
                    content = message["content"]
                output_message_list.append(
                    OllamaMessage(role=message["role"], content=content)
                )
        return output_message_list

    def _prepare_completion_kwargs(
        self,
        messages: list[dict[str, str | list[dict]]],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        stream: bool = True,
        keep_alive: float | str | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Build the keyword arguments for ``ollama.Client.chat``.

        Parameter priority, highest first:
            1. Explicitly passed ``kwargs`` (applied last, overriding all).
            2. Specific parameters (``stop_sequences``, ``grammar``, etc.).
            3. Default values in ``self.kwargs``.

        Args:
            messages: Raw chat messages to clean and send.
            stop_sequences: Stop strings, forwarded via ``options.stop`` only
                when the model supports a stop parameter.
            grammar: Output format constraint, forwarded as Ollama ``format``.
            tools_to_call_from: smolagents tools; their names are resolved to
                Ollama tool schemas through ``self.tools_registry``.
            custom_role_conversions: Role remapping; falls back to
                smolagents' ``tool_role_conversions`` when not given.
            convert_images_to_image_urls: Passed through to message cleaning.
            stream: Whether the chat call should stream.
            keep_alive: Ollama ``keep_alive`` value (model unload timeout).
            **kwargs: Final overrides merged into the result.

        Returns:
            A dict suitable for ``self.client.chat(**result)``.
        """
        # Clean and standardize the message list.
        flatten_messages_as_text = kwargs.pop(
            "flatten_messages_as_text", self.flatten_messages_as_text
        )
        messages = self._get_clean_message_list(
            messages,
            role_conversions=custom_role_conversions or tool_role_conversions,
            convert_images_to_image_urls=convert_images_to_image_urls,
            flatten_messages_as_text=flatten_messages_as_text,
        )
        generation_options = {"num_ctx": DEFAULT_NUM_CTX}
        # Use self.kwargs as the base configuration.
        completion_kwargs = {
            **self.kwargs,
            "messages": messages,
        }
        # Handle specific parameters.
        if stop_sequences is not None:
            # Some models do not support a stop parameter at all.
            if supports_stop_parameter(self.model_id or ""):
                generation_options["stop"] = stop_sequences
        # Define Ollama's parameters; explicit None keeps the keys present.
        completion_kwargs["model"] = self.model_id
        if tools_to_call_from:
            # Resolve each smolagents tool name to its registered Ollama schema.
            completion_kwargs["tools"] = [
                self.tools_registry.get_all_tools()[tool.name]
                for tool in tools_to_call_from
            ]
        else:
            completion_kwargs["tools"] = None
        completion_kwargs["stream"] = stream
        if grammar is not None:
            completion_kwargs["format"] = grammar
        else:
            completion_kwargs["format"] = None
        completion_kwargs["options"] = generation_options
        completion_kwargs["keep_alive"] = keep_alive
        # Finally, use the passed-in kwargs to override all settings.
        completion_kwargs.update(kwargs)
        return completion_kwargs

    def generate(
        self,
        messages: list[dict[str, str | list[dict]]],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Run one non-streaming chat completion against Ollama.

        Args:
            messages: Raw chat messages.
            stop_sequences: Optional stop strings.
            grammar: Optional output format constraint.
            tools_to_call_from: Optional smolagents tools the model may call.
            **kwargs: Extra overrides passed to completion preparation.

        Returns:
            The assistant reply as a smolagents ``ChatMessage``.
        """
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            grammar=grammar,
            tools_to_call_from=tools_to_call_from,
            convert_images_to_image_urls=True,
            custom_role_conversions=self.custom_role_conversions,
            stream=False,
            **kwargs,
        )
        response = self.client.chat(**completion_kwargs)
        # Token accounting from the Ollama response.
        # NOTE(review): prompt_eval_count/eval_count can be absent on some
        # responses (e.g. cached prompts) — confirm downstream tolerates None.
        self.last_input_token_count = response.prompt_eval_count
        self.last_output_token_count = response.eval_count
        response_json_dumped = response.message.model_dump()
        if response_json_dumped["tool_calls"]:
            # Ollama tool calls carry no id/type; synthesize OpenAI-style
            # fields so ChatMessage.from_dict can parse them.
            for tool_call_id, tool_call in enumerate(
                response_json_dumped["tool_calls"]
            ):
                tool_call["type"] = "function"
                tool_call["id"] = tool_call_id
        return ChatMessage.from_dict(response_json_dumped, raw=None)