# First_agent_template/tools/save_image.py
# Erfan
# Fix gradio UI and add image saving tool and .gitignore
# 20f10c1
from __future__ import annotations
import re
import shutil
import uuid
from io import BytesIO
from pathlib import Path
from typing import Any, Optional
from smolagents.tools import Tool
class SaveImageTool(Tool):
    """Tool that stores an image inside the local ``generated_images/`` folder.

    The image may be given as an object exposing ``save`` (e.g. a PIL image),
    an object exposing ``to_pil()`` or ``to_string()``, raw ``bytes`` /
    ``bytearray``, or the path of an existing file. The output filename is
    sanitized and the resolved destination is verified to stay inside
    ``generated_images/`` before anything is written.
    """

    name = "save_image"
    description = (
        "Save an image to the local `generated_images/` folder and return the saved file path. "
        "Use this instead of importing `os` for filesystem operations."
    )
    inputs = {
        "image": {"type": "any", "description": "An image object (PIL, AgentImage), bytes, or an existing file path."},
        "filename": {
            "type": "string",
            "description": "Optional output filename (e.g. `cat.png`). Defaults to a random `.png` name.",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, image: Any, filename: Optional[str] = None) -> str:
        """Save ``image`` under ``generated_images/`` and return the absolute path.

        Args:
            image: Object with a ``save`` method, object with ``to_pil()`` or
                ``to_string()``, ``bytes``/``bytearray`` holding encoded image
                data, or a ``str``/``Path`` naming an existing file (which is
                copied as-is, without re-encoding).
            filename: Desired output filename. Path components are dropped and
                unsafe characters replaced. When ``None`` or empty, a random
                ``image_<8 hex chars>.png`` name is used.

        Returns:
            The resolved path of the written file, as a string.

        Raises:
            ValueError: If the resolved destination falls outside
                ``generated_images/``.
            TypeError: If ``image`` cannot be interpreted as an image. This
                includes a ``str``/``Path`` that does not name an existing file.
        """
        base_dir = Path("generated_images")
        base_dir.mkdir(parents=True, exist_ok=True)
        base_dir_resolved = base_dir.resolve()

        if filename:
            safe_name = self._sanitize_filename(filename)
        else:
            safe_name = f"image_{uuid.uuid4().hex[:8]}.png"

        # Resolve before checking containment so symlinks and `..` segments
        # cannot move the destination out of the output folder.
        out_path = (base_dir / safe_name).resolve()
        if not out_path.is_relative_to(base_dir_resolved):
            raise ValueError("Refusing to write outside `generated_images/`.")

        # If `image` is already a path on disk, just copy it.
        if isinstance(image, (str, Path)):
            src = Path(image)
            if src.is_file():
                try:
                    shutil.copyfile(src, out_path)
                except shutil.SameFileError:
                    # Source already is the destination: nothing to copy,
                    # the file is in place.
                    pass
                return str(out_path)

        pil_img = self._to_pil(image)
        pil_img.save(out_path)
        return str(out_path)

    @staticmethod
    def _sanitize_filename(filename: str) -> str:
        """Reduce ``filename`` to a safe basename that carries an extension.

        Directory parts are dropped, every character outside
        ``[A-Za-z0-9._-]`` becomes ``_``, and leading/trailing ``.`` and ``_``
        are stripped. An empty result is replaced by a random
        ``image_<8 hex chars>.png`` name; a result without any ``.`` gets
        ``.png`` appended.
        """
        name = Path(filename).name  # drop any path parts
        name = re.sub(r"[^A-Za-z0-9._-]", "_", name).strip("._")
        if not name:
            name = f"image_{uuid.uuid4().hex[:8]}.png"
        if "." not in name:
            name += ".png"
        return name

    @staticmethod
    def _to_pil(image: Any):
        """Return an object with a ``save`` method for the given ``image``.

        Conversions are tried in this order: the object itself when it exposes
        ``save``; the result of ``image.to_pil()``; decoding ``bytes`` /
        ``bytearray`` with PIL; opening the file named by ``image.to_string()``.

        Raises:
            TypeError: If none of the conversions apply.
        """
        # Anything that can already save itself is returned unchanged.
        if hasattr(image, "save"):
            return image

        if hasattr(image, "to_pil"):
            pil = image.to_pil()
            if pil is not None and hasattr(pil, "save"):
                return pil

        # PIL is only needed from here on, so import it lazily.
        from PIL import Image

        if isinstance(image, (bytes, bytearray)):
            return Image.open(BytesIO(image))

        if hasattr(image, "to_string"):
            as_str = image.to_string()
            if isinstance(as_str, str):
                p = Path(as_str)
                if p.is_file():
                    return Image.open(p)

        raise TypeError(f"Unsupported image type for saving: {type(image)}")