# agent-ui/backend/tools.py
# Uploaded by lvwerra (HF Staff) — commit b75e768: "Add save_image tool to image agent"
"""
Centralized Tool Definitions & Execution Functions.
All OpenAI function-calling tool definitions live here.
Agent handlers compose tools by importing what they need:
from tools import execute_code, upload_files, download_files
TOOLS = [execute_code, upload_files, download_files]
Execution functions for tools that run server-side (web tools)
are also defined here, prefixed with `execute_`.
"""
import base64
import io
import json
import logging
import re
from typing import List, Dict, Optional
from urllib.parse import urljoin, urlparse
import httpx
import requests
logger = logging.getLogger(__name__)
# ============================================================
# Code execution tools (used by code agent)
# ============================================================
execute_code = {
"type": "function",
"function": {
"name": "execute_code",
"description": "Execute Python code in a stateful environment. Variables and imports persist between executions.",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The Python code to execute."
}
},
"required": ["code"]
}
}
}
upload_files = {
"type": "function",
"function": {
"name": "upload_files",
"description": "Upload files from the local workspace to the code execution environment for analysis. Files will be available at /home/user/<filename>. Use this to load data files, scripts, or any files you need to analyze.",
"parameters": {
"type": "object",
"properties": {
"paths": {
"type": "array",
"items": {"type": "string"},
"description": "List of file paths relative to the workspace root (e.g., ['data/sales.csv', 'config.json'])"
}
},
"required": ["paths"]
}
}
}
download_files = {
"type": "function",
"function": {
"name": "download_files",
"description": "Download files from the code execution environment to the local workspace. Use this to save generated files, processed data, or any output files you want to keep.",
"parameters": {
"type": "object",
"properties": {
"files": {
"type": "array",
"items": {
"type": "object",
"properties": {
"sandbox_path": {
"type": "string",
"description": "Path in the sandbox (e.g., '/home/user/output.csv')"
},
"local_path": {
"type": "string",
"description": "Destination path relative to workspace (e.g., 'results/output.csv')"
}
},
"required": ["sandbox_path", "local_path"]
},
"description": "List of files to download with their sandbox and local paths"
}
},
"required": ["files"]
}
}
}
# ============================================================
# Web tools (used by web agent)
# ============================================================
web_search = {
"type": "function",
"function": {
"name": "web_search",
"description": "Search the web using Google. Returns titles, URLs, and short snippets for each result. Use this to find information, discover relevant pages, and get an overview of a topic.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query"
},
"num_results": {
"type": "integer",
"description": "Number of results to return (default: 5, max: 10)",
"default": 5
}
},
"required": ["query"]
}
}
}
read_url = {
"type": "function",
"function": {
"name": "read_url",
"description": "Fetch a web page and extract its main content as clean text with images and links. Returns content in chunks of ~10,000 characters. If the page is longer than one chunk, the response will indicate the total number of chunks — call again with a higher chunk number to continue reading. Set html=true to get a stripped-down HTML version of the page — only use this if the default text mode doesn't return enough detail (e.g., missing images, tables, or structured data).",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to read"
},
"chunk": {
"type": "integer",
"description": "Which chunk to read (0-indexed, default: 0). Use this to continue reading a long page.",
"default": 0
},
"use_html": {
"type": "boolean",
"description": "If true, return stripped-down HTML instead of extracted text. Only use when the default mode misses important content like images, tables, or page structure.",
"default": False
}
},
"required": ["url"]
}
}
}
screenshot_url = {
"type": "function",
"function": {
"name": "screenshot_url",
"description": "Take a screenshot of a web page. Use this when you need to see the visual layout, images, charts, or design of a page. The screenshot will be sent to you as an image.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to screenshot"
}
},
"required": ["url"]
}
}
}
# ============================================================
# Web tool execution functions
# ============================================================
_USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
def execute_web_search(query: str, serper_key: str, num_results: int = 5) -> str:
    """Search the web via the Serper (google.serper.dev) API.

    Args:
        query: The search query string.
        serper_key: Serper API key, sent in the ``X-API-KEY`` header.
        num_results: Number of results to request (capped at 10).

    Returns:
        A JSON string: a list of ``{"title", "url", "snippet"}`` objects on
        success, or an ``{"error": ...}`` object on any failure. Never raises,
        since tool results must always be strings handed back to the model.
    """
    url = "https://google.serper.dev/search"
    headers = {
        "X-API-KEY": serper_key,
        "Content-Type": "application/json"
    }
    try:
        # `json=` lets requests serialize the body itself instead of
        # hand-rolling json.dumps into `data=`.
        response = requests.post(
            url,
            headers=headers,
            json={"q": query, "num": min(num_results, 10)},
            timeout=10,
        )
        if response.status_code != 200:
            return json.dumps({"error": f"Search API returned status {response.status_code}"})
        data = response.json()
        # Keep only the fields the agent needs from each organic result.
        results = [
            {
                "title": item.get("title", ""),
                "url": item.get("link", ""),
                "snippet": item.get("snippet", ""),
            }
            for item in data.get("organic", [])
        ]
        return json.dumps(results, indent=2)
    except Exception as e:
        # Boundary handler: report network/parse failures to the caller
        # as a JSON error payload rather than crashing the agent loop.
        logger.error(f"Web search error: {e}")
        return json.dumps({"error": str(e)})
