# (removed: Hugging Face Spaces page-scraping artifacts — status banner,
#  file-size line, commit hashes, and line-number gutter)
from typing import Any
from copy import deepcopy
from smolagents import ChatMessage, Tool, tool_role_conversions, MessageRole, ApiModel
from smolagents.models import supports_stop_parameter
from smolagents.utils import make_image_url, encode_image_base64
from ollama import Client
from ollama import Message as OllamaMessage
from tools_registry import ToolsRegistry
DEFAULT_NUM_CTX = 16384
class OllamaModel(ApiModel):
    """smolagents ``ApiModel`` implementation backed by a local Ollama server.

    Chat messages are normalized into ``ollama.Message`` objects and sent to
    the server through ``ollama.Client.chat``; the response is converted back
    into a smolagents ``ChatMessage``.
    """

    def __init__(
        self,
        host: str = "http://localhost:11434",
        model_id: str = "gemma3:12b",
        timeout: int = 120,
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        flatten_messages_as_text: bool = True,
        tools_registry=ToolsRegistry,
        **kwargs,
    ):
        """Create the Ollama client and initialize the ``ApiModel`` base.

        Args:
            host (`str`): Base URL of the Ollama server.
            model_id (`str`): Name of the Ollama model to run.
            timeout (`int`): Request timeout (seconds) passed to the client.
            client_kwargs (`dict[str, Any]`, *optional*): Extra keyword
                arguments for ``ollama.Client``. ``host`` and ``timeout``
                always override duplicate entries here.
            custom_role_conversions (`dict[str, str]`, *optional*): Role-name
                remapping applied when cleaning message lists.
            flatten_messages_as_text (`bool`, default `True`): Whether
                structured message content is collapsed to plain text.
            tools_registry: Registry used to resolve smolagents tool names to
                Ollama tool schemas in ``_prepare_completion_kwargs``.
            **kwargs: Forwarded to ``ApiModel`` and stored as default
                completion kwargs (``self.kwargs``).
        """
        # Explicit host/timeout take precedence over anything in client_kwargs.
        self.client_kwargs = {**(client_kwargs or {}), "host": host, "timeout": timeout}
        self.client = Client(**self.client_kwargs)
        self.tools_registry = tools_registry
        super().__init__(
            model_id=model_id,
            custom_role_conversions=custom_role_conversions,
            client=self.client,
            flatten_messages_as_text=flatten_messages_as_text,
            **kwargs,
        )

    def _get_clean_message_list(
        self,
        message_list: list[dict[str, str | list[dict]]],
        role_conversions: dict[MessageRole, MessageRole] | dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        flatten_messages_as_text: bool = False,
    ) -> list[OllamaMessage]:
        """Normalize a chat message list into ``ollama.Message`` objects.

        Subsequent messages with the same role are concatenated into a single
        message, and image elements are base64-encoded (or converted to image
        URLs when requested).

        Args:
            message_list (`list[dict[str, str]]`): List of chat messages.
            role_conversions (`dict[MessageRole, MessageRole]`, *optional*):
                Mapping to convert roles. Defaults to no conversion.
                (Was a mutable ``{}`` default — replaced with a ``None``
                sentinel to avoid the shared-mutable-default pitfall.)
            convert_images_to_image_urls (`bool`, default `False`): Whether to
                convert images to image URLs.
            flatten_messages_as_text (`bool`, default `False`): Whether to
                flatten messages as text.

        Returns:
            The cleaned, role-merged list of ``ollama.Message`` objects.

        Raises:
            ValueError: If a message carries a role not in ``MessageRole.roles()``.
        """
        role_conversions = role_conversions or {}
        output_message_list: list[dict[str, str | list[dict]]] = []
        message_list = deepcopy(message_list)  # Avoid modifying the original list
        for message in message_list:
            role = message["role"]
            if role not in MessageRole.roles():
                raise ValueError(
                    f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
                )
            if role in role_conversions:
                message["role"] = role_conversions[role]  # type: ignore
            # Encode images if needed.
            if isinstance(message["content"], list):
                for element in message["content"]:
                    assert isinstance(
                        element, dict
                    ), "Error: this element should be a dict:" + str(element)
                    if element["type"] == "image":
                        # Images cannot survive text-flattening, so this is a
                        # programming error rather than a user-input error.
                        assert (
                            not flatten_messages_as_text
                        ), f"Cannot use images with {flatten_messages_as_text=}"
                        if convert_images_to_image_urls:
                            element.update(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": make_image_url(
                                            encode_image_base64(element.pop("image"))
                                        )
                                    },
                                }
                            )
                        else:
                            element["image"] = encode_image_base64(element["image"])
            if (
                len(output_message_list) > 0
                and message["role"] == output_message_list[-1]["role"]
            ):
                # Same role as the previously emitted message: merge into it
                # instead of appending a new message.
                assert isinstance(
                    message["content"], list
                ), "Error: wrong content:" + str(message["content"])
                if flatten_messages_as_text:
                    # NOTE(review): only the first content element is used when
                    # flattening — any additional elements are silently dropped;
                    # confirm this matches upstream message construction.
                    output_message_list[-1]["content"] += (
                        "\n" + message["content"][0]["text"]
                    )
                else:
                    for el in message["content"]:
                        if (
                            el["type"] == "text"
                            and output_message_list[-1]["content"][-1]["type"] == "text"
                        ):
                            # Merge consecutive text messages rather than creating new ones
                            output_message_list[-1]["content"][-1]["text"] += (
                                "\n" + el["text"]
                            )
                        else:
                            output_message_list[-1]["content"].append(el)
            else:
                if flatten_messages_as_text:
                    # NOTE(review): first element only here as well.
                    content = message["content"][0]["text"]
                else:
                    content = message["content"]
                output_message_list.append(
                    OllamaMessage(role=message["role"], content=content)
                )
        return output_message_list

    def _prepare_completion_kwargs(
        self,
        messages: list[dict[str, str | list[dict]]],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        stream: bool = True,
        keep_alive: float | str | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Prepare parameters required for model invocation, handling parameter priorities.

        Parameter priority from high to low:
            1. Explicitly passed kwargs
            2. Specific parameters (stop_sequences, grammar, etc.)
            3. Default values in self.kwargs

        Args:
            messages: Raw chat messages; cleaned via ``_get_clean_message_list``.
            stop_sequences: Stop strings, applied only if the model supports them.
            grammar: Output format constraint, mapped to Ollama's ``format``.
            tools_to_call_from: smolagents tools, resolved through
                ``self.tools_registry`` to Ollama tool schemas.
            custom_role_conversions: Role remapping; falls back to the
                smolagents ``tool_role_conversions`` default.
            convert_images_to_image_urls: Forwarded to message cleaning.
            stream: Whether the request streams its response.
            keep_alive: How long Ollama keeps the model loaded after the call.
            **kwargs: Final overrides applied on top of everything else.

        Returns:
            Keyword arguments ready for ``ollama.Client.chat``.
        """
        # Clean and standardize the message list.
        flatten_messages_as_text = kwargs.pop(
            "flatten_messages_as_text", self.flatten_messages_as_text
        )
        messages = self._get_clean_message_list(
            messages,
            role_conversions=custom_role_conversions or tool_role_conversions,
            convert_images_to_image_urls=convert_images_to_image_urls,
            flatten_messages_as_text=flatten_messages_as_text,
        )
        generation_options = {"num_ctx": DEFAULT_NUM_CTX}
        # Use self.kwargs as the base configuration.
        completion_kwargs = {
            **self.kwargs,
            "messages": messages,
        }
        # Handle specific parameters.
        if stop_sequences is not None:
            # Some models do not support stop parameter
            if supports_stop_parameter(self.model_id or ""):
                generation_options["stop"] = stop_sequences
        # Define Ollama's parameters.
        completion_kwargs["model"] = self.model_id
        if tools_to_call_from:
            completion_kwargs["tools"] = [
                self.tools_registry.get_all_tools()[tool.name]
                for tool in tools_to_call_from
            ]
        else:
            completion_kwargs["tools"] = None
        completion_kwargs["stream"] = stream
        if grammar is not None:
            completion_kwargs["format"] = grammar
        else:
            completion_kwargs["format"] = None
        # NOTE(review): an "options" entry coming from self.kwargs would be
        # overwritten here — confirm that is intended.
        completion_kwargs["options"] = generation_options
        completion_kwargs["keep_alive"] = keep_alive
        # Finally, use the passed-in kwargs to override all settings.
        completion_kwargs.update(kwargs)
        return completion_kwargs

    def generate(
        self,
        messages: list[dict[str, str | list[dict]]],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Run a single (non-streaming) chat completion against Ollama.

        Args:
            messages: Chat messages to send.
            stop_sequences: Optional stop strings.
            grammar: Optional output format constraint.
            tools_to_call_from: Tools the model may call.
            **kwargs: Extra overrides for ``_prepare_completion_kwargs``.

        Returns:
            The model reply as a smolagents ``ChatMessage``.
        """
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            grammar=grammar,
            tools_to_call_from=tools_to_call_from,
            convert_images_to_image_urls=True,
            custom_role_conversions=self.custom_role_conversions,
            stream=False,
            **kwargs,
        )
        response = self.client.chat(**completion_kwargs)
        # Track token usage reported by the server.
        self.last_input_token_count = response.prompt_eval_count
        self.last_output_token_count = response.eval_count
        response_json_dumped = response.message.model_dump()
        if response_json_dumped["tool_calls"]:
            # Ollama tool calls lack type/id fields expected by smolagents;
            # synthesize them from the call's position in the list.
            # NOTE(review): ids are ints here — confirm ChatMessage accepts
            # non-string tool-call ids.
            for tool_call_id, tool_call in enumerate(
                response_json_dumped["tool_calls"]
            ):
                tool_call["type"] = "function"
                tool_call["id"] = tool_call_id
        return ChatMessage.from_dict(response_json_dumped, raw=None)