Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import re | |
| import shutil | |
| import uuid | |
| from io import BytesIO | |
| from pathlib import Path | |
| from typing import Any, Optional | |
| from smolagents.tools import Tool | |
class SaveImageTool(Tool):
    """smolagents tool that writes an image into the local ``generated_images/`` folder.

    Accepts PIL images, objects exposing ``to_pil()`` / ``to_string()`` (such as
    smolagents ``AgentImage``), raw encoded image bytes, or a path to an existing
    file, and returns the resolved path of the file it wrote.
    """

    name = "save_image"
    description = (
        "Save an image to the local `generated_images/` folder and return the saved file path. "
        "Use this instead of importing `os` for filesystem operations."
    )
    inputs = {
        "image": {"type": "any", "description": "An image object (PIL, AgentImage), bytes, or an existing file path."},
        "filename": {
            "type": "string",
            "description": "Optional output filename (e.g. `cat.png`). Defaults to a random `.png` name.",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, image: Any, filename: Optional[str] = None) -> str:
        """Save ``image`` under ``generated_images/`` and return the written path.

        Args:
            image: A PIL image, an object with ``to_pil()`` / ``to_string()``,
                encoded image bytes, or a path to an existing file.
            filename: Optional output file name. Directory components are dropped
                and unsafe characters replaced. When omitted, a random
                ``image_<hex>.png`` name is used. An existing file with the same
                name is overwritten.

        Returns:
            The resolved (absolute) path of the written file, as a string.

        Raises:
            ValueError: If the resolved output path falls outside ``generated_images/``.
            TypeError: If ``image`` is of an unsupported type.
        """
        base_dir = Path("generated_images")
        base_dir.mkdir(parents=True, exist_ok=True)
        base_dir_resolved = base_dir.resolve()

        safe_name = self._sanitize_filename(filename) if filename else f"image_{uuid.uuid4().hex[:8]}.png"
        out_path = (base_dir / safe_name).resolve()
        # Sanitising already strips directory parts. This check still matters
        # because resolve() follows symlinks, so a symlink inside
        # generated_images/ could otherwise redirect the write elsewhere.
        if not out_path.is_relative_to(base_dir_resolved):
            raise ValueError("Refusing to write outside `generated_images/`.")

        # If `image` is already a file on disk, copy it when the formats match.
        # NOTE(review): any readable file path is accepted here; restrict the
        # source location if the calling agent is not trusted.
        if isinstance(image, (str, Path)):
            src = Path(image)
            if src.is_file():
                if src.resolve() == out_path:
                    # Already at the destination; copyfile would raise SameFileError.
                    return str(out_path)
                if self._same_format(src.suffix, out_path.suffix):
                    shutil.copyfile(src, out_path)
                    return str(out_path)
                # Extensions differ: fall through and re-encode with PIL so the
                # file contents match the extension of the returned path.

        pil_img = self._prepare_for_format(self._to_pil(image), out_path.suffix)
        pil_img.save(out_path)
        return str(out_path)

    @staticmethod
    def _sanitize_filename(filename: str) -> str:
        """Reduce ``filename`` to a safe bare file name that has an extension.

        Directory components are discarded. Characters outside ``[A-Za-z0-9._-]``
        become ``_``. Leading and trailing dots/underscores are stripped. An
        empty result gets a random name, and a name without a dot gets ``.png``.
        """
        name = Path(filename).name  # drop any path parts
        name = re.sub(r"[^A-Za-z0-9._-]", "_", name).strip("._")
        if not name:
            name = f"image_{uuid.uuid4().hex[:8]}.png"
        if "." not in name:
            name += ".png"
        return name

    @staticmethod
    def _same_format(suffix_a: str, suffix_b: str) -> bool:
        """Return True if two file extensions denote the same on-disk format (case-insensitive)."""
        aliases = {".jpeg": ".jpg", ".tiff": ".tif"}
        a, b = suffix_a.lower(), suffix_b.lower()
        return aliases.get(a, a) == aliases.get(b, b)

    @staticmethod
    def _prepare_for_format(img: Any, suffix: str) -> Any:
        """Convert ``img`` to RGB when saving as JPEG and its mode isn't JPEG-compatible.

        Pillow cannot write modes such as RGBA or P as JPEG. Converting to RGB
        discards any transparency. Other formats are returned unchanged.
        """
        mode = getattr(img, "mode", None)
        if (
            suffix.lower() in (".jpg", ".jpeg")
            and mode is not None
            and mode not in ("RGB", "L", "CMYK")
            and hasattr(img, "convert")
        ):
            return img.convert("RGB")
        return img

    @staticmethod
    def _to_pil(image: Any):
        """Coerce ``image`` into an object with a PIL-style ``save`` method.

        Checks, in order:
        - anything that already has ``save``;
        - ``to_pil()``;
        - encoded bytes-like data;
        - a path to an existing file;
        - ``to_string()`` returning a path to an existing file.

        Raises:
            TypeError: If none of the above applies.
        """
        from PIL import Image

        if hasattr(image, "save"):
            return image
        if hasattr(image, "to_pil"):
            pil = image.to_pil()
            if pil is not None and hasattr(pil, "save"):
                return pil
        if isinstance(image, (bytes, bytearray, memoryview)):
            return Image.open(BytesIO(image))
        if isinstance(image, (str, Path)) and Path(image).is_file():
            return Image.open(image)
        if hasattr(image, "to_string"):
            as_str = image.to_string()
            if isinstance(as_str, str):
                p = Path(as_str)
                if p.exists() and p.is_file():
                    return Image.open(p)
        raise TypeError(f"Unsupported image type for saving: {type(image)}")