Spaces:
Sleeping
Sleeping
import os
import tempfile
import time
from typing import Dict, List, Optional
from urllib.parse import urlparse

import requests
from smolagents.tools import Tool
class ImageHandlerTool(Tool):
    """Collect images for a topic, preferring web results over generation.

    ``forward`` first asks the injected web-search tool for image URLs and
    downloads them into ``temp_dir``. Any shortfall against the requested
    count is then filled by calling the injected image-generation tool.
    """

    name = "image_handler"
    description = "Gets or generates images for a given topic, with fallback options"
    inputs = {
        'query': {'type': 'string', 'description': 'The topic to get images for'},
        'num_images': {
            'type': 'integer',
            'description': 'Number of images to get/generate',
            'nullable': True
        },
        'style': {
            'type': 'string',
            'description': 'Style for generated images (e.g., "photo", "artistic", "realistic")',
            'nullable': True
        },
        'skip_web_search': {
            'type': 'boolean',
            'description': 'Whether to skip web search and go straight to generation',
            'nullable': True
        }
    }
    output_type = "object"

    def __init__(self, web_search_tool, image_gen_tool, temp_dir: Optional[str] = None):
        """Store the collaborating tools and the download directory.

        Args:
            web_search_tool: Object whose ``forward(query=..., max_results=...)``
                is called to search for images.
            image_gen_tool: Object whose ``forward(prompt=...)`` is called to
                generate an image.
            temp_dir: Directory for downloaded files; defaults to the system
                temporary directory when falsy.
        """
        super().__init__()
        self.web_search = web_search_tool
        self.image_gen = image_gen_tool
        self.temp_dir = temp_dir or tempfile.gettempdir()

    @staticmethod
    def _filename_stem(query: str) -> str:
        """Turn ``query`` into a string that is safe to use as a file name stem."""
        # Anything that is not alphanumeric, '-' or '_' (spaces, path
        # separators, dots, ...) becomes '_' so the name cannot leave temp_dir.
        stem = "".join(
            ch if ch.isalnum() or ch in "-_" else "_" for ch in query.strip()
        )
        # Cap the length so a very long query cannot exceed file-name limits.
        return stem[:100] or "image"

    @staticmethod
    def _image_extension(url: str) -> Optional[str]:
        """Return the lower-cased image extension of ``url``, or None.

        The URL path is checked first so that query strings and fragments
        (``photo.jpg?w=800``) do not hide the extension; the raw URL is checked
        as a fallback so URLs that merely end in an image suffix still match.
        """
        image_extensions = ('.jpg', '.jpeg', '.png', '.gif')
        for candidate in (urlparse(url).path, url):
            extension = os.path.splitext(candidate)[1].lower()
            if extension in image_extensions:
                return extension
        return None

    def _download_image(self, url: str, filename: str) -> Optional[str]:
        """Download ``url`` and save it as ``filename`` inside ``self.temp_dir``.

        Args:
            url: Address of the image to fetch.
            filename: Name of the file to create in ``self.temp_dir``.

        Returns:
            The path of the saved file, or None when the request fails, the
            response is not served with an ``image/*`` content type, or the
            file cannot be written. Failures are printed, never raised.
        """
        try:
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
            }
            response = requests.get(url, timeout=10, headers=headers)
            response.raise_for_status()

            # Check if response is actually an image (header values may vary in case)
            content_type = response.headers.get('content-type', '')
            if not content_type.lower().startswith('image/'):
                return None

            # Ensure temp directory exists
            os.makedirs(self.temp_dir, exist_ok=True)

            # Save the image
            file_path = os.path.join(self.temp_dir, filename)
            with open(file_path, 'wb') as f:
                f.write(response.content)
            return file_path
        except Exception as e:
            # Deliberate best-effort: one bad URL must not abort the whole search.
            print(f"Failed to download image from {url}: {str(e)}")
            return None

    def _try_web_search(self, query: str, num_images: int) -> List[Dict]:
        """Search the web for up to ``num_images`` images of ``query``.

        Several phrasings of the query are tried in turn until enough images
        have been downloaded. Errors are printed and never raised.

        Args:
            query: Topic to search for.
            num_images: Maximum number of images to collect.

        Returns:
            One dict per downloaded image with the keys ``file_path``,
            ``source`` (always ``'web'``), ``url``, ``title`` and
            ``attribution``.
        """
        results: List[Dict] = []
        seen_urls = set()
        try:
            stem = self._filename_stem(query)
            # Try different search queries with better targeting
            search_queries = [
                f"{query} high resolution photo",
                f"{query} professional photography",
                f"{query} best pictures",
                f"{query} travel photography"
            ]
            for search_query in search_queries:
                if len(results) >= num_images:
                    break
                time.sleep(2)  # Rate limiting
                try:
                    search_results = self.web_search.forward(
                        query=search_query, max_results=num_images)
                    if isinstance(search_results, str):  # Handle string responses
                        continue
                    for result in search_results:
                        if len(results) >= num_images:
                            break
                        # Skip a malformed entry instead of abandoning the query.
                        if not isinstance(result, dict):
                            continue
                        # Try both image_url and direct URL fields
                        url = result.get('image_url') or result.get('url')
                        if not isinstance(url, str) or url in seen_urls:
                            continue
                        extension = self._image_extension(url)
                        if extension is None:
                            continue
                        # Remember the URL even if the download fails so the
                        # same address is not fetched again by a later query.
                        seen_urls.add(url)
                        # Index by the number of images saved so far: it is
                        # unique across search queries, so files never collide.
                        filename = f"{stem}_{len(results)}{extension}"
                        file_path = self._download_image(url, filename)
                        if file_path:
                            results.append({
                                'file_path': file_path,
                                'source': 'web',
                                'url': url,
                                'title': result.get('title', ''),
                                'attribution': result.get('source', '')
                            })
                except Exception as search_error:
                    print(f"Search query failed: {str(search_error)}")
                    continue
        except Exception as e:
            print(f"Web search failed: {str(e)}")
        return results

    def _generate_images(self, query: str, num_images: int, style: str = "photo") -> List[Dict]:
        """Generate ``num_images`` images of ``query`` with the generation tool.

        Prompts cycle through three templates, and the image number is appended
        to the subject so that successive prompts differ. Errors are printed
        and never raised.

        Args:
            query: Topic to depict.
            num_images: Number of generation attempts to make.
            style: Style word inserted into each prompt.

        Returns:
            One dict per generated image with the keys ``file_path``,
            ``source`` (always ``'generated'``), ``prompt`` and ``style``.
        """
        results: List[Dict] = []
        try:
            # Enhanced prompts for better generation. The subject is formatted
            # in directly instead of being patched in with str.replace, which
            # would also rewrite matching text in the template or the style.
            templates = [
                "Generate a {style} style image of {subject}, high quality, detailed",
                "Create a {style} representation of {subject}, professional quality",
                "Make a {style} image showing {subject}, realistic and clear"
            ]
            for idx in range(num_images):
                template = templates[idx % len(templates)]
                prompt = template.format(style=style, subject=f"{query} {idx+1}")
                try:
                    response = self.image_gen.forward(prompt=prompt)
                except Exception as gen_error:
                    print(
                        f"Failed to generate image {idx+1}: {str(gen_error)}")
                    continue

                # The tool may answer with a dict holding 'image_path' or with
                # a path string; any other reply is ignored.
                file_path = None
                if isinstance(response, dict) and 'image_path' in response:
                    file_path = response['image_path']
                elif isinstance(response, str) and os.path.exists(response):
                    file_path = response

                if file_path is not None:
                    results.append({
                        'file_path': file_path,
                        'source': 'generated',
                        'prompt': prompt,
                        'style': style
                    })
                time.sleep(1)  # Brief pause between generations
        except Exception as e:
            print(f"Image generation failed: {str(e)}")
        return results

    def forward(self, query: str, num_images: int = 2, style: str = "photo", skip_web_search: bool = False) -> Dict:
        """Gets or generates images for the query

        Args:
            query: What to get images of
            num_images: How many images to get (None falls back to 2)
            style: Style for generated images (None or empty falls back to "photo")
            skip_web_search: Whether to skip web search

        Returns:
            Dict containing results and status. On success it has
            ``"status": "success"``, the ``"images"`` list, the ``"total"``
            count and per-source counts under ``"sources"``. When no image
            could be obtained it has ``"status": "error"`` and a ``"message"``.
        """
        # The input schema marks these arguments as nullable, so an explicit
        # None must behave like an omitted argument.
        if num_images is None:
            num_images = 2
        if not style:
            style = "photo"

        all_results: List[Dict] = []

        # Try web search first unless skipped
        if not skip_web_search:
            all_results.extend(self._try_web_search(query, num_images))

        # If we don't have enough images, try generation
        remaining = num_images - len(all_results)
        if remaining > 0:
            all_results.extend(self._generate_images(query, remaining, style))

        if not all_results:
            return {
                "status": "error",
                "message": "Failed to get any images"
            }

        return {
            "status": "success",
            "images": all_results,
            "total": len(all_results),
            "sources": {
                "web": sum(1 for img in all_results if img['source'] == 'web'),
                "generated": sum(1 for img in all_results if img['source'] == 'generated')
            }
        }