FD900 committed on
Commit
fa39ad6
·
verified ·
1 Parent(s): 06f1955

Update tools/google_search_tool.py

Browse files
Files changed (1) hide show
  1. tools/google_search_tool.py +32 -32
tools/google_search_tool.py CHANGED
@@ -1,36 +1,37 @@
1
- from smolagents import Tool
2
  from googleapiclient.discovery import build
3
  import os
4
 
5
-
6
- class GoogleSearchTool(Tool):
7
  name = "web_search"
8
- description = "Performs a Google search and returns top results in markdown format."
9
 
10
  inputs = {
11
  "query": {
12
  "type": "string",
13
- "description": "Search string to query the web.",
14
  }
15
  }
16
 
17
  output_type = "string"
18
  skip_forward_signature_validation = True
19
 
20
- def __init__(self, api_key: str | None = None, search_engine_id: str | None = None, num_results: int = 10, **kwargs):
21
- from dotenv import load_dotenv
22
- load_dotenv()
23
-
24
- self.api_key = api_key or os.getenv("GOOGLE_SEARCH_API_KEY")
25
- self.search_engine_id = search_engine_id or os.getenv("GOOGLE_SEARCH_ENGINE_ID")
 
 
 
26
 
27
- if not self.api_key:
28
- raise EnvironmentError("GOOGLE_SEARCH_API_KEY is not configured.")
29
- if not self.search_engine_id:
30
- raise EnvironmentError("GOOGLE_SEARCH_ENGINE_ID is not configured.")
31
 
32
- self.search = build("customsearch", "v1", developerKey=self.api_key).cse()
33
- self.max_results = num_results
 
34
 
35
  super().__init__(**kwargs)
36
 
@@ -38,41 +39,40 @@ class GoogleSearchTool(Tool):
38
  return {}
39
 
40
  def forward(self, query: str, *args, **kwargs) -> str:
41
- params = {
42
  "q": query,
43
- "cx": self.search_engine_id,
44
  "fields": "items(link,title,snippet)",
45
- "num": self.max_results,
46
  }
 
47
 
48
- params.update(self._collect_params(*args, **kwargs))
49
- results = self.search.list(**params).execute()
50
-
51
- if "items" not in results:
52
- return "No search results found."
53
 
54
- return "\n\n".join(
55
- f"{item['title']}\n{item['link']}\n{item['snippet']}" for item in results["items"]
56
- )
57
 
 
 
58
 
59
  class GoogleSiteSearchTool(GoogleSearchTool):
60
  name = "site_search"
61
- description = "Performs a Google search scoped to a specific domain, such as wikipedia.org or arxiv.org."
62
 
63
  inputs = {
64
  "query": {
65
  "type": "string",
66
- "description": "Search string to query.",
67
  },
68
  "site": {
69
  "type": "string",
70
- "description": "Domain to restrict search to (e.g., reddit.com, wikipedia.org).",
71
  },
72
  }
73
 
74
  def _collect_params(self, site: str) -> dict:
75
  return {
76
  "siteSearch": site,
77
- "siteSearchFilter": "i",
78
  }
 
1
+ from tools.base_tool import BaseTool
2
  from googleapiclient.discovery import build
3
  import os
4
 
5
+ class GoogleSearchTool(BaseTool):
 
6
  name = "web_search"
7
+ description = "Performs a Google web search for a query and returns the top results in markdown format."
8
 
9
  inputs = {
10
  "query": {
11
  "type": "string",
12
+ "description": "Search query for Google."
13
  }
14
  }
15
 
16
  output_type = "string"
17
  skip_forward_signature_validation = True
18
 
19
def __init__(
    self,
    api_key: str | None = None,
    search_engine_id: str | None = None,
    num_results: int = 10,
    **kwargs,
):
    """Create the Custom Search client used by ``forward``.

    Args:
        api_key: Google API key. When omitted or empty, the
            ``GOOGLE_SEARCH_API_KEY`` environment variable is used.
        search_engine_id: Search engine ID sent as ``cx``. When omitted
            or empty, ``GOOGLE_SEARCH_ENGINE_ID`` is used.
        num_results: Value sent as the ``num`` request parameter.
        **kwargs: Passed through unchanged to the parent initializer.

    Raises:
        EnvironmentError: If either credential is missing; the message
            names the environment variable(s) that are not configured.
    """
    api_key = api_key or os.getenv("GOOGLE_SEARCH_API_KEY")
    search_engine_id = search_engine_id or os.getenv("GOOGLE_SEARCH_ENGINE_ID")

    # Name the exact setting(s) that are absent so a misconfiguration can
    # be fixed without guessing which of the two values was not picked up.
    missing = [
        env_name
        for env_name, value in (
            ("GOOGLE_SEARCH_API_KEY", api_key),
            ("GOOGLE_SEARCH_ENGINE_ID", search_engine_id),
        )
        if not value
    ]
    if missing:
        raise EnvironmentError(
            f"Missing Google Search API credentials: {', '.join(missing)}."
        )

    self.search_service = build("customsearch", "v1", developerKey=api_key).cse()
    self.engine_id = search_engine_id
    self.num_results = num_results

    # Attributes are set before delegating, matching the original order.
    super().__init__(**kwargs)
37
 
 
39
  return {}
40
 
41
  def forward(self, query: str, *args, **kwargs) -> str:
42
+ search_args = {
43
  "q": query,
44
+ "cx": self.engine_id,
45
  "fields": "items(link,title,snippet)",
46
+ "num": self.num_results,
47
  }
48
+ search_args.update(self._collect_params(*args, **kwargs))
49
 
50
+ response = self.search_service.list(**search_args).execute()
51
+ items = response.get("items", [])
 
 
 
52
 
53
+ if not items:
54
+ return "No results found."
 
55
 
56
+ results = [f"**{item['title']}**\n{item['link']}\n{item['snippet']}" for item in items]
57
+ return "\n\n".join(results)
58
 
59
  class GoogleSiteSearchTool(GoogleSearchTool):
60
  name = "site_search"
61
+ description = "Searches a specific domain for a query using Google and returns results."
62
 
63
  inputs = {
64
  "query": {
65
  "type": "string",
66
+ "description": "Search query."
67
  },
68
  "site": {
69
  "type": "string",
70
+ "description": "Domain to restrict search to."
71
  },
72
  }
73
 
74
  def _collect_params(self, site: str) -> dict:
75
  return {
76
  "siteSearch": site,
77
+ "siteSearchFilter": "i"
78
  }