| from smolagents import Tool |
| from googleapiclient.discovery import build |
| import os |
|
|
|
|
class GoogleSearchTool(Tool):
    """smolagents tool that queries the Google Custom Search JSON API.

    Each hit is rendered as a markdown ``[title](link)`` line followed by the
    result snippet, with hits separated by blank lines. Subclasses can add
    extra request parameters (and matching ``inputs``) by overriding
    ``_collect_params``; ``forward`` passes any extra tool inputs to it.
    """

    name = "web_search"
    description = """Performs a google web search for query then returns top search results in markdown format."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform search.",
        },
    }
    output_type = "string"

    # forward() accepts *args/**kwargs so subclasses can declare extra inputs
    # without redefining it; this flag stops smolagents from requiring the
    # forward signature to match ``inputs`` exactly.
    skip_forward_signature_validation = True

    def __init__(
        self,
        api_key: str | None = None,
        search_engine_id: str | None = None,
        num_results: int = 10,
        **kwargs,
    ):
        """Build the Custom Search client.

        Args:
            api_key: Google API key. When ``None``, read from the
                ``GOOGLE_SEARCH_API_KEY`` environment variable.
            search_engine_id: Programmable Search Engine ID (sent as ``cx``).
                When ``None``, read from ``GOOGLE_SEARCH_ENGINE_ID``.
            num_results: Results to request per query; the API accepts 1-10.
            **kwargs: Forwarded to ``Tool.__init__``.

        Raises:
            ValueError: If the API key or engine ID is missing or empty, or if
                ``num_results`` is outside 1-10.
        """
        api_key = api_key if api_key is not None else os.getenv("GOOGLE_SEARCH_API_KEY")
        if not api_key:
            raise ValueError(
                "Please set the GOOGLE_SEARCH_API_KEY environment variable."
            )
        search_engine_id = (
            search_engine_id
            if search_engine_id is not None
            else os.getenv("GOOGLE_SEARCH_ENGINE_ID")
        )
        if not search_engine_id:
            raise ValueError(
                "Please set the GOOGLE_SEARCH_ENGINE_ID environment variable."
            )
        # The API rejects `num` outside 1..10 (HTTP 400 on every request), so
        # fail at construction instead of on the first search.
        if not 1 <= num_results <= 10:
            raise ValueError(
                f"num_results must be between 1 and 10, got {num_results!r}."
            )

        self.cse = build("customsearch", "v1", developerKey=api_key).cse()
        self.cx = search_engine_id
        self.num = num_results
        super().__init__(**kwargs)

    def _collect_params(self) -> dict:
        """Return extra request parameters; the base tool adds none.

        Subclasses override this to accept their additional tool inputs
        (forwarded from ``forward``) and map them to API parameters.
        """
        return {}

    def forward(self, query: str, *args, **kwargs) -> str:
        """Run the search and format the hits as markdown.

        Args:
            query: The search query.
            *args: Extra tool inputs, forwarded to ``_collect_params``.
            **kwargs: Extra tool inputs, forwarded to ``_collect_params``.

        Returns:
            One ``[title](link)`` line plus snippet per hit, joined by blank
            lines, or ``"No results found."`` when there are no hits.
        """
        params = {
            "q": query,
            "cx": self.cx,
            # Request only the fields rendered below.
            "fields": "items(title,link,snippet)",
            "num": self.num,
        }
        # Subclass parameters are merged last, so they win on key clashes.
        params = params | self._collect_params(*args, **kwargs)

        response = self.cse.list(**params).execute()
        items = response.get("items")
        # Treat a missing key and an empty list the same way.
        if not items:
            return "No results found."

        entries = []
        for item in items:
            link = item["link"]
            # Some hits come back without a snippet (or title). Fall back to
            # defaults rather than failing the whole search with KeyError.
            title = item.get("title", link)
            snippet = item.get("snippet", "")
            entries.append(f"[{title}]({link})\n{snippet}")
        return "\n\n".join(entries)
|
|
|
|
class GoogleSiteSearchTool(GoogleSearchTool):
    """Variant of ``GoogleSearchTool`` that restricts hits to a single site.

    Declares an extra ``site`` input. The inherited ``forward`` passes it on
    to ``_collect_params``, which turns it into site-restriction parameters.
    """

    name = "site_search"
    description = """Performs a google search within the website for query then returns top search results in markdown format."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform search.",
        },
        "site": {
            "type": "string",
            "description": "The domain of the site on which to search.",
        },
    }

    def _collect_params(self, site: str) -> dict:
        """Map the ``site`` input onto Custom Search site-restriction params.

        ``siteSearchFilter="i"`` asks the API to include only results from
        ``site``; ``"e"`` would exclude them instead.
        """
        restriction = dict(siteSearch=site)
        restriction["siteSearchFilter"] = "i"
        return restriction
|
|