oldmonk69 committed
Commit 5fa6e05 · verified · 1 Parent(s): a07160e

Upload agents.py

Files changed (1):
  1. agents.py +1778 -0
agents.py ADDED
@@ -0,0 +1,1778 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.resources  # also binds `importlib`; `importlib.resources` is used in ToolCallingAgent.__init__
import json
import os
import tempfile
import textwrap
import time
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from logging import getLogger
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Type, TypeAlias, TypedDict, Union

import yaml
from huggingface_hub import create_repo, metadata_update, snapshot_download, upload_folder
from jinja2 import StrictUndefined, Template
from rich.console import Group
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
from rich.rule import Rule
from rich.text import Text


if TYPE_CHECKING:
    import PIL.Image

from .agent_types import AgentAudio, AgentImage, handle_agent_output_types
from .default_tools import TOOL_MAPPING, FinalAnswerTool
from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonExecutor, PythonExecutor, fix_final_answer_code
from .memory import (
    ActionStep,
    AgentMemory,
    CallbackRegistry,
    FinalAnswerStep,
    MemoryStep,
    PlanningStep,
    SystemPromptStep,
    TaskStep,
    Timing,
    TokenUsage,
    ToolCall,
)
from .models import (
    CODEAGENT_RESPONSE_FORMAT,
    ChatMessage,
    ChatMessageStreamDelta,
    ChatMessageToolCall,
    MessageRole,
    Model,
    agglomerate_stream_deltas,
    parse_json_if_needed,
)
from .monitoring import (
    YELLOW_HEX,
    AgentLogger,
    LogLevel,
    Monitor,
)
from .remote_executors import DockerExecutor, E2BExecutor, ModalExecutor, WasmExecutor
from .tools import BaseTool, Tool, validate_tool_arguments
from .utils import (
    AgentError,
    AgentExecutionError,
    AgentGenerationError,
    AgentMaxStepsError,
    AgentParsingError,
    AgentToolCallError,
    AgentToolExecutionError,
    create_agent_gradio_app_template,
    extract_code_from_text,
    is_valid_name,
    make_init_file,
    parse_code_blobs,
    truncate_content,
)


logger = getLogger(__name__)


def populate_template(template: str, variables: dict[str, Any]) -> str:
    compiled_template = Template(template, undefined=StrictUndefined)
    try:
        return compiled_template.render(**variables)
    except Exception as e:
        raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}")

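# Illustrative usage of `populate_template` (editor's example, not part of the
# committed file): `StrictUndefined` makes any variable referenced in the template
# but missing from `variables` raise instead of rendering silently as empty text:
#
#     populate_template("Task: {{ task }}", variables={"task": "2+2"})  # -> "Task: 2+2"
#     populate_template("Task: {{ task }}", variables={})               # -> raises Exception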

@dataclass
class ActionOutput:
    output: Any
    is_final_answer: bool


@dataclass
class ToolOutput:
    id: str
    output: Any
    is_final_answer: bool
    observation: str
    tool_call: ToolCall


class PlanningPromptTemplate(TypedDict):
    """
    Prompt templates for the planning step.

    Args:
        initial_plan (`str`): Initial plan prompt.
        update_plan_pre_messages (`str`): Update plan pre-messages prompt.
        update_plan_post_messages (`str`): Update plan post-messages prompt.
    """

    initial_plan: str
    update_plan_pre_messages: str
    update_plan_post_messages: str


class ManagedAgentPromptTemplate(TypedDict):
    """
    Prompt templates for the managed agent.

    Args:
        task (`str`): Task prompt.
        report (`str`): Report prompt.
    """

    task: str
    report: str


class FinalAnswerPromptTemplate(TypedDict):
    """
    Prompt templates for the final answer.

    Args:
        pre_messages (`str`): Pre-messages prompt.
        post_messages (`str`): Post-messages prompt.
    """

    pre_messages: str
    post_messages: str


class PromptTemplates(TypedDict):
    """
    Prompt templates for the agent.

    Args:
        system_prompt (`str`): System prompt.
        planning ([`~agents.PlanningPromptTemplate`]): Planning prompt templates.
        managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt templates.
        final_answer ([`~agents.FinalAnswerPromptTemplate`]): Final answer prompt templates.
    """

    system_prompt: str
    planning: PlanningPromptTemplate
    managed_agent: ManagedAgentPromptTemplate
    final_answer: FinalAnswerPromptTemplate


EMPTY_PROMPT_TEMPLATES = PromptTemplates(
    system_prompt="",
    planning=PlanningPromptTemplate(
        initial_plan="",
        update_plan_pre_messages="",
        update_plan_post_messages="",
    ),
    managed_agent=ManagedAgentPromptTemplate(task="", report=""),
    final_answer=FinalAnswerPromptTemplate(pre_messages="", post_messages=""),
)

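# Illustrative sketch (not part of the committed file): in practice, prompt
# templates are loaded from a packaged YAML file rather than built by hand,
# exactly as `ToolCallingAgent.__init__` does further below:
#
#     prompt_templates: PromptTemplates = yaml.safe_load(
#         importlib.resources.files("smolagents.prompts").joinpath("toolcalling_agent.yaml").read_text()
#     )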

@dataclass
class RunResult:
    """Holds extended information about an agent run.

    Attributes:
        output (Any | None): The final output of the agent run, if available.
        state (Literal["success", "max_steps_error"]): The final state of the agent after the run.
        steps (list[dict]): The agent's memory, as a list of steps.
        token_usage (TokenUsage | None): Count of tokens used during the run.
        timing (Timing): Timing details of the agent run: start time, end time, duration.
        messages (list[dict]): The agent's memory, as a list of messages.
            <Deprecated version="1.22.0">
            Parameter 'messages' is deprecated and will be removed in version 1.25. Please use 'steps' instead.
            </Deprecated>
    """

    output: Any | None
    state: Literal["success", "max_steps_error"]
    steps: list[dict]
    token_usage: TokenUsage | None
    timing: Timing

    def __init__(self, output=None, state=None, steps=None, token_usage=None, timing=None, messages=None):
        # Handle deprecated 'messages' parameter
        if messages is not None:
            if steps is not None:
                raise ValueError("Cannot specify both 'messages' and 'steps' parameters. Use 'steps' instead.")
            warnings.warn(
                "Parameter 'messages' is deprecated and will be removed in version 1.25. Please use 'steps' instead.",
                FutureWarning,
                stacklevel=2,
            )
            steps = messages

        # Initialize with dataclass fields
        self.output = output
        self.state = state
        self.steps = steps
        self.token_usage = token_usage
        self.timing = timing

    @property
    def messages(self):
        """Backward compatibility property that returns steps."""
        warnings.warn(
            "Parameter 'messages' is deprecated and will be removed in version 1.25. Please use 'steps' instead.",
            FutureWarning,
            stacklevel=2,
        )
        return self.steps

    def dict(self):
        return {
            "output": self.output,
            "state": self.state,
            "steps": self.steps,
            "token_usage": self.token_usage.dict() if self.token_usage is not None else None,
            "timing": self.timing.dict(),
        }


StreamEvent: TypeAlias = Union[
    ChatMessageStreamDelta,
    ChatMessageToolCall,
    ActionOutput,
    ToolCall,
    ToolOutput,
    PlanningStep,
    ActionStep,
    FinalAnswerStep,
]

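# Illustrative sketch (not part of the committed file): `StreamEvent` is the union
# of event types that `agent.run(task, stream=True)` yields, so a consumer
# typically dispatches on type:
#
#     for event in agent.run("some task", stream=True):
#         if isinstance(event, ChatMessageStreamDelta):
#             ...  # partial model text
#         elif isinstance(event, FinalAnswerStep):
#             print(event.output)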

class MultiStepAgent(ABC):
    """
    Agent class that solves the given task step by step, using the ReAct framework:
    While the objective is not reached, the agent will perform a cycle of action (given by the LLM) and observation (obtained from the environment).

    Args:
        tools (`list[Tool]`): [`Tool`]s that the agent can use.
        model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
        prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
        instructions (`str`, *optional*): Custom instructions for the agent, will be inserted in the system prompt.
        max_steps (`int`, default `20`): Maximum number of steps the agent can take to solve the task.
        add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
        verbosity_level (`LogLevel`, default `LogLevel.INFO`): Level of verbosity of the agent's logs.
        managed_agents (`list`, *optional*): Managed agents that the agent can call.
        step_callbacks (`list[Callable]` | `dict[Type[MemoryStep], Callable | list[Callable]]`, *optional*): Callbacks that will be called at each step.
        planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
        name (`str`, *optional*): Necessary for a managed agent only - the name by which this agent can be called.
        description (`str`, *optional*): Necessary for a managed agent only - the description of this agent.
        provide_run_summary (`bool`, *optional*): Whether to provide a run summary when called as a managed agent.
        final_answer_checks (`list[Callable]`, *optional*): List of validation functions to run before accepting a final answer.
            Each function should:
            - Take the final answer and the agent's memory as arguments.
            - Return a boolean indicating whether the final answer is valid.
        return_full_result (`bool`, default `False`): Whether to return the full [`RunResult`] object or just the final answer output from the agent run.
    """

    def __init__(
        self,
        tools: list[Tool],
        model: Model,
        prompt_templates: PromptTemplates | None = None,
        instructions: str | None = None,
        max_steps: int = 20,
        add_base_tools: bool = False,
        verbosity_level: LogLevel = LogLevel.INFO,
        managed_agents: list | None = None,
        step_callbacks: list[Callable] | dict[Type[MemoryStep], Callable | list[Callable]] | None = None,
        planning_interval: int | None = None,
        name: str | None = None,
        description: str | None = None,
        provide_run_summary: bool = False,
        final_answer_checks: list[Callable] | None = None,
        return_full_result: bool = False,
        logger: AgentLogger | None = None,
    ):
        self.agent_name = self.__class__.__name__
        self.model = model
        self.prompt_templates = prompt_templates or EMPTY_PROMPT_TEMPLATES
        if prompt_templates is not None:
            missing_keys = set(EMPTY_PROMPT_TEMPLATES.keys()) - set(prompt_templates.keys())
            assert not missing_keys, (
                f"Some prompt templates are missing from your custom `prompt_templates`: {missing_keys}"
            )
            for key, value in EMPTY_PROMPT_TEMPLATES.items():
                if isinstance(value, dict):
                    for subkey in value.keys():
                        assert key in prompt_templates.keys() and (subkey in prompt_templates[key].keys()), (
                            f"Some prompt templates are missing from your custom `prompt_templates`: {subkey} under {key}"
                        )

        self.max_steps = max_steps
        self.step_number = 0
        self.planning_interval = planning_interval
        self.state: dict[str, Any] = {}
        self.name = self._validate_name(name)
        self.description = description
        self.provide_run_summary = provide_run_summary
        self.final_answer_checks = final_answer_checks if final_answer_checks is not None else []
        self.return_full_result = return_full_result
        self.instructions = instructions
        self._setup_managed_agents(managed_agents)
        self._setup_tools(tools, add_base_tools)
        self._validate_tools_and_managed_agents(tools, managed_agents)

        self.task: str | None = None
        self.memory = AgentMemory(self.system_prompt)

        if logger is None:
            self.logger = AgentLogger(level=verbosity_level)
        else:
            self.logger = logger

        self.monitor = Monitor(self.model, self.logger)
        self._setup_step_callbacks(step_callbacks)
        self.stream_outputs = False

    @property
    def system_prompt(self) -> str:
        return self.initialize_system_prompt()

    @system_prompt.setter
    def system_prompt(self, value: str):
        raise AttributeError(
            """The 'system_prompt' property is read-only. Use 'self.prompt_templates["system_prompt"]' instead."""
        )

    def _validate_name(self, name: str | None) -> str | None:
        if name is not None and not is_valid_name(name):
            raise ValueError(f"Agent name '{name}' must be a valid Python identifier and not a reserved keyword.")
        return name

    def _setup_managed_agents(self, managed_agents: list | None = None) -> None:
        """Setup managed agents with proper logging."""
        self.managed_agents = {}
        if managed_agents:
            assert all(agent.name and agent.description for agent in managed_agents), (
                "All managed agents need both a name and a description!"
            )
            self.managed_agents = {agent.name: agent for agent in managed_agents}
            # Ensure managed agents can be called as tools by the model: set their inputs and output_type
            for agent in self.managed_agents.values():
                agent.inputs = {
                    "task": {"type": "string", "description": "Long detailed description of the task."},
                    "additional_args": {
                        "type": "object",
                        "description": "Dictionary of extra inputs to pass to the managed agent, e.g. images, dataframes, or any other contextual data it may need.",
                    },
                }
                agent.output_type = "string"

    def _setup_tools(self, tools, add_base_tools):
        assert all(isinstance(tool, BaseTool) for tool in tools), (
            "All elements must be instance of BaseTool (or a subclass)"
        )
        self.tools = {tool.name: tool for tool in tools}
        if add_base_tools:
            self.tools.update(
                {
                    name: cls()
                    for name, cls in TOOL_MAPPING.items()
                    if name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent"
                }
            )
        self.tools.setdefault("final_answer", FinalAnswerTool())

    def _validate_tools_and_managed_agents(self, tools, managed_agents):
        tool_and_managed_agent_names = [tool.name for tool in tools]
        if managed_agents is not None:
            tool_and_managed_agent_names += [agent.name for agent in managed_agents]
        if self.name:
            tool_and_managed_agent_names.append(self.name)
        if len(tool_and_managed_agent_names) != len(set(tool_and_managed_agent_names)):
            raise ValueError(
                "Each tool or managed_agent should have a unique name! You passed these duplicate names: "
                f"{[name for name in tool_and_managed_agent_names if tool_and_managed_agent_names.count(name) > 1]}"
            )

    def _setup_step_callbacks(self, step_callbacks):
        # Initialize step callbacks registry
        self.step_callbacks = CallbackRegistry()
        if step_callbacks:
            # Register callbacks list only for ActionStep for backward compatibility
            if isinstance(step_callbacks, list):
                for callback in step_callbacks:
                    self.step_callbacks.register(ActionStep, callback)
            # Register callbacks dict for specific step classes
            elif isinstance(step_callbacks, dict):
                for step_cls, callbacks in step_callbacks.items():
                    if not isinstance(callbacks, list):
                        callbacks = [callbacks]
                    for callback in callbacks:
                        self.step_callbacks.register(step_cls, callback)
            else:
                raise ValueError("step_callbacks must be a list or a dict")
        # Register monitor update_metrics only for ActionStep for backward compatibility
        self.step_callbacks.register(ActionStep, self.monitor.update_metrics)

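    # Illustrative sketch (not part of the committed file): the two accepted
    # `step_callbacks` forms, per `_setup_step_callbacks` above. Callbacks are
    # invoked as `callback(memory_step, agent=...)`; names here are placeholders.
    #
    #     def on_step(step, agent=None):
    #         print(getattr(step, "step_number", None))
    #
    #     agent = SomeAgentSubclass(tools=[], model=model, step_callbacks=[on_step])           # list: ActionStep only
    #     agent = SomeAgentSubclass(tools=[], model=model, step_callbacks={PlanningStep: on_step})  # dict: per step class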
    def run(
        self,
        task: str,
        stream: bool = False,
        reset: bool = True,
        images: list["PIL.Image.Image"] | None = None,
        additional_args: dict | None = None,
        max_steps: int | None = None,
        return_full_result: bool | None = None,
    ) -> Any | RunResult:
        """
        Run the agent for the given task.

        Args:
            task (`str`): Task to perform.
            stream (`bool`): Whether to run in streaming mode.
                If `True`, returns a generator that yields each step as it is executed. You must iterate over this generator to process the individual steps (e.g., using a for loop or `next()`).
                If `False`, executes all steps internally and returns only the final answer after completion.
            reset (`bool`): Whether to reset the conversation or keep it going from a previous run.
            images (`list[PIL.Image.Image]`, *optional*): Image objects.
            additional_args (`dict`, *optional*): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names!
            max_steps (`int`, *optional*): Maximum number of steps the agent can take to solve the task. If not provided, will use the agent's default value.
            return_full_result (`bool`, *optional*): Whether to return the full [`RunResult`] object or just the final answer output.
                If `None` (default), the agent's `self.return_full_result` setting is used.

        Example:
        ```py
        from smolagents import CodeAgent
        agent = CodeAgent(tools=[])
        agent.run("What is the result of 2 power 3.7384?")
        ```
        """
        max_steps = max_steps or self.max_steps
        self.task = task
        self.interrupt_switch = False
        if additional_args:
            self.state.update(additional_args)
            self.task += f"""
You have been provided with these additional arguments, that you can access directly using the keys as variables:
{str(additional_args)}."""

        self.memory.system_prompt = SystemPromptStep(system_prompt=self.system_prompt)
        if reset:
            self.memory.reset()
            self.monitor.reset()

        self.logger.log_task(
            content=self.task.strip(),
            subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
            level=LogLevel.INFO,
            title=self.name if hasattr(self, "name") else None,
        )
        self.memory.steps.append(TaskStep(task=self.task, task_images=images))

        if getattr(self, "python_executor", None):
            self.python_executor.send_variables(variables=self.state)
            self.python_executor.send_tools({**self.tools, **self.managed_agents})

        if stream:
            # The steps are returned as they are executed through a generator to iterate on.
            return self._run_stream(task=self.task, max_steps=max_steps, images=images)

        run_start_time = time.time()
        steps = list(self._run_stream(task=self.task, max_steps=max_steps, images=images))

        # Outputs are returned only at the end. We only look at the last step.
        assert isinstance(steps[-1], FinalAnswerStep)
        output = steps[-1].output

        return_full_result = return_full_result if return_full_result is not None else self.return_full_result
        if return_full_result:
            total_input_tokens = 0
            total_output_tokens = 0
            correct_token_usage = True
            for step in self.memory.steps:
                if isinstance(step, (ActionStep, PlanningStep)):
                    if step.token_usage is None:
                        correct_token_usage = False
                        break
                    else:
                        total_input_tokens += step.token_usage.input_tokens
                        total_output_tokens += step.token_usage.output_tokens
            if correct_token_usage:
                token_usage = TokenUsage(input_tokens=total_input_tokens, output_tokens=total_output_tokens)
            else:
                token_usage = None

            if self.memory.steps and isinstance(getattr(self.memory.steps[-1], "error", None), AgentMaxStepsError):
                state = "max_steps_error"
            else:
                state = "success"

            step_dicts = self.memory.get_full_steps()

            return RunResult(
                output=output,
                token_usage=token_usage,
                steps=step_dicts,
                timing=Timing(start_time=run_start_time, end_time=time.time()),
                state=state,
            )

        return output

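    # Illustrative sketch (not part of the committed file): reading the extended
    # result object returned when `return_full_result=True`:
    #
    #     result = agent.run("What is 2 ** 3.7384?", return_full_result=True)
    #     print(result.state)        # "success" or "max_steps_error"
    #     print(result.token_usage)  # TokenUsage, or None if any step lacked usage info
    #     answer = result.output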
    def _run_stream(
        self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None
    ) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]:
        self.step_number = 1
        returned_final_answer = False
        while not returned_final_answer and self.step_number <= max_steps:
            if self.interrupt_switch:
                raise AgentError("Agent interrupted.", self.logger)

            # Run a planning step if scheduled
            if self.planning_interval is not None and (
                self.step_number == 1 or (self.step_number - 1) % self.planning_interval == 0
            ):
                planning_start_time = time.time()
                planning_step = None
                for element in self._generate_planning_step(
                    task, is_first_step=len(self.memory.steps) == 1, step=self.step_number
                ):  # Don't use the attribute step_number here, because there can be steps from previous runs
                    yield element
                    planning_step = element
                assert isinstance(planning_step, PlanningStep)  # Last yielded element should be a PlanningStep
                planning_end_time = time.time()
                planning_step.timing = Timing(
                    start_time=planning_start_time,
                    end_time=planning_end_time,
                )
                self._finalize_step(planning_step)
                self.memory.steps.append(planning_step)

            # Start action step!
            action_step_start_time = time.time()
            action_step = ActionStep(
                step_number=self.step_number,
                timing=Timing(start_time=action_step_start_time),
                observations_images=images,
            )
            self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
            try:
                for output in self._step_stream(action_step):
                    # Yield all
                    yield output

                    if isinstance(output, ActionOutput) and output.is_final_answer:
                        final_answer = output.output
                        self.logger.log(
                            Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"),
                            level=LogLevel.INFO,
                        )

                        if self.final_answer_checks:
                            self._validate_final_answer(final_answer)
                        returned_final_answer = True
                        action_step.is_final_answer = True

            except AgentGenerationError as e:
                # Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
                raise e
            except AgentError as e:
                # Other AgentError types are caused by the Model, so we should log them and iterate.
                action_step.error = e
            finally:
                self._finalize_step(action_step)
                self.memory.steps.append(action_step)
                yield action_step
                self.step_number += 1

        if not returned_final_answer and self.step_number == max_steps + 1:
            final_answer = self._handle_max_steps_reached(task)
            yield action_step
        yield FinalAnswerStep(handle_agent_output_types(final_answer))

    def _validate_final_answer(self, final_answer: Any):
        for check_function in self.final_answer_checks:
            try:
                assert check_function(final_answer, self.memory)
            except Exception as e:
                raise AgentError(f"Check {check_function.__name__} failed with error: {e}", self.logger)

    def _finalize_step(self, memory_step: ActionStep | PlanningStep):
        memory_step.timing.end_time = time.time()
        self.step_callbacks.callback(memory_step, agent=self)

    def _handle_max_steps_reached(self, task: str) -> Any:
        action_step_start_time = time.time()
        final_answer = self.provide_final_answer(task)
        final_memory_step = ActionStep(
            step_number=self.step_number,
            error=AgentMaxStepsError("Reached max steps.", self.logger),
            timing=Timing(start_time=action_step_start_time, end_time=time.time()),
            token_usage=final_answer.token_usage,
        )
        final_memory_step.action_output = final_answer.content
        self._finalize_step(final_memory_step)
        self.memory.steps.append(final_memory_step)
        return final_answer.content

    def _generate_planning_step(
        self, task, is_first_step: bool, step: int
    ) -> Generator[ChatMessageStreamDelta | PlanningStep]:
        start_time = time.time()
        if is_first_step:
            input_messages = [
                ChatMessage(
                    role=MessageRole.USER,
                    content=[
                        {
                            "type": "text",
                            "text": populate_template(
                                self.prompt_templates["planning"]["initial_plan"],
                                variables={"task": task, "tools": self.tools, "managed_agents": self.managed_agents},
                            ),
                        }
                    ],
                )
            ]
            if self.stream_outputs and hasattr(self.model, "generate_stream"):
                plan_message_content = ""
                output_stream = self.model.generate_stream(input_messages, stop_sequences=["<end_plan>"])  # type: ignore
                input_tokens, output_tokens = 0, 0
                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
                    for event in output_stream:
                        if event.content is not None:
                            plan_message_content += event.content
                            live.update(Markdown(plan_message_content))
                        if event.token_usage:
                            input_tokens = event.token_usage.input_tokens
                            output_tokens += event.token_usage.output_tokens
                        yield event
            else:
                plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
                plan_message_content = plan_message.content
                input_tokens, output_tokens = 0, 0
                if plan_message.token_usage:
                    input_tokens = plan_message.token_usage.input_tokens
                    output_tokens = plan_message.token_usage.output_tokens
            plan = textwrap.dedent(
                f"""Here are the facts I know and the plan of action that I will follow to solve the task:\n```\n{plan_message_content}\n```"""
            )
        else:
            # Summary mode removes the system prompt and previous planning messages output by the model.
            # Removing previous planning messages avoids influencing too much the new plan.
            memory_messages = self.write_memory_to_messages(summary_mode=True)
            plan_update_pre = ChatMessage(
                role=MessageRole.SYSTEM,
                content=[
                    {
                        "type": "text",
                        "text": populate_template(
                            self.prompt_templates["planning"]["update_plan_pre_messages"], variables={"task": task}
                        ),
                    }
                ],
            )
            plan_update_post = ChatMessage(
                role=MessageRole.USER,
                content=[
                    {
                        "type": "text",
                        "text": populate_template(
                            self.prompt_templates["planning"]["update_plan_post_messages"],
                            variables={
                                "task": task,
                                "tools": self.tools,
                                "managed_agents": self.managed_agents,
                                "remaining_steps": (self.max_steps - step),
                            },
                        ),
                    }
                ],
            )
            input_messages = [plan_update_pre] + memory_messages + [plan_update_post]
            if self.stream_outputs and hasattr(self.model, "generate_stream"):
                plan_message_content = ""
                input_tokens, output_tokens = 0, 0
                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
                    for event in self.model.generate_stream(
                        input_messages,
                        stop_sequences=["<end_plan>"],
                    ):  # type: ignore
                        if event.content is not None:
                            plan_message_content += event.content
                            live.update(Markdown(plan_message_content))
                        if event.token_usage:
                            input_tokens = event.token_usage.input_tokens
                            output_tokens += event.token_usage.output_tokens
                        yield event
            else:
                plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
                plan_message_content = plan_message.content
                input_tokens, output_tokens = 0, 0
                if plan_message.token_usage:
                    input_tokens = plan_message.token_usage.input_tokens
                    output_tokens = plan_message.token_usage.output_tokens
            plan = textwrap.dedent(
                f"""I still need to solve the task I was given:\n```\n{self.task}\n```\n\nHere are the facts I know and my new/updated plan of action to solve the task:\n```\n{plan_message_content}\n```"""
            )
        log_headline = "Initial plan" if is_first_step else "Updated plan"
        self.logger.log(Rule(f"[bold]{log_headline}", style="orange"), Text(plan), level=LogLevel.INFO)
        yield PlanningStep(
            model_input_messages=input_messages,
            plan=plan,
            model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content=plan_message_content),
            token_usage=TokenUsage(input_tokens=input_tokens, output_tokens=output_tokens),
            timing=Timing(start_time=start_time, end_time=time.time()),
        )

    @abstractmethod
    def initialize_system_prompt(self) -> str:
        """To be implemented in child classes"""
        ...

    def interrupt(self):
        """Interrupts the agent execution."""
        self.interrupt_switch = True

    def write_memory_to_messages(
        self,
        summary_mode: bool = False,
    ) -> list[ChatMessage]:
        """
        Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
        that can be used as input to the LLM. Adds a number of keywords (such as PLAN, error, etc) to help
        the LLM.
        """
        messages = self.memory.system_prompt.to_messages(summary_mode=summary_mode)
        for memory_step in self.memory.steps:
            messages.extend(memory_step.to_messages(summary_mode=summary_mode))
        return messages

    def _step_stream(
        self, memory_step: ActionStep
    ) -> Generator[ChatMessageStreamDelta | ToolCall | ToolOutput | ActionOutput]:
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        Yields ChatMessageStreamDelta during the run if streaming is enabled.
        At the end, yields either None if the step is not final, or the final answer.
        """
        raise NotImplementedError("This method should be implemented in child classes")

    def step(self, memory_step: ActionStep) -> Any:
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        Returns either None if the step is not final, or the final answer.
        """
        return list(self._step_stream(memory_step))[-1]

    def extract_action(self, model_output: str, split_token: str) -> tuple[str, str]:
        """
        Parse action from the LLM output

        Args:
            model_output (`str`): Output of the LLM
            split_token (`str`): Separator for the action. Should match the example in the system prompt.
        """
        try:
            split = model_output.split(split_token)
            rationale, action = (
                split[-2],
                split[-1],
            )  # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
        except Exception:
            raise AgentParsingError(
                f"No '{split_token}' token provided in your output.\nYour output:\n{model_output}\n. Be sure to include an action, prefaced with '{split_token}'!",
                self.logger,
            )
        return rationale.strip(), action.strip()

    def provide_final_answer(self, task: str) -> ChatMessage:
        """
        Provide the final answer to the task, based on the logs of the agent's interactions.

        Args:
            task (`str`): Task to perform.

        Returns:
            `ChatMessage`: The message containing the final answer to the task.
        """
        messages = [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content=[
                    {
                        "type": "text",
                        "text": self.prompt_templates["final_answer"]["pre_messages"],
                    }
                ],
            )
        ]
        messages += self.write_memory_to_messages()[1:]
        messages.append(
            ChatMessage(
                role=MessageRole.USER,
                content=[
                    {
                        "type": "text",
                        "text": populate_template(
                            self.prompt_templates["final_answer"]["post_messages"], variables={"task": task}
                        ),
                    }
                ],
            )
        )
        try:
            chat_message: ChatMessage = self.model.generate(messages)
            return chat_message
        except Exception as e:
            return ChatMessage(
                role=MessageRole.ASSISTANT,
                content=[{"type": "text", "text": f"Error in generating final LLM output: {e}"}],
            )

    def visualize(self):
        """Creates a rich tree visualization of the agent's structure."""
        self.logger.visualize_agent_tree(self)

    def replay(self, detailed: bool = False):
        """Prints a pretty replay of the agent's steps.

        Args:
            detailed (bool, optional): If True, also displays the memory at each step. Defaults to False.
                Careful: will increase log length exponentially. Use only for debugging.
        """
        self.memory.replay(self.logger, detailed=detailed)

    def __call__(self, task: str, **kwargs):
        """Adds additional prompting for the managed agent, runs it, and wraps the output.
        This method is called only when this agent is used as a managed agent.
        """
        full_task = populate_template(
            self.prompt_templates["managed_agent"]["task"],
            variables=dict(name=self.name, task=task),
        )
        result = self.run(full_task, **kwargs)
        if isinstance(result, RunResult):
            report = result.output
        else:
            report = result
        answer = populate_template(
            self.prompt_templates["managed_agent"]["report"], variables=dict(name=self.name, final_answer=report)
        )
        if self.provide_run_summary:
            answer += "\n\nFor more detail, find below a summary of this agent's work:\n<summary_of_work>\n"
            for message in self.write_memory_to_messages(summary_mode=True):
                content = message.content
                answer += "\n" + truncate_content(str(content)) + "\n---"
            answer += "\n</summary_of_work>"
        return answer

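    # Illustrative sketch (not part of the committed file): wiring a managed agent.
    # `name` and `description` are required for managed agents, and calls are routed
    # through `__call__` above; `web_agent`/`manager` are placeholder names, and
    # `CodeAgent` is the subclass used in the `run` docstring example:
    #
    #     web_agent = ToolCallingAgent(tools=[...], model=model, name="web_search",
    #                                  description="Searches the web for a given query.")
    #     manager = CodeAgent(tools=[], model=model, managed_agents=[web_agent])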
    def save(self, output_dir: str | Path, relative_path: str | None = None):
        """
        Saves the relevant code files for your agent. This will copy the code of your agent in `output_dir` as well as autogenerate:

        - a `tools` folder containing the logic for each of the tools under `tools/{tool_name}.py`.
        - a `managed_agents` folder containing the logic for each of the managed agents.
        - an `agent.json` file containing a dictionary representing your agent.
        - a `prompts.yaml` file containing the prompt templates used by your agent.
        - an `app.py` file providing a UI for your agent when it is exported to a Space with `agent.push_to_hub()`
        - a `requirements.txt` containing the names of the modules used by your agent (as detected when inspecting its
          code)

        Args:
            output_dir (`str` or `Path`): The folder in which you want to save your agent.
        """
        make_init_file(output_dir)

        # Recursively save managed agents
        if self.managed_agents:
            make_init_file(os.path.join(output_dir, "managed_agents"))
            for agent_name, agent in self.managed_agents.items():
                agent_suffix = f"managed_agents.{agent_name}"
                if relative_path:
                    agent_suffix = relative_path + "." + agent_suffix
                agent.save(os.path.join(output_dir, "managed_agents", agent_name), relative_path=agent_suffix)

        class_name = self.__class__.__name__

        # Save tools to different .py files
        for tool in self.tools.values():
            make_init_file(os.path.join(output_dir, "tools"))
            tool.save(os.path.join(output_dir, "tools"), tool_file_name=tool.name, make_gradio_app=False)

        # Save prompts to yaml
        yaml_prompts = yaml.safe_dump(
            self.prompt_templates,
            default_style="|",  # This forces block literals for all strings
            default_flow_style=False,
            width=float("inf"),
            sort_keys=False,
            allow_unicode=True,
            indent=2,
        )

        with open(os.path.join(output_dir, "prompts.yaml"), "w", encoding="utf-8") as f:
            f.write(yaml_prompts)

        # Save agent dictionary to json
        agent_dict = self.to_dict()
        agent_dict["tools"] = [tool.name for tool in self.tools.values()]
        agent_dict["managed_agents"] = {agent.name: agent.__class__.__name__ for agent in self.managed_agents.values()}
        with open(os.path.join(output_dir, "agent.json"), "w", encoding="utf-8") as f:
            json.dump(agent_dict, f, indent=4)

        # Save requirements
        with open(os.path.join(output_dir, "requirements.txt"), "w", encoding="utf-8") as f:
            f.writelines(f"{r}\n" for r in agent_dict["requirements"])

        # Make agent.py file with Gradio UI
        agent_name = f"agent_{self.name}" if getattr(self, "name", None) else "agent"
        managed_agent_relative_path = relative_path + "." if relative_path is not None else ""
        app_template = create_agent_gradio_app_template()

        # Render the app.py file from Jinja2 template
        app_text = app_template.render(
            {
                "agent_name": agent_name,
                "class_name": class_name,
                "agent_dict": agent_dict,
                "tools": self.tools,
                "managed_agents": self.managed_agents,
                "managed_agent_relative_path": managed_agent_relative_path,
            }
        )

        with open(os.path.join(output_dir, "app.py"), "w", encoding="utf-8") as f:
            f.write(app_text + "\n")  # Append newline at the end

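    # Illustrative sketch (not part of the committed file): the layout produced by
    # `agent.save("my_agent")`, per the docstring and code above:
    #
    #     my_agent/
    #     |-- agent.json
    #     |-- app.py
    #     |-- prompts.yaml
    #     |-- requirements.txt
    #     |-- tools/<tool_name>.py ...
    #     `-- managed_agents/<agent_name>/ ...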
    def to_dict(self) -> dict[str, Any]:
        """Convert the agent to a dictionary representation.

        Returns:
            `dict`: Dictionary representation of the agent.
        """
        # TODO: handle serializing step_callbacks and final_answer_checks
        for attr in ["final_answer_checks", "step_callbacks"]:
            if getattr(self, attr, None):
                self.logger.log(f"This agent has {attr}: they will be ignored by this method.", LogLevel.INFO)

        tool_dicts = [tool.to_dict() for tool in self.tools.values()]
        tool_requirements = {req for tool in self.tools.values() for req in tool.to_dict()["requirements"]}
        managed_agents_requirements = {
            req for managed_agent in self.managed_agents.values() for req in managed_agent.to_dict()["requirements"]
        }
        requirements = tool_requirements | managed_agents_requirements
        if hasattr(self, "authorized_imports"):
            requirements.update(
                {package.split(".")[0] for package in self.authorized_imports if package not in BASE_BUILTIN_MODULES}
            )

        agent_dict = {
            "class": self.__class__.__name__,
            "tools": tool_dicts,
            "model": {
                "class": self.model.__class__.__name__,
                "data": self.model.to_dict(),
            },
            "managed_agents": [managed_agent.to_dict() for managed_agent in self.managed_agents.values()],
            "prompt_templates": self.prompt_templates,
            "max_steps": self.max_steps,
            "verbosity_level": int(self.logger.level),
            "planning_interval": self.planning_interval,
            "name": self.name,
            "description": self.description,
            "requirements": sorted(requirements),
        }
        return agent_dict

    @classmethod
    def from_dict(cls, agent_dict: dict[str, Any], **kwargs) -> "MultiStepAgent":
        """Create agent from a dictionary representation.

        Args:
            agent_dict (`dict[str, Any]`): Dictionary representation of the agent.
            **kwargs: Additional keyword arguments that will override agent_dict values.

        Returns:
            `MultiStepAgent`: Instance of the agent class.
        """
        # Load model
        model_info = agent_dict["model"]
        model_class = getattr(importlib.import_module("smolagents.models"), model_info["class"])
        model = model_class.from_dict(model_info["data"])
        # Load tools
        tools = []
        for tool_info in agent_dict["tools"]:
            tools.append(Tool.from_code(tool_info["code"]))
        # Load managed agents
        managed_agents = []
        for managed_agent_dict in agent_dict["managed_agents"]:
            agent_class = getattr(importlib.import_module("smolagents.agents"), managed_agent_dict["class"])
            managed_agent = agent_class.from_dict(managed_agent_dict, **kwargs)
            managed_agents.append(managed_agent)
        # Extract base agent parameters
        agent_args = {
            "model": model,
            "tools": tools,
            "managed_agents": managed_agents,
            "prompt_templates": agent_dict.get("prompt_templates"),
            "max_steps": agent_dict.get("max_steps"),
            "verbosity_level": agent_dict.get("verbosity_level"),
            "planning_interval": agent_dict.get("planning_interval"),
            "name": agent_dict.get("name"),
            "description": agent_dict.get("description"),
        }
        # Filter out None values to use defaults from __init__
        agent_args = {k: v for k, v in agent_args.items() if v is not None}
        # Update with any additional kwargs
        agent_args.update(kwargs)
        # Create agent instance
        return cls(**agent_args)

    @classmethod
    def from_hub(
        cls,
        repo_id: str,
        token: str | None = None,
        trust_remote_code: bool = False,
        **kwargs,
    ):
        """
        Loads an agent defined on the Hub.

        <Tip warning={true}>

        Loading an agent from the Hub means that you'll download and execute its code locally.
        ALWAYS inspect the agent you're downloading before loading it within your runtime, as you would do when
        installing a package using pip/npm/apt.

        </Tip>

        Args:
            repo_id (`str`):
                The name of the repo on the Hub where your agent is defined.
            token (`str`, *optional*):
                The token to identify you on hf.co. If unset, will use the token generated when running
                `huggingface-cli login` (stored in `~/.huggingface`).
            trust_remote_code (`bool`, *optional*, defaults to False):
                This flag marks that you understand the risk of running remote code and that you trust this agent.
                If not set to True, loading the agent from the Hub will fail.
            kwargs (additional keyword arguments, *optional*):
                Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
                `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your agent, and the
                others will be passed along to its init.
        """
        if not trust_remote_code:
            raise ValueError(
                "Loading an agent from Hub requires you to acknowledge you trust its code: to do so, pass `trust_remote_code=True`."
            )

        # Get the agent's Hub folder.
        download_kwargs = {"token": token, "repo_type": "space"} | {
            key: kwargs.pop(key)
            for key in [
                "cache_dir",
                "force_download",
                "proxies",
                "revision",
                "local_files_only",
            ]
            if key in kwargs
        }

        download_folder = Path(snapshot_download(repo_id=repo_id, **download_kwargs))
        return cls.from_folder(download_folder, **kwargs)

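    # Illustrative usage (not part of the committed file): `trust_remote_code=True`
    # is mandatory, and Hub-related kwargs such as `revision` are split off before
    # the rest reaches the agent's __init__; the repo id is a placeholder:
    #
    #     agent = CodeAgent.from_hub("username/agent-space", trust_remote_code=True, revision="main")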
    @classmethod
    def from_folder(cls, folder: str | Path, **kwargs):
        """Loads an agent from a local folder.

        Args:
            folder (`str` or `Path`): The folder where the agent is saved.
            **kwargs: Additional keyword arguments that will be passed to the agent's init.
        """
        # Load agent.json
        folder = Path(folder)
        agent_dict = json.loads((folder / "agent.json").read_text())

        # Load managed agents from their respective folders, recursively
        managed_agents = []
        for managed_agent_name, managed_agent_class_name in agent_dict["managed_agents"].items():
            agent_cls = getattr(importlib.import_module("smolagents.agents"), managed_agent_class_name)
            managed_agents.append(agent_cls.from_folder(folder / "managed_agents" / managed_agent_name))
        agent_dict["managed_agents"] = {}

        # Load tools
        tools = []
        for tool_name in agent_dict["tools"]:
            tool_code = (folder / "tools" / f"{tool_name}.py").read_text()
            tools.append({"name": tool_name, "code": tool_code})
        agent_dict["tools"] = tools

        # Add managed agents to kwargs to override the empty list in from_dict
        if managed_agents:
            kwargs["managed_agents"] = managed_agents

        return cls.from_dict(agent_dict, **kwargs)

    def push_to_hub(
        self,
        repo_id: str,
        commit_message: str = "Upload agent",
        private: bool | None = None,
        token: bool | str | None = None,
        create_pr: bool = False,
    ) -> str:
        """
        Upload the agent to the Hub.

        Parameters:
            repo_id (`str`):
                The name of the repository you want to push to. It should contain your organization name when
                pushing to a given organization.
            commit_message (`str`, *optional*, defaults to `"Upload agent"`):
                Message to commit while pushing.
            private (`bool`, *optional*, defaults to `None`):
                Whether to make the repo private. If `None`, the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
            token (`bool` or `str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`).
            create_pr (`bool`, *optional*, defaults to `False`):
                Whether to create a PR with the uploaded files or directly commit.
        """
        repo_url = create_repo(
            repo_id=repo_id,
            token=token,
            private=private,
            exist_ok=True,
            repo_type="space",
            space_sdk="gradio",
        )
        repo_id = repo_url.repo_id
        metadata_update(
            repo_id,
            {"tags": ["smolagents", "agent"]},
            repo_type="space",
            token=token,
            overwrite=True,
        )

        with tempfile.TemporaryDirectory() as work_dir:
            self.save(work_dir)
            logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
            return upload_folder(
                repo_id=repo_id,
                commit_message=commit_message,
                folder_path=work_dir,
                token=token,
                create_pr=create_pr,
                repo_type="space",
            )

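    # Illustrative usage (not part of the committed file): pushing creates (or
    # reuses) a Gradio Space and uploads the files written by `save` above;
    # the repo id is a placeholder:
    #
    #     agent.push_to_hub("username/my-agent", commit_message="Upload agent", create_pr=False)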

class ToolCallingAgent(MultiStepAgent):
    """
    This agent uses JSON-like tool calls, leveraging the LLM engine's native tool-calling capabilities.

    Args:
        tools (`list[Tool]`): [`Tool`]s that the agent can use.
        model (`Model`): Model that will generate the agent's actions.
        prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
        planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
        stream_outputs (`bool`, *optional*, default `False`): Whether to stream outputs during execution.
        max_tool_threads (`int`, *optional*): Maximum number of threads for parallel tool calls.
            Higher values increase concurrency but also resource usage.
            Defaults to `ThreadPoolExecutor`'s default.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        tools: list[Tool],
        model: Model,
        prompt_templates: PromptTemplates | None = None,
        planning_interval: int | None = None,
        stream_outputs: bool = False,
        max_tool_threads: int | None = None,
        **kwargs,
    ):
        prompt_templates = prompt_templates or yaml.safe_load(
            importlib.resources.files("smolagents.prompts").joinpath("toolcalling_agent.yaml").read_text()
        )
        super().__init__(
            tools=tools,
            model=model,
            prompt_templates=prompt_templates,
            planning_interval=planning_interval,
            **kwargs,
        )
        # Streaming setup
        self.stream_outputs = stream_outputs
        if self.stream_outputs and not hasattr(self.model, "generate_stream"):
            raise ValueError(
                "`stream_outputs` is set to True, but the model class implements no `generate_stream` method."
            )
        # Tool calling setup
        self.max_tool_threads = max_tool_threads

    @property
    def tools_and_managed_agents(self):
        """Returns a combined list of tools and managed agents."""
        return list(self.tools.values()) + list(self.managed_agents.values())

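    # Illustrative usage (not part of the committed file): a streaming agent with a
    # bounded tool-thread pool; `model` is assumed to implement `generate_stream`:
    #
    #     agent = ToolCallingAgent(tools=[my_tool], model=model, stream_outputs=True, max_tool_threads=4)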
1239
+ def initialize_system_prompt(self) -> str:
1240
+ system_prompt = populate_template(
1241
+ self.prompt_templates["system_prompt"],
1242
+ variables={
1243
+ "tools": self.tools,
1244
+ "managed_agents": self.managed_agents,
1245
+ "custom_instructions": self.instructions,
1246
+ },
1247
+ )
1248
+ return system_prompt
1249
+
1250
+ def _step_stream(
1251
+ self, memory_step: ActionStep
1252
+ ) -> Generator[ChatMessageStreamDelta | ToolCall | ToolOutput | ActionOutput]:
1253
+ """
1254
+ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
1255
+ Yields ChatMessageStreamDelta during the run if streaming is enabled.
1256
+ At the end, yields either None if the step is not final, or the final answer.
1257
+ """
1258
+ memory_messages = self.write_memory_to_messages()
1259
+
1260
+ input_messages = memory_messages.copy()
1261
+
1262
+ # Add new step in logs
1263
+ memory_step.model_input_messages = input_messages
1264
+
1265
+ try:
1266
+ if self.stream_outputs and hasattr(self.model, "generate_stream"):
1267
+ output_stream = self.model.generate_stream(
1268
+ input_messages,
1269
+ stop_sequences=["Observation:", "Calling tools:"],
1270
+ tools_to_call_from=self.tools_and_managed_agents,
1271
+ )
1272
+
1273
+ chat_message_stream_deltas: list[ChatMessageStreamDelta] = []
1274
+ with Live("", console=self.logger.console, vertical_overflow="visible") as live:
1275
+ for event in output_stream:
1276
+ chat_message_stream_deltas.append(event)
1277
+ live.update(
1278
+ Markdown(agglomerate_stream_deltas(chat_message_stream_deltas).render_as_markdown())
1279
+ )
1280
+ yield event
1281
+ chat_message = agglomerate_stream_deltas(chat_message_stream_deltas)
1282
+ else:
1283
+ chat_message: ChatMessage = self.model.generate(
1284
+ input_messages,
1285
+ stop_sequences=["Observation:", "Calling tools:"],
1286
+ tools_to_call_from=self.tools_and_managed_agents,
1287
+ )
1288
+ if chat_message.content is None and chat_message.raw is not None:
1289
+ log_content = str(chat_message.raw)
1290
+ else:
1291
+ log_content = str(chat_message.content) or ""
1292
+
1293
+ self.logger.log_markdown(
1294
+ content=log_content,
1295
+ title="Output message of the LLM:",
1296
+ level=LogLevel.DEBUG,
1297
+ )
1298
+
1299
+ # Record model output
1300
+ memory_step.model_output_message = chat_message
1301
+ memory_step.model_output = chat_message.content
1302
+ memory_step.token_usage = chat_message.token_usage
1303
+ except Exception as e:
1304
+ raise AgentGenerationError(f"Error while generating output:\n{e}", self.logger) from e
1305
+
1306
+ if chat_message.tool_calls is None or len(chat_message.tool_calls) == 0:
1307
+ try:
1308
+ chat_message = self.model.parse_tool_calls(chat_message)
1309
+ except Exception as e:
1310
+ raise AgentParsingError(f"Error while parsing tool call from model output: {e}", self.logger)
1311
+ else:
1312
+ for tool_call in chat_message.tool_calls:
1313
+ tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
1314
+ final_answer, got_final_answer = None, False
1315
+ for output in self.process_tool_calls(chat_message, memory_step):
1316
+ yield output
1317
+ if isinstance(output, ToolOutput):
1318
+ if output.is_final_answer:
1319
+ if len(chat_message.tool_calls) > 1:
1320
+ raise AgentExecutionError(
1321
+ "If you want to return an answer, please do not perform any other tool calls than the final answer tool call!",
1322
+ self.logger,
1323
+ )
1324
+ if got_final_answer:
1325
+ raise AgentToolExecutionError(
1326
+ "You returned multiple final answers. Please return only one single final answer!",
1327
+ self.logger,
1328
+ )
1329
+ final_answer = output.output
1330
+ got_final_answer = True
1331
+
1332
+ # Manage state variables
1333
+ if isinstance(final_answer, str) and final_answer in self.state.keys():
1334
+ final_answer = self.state[final_answer]
1335
+ yield ActionOutput(
1336
+ output=final_answer,
1337
+ is_final_answer=got_final_answer,
1338
+ )
1339
+
+     def process_tool_calls(
+         self, chat_message: ChatMessage, memory_step: ActionStep
+     ) -> Generator[ToolCall | ToolOutput]:
+         """Process tool calls from the model output and update agent memory.
+
+         Args:
+             chat_message (`ChatMessage`): Chat message containing tool calls from the model.
+             memory_step (`ActionStep`): Memory ActionStep to update with results.
+
+         Yields:
+             `ToolCall | ToolOutput`: The tool call or tool output.
+         """
+         parallel_calls: dict[str, ToolCall] = {}
+         assert chat_message.tool_calls is not None
+         for chat_tool_call in chat_message.tool_calls:
+             tool_call = ToolCall(
+                 name=chat_tool_call.function.name, arguments=chat_tool_call.function.arguments, id=chat_tool_call.id
+             )
+             yield tool_call
+             parallel_calls[tool_call.id] = tool_call
+
+         # Helper function to process a single tool call
+         def process_single_tool_call(tool_call: ToolCall) -> ToolOutput:
+             tool_name = tool_call.name
+             tool_arguments = tool_call.arguments or {}
+             self.logger.log(
+                 Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")),
+                 level=LogLevel.INFO,
+             )
+             tool_call_result = self.execute_tool_call(tool_name, tool_arguments)
+             tool_call_result_type = type(tool_call_result)
+             if tool_call_result_type in [AgentImage, AgentAudio]:
+                 if tool_call_result_type == AgentImage:
+                     observation_name = "image.png"
+                 elif tool_call_result_type == AgentAudio:
+                     observation_name = "audio.mp3"
+                 # TODO: tool_call_result naming could allow for different names of same type
+                 self.state[observation_name] = tool_call_result
+                 observation = f"Stored '{observation_name}' in memory."
+             else:
+                 observation = str(tool_call_result).strip()
+             self.logger.log(
+                 f"Observations: {observation.replace('[', '|')}",  # escape potential rich-tag-like components
+                 level=LogLevel.INFO,
+             )
+             is_final_answer = tool_name == "final_answer"
+
+             return ToolOutput(
+                 id=tool_call.id,
+                 output=tool_call_result,
+                 is_final_answer=is_final_answer,
+                 observation=observation,
+                 tool_call=tool_call,
+             )
+
+         # Process tool calls in parallel
+         outputs = {}
+         if len(parallel_calls) == 1:
+             # If there's only one call, process it directly
+             tool_call = list(parallel_calls.values())[0]
+             tool_output = process_single_tool_call(tool_call)
+             outputs[tool_output.id] = tool_output
+             yield tool_output
+         else:
+             # If multiple tool calls, process them in parallel
+             with ThreadPoolExecutor(self.max_tool_threads) as executor:
+                 futures = [
+                     executor.submit(process_single_tool_call, tool_call) for tool_call in parallel_calls.values()
+                 ]
+                 for future in as_completed(futures):
+                     tool_output = future.result()
+                     outputs[tool_output.id] = tool_output
+                     yield tool_output
+
+         memory_step.tool_calls = [parallel_calls[k] for k in sorted(parallel_calls.keys())]
+         memory_step.observations = memory_step.observations or ""
+         for tool_output in [outputs[k] for k in sorted(outputs.keys())]:
+             memory_step.observations += tool_output.observation + "\n"
+         memory_step.observations = (
+             memory_step.observations.rstrip("\n") if memory_step.observations else memory_step.observations
+         )
+
+     def _substitute_state_variables(self, arguments: dict[str, str] | str) -> dict[str, Any] | str:
+         """Replace string values in arguments with their corresponding state values if they exist."""
+         if isinstance(arguments, dict):
+             return {
+                 key: self.state.get(value, value) if isinstance(value, str) else value
+                 for key, value in arguments.items()
+             }
+         return arguments
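+
+     # Illustrative sketch (hypothetical values): if a previous step stored an image under
+     # the key "image.png", an argument dict referencing that key is resolved in place:
+     #
+     #     self.state["image.png"] = stored_agent_image
+     #     self._substitute_state_variables({"path": "image.png", "scale": 2})
+     #     # -> {"path": stored_agent_image, "scale": 2}
+     #     # Strings not found in state pass through unchanged.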
+
+     def execute_tool_call(self, tool_name: str, arguments: dict[str, str] | str) -> Any:
+         """
+         Execute a tool or managed agent with the provided arguments.
+
+         The arguments are replaced with the actual values from the state if they refer to state variables.
+
+         Args:
+             tool_name (`str`): Name of the tool or managed agent to execute.
+             arguments (`dict[str, str] | str`): Arguments passed to the tool call.
+         """
+         # Check if the tool exists
+         available_tools = {**self.tools, **self.managed_agents}
+         if tool_name not in available_tools:
+             raise AgentToolExecutionError(
+                 f"Unknown tool {tool_name}, should be one of: {', '.join(available_tools)}.", self.logger
+             )
+
+         # Get the tool and substitute state variables in arguments
+         tool = available_tools[tool_name]
+         arguments = self._substitute_state_variables(arguments)
+         is_managed_agent = tool_name in self.managed_agents
+
+         try:
+             validate_tool_arguments(tool, arguments)
+         except (ValueError, TypeError) as e:
+             raise AgentToolCallError(str(e), self.logger) from e
+         except Exception as e:
+             error_msg = f"Error executing tool '{tool_name}' with arguments {str(arguments)}: {type(e).__name__}: {e}"
+             raise AgentToolExecutionError(error_msg, self.logger) from e
+
+         try:
+             # Call tool with appropriate arguments
+             if isinstance(arguments, dict):
+                 return tool(**arguments) if is_managed_agent else tool(**arguments, sanitize_inputs_outputs=True)
+             else:
+                 return tool(arguments) if is_managed_agent else tool(arguments, sanitize_inputs_outputs=True)
+         except Exception as e:
+             # Handle execution errors
+             if is_managed_agent:
+                 error_msg = (
+                     f"Error executing request to team member '{tool_name}' with arguments {str(arguments)}: {e}\n"
+                     "Please try again or request to another team member"
+                 )
+             else:
+                 error_msg = (
+                     f"Error executing tool '{tool_name}' with arguments {str(arguments)}: {type(e).__name__}: {e}\n"
+                     "Please try again or use another tool"
+                 )
+             raise AgentToolExecutionError(error_msg, self.logger) from e
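+
+     # Illustrative sketch (hypothetical tool name and arguments): direct invocation,
+     # e.g. from a custom step loop; state keys in the arguments are resolved first.
+     #
+     #     result = agent.execute_tool_call("web_search", {"query": "smolagents docs"})
+     #     # Unknown names raise AgentToolExecutionError; invalid arguments raise AgentToolCallError.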
+
+
+ class CodeAgent(MultiStepAgent):
+     """
+     In this agent, the tool calls will be formulated by the LLM in code format, then parsed and executed.
+
+     Args:
+         tools (`list[Tool]`): [`Tool`]s that the agent can use.
+         model (`Model`): Model that will generate the agent's actions.
+         prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
+         additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
+         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
+         executor_type (`Literal["local", "e2b", "modal", "docker", "wasm"]`, default `"local"`): Type of code executor.
+         executor_kwargs (`dict`, *optional*): Additional arguments to pass to initialize the executor.
+         max_print_outputs_length (`int`, *optional*): Maximum length of the print outputs.
+         stream_outputs (`bool`, *optional*, default `False`): Whether to stream outputs during execution.
+         use_structured_outputs_internally (`bool`, default `False`): Whether to use structured generation at each action step: improves performance for many models.
+
+             <Added version="1.17.0"/>
+         code_block_tags (`tuple[str, str]` | `Literal["markdown"]`, *optional*): Opening and closing tags for code blocks. Pass a custom tuple, pass "markdown" to use markdown code fences, or leave empty to use ("<code>", "</code>").
+         **kwargs: Additional keyword arguments.
+     """
+
+     def __init__(
+         self,
+         tools: list[Tool],
+         model: Model,
+         prompt_templates: PromptTemplates | None = None,
+         additional_authorized_imports: list[str] | None = None,
+         planning_interval: int | None = None,
+         executor_type: Literal["local", "e2b", "modal", "docker", "wasm"] = "local",
+         executor_kwargs: dict[str, Any] | None = None,
+         max_print_outputs_length: int | None = None,
+         stream_outputs: bool = False,
+         use_structured_outputs_internally: bool = False,
+         code_block_tags: str | tuple[str, str] | None = None,
+         **kwargs,
+     ):
+         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
+         self.authorized_imports = sorted(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
+         self.max_print_outputs_length = max_print_outputs_length
+         self._use_structured_outputs_internally = use_structured_outputs_internally
+         if self._use_structured_outputs_internally:
+             prompt_templates = prompt_templates or yaml.safe_load(
+                 importlib.resources.files("smolagents.prompts").joinpath("structured_code_agent.yaml").read_text()
+             )
+         else:
+             prompt_templates = prompt_templates or yaml.safe_load(
+                 importlib.resources.files("smolagents.prompts").joinpath("code_agent.yaml").read_text()
+             )
+
+         if isinstance(code_block_tags, str) and code_block_tags != "markdown":
+             raise ValueError("Only 'markdown' is supported for a string argument to `code_block_tags`.")
+         self.code_block_tags = (
+             code_block_tags
+             if isinstance(code_block_tags, tuple)
+             else ("```python", "```")
+             if code_block_tags == "markdown"
+             else ("<code>", "</code>")
+         )
+
+         super().__init__(
+             tools=tools,
+             model=model,
+             prompt_templates=prompt_templates,
+             planning_interval=planning_interval,
+             **kwargs,
+         )
+         self.stream_outputs = stream_outputs
+         if self.stream_outputs and not hasattr(self.model, "generate_stream"):
+             raise ValueError(
+                 "`stream_outputs` is set to True, but the model class implements no `generate_stream` method."
+             )
+         if "*" in self.additional_authorized_imports:
+             self.logger.log(
+                 "Caution: you set an authorization for all imports, meaning your agent can decide to import any package it deems necessary. This might raise issues if the package is not installed in your environment.",
+                 level=LogLevel.INFO,
+             )
+         if executor_type not in {"local", "e2b", "modal", "docker", "wasm"}:
+             raise ValueError(f"Unsupported executor type: {executor_type}")
+         self.executor_type = executor_type
+         self.executor_kwargs: dict[str, Any] = executor_kwargs or {}
+         self.python_executor = self.create_python_executor()
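+
+     # Illustrative sketch of typical construction (class names assumed from the
+     # package's public API; the model and tool choices are hypothetical):
+     #
+     #     from smolagents import CodeAgent, InferenceClientModel, WebSearchTool
+     #     agent = CodeAgent(
+     #         tools=[WebSearchTool()],
+     #         model=InferenceClientModel(),
+     #         additional_authorized_imports=["pandas"],
+     #         executor_type="local",
+     #     )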
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.cleanup()
+
+     def cleanup(self):
+         """Clean up resources used by the agent, such as the remote Python executor."""
+         if hasattr(self.python_executor, "cleanup"):
+             self.python_executor.cleanup()
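+
+     # The context-manager hooks above let remote executors be released deterministically.
+     # A minimal sketch (the task string is hypothetical):
+     #
+     #     with CodeAgent(tools=[], model=model, executor_type="docker") as agent:
+     #         agent.run("Compute 2**10")
+     #     # cleanup() runs on exit, shutting down the executor if it supports it.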
+
+     def create_python_executor(self) -> PythonExecutor:
+         if self.executor_type == "local":
+             return LocalPythonExecutor(
+                 self.additional_authorized_imports,
+                 **{"max_print_outputs_length": self.max_print_outputs_length} | self.executor_kwargs,
+             )
+         else:
+             if self.managed_agents:
+                 raise Exception("Managed agents are not yet supported with remote code execution.")
+             remote_executors = {
+                 "e2b": E2BExecutor,
+                 "docker": DockerExecutor,
+                 "wasm": WasmExecutor,
+                 "modal": ModalExecutor,
+             }
+             return remote_executors[self.executor_type](
+                 self.additional_authorized_imports, self.logger, **self.executor_kwargs
+             )
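+
+     # Note: executor_kwargs are forwarded verbatim to the chosen executor class, so the
+     # valid keys depend on that executor's signature. A hypothetical example (the key
+     # shown is illustrative only, not a documented parameter):
+     #
+     #     CodeAgent(tools=[], model=model, executor_type="docker",
+     #               executor_kwargs={"image_name": "python:3.12-slim"})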
+
+     def initialize_system_prompt(self) -> str:
+         system_prompt = populate_template(
+             self.prompt_templates["system_prompt"],
+             variables={
+                 "tools": self.tools,
+                 "managed_agents": self.managed_agents,
+                 "authorized_imports": (
+                     "You can import from any package you want."
+                     if "*" in self.authorized_imports
+                     else str(self.authorized_imports)
+                 ),
+                 "custom_instructions": self.instructions,
+                 "code_block_opening_tag": self.code_block_tags[0],
+                 "code_block_closing_tag": self.code_block_tags[1],
+             },
+         )
+         return system_prompt
+
+     def _step_stream(
+         self, memory_step: ActionStep
+     ) -> Generator[ChatMessageStreamDelta | ToolCall | ToolOutput | ActionOutput]:
+         """
+         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
+         Yields ChatMessageStreamDelta during the run if streaming is enabled.
+         At the end, yields an ActionOutput that carries the final answer when the step is final.
+         """
+         memory_messages = self.write_memory_to_messages()
+
+         input_messages = memory_messages.copy()
+         ### Generate model output ###
+         memory_step.model_input_messages = input_messages
+         stop_sequences = ["Observation:", "Calling tools:"]
+         if self.code_block_tags[1] not in self.code_block_tags[0]:
+             # Only stop on the closing tag if it is not a substring of the opening tag;
+             # otherwise the opening tag itself would trigger the stop and cut code generation short.
+             stop_sequences.append(self.code_block_tags[1])
+         try:
+             additional_args: dict[str, Any] = {}
+             if self._use_structured_outputs_internally:
+                 additional_args["response_format"] = CODEAGENT_RESPONSE_FORMAT
+             if self.stream_outputs:
+                 output_stream = self.model.generate_stream(
+                     input_messages,
+                     stop_sequences=stop_sequences,
+                     **additional_args,
+                 )
+                 chat_message_stream_deltas: list[ChatMessageStreamDelta] = []
+                 with Live("", console=self.logger.console, vertical_overflow="visible") as live:
+                     for event in output_stream:
+                         chat_message_stream_deltas.append(event)
+                         live.update(
+                             Markdown(agglomerate_stream_deltas(chat_message_stream_deltas).render_as_markdown())
+                         )
+                         yield event
+                 chat_message = agglomerate_stream_deltas(chat_message_stream_deltas)
+                 memory_step.model_output_message = chat_message
+                 output_text = chat_message.content
+             else:
+                 chat_message: ChatMessage = self.model.generate(
+                     input_messages,
+                     stop_sequences=stop_sequences,
+                     **additional_args,
+                 )
+                 memory_step.model_output_message = chat_message
+                 output_text = chat_message.content
+                 self.logger.log_markdown(
+                     content=output_text or "",
+                     title="Output message of the LLM:",
+                     level=LogLevel.DEBUG,
+                 )
+
+             if not self._use_structured_outputs_internally:
+                 # This adds the end code sequence (i.e. the closing code block tag) to the history.
+                 # This will nudge subsequent LLM calls to finish with this end code sequence, thus efficiently stopping generation.
+                 if output_text and not output_text.strip().endswith(self.code_block_tags[1]):
+                     output_text += self.code_block_tags[1]
+                     memory_step.model_output_message.content = output_text
+
+             memory_step.token_usage = chat_message.token_usage
+             memory_step.model_output = output_text
+         except Exception as e:
+             raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e
+
+         ### Parse output ###
+         try:
+             if self._use_structured_outputs_internally:
+                 code_action = json.loads(output_text)["code"]
+                 code_action = extract_code_from_text(code_action, self.code_block_tags) or code_action
+             else:
+                 code_action = parse_code_blobs(output_text, self.code_block_tags)
+             code_action = fix_final_answer_code(code_action)
+             memory_step.code_action = code_action
+         except Exception as e:
+             error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
+             raise AgentParsingError(error_msg, self.logger) from e
+
+         tool_call = ToolCall(
+             name="python_interpreter",
+             arguments=code_action,
+             id=f"call_{len(self.memory.steps)}",
+         )
+         yield tool_call
+         memory_step.tool_calls = [tool_call]
+
+         ### Execute action ###
+         self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO)
+         try:
+             code_output = self.python_executor(code_action)
+             execution_outputs_console = []
+             if len(code_output.logs) > 0:
+                 execution_outputs_console += [
+                     Text("Execution logs:", style="bold"),
+                     Text(code_output.logs),
+                 ]
+             observation = "Execution logs:\n" + code_output.logs
+         except Exception as e:
+             if hasattr(self.python_executor, "state") and "_print_outputs" in self.python_executor.state:
+                 execution_logs = str(self.python_executor.state["_print_outputs"])
+                 if len(execution_logs) > 0:
+                     execution_outputs_console = [
+                         Text("Execution logs:", style="bold"),
+                         Text(execution_logs),
+                     ]
+                     memory_step.observations = "Execution logs:\n" + execution_logs
+                     self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
+             error_msg = str(e)
+             if "Import of " in error_msg and " is not allowed" in error_msg:
+                 self.logger.log(
+                     "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
+                     level=LogLevel.INFO,
+                 )
+             raise AgentExecutionError(error_msg, self.logger)
+
+         truncated_output = truncate_content(str(code_output.output))
+         observation += "Last output from code snippet:\n" + truncated_output
+         memory_step.observations = observation
+
+         if not code_output.is_final_answer:
+             execution_outputs_console += [
+                 Text(
+                     f"Out: {truncated_output}",
+                 ),
+             ]
+         self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
+         memory_step.action_output = code_output.output
+         yield ActionOutput(output=code_output.output, is_final_answer=code_output.is_final_answer)
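+
+     # Illustrative sketch (hypothetical model output): with the default code_block_tags,
+     # a step expects the model to answer with something like
+     #
+     #     Thought: I will compute the result.
+     #     <code>
+     #     result = 2 + 2
+     #     final_answer(result)
+     #     </code>
+     #
+     # parse_code_blobs() extracts the snippet between the tags, and the executed snippet
+     # signals is_final_answer=True because final_answer() was called.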
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert the agent to a dictionary representation.
+
+         Returns:
+             `dict`: Dictionary representation of the agent.
+         """
+         agent_dict = super().to_dict()
+         agent_dict["authorized_imports"] = self.authorized_imports
+         agent_dict["executor_type"] = self.executor_type
+         agent_dict["executor_kwargs"] = self.executor_kwargs
+         agent_dict["max_print_outputs_length"] = self.max_print_outputs_length
+         agent_dict["code_block_tags"] = self.code_block_tags  # included so from_dict() can restore custom tags
+         return agent_dict
+
+     @classmethod
+     def from_dict(cls, agent_dict: dict[str, Any], **kwargs) -> "CodeAgent":
+         """Create CodeAgent from a dictionary representation.
+
+         Args:
+             agent_dict (`dict[str, Any]`): Dictionary representation of the agent.
+             **kwargs: Additional keyword arguments that will override agent_dict values.
+
+         Returns:
+             `CodeAgent`: Instance of the CodeAgent class.
+         """
+         # Add CodeAgent-specific parameters to kwargs
+         code_agent_kwargs = {
+             "additional_authorized_imports": agent_dict.get("authorized_imports"),
+             "executor_type": agent_dict.get("executor_type"),
+             "executor_kwargs": agent_dict.get("executor_kwargs"),
+             "max_print_outputs_length": agent_dict.get("max_print_outputs_length"),
+             "code_block_tags": agent_dict.get("code_block_tags"),
+         }
+         # Filter out None values
+         code_agent_kwargs = {k: v for k, v in code_agent_kwargs.items() if v is not None}
+         # Update with any additional kwargs
+         code_agent_kwargs.update(kwargs)
+         # Call the parent class's from_dict method
+         return super().from_dict(agent_dict, **code_agent_kwargs)
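+
+     # Illustrative round-trip sketch (assumes the agent's model config is serializable;
+     # `model` is a hypothetical replacement passed explicitly):
+     #
+     #     saved = agent.to_dict()
+     #     restored = CodeAgent.from_dict(saved, model=model)
+     #     # Explicit kwargs (here `model`) override the values stored in `saved`.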