Spaces:
Sleeping
Sleeping
File size: 7,669 Bytes
a33f19e e27aff8 a33f19e e27aff8 a33f19e e27aff8 a33f19e e27aff8 a33f19e e27aff8 a33f19e e27aff8 a33f19e e27aff8 a33f19e e27aff8 a33f19e e27aff8 a33f19e e27aff8 a33f19e e27aff8 a33f19e e27aff8 a33f19e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import os
import tempfile
import requests
from typing import Dict, List, Optional
from smolagents.tools import Tool
import time
class ImageHandlerTool(Tool):
    """Agent tool that collects images for a topic: web search first, generation as fallback.

    Images found via web search are downloaded into ``temp_dir``. If fewer than the
    requested number are obtained, the remainder is produced by the injected image
    generation tool.
    """

    name = "image_handler"
    description = "Gets or generates images for a given topic, with fallback options"
    inputs = {
        'query': {'type': 'string', 'description': 'The topic to get images for'},
        'num_images': {
            'type': 'integer',
            'description': 'Number of images to get/generate',
            'nullable': True
        },
        'style': {
            'type': 'string',
            'description': 'Style for generated images (e.g., "photo", "artistic", "realistic")',
            'nullable': True
        },
        'skip_web_search': {
            'type': 'boolean',
            'description': 'Whether to skip web search and go straight to generation',
            'nullable': True
        }
    }
    output_type = "object"

    def __init__(self, web_search_tool, image_gen_tool, temp_dir=None):
        """Store the collaborating tools and the download directory.

        Args:
            web_search_tool: Object exposing ``forward(query=..., max_results=...)``.
                Its results are read via ``.get('image_url' / 'url' / 'title' / 'source')``.
            image_gen_tool: Object exposing ``forward(prompt=...)`` that returns either a
                dict with an ``image_path`` key or a path string.
            temp_dir: Directory for downloaded images; defaults to the system temp dir.
        """
        super().__init__()
        self.web_search = web_search_tool
        self.image_gen = image_gen_tool
        self.temp_dir = temp_dir or tempfile.gettempdir()

    @staticmethod
    def _filename_stem(query: str) -> str:
        """Turn ``query`` into a filesystem-safe filename stem.

        The query comes from the calling agent, so path separators and ``..`` must not
        be able to create subdirectories or escape ``temp_dir``.
        """
        import re

        # \w is Unicode-aware, so non-ASCII topics keep their letters. Separators, dots
        # and other punctuation collapse to "_". Truncate to stay under filename limits.
        stem = re.sub(r'[^\w-]+', '_', query).strip('_')[:100]
        return stem or 'image'

    def _download_image(self, url: str, filename: str) -> Optional[str]:
        """Download ``url`` and save it as ``filename`` inside ``self.temp_dir``.

        Returns:
            The saved file path. Returns ``None`` if the request fails, the response's
            content type is not ``image/*``, or the file cannot be written. Failures are
            printed rather than raised, so one bad URL does not abort a search.
        """
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        try:
            response = requests.get(url, timeout=10, headers=headers)
            response.raise_for_status()
            # Reject HTML error pages and the like. Media types are case-insensitive.
            content_type = response.headers.get('content-type', '').lower()
            if not content_type.startswith('image/'):
                return None
            os.makedirs(self.temp_dir, exist_ok=True)
            file_path = os.path.join(self.temp_dir, filename)
            with open(file_path, 'wb') as f:
                f.write(response.content)
            return file_path
        except Exception as e:
            print(f"Failed to download image from {url}: {str(e)}")
            return None

    def _try_web_search(self, query: str, num_images: int) -> List[Dict]:
        """Search the web for images of ``query`` and download up to ``num_images`` of them.

        Several phrasings of the query are tried in order until enough images are saved.
        A result is used only if its ``image_url`` (or, failing that, ``url``) has a path
        ending in a common image extension and the download succeeds. Each distinct URL
        is attempted at most once.

        Returns:
            A list of dicts with keys ``file_path``, ``source`` (always ``'web'``), ``url``,
            ``title`` and ``attribution``. Errors are printed, never raised.
        """
        from urllib.parse import urlparse

        results: List[Dict] = []
        seen_urls = set()  # the query variants frequently return the same images
        try:
            stem = self._filename_stem(query)
            # Try different search queries with better targeting
            search_queries = [
                f"{query} high resolution photo",
                f"{query} professional photography",
                f"{query} best pictures",
                f"{query} travel photography"
            ]
            for search_query in search_queries:
                if len(results) >= num_images:
                    break
                time.sleep(2)  # Rate limiting
                try:
                    search_results = self.web_search.forward(
                        query=search_query, max_results=num_images)
                    if isinstance(search_results, str):  # Handle string responses
                        continue
                    for result in search_results:
                        if len(results) >= num_images:
                            break
                        if not isinstance(result, dict):
                            continue
                        # Try both image_url and direct URL fields
                        url = result.get('image_url') or result.get('url')
                        if not isinstance(url, str) or url in seen_urls:
                            continue
                        # Check the URL *path*, so a query string ("...jpg?w=800")
                        # does not hide the extension.
                        url_path = urlparse(url).path.lower()
                        if not url_path.endswith(('.jpg', '.jpeg', '.png', '.gif')):
                            continue
                        seen_urls.add(url)
                        # Number files by how many images are already saved. A per-query
                        # index would restart at 0 for every query variant and overwrite
                        # earlier downloads.
                        extension = os.path.splitext(url_path)[1]
                        filename = f"{stem}_{len(results)}{extension}"
                        file_path = self._download_image(url, filename)
                        if file_path:
                            results.append({
                                'file_path': file_path,
                                'source': 'web',
                                'url': url,
                                'title': result.get('title', ''),
                                'attribution': result.get('source', '')
                            })
                except Exception as search_error:
                    print(f"Search query failed: {str(search_error)}")
        except Exception as e:
            print(f"Web search failed: {str(e)}")
        return results

    def _generate_images(self, query: str, num_images: int, style: str = "photo") -> List[Dict]:
        """Generate ``num_images`` images of ``query`` with the image generation tool.

        Prompts cycle through three templates. The subject in each prompt carries a
        1-based counter ("<query> 1", "<query> 2", ...) so successive requests differ.
        A response is accepted if it is a dict with an ``image_path`` key, or a string
        naming an existing file.

        Returns:
            A list of dicts with keys ``file_path``, ``source`` (always ``'generated'``),
            ``prompt`` and ``style``. Errors are printed, never raised.
        """
        results: List[Dict] = []
        # Filled with str.format, so braces inside query/style are inserted verbatim.
        prompt_templates = [
            "Generate a {style} style image of {subject}, high quality, detailed",
            "Create a {style} representation of {subject}, professional quality",
            "Make a {style} image showing {subject}, realistic and clear"
        ]
        try:
            for idx in range(num_images):
                if idx:
                    # Pause between attempts, including after a failed one.
                    time.sleep(1)
                # Insert the numbered subject directly. Substituting into a finished prompt
                # would also rewrite matching style text, and would misbehave for an empty query.
                prompt = prompt_templates[idx % len(prompt_templates)].format(
                    style=style, subject=f"{query} {idx + 1}")
                try:
                    response = self.image_gen.forward(prompt=prompt)
                    if isinstance(response, dict) and 'image_path' in response:
                        file_path = response['image_path']
                    elif isinstance(response, str) and os.path.exists(response):
                        file_path = response
                    else:
                        print(f"Unusable response for image {idx+1}: {type(response).__name__}")
                        continue
                except Exception as gen_error:
                    print(f"Failed to generate image {idx+1}: {str(gen_error)}")
                    continue
                results.append({
                    'file_path': file_path,
                    'source': 'generated',
                    'prompt': prompt,
                    'style': style
                })
        except Exception as e:
            print(f"Image generation failed: {str(e)}")
        return results

    def forward(self, query: str, num_images: int = 2, style: str = "photo", skip_web_search: bool = False) -> Dict:
        """Gets or generates images for the query

        Args:
            query: What to get images of
            num_images: How many images to get; ``None`` means the default of 2
            style: Style for generated images; ``None`` means ``"photo"``
            skip_web_search: Whether to skip web search
        Returns:
            On success, ``{"status": "success", "images": [...], "total": int,
            "sources": {"web": int, "generated": int}}``.
            If no image could be obtained, ``{"status": "error", "message": str}``.
        """
        # These inputs are declared nullable, so an agent may pass None explicitly.
        # Fall back to the signature defaults instead of failing on `len(...) < None`.
        if num_images is None:
            num_images = 2
        if style is None:
            style = "photo"

        all_results: List[Dict] = []
        # Try web search first unless skipped
        if not skip_web_search:
            all_results.extend(self._try_web_search(query, num_images))

        # Fill any shortfall with generated images
        remaining = num_images - len(all_results)
        if remaining > 0:
            all_results.extend(self._generate_images(query, remaining, style))

        if not all_results:
            return {
                "status": "error",
                "message": "Failed to get any images"
            }
        return {
            "status": "success",
            "images": all_results,
            "total": len(all_results),
            "sources": {
                "web": sum(1 for img in all_results if img['source'] == 'web'),
                "generated": sum(1 for img in all_results if img['source'] == 'generated')
            }
        }
|