Shane commited on
Commit
5acc791
·
1 Parent(s): 9778322

refactoring with help from Claude to remove duplicated code

Browse files
Files changed (1) hide show
  1. tools/web_search.py +46 -31
tools/web_search.py CHANGED
@@ -1,7 +1,31 @@
1
- from typing import Any, Optional
2
  from smolagents.tools import Tool
3
  import duckduckgo_search
4
  import googlesearch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  class DuckDuckGoSearchTool(Tool):
7
  name = "web_search"
@@ -12,20 +36,19 @@ class DuckDuckGoSearchTool(Tool):
12
  def __init__(self, max_results=10, **kwargs):
13
  super().__init__()
14
  self.max_results = max_results
15
- try:
16
- from duckduckgo_search import DDGS
17
- except ImportError as e:
18
- raise ImportError(
19
- "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
20
- ) from e
21
- self.ddgs = DDGS(**kwargs)
22
 
 
23
  def forward(self, query: str) -> str:
24
  results = self.ddgs.text(query, max_results=self.max_results)
25
- if len(results) == 0:
26
- raise Exception("No results found! Try a less restrictive/shorter query.")
27
- postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
28
- return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
 
29
 
30
  class GoogleSearchTool(Tool):
31
  name = "google_search"
@@ -36,24 +59,16 @@ class GoogleSearchTool(Tool):
36
  def __init__(self, max_results=10, **kwargs):
37
  super().__init__()
38
  self.max_results = max_results
39
- try:
40
- from googlesearch import search
41
- except ImportError as e:
42
- raise ImportError(
43
- "You must install package `googlesearch-python` to run this tool: for instance run `pip install googlesearch-python`."
44
- ) from e
45
- self.search = search
46
 
 
47
  def forward(self, query: str) -> str:
48
- results = []
49
- try:
50
- # Get search results (URLs only)
51
- search_results = list(self.search(query, num_results=self.max_results))
52
-
53
- # If no results found
54
- if len(search_results) == 0:
55
- raise Exception("No results found! Try a less restrictive/shorter query.")
56
-
57
- return "## Search Results\n\n" + "\n\n".join([f"[Result {i+1}]({url})\n{url}" for i, url in enumerate(search_results)])
58
- except Exception as e:
59
- return f"Error performing search: {str(e)}"
 
1
+ from typing import Any, Optional, Callable, List
2
  from smolagents.tools import Tool
3
  import duckduckgo_search
4
  import googlesearch
5
+ from functools import wraps
6
+
7
+ def setup_search_dependency(package_name: str, import_func: Callable):
8
+ """Utility function to handle search dependency setup"""
9
+ try:
10
+ return import_func()
11
+ except ImportError as e:
12
+ raise ImportError(
13
+ f"You must install package `{package_name}` to run this tool: "
14
+ f"for instance run `pip install {package_name}`."
15
+ ) from e
16
+
17
+ def handle_search_errors(func):
18
+ """Decorator to handle common search error cases"""
19
+ @wraps(func)
20
+ def wrapper(*args, **kwargs):
21
+ try:
22
+ results = func(*args, **kwargs)
23
+ if not results:
24
+ raise Exception("No results found! Try a less restrictive/shorter query.")
25
+ return results
26
+ except Exception as e:
27
+ return f"Error performing search: {str(e)}"
28
+ return wrapper
29
 
30
  class DuckDuckGoSearchTool(Tool):
31
  name = "web_search"
 
36
  def __init__(self, max_results=10, **kwargs):
37
  super().__init__()
38
  self.max_results = max_results
39
+ self.ddgs = setup_search_dependency(
40
+ 'duckduckgo-search',
41
+ lambda: duckduckgo_search.DDGS(**kwargs)
42
+ )
 
 
 
43
 
44
+ @handle_search_errors
45
  def forward(self, query: str) -> str:
46
  results = self.ddgs.text(query, max_results=self.max_results)
47
+ formatted_results = [
48
+ f"[{result['title']}]({result['href']})\n{result['body']}"
49
+ for result in results
50
+ ]
51
+ return "## Search Results\n\n" + "\n\n".join(formatted_results)
52
 
53
  class GoogleSearchTool(Tool):
54
  name = "google_search"
 
59
  def __init__(self, max_results=10, **kwargs):
60
  super().__init__()
61
  self.max_results = max_results
62
+ self.search = setup_search_dependency(
63
+ 'googlesearch-python',
64
+ lambda: googlesearch.search
65
+ )
 
 
 
66
 
67
+ @handle_search_errors
68
  def forward(self, query: str) -> str:
69
+ search_results = list(self.search(query, num_results=self.max_results))
70
+ formatted_results = [
71
+ f"[Result {i+1}]({url})\n{url}"
72
+ for i, url in enumerate(search_results)
73
+ ]
74
+ return "## Search Results\n\n" + "\n\n".join(formatted_results)