_CHUNK_SIZE = 10_000
_read_url_cache: Dict[str, str] = {} # url -> full markdown content
def _fetch_html(url: str) -> str:
    """Download a page and return its raw HTML.

    Follows redirects and sends a browser User-Agent. Raises RuntimeError
    on any non-200 status; httpx transport errors propagate unchanged.
    """
    response = httpx.get(
        url,
        headers={"User-Agent": _USER_AGENT},
        follow_redirects=True,
        timeout=15,
    )
    if response.status_code == 200:
        return response.text
    raise RuntimeError(f"HTTP {response.status_code} fetching {url}")
def _extract_text(html: str, url: str) -> str:
    """Extract main content as text with inline images and links.

    Uses trafilatura (preferred) with fallback to readability+markdownify.

    Args:
        html: Raw HTML of the page.
        url: Source URL — passed to trafilatura for context and used to
            absolutize relative image links in the fallback path.

    Returns:
        Markdown-ish text, prefixed with "# <title>" when a title exists,
        or an error string if neither extraction library is installed.
    """
    # Try trafilatura first
    try:
        import trafilatura
        text = trafilatura.extract(
            html, include_images=True, include_tables=True,
            include_links=True, output_format="txt", url=url,
        )
        # Only accept trafilatura's output when it found substantial
        # content; otherwise fall through to the readability path below.
        if text and len(text.strip()) > 50:
            from bs4 import BeautifulSoup
            soup = BeautifulSoup(html, "html.parser")
            title_tag = soup.find("title")
            title = title_tag.get_text(strip=True) if title_tag else ""
            body = text.strip()
            # Avoid repeating the title when the body already starts with it.
            return f"# {title}\n\n{body}" if title and not body.startswith(title) else body
    except ImportError:
        # trafilatura (or bs4) not installed — try the fallback extractors.
        pass
    # Fallback: readability + markdownify
    try:
        from readability import Document
        from markdownify import markdownify
    except ImportError:
        return "Error: trafilatura or readability-lxml packages required."
    doc = Document(html)
    title = doc.title()
    content_html = doc.summary()
    md = markdownify(content_html, strip=["script", "style"])
    # markdownify leaves relative image URLs untouched; rewrite them to be
    # absolute against the page URL so downstream consumers can fetch them.
    def resolve_match(match):
        img_url = match.group(2)
        if img_url.startswith(("http://", "https://", "data:")):
            return match.group(0)
        return f"![{match.group(1)}]({urljoin(url, img_url)})"
    md = re.sub(r'!\[([^\]]*)\]\(([^)]+)\)', resolve_match, md)
    # Collapse runs of 3+ blank lines left behind by stripped elements.
    md = re.sub(r'\n{3,}', '\n\n', md).strip()
    return f"# {title}\n\n{md}" if title else md
