#!/usr/bin/env python # coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import ast import base64 import importlib.util import inspect import json import keyword import os import random import re import time from functools import lru_cache from io import BytesIO from logging import Logger from pathlib import Path from textwrap import dedent from typing import TYPE_CHECKING, Any, Callable import jinja2 if TYPE_CHECKING: from smolagents.memory import AgentLogger __all__ = ["AgentError"] @lru_cache def _is_package_available(package_name: str) -> bool: return importlib.util.find_spec(package_name) is not None BASE_BUILTIN_MODULES = [ "collections", "datetime", "itertools", "math", "queue", "random", "re", "stat", "statistics", "time", "unicodedata", ] def sanitize_for_rich(value) -> str: """ Convert arbitrary values (including bytes / control characters) into a safe string for Rich. - Decodes bytes-like inputs using UTF-8 with replacement. - Escapes bracket sequences that could be interpreted as markup while preserving valid Rich tags. - Replaces ASCII control characters (except common whitespace) with visible escape sequences. """ if value is None: s = "" elif isinstance(value, str): s = value elif isinstance(value, (bytes, bytearray, memoryview)): s = bytes(value).decode("utf-8", errors="replace") else: s = str(value) out: list[str] = [] for ch in s: code = ord(ch) if ch in ("\n", "\t", "\r"): out.append(ch) elif code < 32 or code == 127: out.append(f"\\x{code:02x}") else: out.append(ch) return "".join(out) class AgentError(Exception): """Base class for other agent-related exceptions""" def __init__(self, message, logger: "AgentLogger"): super().__init__(message) self.message = message logger.log_error(message) def dict(self) -> dict[str, str]: return {"type": self.__class__.__name__, "message": str(self.message)} class AgentParsingError(AgentError): """Exception raised for errors in parsing in the agent""" pass class AgentExecutionError(AgentError): """Exception raised for errors in execution in the agent""" pass class AgentMaxStepsError(AgentError): """Exception raised for errors in execution in the agent""" pass class AgentToolCallError(AgentExecutionError): """Exception raised for errors when incorrect arguments are passed to the tool""" pass class AgentToolExecutionError(AgentExecutionError): """Exception raised for errors when executing a tool""" pass class AgentGenerationError(AgentError): """Exception raised for errors in generation in the agent""" pass def make_json_serializable(obj: Any) -> Any: """Recursive function to make objects JSON serializable""" if obj is None: return None elif isinstance(obj, (str, int, float, bool)): # Try to parse string as JSON if it looks like a JSON object/array if isinstance(obj, str): try: if (obj.startswith("{") and obj.endswith("}")) or (obj.startswith("[") and obj.endswith("]")): parsed = json.loads(obj) return make_json_serializable(parsed) except json.JSONDecodeError: pass return obj elif isinstance(obj, (list, tuple)): return [make_json_serializable(item) for item in obj] elif isinstance(obj, dict): return {str(k): make_json_serializable(v) for k, v in obj.items()} elif hasattr(obj, "__dict__"): # For custom objects, convert their __dict__ to a serializable format return {"_type": obj.__class__.__name__, **{k: make_json_serializable(v) for k, v in obj.__dict__.items()}} else: # For any other type, convert to string return str(obj) def parse_json_blob(json_blob: str) -> tuple[dict[str, str], str]: "Extracts the JSON blob from the input and returns the JSON data and the rest of the input." try: first_accolade_index = json_blob.find("{") last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1] json_str = json_blob[first_accolade_index : last_accolade_index + 1] json_data = json.loads(json_str, strict=False) return json_data, json_blob[:first_accolade_index] except IndexError: raise ValueError("The model output does not contain any JSON blob.") except json.JSONDecodeError as e: place = e.pos if json_blob[place - 1 : place + 2] == "},\n": raise ValueError( "JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL." ) raise ValueError( f"The JSON blob you used is invalid due to the following error: {e}.\n" f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n" f"'{json_blob[place - 4 : place + 5]}'." ) def extract_code_from_text(text: str, code_block_tags: tuple[str, str]) -> str | None: """Extract code from the LLM's output.""" pattern = rf"{code_block_tags[0]}(.*?){code_block_tags[1]}" matches = re.findall(pattern, text, re.DOTALL) if matches: return "\n\n".join(match.strip() for match in matches) return None def parse_code_blobs(text: str, code_block_tags: tuple[str, str]) -> str: """Extract code blocs from the LLM's output. If a valid code block is passed, it returns it directly. Args: text (`str`): LLM's output text to parse. Returns: `str`: Extracted code block. Raises: ValueError: If no valid code block is found in the text. """ matches = extract_code_from_text(text, code_block_tags) if not matches: # Fallback to markdown pattern matches = extract_code_from_text(text, ("```(?:python|py)", "\n```")) if matches: return matches # Maybe the LLM outputted a code blob directly try: ast.parse(text) return text except SyntaxError: pass if "final" in text and "answer" in text: raise ValueError( dedent( f""" Your code snippet is invalid, because the regex pattern {code_block_tags[0]}(.*?){code_block_tags[1]} was not found in it. Here is your code snippet: {text} It seems like you're trying to return the final answer, you can do it as follows: {code_block_tags[0]} final_answer("YOUR FINAL ANSWER HERE") {code_block_tags[1]} """ ).strip() ) raise ValueError( dedent( f""" Your code snippet is invalid, because the regex pattern {code_block_tags[0]}(.*?){code_block_tags[1]} was not found in it. Here is your code snippet: {text} Make sure to include code with the correct pattern, for instance: Thoughts: Your thoughts {code_block_tags[0]} # Your python code here {code_block_tags[1]} """ ).strip() ) MAX_LENGTH_TRUNCATE_CONTENT = 20000 def truncate_content(content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT) -> str: if len(content) <= max_length: return content else: return ( content[: max_length // 2] + f"\n..._This content has been truncated to stay below {max_length} characters_...\n" + content[-max_length // 2 :] ) class ImportFinder(ast.NodeVisitor): def __init__(self): self.packages = set() def visit_Import(self, node): for alias in node.names: # Get the base package name (before any dots) base_package = alias.name.split(".")[0] self.packages.add(base_package) def visit_ImportFrom(self, node): if node.module: # for "from x import y" statements # Get the base package name (before any dots) base_package = node.module.split(".")[0] self.packages.add(base_package) def instance_to_source(instance, base_cls=None): """Convert an instance to its class source code representation.""" cls = instance.__class__ class_name = cls.__name__ # Start building class lines class_lines = [] if base_cls: class_lines.append(f"class {class_name}({base_cls.__name__}):") else: class_lines.append(f"class {class_name}:") # Add docstring if it exists and differs from base if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__): class_lines.append(f' """{cls.__doc__}"""') # Add class-level attributes class_attrs = { name: value for name, value in cls.__dict__.items() if not name.startswith("__") and not name == "_abc_impl" and not callable(value) and not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value) } for name, value in class_attrs.items(): if isinstance(value, str): # multiline value if "\n" in value: escaped_value = value.replace('"""', r"\"\"\"") # Escape triple quotes class_lines.append(f' {name} = """{escaped_value}"""') else: class_lines.append(f" {name} = {json.dumps(value)}") else: class_lines.append(f" {name} = {repr(value)}") if class_attrs: class_lines.append("") # Add methods methods = { name: func.__wrapped__ if hasattr(func, "__wrapped__") else func for name, func in cls.__dict__.items() if callable(func) and ( not base_cls or not hasattr(base_cls, name) or ( isinstance(func, (staticmethod, classmethod)) or (getattr(base_cls, name).__code__.co_code != func.__code__.co_code) ) ) } for name, method in methods.items(): method_source = get_source(method) # Clean up the indentation method_lines = method_source.split("\n") first_line = method_lines[0] indent = len(first_line) - len(first_line.lstrip()) method_lines = [line[indent:] for line in method_lines] method_source = "\n".join([" " + line if line.strip() else line for line in method_lines]) class_lines.append(method_source) class_lines.append("") # Find required imports using ImportFinder import_finder = ImportFinder() import_finder.visit(ast.parse("\n".join(class_lines))) required_imports = import_finder.packages # Build final code with imports final_lines = [] # Add base class import if needed if base_cls: final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}") # Add discovered imports for package in required_imports: final_lines.append(f"import {package}") if final_lines: # Add empty line after imports final_lines.append("") # Add the class code final_lines.extend(class_lines) return "\n".join(final_lines) def get_source(obj) -> str: """Get the source code of a class or callable object (e.g.: function, method). First attempts to get the source code using `inspect.getsource`. In a dynamic environment (e.g.: Jupyter, IPython), if this fails, falls back to retrieving the source code from the current interactive shell session. Args: obj: A class or callable object (e.g.: function, method) Returns: str: The source code of the object, dedented and stripped Raises: TypeError: If object is not a class or callable OSError: If source code cannot be retrieved from any source ValueError: If source cannot be found in IPython history Note: TODO: handle Python standard REPL """ if not (isinstance(obj, type) or callable(obj)): raise TypeError(f"Expected class or callable, got {type(obj)}") inspect_error = None try: # Handle dynamically created classes source = getattr(obj, "__source__", None) or inspect.getsource(obj) return dedent(source).strip() except OSError as e: # let's keep track of the exception to raise it if all further methods fail inspect_error = e try: import IPython shell = IPython.get_ipython() if not shell: raise ImportError("No active IPython shell found") all_cells = "\n".join(shell.user_ns.get("In", [])).strip() if not all_cells: raise ValueError("No code cells found in IPython session") tree = ast.parse(all_cells) for node in ast.walk(tree): if isinstance(node, (ast.ClassDef, ast.FunctionDef)) and node.name == obj.__name__: return dedent("\n".join(all_cells.split("\n")[node.lineno - 1 : node.end_lineno])).strip() raise ValueError(f"Could not find source code for {obj.__name__} in IPython history") except ImportError: # IPython is not available, let's just raise the original inspect error raise inspect_error except ValueError as e: # IPython is available but we couldn't find the source code, let's raise the error raise e from inspect_error def encode_image_base64(image): buffered = BytesIO() image.save(buffered, format="PNG") return base64.b64encode(buffered.getvalue()).decode("utf-8") def make_image_url(base64_image): return f"data:image/png;base64,{base64_image}" def make_init_file(folder: str | Path): os.makedirs(folder, exist_ok=True) # Create __init__ with open(os.path.join(folder, "__init__.py"), "w"): pass def is_valid_name(name: str) -> bool: return name.isidentifier() and not keyword.iskeyword(name) if isinstance(name, str) else False AGENT_GRADIO_APP_TEMPLATE = """import yaml import os from smolagents import GradioUI, {{ class_name }}, {{ agent_dict['model']['class'] }} # Get current directory path CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) {% for tool in tools.values() -%} from {{managed_agent_relative_path}}tools.{{ tool.name }} import {{ tool.__class__.__name__ }} as {{ tool.name | camelcase }} {% endfor %} {% for managed_agent in managed_agents.values() -%} from {{managed_agent_relative_path}}managed_agents.{{ managed_agent.name }}.app import agent_{{ managed_agent.name }} {% endfor %} model = {{ agent_dict['model']['class'] }}( {% for key in agent_dict['model']['data'] if key != 'class' -%} {{ key }}={{ agent_dict['model']['data'][key]|repr }}, {% endfor %}) {% for tool in tools.values() -%} {{ tool.name }} = {{ tool.name | camelcase }}() {% endfor %} with open(os.path.join(CURRENT_DIR, "prompts.yaml"), 'r') as stream: prompt_templates = yaml.safe_load(stream) {{ agent_name }} = {{ class_name }}( model=model, tools=[{% for tool_name in tools.keys() if tool_name != "final_answer" %}{{ tool_name }}{% if not loop.last %}, {% endif %}{% endfor %}], managed_agents=[{% for subagent_name in managed_agents.keys() %}agent_{{ subagent_name }}{% if not loop.last %}, {% endif %}{% endfor %}], {% for attribute_name, value in agent_dict.items() if attribute_name not in ["class", "model", "tools", "prompt_templates", "authorized_imports", "managed_agents", "requirements"] -%} {{ attribute_name }}={{ value|repr }}, {% endfor %}prompt_templates=prompt_templates ) if __name__ == "__main__": GradioUI({{ agent_name }}).launch() """.strip() def create_agent_gradio_app_template(): env = jinja2.Environment(loader=jinja2.BaseLoader(), undefined=jinja2.StrictUndefined) env.filters["repr"] = repr env.filters["camelcase"] = lambda value: "".join(word.capitalize() for word in value.split("_")) return env.from_string(AGENT_GRADIO_APP_TEMPLATE) class RateLimiter: """Simple rate limiter that enforces a minimum delay between consecutive requests. This class is useful for limiting the rate of operations such as API requests, by ensuring that calls to `throttle()` are spaced out by at least a given interval based on the desired requests per minute. If no rate is specified (i.e., `requests_per_minute` is None), rate limiting is disabled and `throttle()` becomes a no-op. Args: requests_per_minute (`float | None`): Maximum number of allowed requests per minute. Use `None` to disable rate limiting. """ def __init__(self, requests_per_minute: float | None = None): self._enabled = requests_per_minute is not None self._interval = 60.0 / requests_per_minute if self._enabled else 0.0 self._last_call = 0.0 def throttle(self): """Pause execution to respect the rate limit, if enabled.""" if not self._enabled: return now = time.time() elapsed = now - self._last_call if elapsed < self._interval: time.sleep(self._interval - elapsed) self._last_call = time.time() class Retrying: """Simple retrying controller. Inspired from library [tenacity](https://github.com/jd/tenacity/).""" def __init__( self, max_attempts: int = 1, wait_seconds: float = 0.0, exponential_base: float = 2.0, jitter: bool = True, retry_predicate: Callable[[BaseException], bool] | None = None, reraise: bool = False, before_sleep_logger: tuple[Logger, int] | None = None, after_logger: tuple[Logger, int] | None = None, ): self.max_attempts = max_attempts self.wait_seconds = wait_seconds self.exponential_base = exponential_base self.jitter = jitter self.retry_predicate = retry_predicate self.reraise = reraise self.before_sleep_logger = before_sleep_logger self.after_logger = after_logger def __call__(self, fn, *args: Any, **kwargs: Any) -> Any: start_time = time.time() delay = self.wait_seconds for attempt_number in range(1, self.max_attempts + 1): try: result = fn(*args, **kwargs) # Log after successful call if we had retries if self.after_logger and attempt_number > 1: logger, log_level = self.after_logger seconds = time.time() - start_time fn_name = getattr(fn, "__name__", repr(fn)) logger.log( log_level, f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}/{self.max_attempts}.", ) return result except BaseException as e: # Check if we should retry should_retry = self.retry_predicate(e) if self.retry_predicate else False # If this is the last attempt or we shouldn't retry, raise if not should_retry or attempt_number >= self.max_attempts: if self.reraise: raise raise # Log after failed attempt if self.after_logger: logger, log_level = self.after_logger seconds = time.time() - start_time fn_name = getattr(fn, "__name__", repr(fn)) logger.log( log_level, f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}/{self.max_attempts}.", ) # Exponential backoff with jitter # https://cookbook.openai.com/examples/how_to_handle_rate_limits#example-3-manual-backoff-implementation delay *= self.exponential_base * (1 + self.jitter * random.random()) # Log before sleeping if self.before_sleep_logger: logger, log_level = self.before_sleep_logger fn_name = getattr(fn, "__name__", repr(fn)) logger.log( log_level, f"Retrying {fn_name} in {delay} seconds as it raised {e.__class__.__name__}: {e}.", ) # Sleep before next attempt if delay > 0: time.sleep(delay)