# First_agent_template / tools / image_handler.py
# Last commit e27aff8 by juanmaguitar: "more reliable way of getting images"
import os
import re
import tempfile
import time
from typing import Dict, List, Optional
from urllib.parse import urlparse

import requests
from smolagents.tools import Tool
class ImageHandlerTool(Tool):
    """Agent tool that finds images for a topic, falling back to generation.

    Web search results are tried first (unless ``skip_web_search`` is set) and
    any shortfall against ``num_images`` is filled with the injected image
    generation tool. ``forward`` returns a dict with a ``status`` key plus
    either the collected image records or an error message.
    """

    name = "image_handler"
    description = "Gets or generates images for a given topic, with fallback options"
    inputs = {
        'query': {'type': 'string', 'description': 'The topic to get images for'},
        'num_images': {
            'type': 'integer',
            'description': 'Number of images to get/generate',
            'nullable': True
        },
        'style': {
            'type': 'string',
            'description': 'Style for generated images (e.g., "photo", "artistic", "realistic")',
            'nullable': True
        },
        'skip_web_search': {
            'type': 'boolean',
            'description': 'Whether to skip web search and go straight to generation',
            'nullable': True
        }
    }
    output_type = "object"

    def __init__(self, web_search_tool, image_gen_tool, temp_dir=None):
        """
        Args:
            web_search_tool: Tool whose ``forward(query=..., max_results=...)`` is
                used for image search; hits are read as dicts with
                ``image_url``/``url``, ``title`` and ``source`` keys.
            image_gen_tool: Tool whose ``forward(prompt=...)`` is expected to
                return either a dict with ``image_path`` or an existing file path.
            temp_dir: Directory for downloaded images; defaults to the system
                temp directory.
        """
        super().__init__()
        self.web_search = web_search_tool
        self.image_gen = image_gen_tool
        self.temp_dir = temp_dir or tempfile.gettempdir()

    @staticmethod
    def _safe_slug(text: str, max_length: int = 100) -> str:
        """Turn ``text`` into a filename-safe stem.

        Path separators, dots and other punctuation collapse to ``_``, so a
        query such as ``"../etc"`` or ``"Paris/France"`` can neither escape
        ``temp_dir`` nor point at a missing subdirectory. Returns ``"image"``
        when nothing usable remains.
        """
        slug = re.sub(r'[^\w-]+', '_', str(text)).strip('_')[:max_length]
        return slug or "image"

    @staticmethod
    def _image_extension(url) -> Optional[str]:
        """Return the image extension (e.g. ``".png"``) implied by ``url``, else None.

        The parsed URL path is checked first so query strings like ``?w=800``
        do not hide the extension; the raw URL is checked as a fallback so
        links of the form ``...?file=photo.jpg`` are still accepted.
        """
        if not isinstance(url, str) or not url:
            return None
        allowed = ('.jpg', '.jpeg', '.png', '.gif')
        try:
            path = urlparse(url).path
        except ValueError:  # e.g. a malformed IPv6 host
            path = ''
        for candidate in (path, url):
            ext = os.path.splitext(candidate.lower())[1]
            if ext in allowed:
                return ext
        return None

    def _download_image(self, url: str, filename: str) -> Optional[str]:
        """Download ``url`` into ``self.temp_dir/filename``.

        Returns the saved file path, or None when the request fails or the
        server does not report an ``image/*`` content type. Errors are printed
        rather than raised so one bad URL does not abort the whole search.
        """
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        try:
            response = requests.get(url, timeout=10, headers=headers)
            response.raise_for_status()
            # Check if response is actually an image (header values may vary in case)
            content_type = response.headers.get('content-type', '')
            if not content_type.lower().startswith('image/'):
                return None
            # Ensure temp directory exists
            os.makedirs(self.temp_dir, exist_ok=True)
            file_path = os.path.join(self.temp_dir, filename)
            with open(file_path, 'wb') as f:
                f.write(response.content)
            return file_path
        except Exception as e:
            print(f"Failed to download image from {url}: {str(e)}")
            return None

    def _try_web_search(self, query: str, num_images: int) -> List[Dict]:
        """Download up to ``num_images`` images found through the web search tool.

        Several phrasings of ``query`` are searched in turn, each preceded by a
        2 s pause, until enough images have been saved. A URL is downloaded at
        most once per call, and every saved file gets a distinct name so later
        downloads never overwrite earlier results. Failures are printed and
        skipped; the possibly shorter list is returned.
        """
        results: List[Dict] = []
        if num_images <= 0:
            return results

        # Try different search queries with better targeting
        search_queries = [
            f"{query} high resolution photo",
            f"{query} professional photography",
            f"{query} best pictures",
            f"{query} travel photography"
        ]
        slug = self._safe_slug(query)
        seen_urls = set()

        for search_query in search_queries:
            if len(results) >= num_images:
                break
            time.sleep(2)  # Rate limiting
            try:
                search_results = self.web_search.forward(
                    query=search_query, max_results=num_images)
                # A plain-text response has no structured hits to take URLs from.
                if isinstance(search_results, str):
                    continue
                for result in search_results:
                    if len(results) >= num_images:
                        break
                    if not isinstance(result, dict):
                        continue
                    # Try both image_url and direct URL fields
                    url = result.get('image_url') or result.get('url')
                    ext = self._image_extension(url)
                    if ext is None or url in seen_urls:
                        continue
                    seen_urls.add(url)
                    # Number by images saved so far: unique within this call,
                    # unlike the per-response index, which restarts every query.
                    filename = f"{slug}_{len(results)}{ext}"
                    file_path = self._download_image(url, filename)
                    if file_path:
                        results.append({
                            'file_path': file_path,
                            'source': 'web',
                            'url': url,
                            'title': result.get('title', ''),
                            'attribution': result.get('source', '')
                        })
            except Exception as search_error:
                print(f"Search query failed: {str(search_error)}")
        return results

    def _generate_images(self, query: str, num_images: int, style: str = "photo") -> List[Dict]:
        """Generate up to ``num_images`` images for ``query`` with the generation tool.

        Prompts cycle through three templates and number the subject
        (``"<query> 1"``, ``"<query> 2"``, ...) so each request differs.
        Generations are spaced 1 s apart; failures and responses without a
        usable path are printed and skipped.
        """
        results: List[Dict] = []
        # Enhanced prompts for better generation
        prompt_templates = [
            "Generate a {style} style image of {subject}, high quality, detailed",
            "Create a {style} representation of {subject}, professional quality",
            "Make a {style} image showing {subject}, realistic and clear"
        ]
        for idx in range(num_images):
            if idx:
                time.sleep(1)  # Brief pause between generations
            # Insert the numbered subject directly: a text replace of ``query``
            # would also rewrite template words (e.g. query "image") and
            # scatters the number between every character when query is "".
            prompt = prompt_templates[idx % len(prompt_templates)].format(
                style=style, subject=f"{query} {idx + 1}")
            try:
                response = self.image_gen.forward(prompt=prompt)
            except Exception as gen_error:
                print(
                    f"Failed to generate image {idx+1}: {str(gen_error)}")
                continue

            if isinstance(response, dict) and 'image_path' in response:
                file_path = response['image_path']
            elif isinstance(response, str) and os.path.exists(response):
                file_path = response
            else:
                print(f"Image generator returned no usable image path for image {idx+1}")
                continue

            results.append({
                'file_path': file_path,
                'source': 'generated',
                'prompt': prompt,
                'style': style
            })
        return results

    def forward(self, query: str, num_images: int = 2, style: str = "photo", skip_web_search: bool = False) -> Dict:
        """Gets or generates images for the query

        Args:
            query: What to get images of
            num_images: How many images to get; None means the default of 2
            style: Style for generated images; None means the default "photo"
            skip_web_search: Whether to skip web search (None is treated as False)

        Returns:
            On success ``{"status": "success", "images": [...], "total": int,
            "sources": {"web": int, "generated": int}}``; when no image could be
            obtained ``{"status": "error", "message": "Failed to get any images"}``.
            Each image record has ``file_path`` and ``source`` ("web" or
            "generated") plus source-specific metadata.
        """
        # These inputs are declared nullable, so the agent may pass an explicit
        # None instead of omitting them; fall back to the signature defaults.
        if num_images is None:
            num_images = 2
        if style is None:
            style = "photo"

        all_results: List[Dict] = []
        # Try web search first unless skipped
        if not skip_web_search:
            all_results.extend(self._try_web_search(query, num_images))

        # If we don't have enough images, try generation
        remaining = num_images - len(all_results)
        if remaining > 0:
            all_results.extend(self._generate_images(query, remaining, style))

        if not all_results:
            return {
                "status": "error",
                "message": "Failed to get any images"
            }

        return {
            "status": "success",
            "images": all_results,
            "total": len(all_results),
            "sources": {
                "web": sum(1 for img in all_results if img['source'] == 'web'),
                "generated": sum(1 for img in all_results if img['source'] == 'generated')
            }
        }