# Gabandino's picture
# Fix typo for final submission
# a14ace8 verified
from smolagents import Tool
from googleapiclient.discovery import build
import os
class GoogleSearchTool(Tool):
    """Tool that runs a query through the Google Custom Search API.

    Results are returned as one plain-text string: each hit contributes its
    title, link and snippet on separate lines, and hits are separated by a
    blank line.
    """

    name = "web_search"
    description = """Performs a google web search for a query then returns top search results in markdown format."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform a web search for"
        }
    }
    output_type = "string"
    # forward() takes *args/**kwargs so subclasses can declare extra inputs.
    # NOTE(review): presumably this flag stops smolagents from comparing the
    # forward() signature against `inputs` - confirm against the Tool docs.
    skip_forward_signature_validation = True

    def __init__(
        self,
        api_key: str | None = None,
        search_engine_id: str | None = None,
        num_results: int = 10,
        **kwargs,
    ):
        """Create the Custom Search client.

        Args:
            api_key: Google API key. When omitted (or empty), the
                ``GOOGLE_SEARCH_API_KEY`` environment variable is used.
            search_engine_id: Programmable Search Engine id (``cx``). When
                omitted (or empty), ``GOOGLE_SEARCH_ENGINE_ID`` is used.
            num_results: Value sent as the ``num`` request parameter.
                NOTE(review): the Custom Search API reference documents 1-10
                as the valid range - larger values are not clamped here.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.

        Raises:
            ValueError: If no API key or no search engine id can be resolved.
        """
        from dotenv import load_dotenv

        # Populate os.environ from a local .env file, if one exists.
        load_dotenv()

        # Explicit arguments take precedence; the environment is the fallback.
        # (Previously the arguments were always overwritten by the env values.)
        api_key = api_key or os.getenv("GOOGLE_SEARCH_API_KEY")
        search_engine_id = search_engine_id or os.getenv("GOOGLE_SEARCH_ENGINE_ID")

        if not api_key:
            raise ValueError("GOOGLE_SEARCH_API_KEY is not set")
        if not search_engine_id:
            raise ValueError("GOOGLE_SEARCH_ENGINE_ID is not set")

        self.cse = build(
            "customsearch",
            "v1",
            developerKey=api_key,
        ).cse()
        self.cx = search_engine_id
        self.num = num_results
        super().__init__(**kwargs)

    def _collect_params(self) -> dict:
        """Return extra request parameters; the base tool adds none.

        Subclasses override this to turn their additional inputs into
        Custom Search request parameters.
        """
        return {}

    def forward(self, query: str, *args, **kwargs) -> str:
        """Run the search and format the hits as text.

        Args:
            query: The search query string.
            *args: Extra positional inputs, passed to ``_collect_params``.
            **kwargs: Extra keyword inputs, passed to ``_collect_params``.

        Returns:
            ``"No results found"`` when the response has no items; otherwise
            one ``title\\nlink\\nsnippet`` entry per hit, joined by blank lines.
        """
        params = {
            "q": query,
            "cx": self.cx,
            # Ask the API to return only the fields that are rendered below.
            "fields": "items(link, title, snippet)",
            "num": self.num,
        }
        # Subclass-specific parameters override the defaults on key clashes.
        params = params | self._collect_params(*args, **kwargs)
        res = self.cse.list(**params).execute()

        # Treat a missing key and an empty list the same way.
        items = res.get("items")
        if not items:
            return "No results found"

        # .get() keeps one hit with a missing field (e.g. no snippet) from
        # raising KeyError and discarding every other result.
        return "\n\n".join(
            f"{item.get('title', '')}\n{item.get('link', '')}\n{item.get('snippet', '')}"
            for item in items
        )
class GoogleSiteSearchTool(GoogleSearchTool):
    """Google search tool that takes an extra ``site`` input for one domain.

    Only the tool metadata and the extra request parameters differ from the
    parent class; the inherited ``forward`` merges those parameters into the
    request it sends.
    """

    name = "site_search"
    description = """Searches a specific website for a given query and returns the site contents in markdown format. Use when information is likely to be found on a particular domain, such as reddit.com, wikipedia.org, ieee.org, or arxiv.org."""
    inputs = {
        "query": {"type": "string", "description": "The query to perform search."},
        "site": {"type": "string", "description": "The domain of the site on which to search"},
    }

    def _collect_params(self, site: str) -> dict:
        """Build the request parameters that carry the ``site`` input.

        Args:
            site: Domain to search on, e.g. ``"wikipedia.org"``.

        Returns:
            A dict with ``siteSearch`` set to *site* and ``siteSearchFilter``
            set to ``"i"``.
        """
        # NOTE(review): per the Custom Search API reference, "i" means
        # "include" results from the site (as opposed to "e", exclude).
        return dict(siteSearch=site, siteSearchFilter="i")