File size: 9,130 Bytes
28a72c5
 
 
 
 
 
 
 
 
 
 
6dac870
28a72c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
from typing import Any
from copy import deepcopy

from smolagents import ChatMessage, Tool, tool_role_conversions, MessageRole, ApiModel
from smolagents.models import supports_stop_parameter
from smolagents.utils import make_image_url, encode_image_base64
from ollama import Client
from ollama import Message as OllamaMessage

from tools_registry import ToolsRegistry

DEFAULT_NUM_CTX = 16384


class OllamaModel(ApiModel):
    """smolagents ``ApiModel`` backed by a local Ollama server.

    Translates smolagents chat messages into ``ollama.Message`` objects,
    forwards the request through ``ollama.Client.chat``, and converts the
    response back into a smolagents ``ChatMessage``.
    """

    def __init__(
        self,
        host: str = "http://localhost:11434",
        model_id: str = "gemma3:12b",
        timeout: int = 120,
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        flatten_messages_as_text: bool = True,
        tools_registry=ToolsRegistry,
        **kwargs,
    ):
        """
        Args:
            host: Base URL of the Ollama server.
            model_id: Name of the Ollama model to query.
            timeout: Request timeout (seconds) for the Ollama client.
            client_kwargs: Extra keyword arguments for ``ollama.Client``.
                ``host`` and ``timeout`` always override same-named entries.
            custom_role_conversions: Optional role-name remapping applied
                while cleaning the message list.
            flatten_messages_as_text: If True, content lists are collapsed
                into plain text strings.
            tools_registry: Object exposing ``get_all_tools()`` used to map
                smolagents tool names to the callables handed to Ollama.
            **kwargs: Forwarded to ``ApiModel.__init__``.
        """
        self.client_kwargs = {**(client_kwargs or {}), "host": host, "timeout": timeout}
        self.client = Client(**self.client_kwargs)
        self.tools_registry = tools_registry
        super().__init__(
            model_id=model_id,
            custom_role_conversions=custom_role_conversions,
            client=self.client,
            flatten_messages_as_text=flatten_messages_as_text,
            **kwargs,
        )

    def _get_clean_message_list(
        self,
        message_list: list[dict[str, str | list[dict]]],
        role_conversions: dict[MessageRole, MessageRole] | dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        flatten_messages_as_text: bool = False,
    ) -> list[OllamaMessage]:
        """
        Subsequent messages with the same role will be concatenated to a single message.
        output_message_list is a list of messages that will be used to generate the final message that is chat template compatible with transformers LLM chat template.

        Args:
            message_list (`list[dict[str, str]]`): List of chat messages.
            role_conversions (`dict[MessageRole, MessageRole]`, *optional* ): Mapping to convert roles.
            convert_images_to_image_urls (`bool`, default `False`): Whether to convert images to image URLs.
            flatten_messages_as_text (`bool`, default `False`): Whether to flatten messages as text.
        """
        # Use a None sentinel instead of a mutable `{}` default argument.
        role_conversions = role_conversions or {}
        # NOTE: entries are `ollama.Message` instances; the merge logic below
        # relies on them supporting dict-style subscript access — confirm the
        # installed ollama version's Message supports `__getitem__`/`__setitem__`.
        output_message_list: list[OllamaMessage] = []
        message_list = deepcopy(message_list)  # Avoid modifying the original list
        for message in message_list:
            role = message["role"]
            if role not in MessageRole.roles():
                raise ValueError(
                    f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
                )

            if role in role_conversions:
                message["role"] = role_conversions[role]  # type: ignore
            # encode images if needed
            if isinstance(message["content"], list):
                for element in message["content"]:
                    assert isinstance(
                        element, dict
                    ), "Error: this element should be a dict:" + str(element)
                    if element["type"] == "image":
                        # Images cannot survive text flattening.
                        assert (
                            not flatten_messages_as_text
                        ), f"Cannot use images with {flatten_messages_as_text=}"
                        if convert_images_to_image_urls:
                            element.update(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": make_image_url(
                                            encode_image_base64(element.pop("image"))
                                        )
                                    },
                                }
                            )
                        else:
                            element["image"] = encode_image_base64(element["image"])

            if (
                len(output_message_list) > 0
                and message["role"] == output_message_list[-1]["role"]
            ):
                # Same role as the previous message: merge instead of appending.
                assert isinstance(
                    message["content"], list
                ), "Error: wrong content:" + str(message["content"])
                if flatten_messages_as_text:
                    output_message_list[-1]["content"] += (
                        "\n" + message["content"][0]["text"]
                    )
                else:
                    for el in message["content"]:
                        if (
                            el["type"] == "text"
                            and output_message_list[-1]["content"][-1]["type"] == "text"
                        ):
                            # Merge consecutive text messages rather than creating new ones
                            output_message_list[-1]["content"][-1]["text"] += (
                                "\n" + el["text"]
                            )
                        else:
                            output_message_list[-1]["content"].append(el)
            else:
                if flatten_messages_as_text:
                    # Collapse the content list to its first text element.
                    content = message["content"][0]["text"]
                else:
                    content = message["content"]
                output_message_list.append(
                    OllamaMessage(role=message["role"], content=content)
                )
        return output_message_list

    def _prepare_completion_kwargs(
        self,
        messages: list[dict[str, str | list[dict]]],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        convert_images_to_image_urls: bool = False,
        stream: bool = True,
        keep_alive: float | str | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Prepare parameters required for model invocation, handling parameter priorities.

        Parameter priority from high to low:
        1. Explicitly passed kwargs
        2. Specific parameters (stop_sequences, grammar, etc.)
        3. Default values in self.kwargs

        Returns:
            dict[str, Any]: Keyword arguments ready to pass to ``Client.chat``.
        """
        # Clean and standardize the message list
        flatten_messages_as_text = kwargs.pop(
            "flatten_messages_as_text", self.flatten_messages_as_text
        )
        messages = self._get_clean_message_list(
            messages,
            role_conversions=custom_role_conversions or tool_role_conversions,
            convert_images_to_image_urls=convert_images_to_image_urls,
            flatten_messages_as_text=flatten_messages_as_text,
        )
        # Ollama generation options; num_ctx sets the context window size.
        generation_options = {"num_ctx": DEFAULT_NUM_CTX}
        # Use self.kwargs as the base configuration
        completion_kwargs = {
            **self.kwargs,
            "messages": messages,
        }

        # Handle specific parameters
        if stop_sequences is not None:
            # Some models do not support stop parameter
            if supports_stop_parameter(self.model_id or ""):
                generation_options["stop"] = stop_sequences

        # Define Ollama's parameters
        completion_kwargs["model"] = self.model_id
        if tools_to_call_from:
            # Resolve each smolagents tool name to its registered callable.
            completion_kwargs["tools"] = [
                self.tools_registry.get_all_tools()[tool.name]
                for tool in tools_to_call_from
            ]
        else:
            completion_kwargs["tools"] = None
        completion_kwargs["stream"] = stream
        if grammar is not None:
            # Ollama expresses constrained output via the `format` field.
            completion_kwargs["format"] = grammar
        else:
            completion_kwargs["format"] = None

        completion_kwargs["options"] = generation_options
        completion_kwargs["keep_alive"] = keep_alive

        # Finally, use the passed-in kwargs to override all settings
        completion_kwargs.update(kwargs)

        return completion_kwargs

    def generate(
        self,
        messages: list[dict[str, str | list[dict]]],
        stop_sequences: list[str] | None = None,
        grammar: str | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Run a non-streaming chat completion and return a ``ChatMessage``.

        Args:
            messages: Chat history as smolagents message dicts.
            stop_sequences: Optional stop strings (applied only when the
                model supports them).
            grammar: Optional output format constraint passed to Ollama.
            tools_to_call_from: Tools the model may call.
            **kwargs: Extra overrides forwarded to ``Client.chat``.
        """
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            grammar=grammar,
            tools_to_call_from=tools_to_call_from,
            convert_images_to_image_urls=True,
            custom_role_conversions=self.custom_role_conversions,
            stream=False,
            **kwargs,
        )
        response = self.client.chat(**completion_kwargs)

        # Token accounting from Ollama's response metadata.
        self.last_input_token_count = response.prompt_eval_count
        self.last_output_token_count = response.eval_count

        response_json_dumped = response.message.model_dump()

        if response_json_dumped["tool_calls"]:
            # Ollama tool calls carry no type/id; synthesize them so the dict
            # matches what ChatMessage.from_dict expects.
            # NOTE(review): id is set to the integer list index — confirm
            # smolagents accepts non-string tool-call ids.
            for tool_call_id, tool_call in enumerate(
                response_json_dumped["tool_calls"]
            ):
                tool_call["type"] = "function"
                tool_call["id"] = tool_call_id

        return ChatMessage.from_dict(response_json_dumped, raw=None)