Spaces:
Sleeping
Sleeping
File size: 2,424 Bytes
f0df228 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from smolagents import Tool
from googleapiclient.discovery import build
import os
class GoogleSearchTool(Tool):
    """smolagents tool that queries the Google Custom Search JSON API.

    Credentials come from the constructor arguments or, failing that, from the
    ``GOOGLE_SEARCH_API_KEY`` / ``GOOGLE_SEARCH_ENGINE_ID`` environment
    variables (``load_dotenv()`` is called first, so a ``.env`` file works).

    Subclasses can narrow the query by overriding ``_collect_params``; its
    return value is merged into the request parameters in ``forward``.
    """

    name = "web_search"
    description = "Performs a Google search and returns top results in markdown format."
    inputs = {
        "query": {
            "type": "string",
            "description": "Search string to query the web.",
        }
    }
    output_type = "string"
    # forward() takes *args/**kwargs so subclasses can add inputs without
    # redefining it; this flag tells smolagents not to check its signature
    # against `inputs`.
    skip_forward_signature_validation = True

    def __init__(self, api_key: str | None = None, search_engine_id: str | None = None, num_results: int = 10, **kwargs):
        """Resolve credentials and build the Custom Search client.

        Args:
            api_key: Google API key; falls back to ``GOOGLE_SEARCH_API_KEY``.
            search_engine_id: Search engine ID sent as ``cx``; falls back to
                ``GOOGLE_SEARCH_ENGINE_ID``.
            num_results: Number of results requested per search. The Custom
                Search API only accepts values from 1 to 10.
            **kwargs: Forwarded to ``Tool.__init__``.

        Raises:
            EnvironmentError: If the API key or search engine ID is missing.
            ValueError: If ``num_results`` is outside the range 1-10.
        """
        from dotenv import load_dotenv

        load_dotenv()
        self.api_key = api_key or os.getenv("GOOGLE_SEARCH_API_KEY")
        self.search_engine_id = search_engine_id or os.getenv("GOOGLE_SEARCH_ENGINE_ID")
        if not self.api_key:
            raise EnvironmentError("GOOGLE_SEARCH_API_KEY is not configured.")
        if not self.search_engine_id:
            raise EnvironmentError("GOOGLE_SEARCH_ENGINE_ID is not configured.")
        # The API rejects `num` outside 1..10. Fail fast here instead of
        # having every forward() call end in an HttpError.
        if not 1 <= num_results <= 10:
            raise ValueError(f"num_results must be between 1 and 10, got {num_results}.")
        self.search = build("customsearch", "v1", developerKey=self.api_key).cse()
        self.max_results = num_results
        super().__init__(**kwargs)

    def _collect_params(self) -> dict:
        """Return extra request parameters; the base tool adds none."""
        return {}

    def forward(self, query: str, *args, **kwargs) -> str:
        """Run the search and format each hit as title, link and snippet.

        Any extra positional/keyword arguments are passed to
        ``_collect_params`` and its result is merged into the request.

        Returns:
            One ``title\\nlink\\nsnippet`` entry per hit, separated by blank
            lines, or ``"No search results found."`` when nothing matched.
        """
        params = {
            "q": query,
            "cx": self.search_engine_id,
            "fields": "items(link,title,snippet)",
            "num": self.max_results,
        }
        params.update(self._collect_params(*args, **kwargs))
        results = self.search.list(**params).execute()
        # `items` is absent when nothing matched; treat an empty list the same.
        items = results.get("items")
        if not items:
            return "No search results found."
        # Individual results may lack a snippet (or title), and a partial
        # response omits absent fields, so don't index them directly.
        return "\n\n".join(
            f"{item.get('title', '')}\n{item['link']}\n{item.get('snippet', '')}" for item in items
        )
class GoogleSiteSearchTool(GoogleSearchTool):
    """Google search restricted to a single domain.

    Adds a ``site`` input. ``forward`` (inherited) passes it to
    ``_collect_params``, which turns it into the Custom Search ``siteSearch``
    parameter with ``siteSearchFilter`` set to ``"i"``.
    """

    name = "site_search"
    description = "Performs a Google search scoped to a specific domain, such as wikipedia.org or arxiv.org."
    inputs = {
        "query": {"type": "string", "description": "Search string to query."},
        "site": {"type": "string", "description": "Domain to restrict search to (e.g., reddit.com, wikipedia.org)."},
    }

    def _collect_params(self, site: str) -> dict:
        """Build the request parameters that limit results to ``site``."""
        restriction = dict(siteSearch=site)
        # "i" selects include mode, i.e. keep only results from `site`.
        restriction["siteSearchFilter"] = "i"
        return restriction