import os
import re
import tempfile
import time
from typing import Dict, List, Optional
from urllib.parse import urlparse

import requests
from smolagents.tools import Tool


class ImageHandlerTool(Tool):
    """Find images for a topic via web search, generating any that are missing.

    Web results are downloaded into ``temp_dir``. If fewer than the requested
    number can be downloaded (or web search is skipped), the remainder is
    requested from the image-generation tool.
    """

    name = "image_handler"
    description = "Gets or generates images for a given topic, with fallback options"
    inputs = {
        'query': {'type': 'string', 'description': 'The topic to get images for'},
        'num_images': {
            'type': 'integer',
            'description': 'Number of images to get/generate',
            'nullable': True
        },
        'style': {
            'type': 'string',
            'description': 'Style for generated images (e.g., "photo", "artistic", "realistic")',
            'nullable': True
        },
        'skip_web_search': {
            'type': 'boolean',
            'description': 'Whether to skip web search and go straight to generation',
            'nullable': True
        }
    }
    output_type = "object"

    # Extensions accepted for web image URLs. The matched extension is also
    # used for the saved file, so a PNG is not written out as ".jpg".
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif')

    def __init__(self, web_search_tool, image_gen_tool, temp_dir=None):
        """Create the tool.

        Args:
            web_search_tool: Object whose ``forward(query=..., max_results=...)``
                returns search results. Result items are read with ``.get`` for
                ``'image_url'``, ``'url'``, ``'title'`` and ``'source'``.
            image_gen_tool: Object whose ``forward(prompt=...)`` returns either
                a dict with an ``'image_path'`` key or an existing file path.
            temp_dir: Directory that downloaded images are written to. Defaults
                to the system temporary directory.
        """
        super().__init__()
        self.web_search = web_search_tool
        self.image_gen = image_gen_tool
        self.temp_dir = temp_dir or tempfile.gettempdir()

    @staticmethod
    def _filename_stem(query: str) -> str:
        """Return a filesystem-safe filename stem derived from ``query``.

        Each run of characters other than word characters, '.' and '-' becomes
        a single '_', so path separators in the query cannot escape
        ``temp_dir``. Leading and trailing '.'/'_' are trimmed to avoid names
        such as '..'. An empty result falls back to 'image'.
        """
        stem = re.sub(r'[^\w.-]+', '_', query).strip('._')
        return stem or 'image'

    @classmethod
    def _image_extension(cls, url: str) -> Optional[str]:
        """Return the image extension that ``url`` points at, or None.

        The parsed URL path is checked so that query strings
        (``photo.png?w=800``) are tolerated. The full URL is checked too, so
        everything a plain ``url.endswith(...)`` test accepted is still accepted.
        """
        try:
            path = urlparse(url).path
        except ValueError:  # malformed URL (e.g. broken IPv6 host)
            path = ''
        for candidate in (path.lower(), url.lower()):
            for extension in cls._IMAGE_EXTENSIONS:
                if candidate.endswith(extension):
                    return extension
        return None

    def _download_image(self, url: str, filename: str) -> Optional[str]:
        """Download ``url`` and save it as ``temp_dir/filename``.

        Returns:
            The saved file path, or None if the request failed, the write
            failed, or the server did not report an ``image/*`` content type.
            Errors are printed, never raised.
        """
        try:
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
            }
            response = requests.get(url, timeout=10, headers=headers)
            response.raise_for_status()

            # Reject non-image bodies (e.g. HTML pages) served from image-looking
            # URLs. Media types are case-insensitive, so normalise first.
            content_type = response.headers.get('content-type', '').lower()
            if not content_type.startswith('image/'):
                return None

            os.makedirs(self.temp_dir, exist_ok=True)
            file_path = os.path.join(self.temp_dir, filename)
            with open(file_path, 'wb') as f:
                f.write(response.content)
            return file_path
        except Exception as e:
            # Best effort: one bad image must not abort the whole search.
            print(f"Failed to download image from {url}: {str(e)}")
            return None

    def _try_web_search(self, query: str, num_images: int) -> List[Dict]:
        """Search the web for images of ``query`` and download up to ``num_images``.

        Several query phrasings are tried in turn until enough images have
        been downloaded. A phrasing that raises is printed and skipped. Never
        raises.

        Returns:
            List of dicts with keys ``file_path``, ``source`` (always
            ``'web'``), ``url``, ``title`` and ``attribution``.
        """
        results: List[Dict] = []
        # The phrasings overlap, so the same URL may come back more than once;
        # only attempt each URL once per call.
        attempted_urls = set()
        stem = self._filename_stem(query)
        search_queries = [
            f"{query} high resolution photo",
            f"{query} professional photography",
            f"{query} best pictures",
            f"{query} travel photography",
        ]
        try:
            for attempt, search_query in enumerate(search_queries):
                if len(results) >= num_images:
                    break
                if attempt:
                    time.sleep(2)  # Rate limiting between consecutive searches
                try:
                    search_results = self.web_search.forward(
                        query=search_query, max_results=num_images)
                    if isinstance(search_results, str):
                        # A string response carries no structured results
                        # (presumably an error or formatted text) - skip it.
                        continue

                    for result in search_results:
                        if len(results) >= num_images:
                            break
                        if not hasattr(result, 'get'):
                            continue
                        # Try both image_url and direct URL fields
                        url = result.get('image_url') or result.get('url')
                        if not isinstance(url, str) or url in attempted_urls:
                            continue
                        extension = self._image_extension(url)
                        if extension is None:
                            continue
                        attempted_urls.add(url)

                        # Number files by how many images have been saved so
                        # far. Numbering by position within one phrasing's
                        # results made later phrasings overwrite earlier files.
                        filename = f"{stem}_{len(results)}{extension}"
                        file_path = self._download_image(url, filename)
                        if file_path:
                            results.append({
                                'file_path': file_path,
                                'source': 'web',
                                'url': url,
                                'title': result.get('title', ''),
                                'attribution': result.get('source', '')
                            })
                except Exception as search_error:
                    print(f"Search query failed: {str(search_error)}")
        except Exception as e:
            print(f"Web search failed: {str(e)}")
        return results

    def _generate_images(self, query: str, num_images: int,
                         style: str = "photo") -> List[Dict]:
        """Ask the image-generation tool for ``num_images`` images of ``query``.

        Prompts rotate through three templates. Each prompt numbers the subject
        (``"<query> 1"``, ``"<query> 2"``, ...) so successive requests differ.
        Individual failures are printed and skipped. Never raises.

        Returns:
            List of dicts with keys ``file_path``, ``source`` (always
            ``'generated'``), ``prompt`` and ``style``.
        """
        results: List[Dict] = []
        templates = (
            "Generate a {style} style image of {subject}, high quality, detailed",
            "Create a {style} representation of {subject}, professional quality",
            "Make a {style} image showing {subject}, realistic and clear",
        )
        try:
            for idx in range(num_images):
                if idx:
                    time.sleep(1)  # Brief pause between generations
                # Format the numbered subject in directly. Using str.replace()
                # on the finished prompt broke for an empty query and also
                # numbered any copy of the query that appeared inside `style`.
                subject = f"{query} {idx + 1}"
                prompt = templates[idx % len(templates)].format(
                    style=style, subject=subject)
                try:
                    response = self.image_gen.forward(prompt=prompt)
                except Exception as gen_error:
                    print(
                        f"Failed to generate image {idx+1}: {str(gen_error)}")
                    continue

                if isinstance(response, dict) and 'image_path' in response:
                    file_path = response['image_path']
                elif isinstance(response, str) and os.path.exists(response):
                    file_path = response
                else:
                    continue  # Response does not reference a usable image
                results.append({
                    'file_path': file_path,
                    'source': 'generated',
                    'prompt': prompt,
                    'style': style
                })
        except Exception as e:
            print(f"Image generation failed: {str(e)}")
        return results

    def forward(self, query: str, num_images: int = 2, style: str = "photo",
                skip_web_search: bool = False) -> Dict:
        """Gets or generates images for the query.

        Args:
            query: What to get images of.
            num_images: How many images to get. None means the default of 2.
            style: Style for generated images. None means the default "photo".
            skip_web_search: Whether to skip web search and only generate.

        Returns:
            On success, ``{"status": "success", "images": [...], "total": int,
            "sources": {"web": int, "generated": int}}``. If no image could be
            obtained, ``{"status": "error", "message": ...}``.
        """
        # These inputs are declared nullable, so an agent may pass an explicit
        # None. Fall back to the signature defaults instead of crashing on
        # comparisons like `len(...) < None`.
        if num_images is None:
            num_images = 2
        if style is None:
            style = "photo"

        all_results: List[Dict] = []

        # Try web search first unless skipped
        if not skip_web_search:
            all_results.extend(self._try_web_search(query, num_images))

        # Fill any shortfall with generated images
        remaining = num_images - len(all_results)
        if remaining > 0:
            all_results.extend(self._generate_images(query, remaining, style))

        if not all_results:
            return {
                "status": "error",
                "message": "Failed to get any images"
            }

        web_count = sum(1 for img in all_results if img['source'] == 'web')
        generated_count = sum(
            1 for img in all_results if img['source'] == 'generated')
        return {
            "status": "success",
            "images": all_results,
            "total": len(all_results),
            "sources": {
                "web": web_count,
                "generated": generated_count
            }
        }