FD900 committed on
Commit
f0df228
·
verified ·
1 Parent(s): ccbae19

Update tools/google_search_tool.py

Browse files
Files changed (1) hide show
  1. tools/google_search_tool.py +78 -0
tools/google_search_tool.py CHANGED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import Tool
2
+ from googleapiclient.discovery import build
3
+ import os
4
+
5
+
6
class GoogleSearchTool(Tool):
    """Tool that runs a query through the Google Custom Search API.

    Credentials are taken from the constructor arguments or, failing that,
    from the ``GOOGLE_SEARCH_API_KEY`` and ``GOOGLE_SEARCH_ENGINE_ID``
    environment variables (a ``.env`` file is loaded first).

    Subclasses can add request parameters by overriding ``_collect_params``;
    any extra arguments given to ``forward`` are passed straight to it.
    """

    name = "web_search"
    # NOTE(review): the text says "markdown", but forward() emits plain
    # "title / link / snippet" blocks. Left unchanged to keep the prompt and
    # output stable -- confirm which of the two is intended.
    description = "Performs a Google search and returns top results in markdown format."

    inputs = {
        "query": {
            "type": "string",
            "description": "Search string to query the web.",
        }
    }

    output_type = "string"
    # forward() takes *args/**kwargs so subclasses can declare extra inputs;
    # presumably this flag disables the base class's forward-signature check.
    skip_forward_signature_validation = True

    def __init__(self, api_key: str | None = None, search_engine_id: str | None = None, num_results: int = 10, **kwargs):
        """Configure the Custom Search client.

        Args:
            api_key: Google API key; falls back to ``GOOGLE_SEARCH_API_KEY``.
            search_engine_id: Programmable Search Engine id; falls back to
                ``GOOGLE_SEARCH_ENGINE_ID``.
            num_results: Number of results to request per search. Clamped to
                the 1-10 range accepted by the API.
            **kwargs: Forwarded to ``Tool.__init__``.

        Raises:
            EnvironmentError: If the API key or the search engine id cannot
                be resolved.
        """
        # Imported lazily so python-dotenv is only needed when the tool is built.
        from dotenv import load_dotenv
        load_dotenv()

        self.api_key = api_key or os.getenv("GOOGLE_SEARCH_API_KEY")
        self.search_engine_id = search_engine_id or os.getenv("GOOGLE_SEARCH_ENGINE_ID")

        if not self.api_key:
            raise EnvironmentError("GOOGLE_SEARCH_API_KEY is not configured.")
        if not self.search_engine_id:
            raise EnvironmentError("GOOGLE_SEARCH_ENGINE_ID is not configured.")

        self.search = build("customsearch", "v1", developerKey=self.api_key).cse()
        # The Custom Search API rejects `num` outside 1..10, so clamp here
        # instead of letting every later search fail with an HTTP 400.
        self.max_results = max(1, min(num_results, 10))

        super().__init__(**kwargs)

    def _collect_params(self) -> dict:
        """Return extra request parameters; the base tool adds none."""
        return {}

    def forward(self, query: str, *args, **kwargs) -> str:
        """Run the search and format the hits as text.

        Args:
            query: Search string.
            *args: Extra positional arguments passed to ``_collect_params``.
            **kwargs: Extra keyword arguments passed to ``_collect_params``.

        Returns:
            One ``title``/``link``/``snippet`` block per result, separated by
            blank lines, or ``"No search results found."`` when the response
            has no items.
        """
        params = {
            "q": query,
            "cx": self.search_engine_id,
            "fields": "items(link,title,snippet)",
            "num": self.max_results,
        }

        params.update(self._collect_params(*args, **kwargs))
        results = self.search.list(**params).execute()

        # Treat a missing key and an empty list the same way.
        items = results.get("items")
        if not items:
            return "No search results found."

        # A result may lack a field (a snippet, for instance); fall back to an
        # empty string so one sparse hit does not discard the whole response.
        return "\n\n".join(
            f"{item.get('title', '')}\n{item.get('link', '')}\n{item.get('snippet', '')}"
            for item in items
        )
57
+
58
+
59
class GoogleSiteSearchTool(GoogleSearchTool):
    """Variant of the Google search tool that restricts results to one domain.

    Adds a ``site`` input alongside ``query`` and turns it into the
    ``siteSearch`` / ``siteSearchFilter`` request parameters.
    """

    name = "site_search"
    description = "Performs a Google search scoped to a specific domain, such as wikipedia.org or arxiv.org."

    inputs = {
        "query": {"type": "string", "description": "Search string to query."},
        "site": {
            "type": "string",
            "description": "Domain to restrict search to (e.g., reddit.com, wikipedia.org).",
        },
    }

    def _collect_params(self, site: str) -> dict:
        """Build the request parameters that limit the search to ``site``.

        Args:
            site: Domain the results should be limited to.

        Returns:
            Dict with ``siteSearch`` set to ``site`` and ``siteSearchFilter``
            set to ``"i"``.
        """
        site_params = {"siteSearch": site}
        # "i" presumably means "include" in the Custom Search API -- confirm.
        site_params["siteSearchFilter"] = "i"
        return site_params