def _extract_html(raw_html: str) -> str:
    """Return stripped-down HTML preserving structure for inspection.

    Removes scripts/styles/SVGs, strips non-essential attributes,
    and focuses on the main content area. Capped at 30k chars.
    """
    from bs4 import BeautifulSoup
    soup = BeautifulSoup(raw_html, "html.parser")
    # Drop elements that carry no readable content.
    for tag in soup.find_all(["script", "style", "svg", "noscript", "iframe"]):
        tag.decompose()
    # Attributes worth keeping for navigation/inspection; everything else
    # (inline styles, event handlers, framework noise) is removed.
    keep_attrs = {"href", "src", "alt", "title", "class", "id",
                  "data-src", "srcset", "width", "height", "role"}
    for tag in soup.find_all(True):
        if tag.attrs is None:
            continue
        # Snapshot before mutating — deleting from the live attrs dict
        # while iterating it would skip entries.
        attrs = dict(tag.attrs)
        for attr in attrs:
            if attr not in keep_attrs:
                del tag[attr]
    # Prefer the semantic main-content container; the mw-* fallbacks target
    # MediaWiki-style pages. <body> is the last resort.
    main = (soup.find("main") or soup.find(id="content")
            or soup.find(class_="mw-body-content")
            or soup.find(id="mw-content-text") or soup.body)
    result = main.prettify() if main else soup.prettify()
    # Squeeze blank lines that prettify() leaves behind.
    result = re.sub(r'\n\s*\n', '\n', result)
    if len(result) > 30_000:
        result = result[:30_000] + "\n<!-- truncated at 30k chars -->"
    return result
def execute_read_url(url: str, chunk: int = 0, use_html: bool = False) -> str:
    """Fetch URL and return a specific chunk (0-indexed) of the content.

    By default extracts clean text with images/links via trafilatura.
    Set use_html=True to get stripped-down HTML — only use when the default
    text mode doesn't return enough detail (e.g., missing images, tables,
    or structured data).

    Args:
        url: Page to read.
        chunk: Which ~10k-char chunk to return; clamped into valid range.
        use_html: Return stripped-down HTML instead of extracted text.

    Returns:
        The requested chunk (with a pagination footer when the page spans
        multiple chunks), or an error string on failure.
    """
    # Text and HTML extractions are cached under separate keys.
    cache_key = f"{url}::{'html' if use_html else 'text'}"
    if cache_key in _read_url_cache:
        full_content = _read_url_cache[cache_key]
    else:
        try:
            raw_html = _fetch_html(url)
            full_content = _extract_html(raw_html) if use_html else _extract_text(raw_html, url)
        except Exception as e:
            logger.error(f"Read URL error for {url}: {e}")
            return f"Error reading {url}: {str(e)}"
        # Bug fix: detect extractor error strings BEFORE caching, so a
        # transient failure (e.g. missing extraction package) is not
        # served from the cache on every later call for this URL.
        if full_content.startswith("Error"):
            return full_content
        _read_url_cache[cache_key] = full_content
    total_len = len(full_content)
    total_chunks = max(1, (total_len + _CHUNK_SIZE - 1) // _CHUNK_SIZE)  # ceil division
    # Clamp the requested chunk into [0, total_chunks - 1].
    chunk = max(0, min(chunk, total_chunks - 1))
    if total_chunks == 1:
        return full_content
    start = chunk * _CHUNK_SIZE
    end = start + _CHUNK_SIZE
    chunk_content = full_content[start:end]
    return f"{chunk_content}\n\n[Chunk {chunk}/{total_chunks - 1} | Chars {start}-{min(end, total_len)} of {total_len} total]"
def execute_screenshot_url(url: str) -> Optional[str]:
    """Take a screenshot of a URL using Playwright, return base64 PNG.

    Returns None when Playwright is not installed or on any navigation /
    rendering error; callers are expected to handle None gracefully.
    """
    try:
        from playwright.sync_api import sync_playwright
    except ImportError:
        return None  # Caller should handle gracefully
    try:
        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            try:
                page = browser.new_page(viewport={"width": 1280, "height": 720})
                # "networkidle" waits until the page stops issuing network
                # requests, giving JS-rendered content a chance to appear.
                page.goto(url, wait_until="networkidle", timeout=15000)
                screenshot_bytes = page.screenshot(full_page=False)
            finally:
                # Bug fix: close the browser even when goto/screenshot raises
                # (e.g. navigation timeout) instead of leaking it until the
                # sync_playwright() context unwinds.
                browser.close()
            return base64.b64encode(screenshot_bytes).decode("utf-8")
    except Exception as e:
        logger.error(f"Screenshot error for {url}: {e}")
        return None
# ============================================================
# Image tools (used by image agent)
# ============================================================
generate_image = {
"type": "function",
"function": {
"name": "generate_image",
"description": "Generate an image from a text prompt. Returns an image reference name (e.g., 'image_1') that you can see and use with edit_image.",
"parameters": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Detailed text description of the image to generate"
},
"model": {
"type": "string",
"description": "HuggingFace model to use (default: black-forest-labs/FLUX.1-schnell)",
"default": "black-forest-labs/FLUX.1-schnell"
}
},
"required": ["prompt"]
}
}
}
edit_image = {
"type": "function",
"function": {
"name": "edit_image",
"description": "Edit or transform an existing image using a text prompt. The source can be a URL (https://...) or a reference to a previously generated/loaded image (e.g., 'image_1').",
"parameters": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Text description of the edit or transformation to apply"
},
"source": {
"type": "string",
"description": "Image URL or reference name from a previous tool call (e.g., 'image_1')"
},
"model": {
"type": "string",
"description": "HuggingFace model to use (default: black-forest-labs/FLUX.1-Kontext-dev)",
"default": "black-forest-labs/FLUX.1-Kontext-dev"
}
},
"required": ["prompt", "source"]
}
}
}
read_image = {
"type": "function",
"function": {
"name": "read_image",
"description": "Load a raster image (PNG, JPEG, GIF, WebP, BMP) from a URL or local file path. SVG is NOT supported. Returns an image reference name (e.g., 'image_1') that you can see and use with edit_image.",
"parameters": {
"type": "object",
"properties": {
"source": {
"type": "string",
"description": "URL (http/https) or local file path (e.g., 'plot.png', 'output/chart.jpg')"
}
},
"required": ["source"]
}
}
}
save_image = {
"type": "function",
"function": {
"name": "save_image",
"description": "Save an image to the workspace as a PNG file. Source can be a reference (e.g., 'image_1') or a URL.",
"parameters": {
"type": "object",
"properties": {
"source": {
"type": "string",
"description": "Image reference from a previous tool call (e.g., 'image_1') or a URL"
},
"filename": {
"type": "string",
"description": "Filename to save as (e.g., 'logo.png'). Will be saved in the workspace root."
}
},
"required": ["source", "filename"]
}
}
}
# Keep old name as alias for backwards compatibility
read_image_url = read_image
# ============================================================
# Image tool execution functions
# ============================================================
def execute_generate_image(prompt: str, hf_token: str, model: str = "black-forest-labs/FLUX.1-schnell") -> tuple:
    """Text-to-image via HF InferenceClient.

    Returns a (base64_png, error) pair: (base64 string, None) on success or
    (None, error message) on failure. Never raises.
    """
    try:
        from huggingface_hub import InferenceClient
    except ImportError:
        return None, "huggingface_hub not installed"
    try:
        # The client returns a PIL image; re-encode it as PNG bytes.
        pil_image = InferenceClient(token=hf_token).text_to_image(prompt, model=model)
        png_buffer = io.BytesIO()
        pil_image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
        return encoded, None
    except Exception as e:
        logger.error(f"Generate image error: {e}")
        return None, str(e)
