from functools import wraps
import importlib
from typing import Any, Optional, Callable, List

from smolagents.tools import Tool

# The search backends are optional dependencies. Importing them eagerly made a
# missing package fail at module load with a bare ImportError, so the install
# hint in setup_search_dependency was unreachable. The module-level names stay
# bound (to None when the package is absent) for any code that references them;
# the tools themselves import the packages lazily through importlib.
try:
    import duckduckgo_search
except ImportError:
    duckduckgo_search = None

try:
    import googlesearch
except ImportError:
    googlesearch = None


def setup_search_dependency(package_name: str, import_func: Callable):
    """Run ``import_func`` and turn an ImportError into an install hint.

    Args:
        package_name: The pip distribution name shown in the error message.
        import_func: Zero-argument callable that imports (and possibly
            instantiates) the dependency.

    Returns:
        Whatever ``import_func`` returns.

    Raises:
        ImportError: If ``import_func`` raises ImportError; the original
            error is chained as the cause.
    """
    try:
        return import_func()
    except ImportError as e:
        raise ImportError(
            f"You must install package `{package_name}` to run this tool: "
            f"for instance run `pip install {package_name}`."
        ) from e


def handle_search_errors(func):
    """Decorator that converts search failures into an error string.

    A truthy result of the wrapped call is returned unchanged. A falsy result,
    or any Exception raised by the call, is turned into a string starting with
    "Error performing search: " rather than being propagated to the caller.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            results = func(*args, **kwargs)
            if not results:
                # Routed through the except clause below so the "no results"
                # case produces the same message prefix as real failures.
                raise Exception("No results found! Try a less restrictive/shorter query.")
            return results
        except Exception as e:
            return f"Error performing search: {str(e)}"

    return wrapper


def _format_search_results(entries: List[str]) -> str:
    """Join pre-formatted result entries under a "## Search Results" header.

    Returns "" when ``entries`` is empty so that ``handle_search_errors`` sees a
    falsy result and reports "No results found!" instead of a bare header.
    """
    if not entries:
        return ""
    return "## Search Results\n\n" + "\n\n".join(entries)


class DuckDuckGoSearchTool(Tool):
    """smolagents tool that performs a text search through ``duckduckgo_search.DDGS``."""

    name = "web_search"
    description = "Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."
    inputs = {'query': {'type': 'string', 'description': 'The search query to perform.'}}
    output_type = "string"

    def __init__(self, max_results=10, **kwargs):
        """Create the tool and its ``DDGS`` client.

        Args:
            max_results: Passed as ``max_results`` to ``DDGS.text`` on every query.
            **kwargs: Forwarded unchanged to the ``DDGS`` constructor.

        Raises:
            ImportError: If the ``duckduckgo-search`` package is not installed.
        """
        super().__init__()
        self.max_results = max_results
        self.ddgs = setup_search_dependency(
            'duckduckgo-search',
            lambda: importlib.import_module("duckduckgo_search").DDGS(**kwargs),
        )

    @handle_search_errors
    def forward(self, query: str) -> str:
        """Search DuckDuckGo and return results as markdown links with snippets."""
        # list(... or []) makes the emptiness check reliable in case the backend
        # returns None or a lazy iterator instead of a list.
        results = list(self.ddgs.text(query, max_results=self.max_results) or [])
        formatted_results = [
            f"[{result['title']}]({result['href']})\n{result['body']}"
            for result in results
        ]
        return _format_search_results(formatted_results)


class GoogleSearchTool(Tool):
    """smolagents tool that performs a Google search through ``googlesearch.search``."""

    name = "google_search"
    description = "Performs a Google web search based on your query and returns the top search results."
    inputs = {'query': {'type': 'string', 'description': 'The search query to perform.'}}
    output_type = "string"

    def __init__(self, max_results=10, **kwargs):
        """Create the tool and resolve the ``googlesearch.search`` function.

        Args:
            max_results: Passed as ``num_results`` to ``googlesearch.search``.
            **kwargs: Accepted but currently unused (not forwarded anywhere).

        Raises:
            ImportError: If the ``googlesearch-python`` package is not installed.
        """
        super().__init__()
        self.max_results = max_results
        self.search = setup_search_dependency(
            'googlesearch-python',
            lambda: importlib.import_module("googlesearch").search,
        )

    @handle_search_errors
    def forward(self, query: str) -> str:
        """Search Google and return each result URL as a numbered markdown link.

        Each entry's body repeats the URL, since only URLs are read from the
        search call here.
        """
        search_results = list(self.search(query, num_results=self.max_results))
        formatted_results = [
            f"[Result {i}]({url})\n{url}"
            for i, url in enumerate(search_results, start=1)
        ]
        return _format_search_results(formatted_results)