Gabandino committed on
Commit
a14ace8
·
verified ·
1 Parent(s): b04d151

Fix typo for final submission

Browse files
Files changed (1) hide show
  1. tools/google_search_tools.py +78 -78
tools/google_search_tools.py CHANGED
@@ -1,79 +1,79 @@
1
- from smolagents import Tool
2
- from googleapiclient.discovery import build
3
- import os
4
class GoogleSearchTool(Tool):
    """Tool that queries a Google Custom Search engine and returns the hits as text.

    Credentials are taken from the constructor arguments when supplied and
    otherwise from the ``GOOGLE_API_KEY`` and ``GOOGLE_SEARCH_ENGINE_ID``
    environment variables (``load_dotenv()`` is called first).
    """

    name = "web_search"
    description = """Performs a google web search for a query then returns top search results in markdown format."""

    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform a web search for"
        }
    }
    output_type = "string"

    # NOTE(review): presumably tells the Tool base class not to check
    # forward()'s signature against `inputs`, which is needed because
    # forward() takes *args/**kwargs -- confirm against smolagents.
    skip_forward_signature_validation = True

    def __init__(
        self,
        api_key: str | None = None,
        search_engine_id: str | None = None,
        num_results: int = 10,
        **kwargs,
    ):
        """Build the Custom Search client.

        Args:
            api_key: API key; falls back to ``GOOGLE_API_KEY`` when omitted.
            search_engine_id: Search engine id; falls back to
                ``GOOGLE_SEARCH_ENGINE_ID`` when omitted.
            num_results: Value sent as the ``num`` request parameter.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.

        Raises:
            ValueError: If no API key or no search engine id is available.
        """
        from dotenv import load_dotenv

        load_dotenv()
        # Explicit arguments win; the environment is only a fallback.
        # (Previously both arguments were silently overwritten.)
        api_key = api_key or os.getenv("GOOGLE_API_KEY")
        search_engine_id = search_engine_id or os.getenv("GOOGLE_SEARCH_ENGINE_ID")
        if not api_key:
            raise ValueError("GOOGLE_API_KEY is not set")
        if not search_engine_id:
            raise ValueError("GOOGLE_SEARCH_ENGINE_ID is not set")

        self.cse = build(
            "customsearch",
            "v1",
            developerKey=api_key
        ).cse()
        self.cx = search_engine_id
        self.num = num_results
        super().__init__(**kwargs)

    def _collect_params(self) -> dict:
        """Return extra request parameters; the base tool adds none."""
        return {}

    def forward(self, query: str, *args, **kwargs) -> str:
        """Run the search and format each hit as title, link and snippet.

        Args:
            query: Search string sent as the ``q`` parameter.
            *args, **kwargs: Passed to ``_collect_params`` so subclasses can
                contribute additional request parameters.

        Returns:
            One ``title\\nlink\\nsnippet`` entry per hit, separated by blank
            lines, or ``"No results found"`` when there are no hits.
        """
        params = {
            "q": query,
            "cx": self.cx,
            "fields": "items(link, title, snippet)",
            "num": self.num,
        }

        params = params | self._collect_params(*args, **kwargs)
        res = self.cse.list(**params).execute()
        # Treat a missing key and an empty list the same way.
        items = res.get("items")
        if not items:
            return "No results found"

        # Use .get so a hit lacking one field does not abort the whole search.
        return "\n\n".join(
            f"{item.get('title', '')}\n{item.get('link', '')}\n{item.get('snippet', '')}"
            for item in items
        )
60
-
61
class GoogleSiteSearchTool(GoogleSearchTool):
    """Variant of GoogleSearchTool that adds a ``site`` input to the request."""

    name = "site_search"
    description = "Searches a specific website for a given query and returns the site contents in markdown format. Use when information is likely to be found on a particular domain, such as reddit.com, wikipedia.org, ieee.org, or arxiv.org."
    inputs = {
        "query": {"type": "string", "description": "The query to perform search."},
        "site": {"type": "string", "description": "The domain of the site on which to search"},
    }

    def _collect_params(self, site: str) -> dict:
        """Return the extra request parameters that carry ``site``."""
        # NOTE(review): "i" presumably selects "include only this site";
        # confirm against the Custom Search JSON API reference.
        site_params = dict(siteSearch=site, siteSearchFilter="i")
        return site_params
 