def execute_edit_image(prompt: str, source_image_bytes: bytes, hf_token: str, model: str = "black-forest-labs/FLUX.1-Kontext-dev") -> tuple:
    """Image-to-image via HF InferenceClient.

    Returns a (base64_png, error) pair: (base64 string, None) on success or
    (None, error message) on failure. Never raises.
    """
    try:
        from huggingface_hub import InferenceClient
        from PIL import Image
    except ImportError:
        return None, "huggingface_hub or Pillow not installed"
    try:
        client = InferenceClient(token=hf_token)
        input_image = Image.open(io.BytesIO(source_image_bytes))
        # Most editing models expect inputs around 1024px; shrink anything
        # larger (in place, preserving aspect ratio) to avoid API failures.
        max_dim = 1024
        if max(input_image.size) > max_dim:
            input_image.thumbnail((max_dim, max_dim), Image.LANCZOS)
            logger.info(f"Resized input image to {input_image.size} for editing")
        edited = client.image_to_image(input_image, prompt=prompt, model=model)
        png_buffer = io.BytesIO()
        edited.save(png_buffer, format="PNG")
        return base64.b64encode(png_buffer.getvalue()).decode("utf-8"), None
    except Exception as e:
        logger.error(f"Edit image error: {e}")
        return None, str(e)
def execute_read_image(source: str, files_root: str = None) -> Optional[str]:
    """Load image from URL or local file path, return base64 string or None on error.

    Supported formats: PNG, JPEG, GIF, WebP, BMP. SVG is NOT supported.

    Args:
        source: http(s) URL, or a file path. Relative paths are resolved
            against files_root when it is provided.
        files_root: Optional sandbox root; resolved paths are confined to it.
    """
    import os
    # Check if it's a URL
    if source.startswith(("http://", "https://")):
        try:
            resp = httpx.get(
                source,
                follow_redirects=True,
                timeout=15,
                headers={"User-Agent": _USER_AGENT}
            )
            if resp.status_code != 200:
                logger.error(f"Read image error: HTTP {resp.status_code} for {source}")
                return None
            return base64.b64encode(resp.content).decode("utf-8")
        except Exception as e:
            logger.error(f"Read image URL error for {source}: {e}")
            return None
    # Local file path
    if files_root:
        root = os.path.normpath(files_root)
        full_path = os.path.normpath(os.path.join(root, source))
        # Security fix: a bare startswith(root) check lets traversal escape
        # to a sibling directory sharing the root's prefix (e.g. root
        # "/data" would accept "/data2/secret"). Require the resolved path
        # to be the root itself or live strictly under root + separator.
        if full_path != root and not full_path.startswith(root + os.sep):
            logger.error(f"Read image error: path escapes files_root: {source}")
            return None
    else:
        full_path = os.path.abspath(source)
    try:
        if not os.path.isfile(full_path):
            logger.error(f"Read image error: file not found: {full_path}")
            return None
        with open(full_path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")
    except Exception as e:
        logger.error(f"Read image file error for {full_path}: {e}")
        return None
def extract_and_download_images(markdown: str, max_images: int = 5) -> List[str]:
    """Extract image URLs from markdown and download them as base64 strings.

    Returns list of base64-encoded image strings (PNG/JPEG).
    Skips SVGs, data URIs, and failed downloads.

    Args:
        markdown: Text containing ``![alt](url)`` image references.
        max_images: Maximum number of images to download.
    """
    import re as _re
    img_pattern = _re.compile(r'!\[[^\]]*\]\(([^)]+)\)')
    urls = img_pattern.findall(markdown)
    results = []
    for url in urls:
        if len(results) >= max_images:
            break
        # Skip inline data URIs and obvious SVG links before fetching.
        if url.startswith("data:") or url.endswith(".svg"):
            continue
        try:
            resp = httpx.get(
                url,
                follow_redirects=True,
                timeout=10,
                headers={"User-Agent": _USER_AGENT}
            )
            if resp.status_code != 200:
                continue
            ct = resp.headers.get("content-type", "")
            # Bug fix: "image/svg+xml" previously slipped through the generic
            # image/* check even though the docstring promises SVGs are
            # skipped — reject SVG responses by content type as well.
            if not ct.startswith("image/") or ct.startswith("image/svg"):
                continue
            results.append(base64.b64encode(resp.content).decode("utf-8"))
        except Exception:
            # Best-effort: a failed download just skips that image.
            continue
    return results
# Keep old name as alias
def execute_read_image_url(url: str) -> Optional[str]:
return execute_read_image(url)
# ============================================================
# HTML display tool (used by command center)
# ============================================================
show_html = {
"type": "function",
"function": {
"name": "show_html",
"description": "Display HTML content in the chat. Accepts either a file path to an HTML file or a raw HTML string. Use this to show interactive visualizations, maps, charts, or any HTML content produced by a code agent.",
"parameters": {
"type": "object",
"properties": {
"source": {
"type": "string",
"description": "Either a file path (e.g., 'workspace/map.html') or a raw HTML string (starting with '<')"
}
},
"required": ["source"]
}
}
}
def execute_show_html(source: str, files_root: str = None) -> dict:
    """Load HTML from a file path or use a raw HTML string.

    Returns dict with:
      - "content": str description for the LLM
      - "html": the HTML content string (or None on error)
    """
    import os
    # Raw markup is detected by a leading '<' after trimming whitespace.
    if source.strip().startswith("<"):
        return {"content": "Rendered inline HTML content.", "html": source}
    # Anything else is treated as a file path, resolved against
    # files_root when it is relative.
    target = source
    if files_root and not os.path.isabs(target):
        target = os.path.join(files_root, target)
    try:
        with open(target, "r", encoding="utf-8") as handle:
            markup = handle.read()
    except Exception as exc:
        return {
            "content": f"Failed to load HTML from '{source}': {exc}",
            "html": None,
        }
    return {
        "content": f"Rendered HTML from file: {source}",
        "html": markup,
    }
# ============================================================
# Direct tool registry (used by command center)
# ============================================================
# Each entry combines the OpenAI tool schema with an execute function.
# The execute function receives (args_dict, context_dict).
DIRECT_TOOL_REGISTRY = {
"show_html": {
"schema": show_html,
"execute": lambda args, ctx: execute_show_html(
args.get("source", ""), files_root=ctx.get("files_root")
),
},
}