1
+ from smolagents import Tool
2
+ from googleapiclient.discovery import build
3
+ import os
4
class GoogleSearchTool(Tool):
    """Tool that queries a Google Custom Search engine and returns the hits as text.

    Credentials are taken from the constructor arguments when supplied and
    otherwise from the ``GOOGLE_SEARCH_API_KEY`` and
    ``GOOGLE_SEARCH_ENGINE_ID`` environment variables (``load_dotenv()`` is
    called first).
    """

    name = "web_search"
    description = """Performs a google web search for a query then returns top search results in markdown format."""

    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform a web search for"
        }
    }
    output_type = "string"

    # NOTE(review): presumably tells the Tool base class not to check
    # forward()'s signature against `inputs`, which is needed because
    # forward() takes *args/**kwargs -- confirm against smolagents.
    skip_forward_signature_validation = True

    def __init__(
        self,
        api_key: str | None = None,
        search_engine_id: str | None = None,
        num_results: int = 10,
        **kwargs,
    ):
        """Build the Custom Search client.

        Args:
            api_key: API key; falls back to ``GOOGLE_SEARCH_API_KEY`` when
                omitted.
            search_engine_id: Search engine id; falls back to
                ``GOOGLE_SEARCH_ENGINE_ID`` when omitted.
            num_results: Value sent as the ``num`` request parameter.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.

        Raises:
            ValueError: If no API key or no search engine id is available.
        """
        from dotenv import load_dotenv

        load_dotenv()
        # Explicit arguments win; the environment is only a fallback.
        # (Previously both arguments were silently overwritten.)
        api_key = api_key or os.getenv("GOOGLE_SEARCH_API_KEY")
        search_engine_id = search_engine_id or os.getenv("GOOGLE_SEARCH_ENGINE_ID")
        if not api_key:
            raise ValueError("GOOGLE_SEARCH_API_KEY is not set")
        if not search_engine_id:
            raise ValueError("GOOGLE_SEARCH_ENGINE_ID is not set")

        self.cse = build(
            "customsearch",
            "v1",
            developerKey=api_key
        ).cse()
        self.cx = search_engine_id
        self.num = num_results
        super().__init__(**kwargs)

    def _collect_params(self) -> dict:
        """Return extra request parameters; the base tool adds none."""
        return {}

    def forward(self, query: str, *args, **kwargs) -> str:
        """Run the search and format each hit as title, link and snippet.

        Args:
            query: Search string sent as the ``q`` parameter.
            *args, **kwargs: Passed to ``_collect_params`` so subclasses can
                contribute additional request parameters.

        Returns:
            One ``title\\nlink\\nsnippet`` entry per hit, separated by blank
            lines, or ``"No results found"`` when there are no hits.
        """
        params = {
            "q": query,
            "cx": self.cx,
            "fields": "items(link, title, snippet)",
            "num": self.num,
        }

        params = params | self._collect_params(*args, **kwargs)
        res = self.cse.list(**params).execute()
        # Treat a missing key and an empty list the same way.
        items = res.get("items")
        if not items:
            return "No results found"

        # Use .get so a hit lacking one field does not abort the whole search.
        return "\n\n".join(
            f"{item.get('title', '')}\n{item.get('link', '')}\n{item.get('snippet', '')}"
            for item in items
        )
60
+
61
class GoogleSiteSearchTool(GoogleSearchTool):
    """Variant of GoogleSearchTool that adds a ``site`` input to the request."""

    name = "site_search"
    description = "Searches a specific website for a given query and returns the site contents in markdown format. Use when information is likely to be found on a particular domain, such as reddit.com, wikipedia.org, ieee.org, or arxiv.org."
    inputs = {
        "query": {"type": "string", "description": "The query to perform search."},
        "site": {"type": "string", "description": "The domain of the site on which to search"},
    }

    def _collect_params(self, site: str) -> dict:
        """Return the extra request parameters that carry ``site``."""
        # NOTE(review): "i" presumably selects "include only this site";
        # confirm against the Custom Search JSON API reference.
        site_params = dict(siteSearch=site, siteSearchFilter="i")
        return site_params