File size: 2,424 Bytes
f0df228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from smolagents import Tool
from googleapiclient.discovery import build
import os


class GoogleSearchTool(Tool):
    """Tool that runs a query through the Google Custom Search JSON API.

    Credentials are taken from the constructor arguments or, when those are
    omitted, from the ``GOOGLE_SEARCH_API_KEY`` and ``GOOGLE_SEARCH_ENGINE_ID``
    environment variables (a ``.env`` file is loaded first via ``dotenv``).

    Subclasses can add request parameters by overriding ``_collect_params``;
    any extra arguments given to ``forward`` are passed through to that hook.
    """

    name = "web_search"
    description = "Performs a Google search and returns top results in markdown format."

    inputs = {
        "query": {
            "type": "string",
            "description": "Search string to query the web.",
        }
    }

    output_type = "string"
    skip_forward_signature_validation = True

    def __init__(self, api_key: str | None = None, search_engine_id: str | None = None, num_results: int = 10, **kwargs):
        """Resolve credentials and build the Custom Search client.

        Args:
            api_key: API key; falls back to ``GOOGLE_SEARCH_API_KEY``.
            search_engine_id: Search engine id (``cx``); falls back to
                ``GOOGLE_SEARCH_ENGINE_ID``.
            num_results: Number of results requested per search. Values
                outside the 1-10 range accepted by the API are clamped when
                the request is issued.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.

        Raises:
            EnvironmentError: If the API key or the search engine id cannot
                be resolved from the arguments or the environment.
        """
        from dotenv import load_dotenv
        load_dotenv()

        self.api_key = api_key or os.getenv("GOOGLE_SEARCH_API_KEY")
        self.search_engine_id = search_engine_id or os.getenv("GOOGLE_SEARCH_ENGINE_ID")

        if not self.api_key:
            raise EnvironmentError("GOOGLE_SEARCH_API_KEY is not configured.")
        if not self.search_engine_id:
            raise EnvironmentError("GOOGLE_SEARCH_ENGINE_ID is not configured.")

        self.search = build("customsearch", "v1", developerKey=self.api_key).cse()
        self.max_results = num_results

        super().__init__(**kwargs)

    def _collect_params(self) -> dict:
        """Return extra request parameters; the base tool adds none."""
        return {}

    def forward(self, query: str, *args, **kwargs) -> str:
        """Run the search and format the hits as plain text.

        Args:
            query: Search string sent as the ``q`` parameter.
            *args: Passed through to ``_collect_params``.
            **kwargs: Passed through to ``_collect_params``.

        Returns:
            One ``title / link / snippet`` paragraph per result, separated by
            blank lines, or ``"No search results found."`` when the response
            carries no items.
        """
        params = {
            "q": query,
            "cx": self.search_engine_id,
            "fields": "items(link,title,snippet)",
            # The API rejects `num` outside 1..10, so clamp instead of letting
            # an oversized `num_results` turn every call into an HTTP error.
            "num": max(1, min(self.max_results, 10)),
        }

        params.update(self._collect_params(*args, **kwargs))
        results = self.search.list(**params).execute()

        # A missing key and an empty list both mean "nothing to show".
        items = results.get("items")
        if not items:
            return "No search results found."

        # Fields absent from an item are omitted from the response, so read
        # them defensively rather than failing the whole result set.
        return "\n\n".join(
            f"{item.get('title', '')}\n{item.get('link', '')}\n{item.get('snippet', '')}"
            for item in items
        )


class GoogleSiteSearchTool(GoogleSearchTool):
    """Search tool that takes an additional ``site`` input to scope the query.

    Only the tool metadata and the ``_collect_params`` hook are overridden;
    everything else comes from ``GoogleSearchTool``.
    """

    name = "site_search"
    description = "Performs a Google search scoped to a specific domain, such as wikipedia.org or arxiv.org."

    inputs = {
        "query": {"type": "string", "description": "Search string to query."},
        "site": {
            "type": "string",
            "description": "Domain to restrict search to (e.g., reddit.com, wikipedia.org).",
        },
    }

    def _collect_params(self, site: str) -> dict:
        """Return the request parameters that limit the search to ``site``."""
        # "i" selects the include mode of the API's site filter.
        site_filter = dict(siteSearch=site, siteSearchFilter="i")
        return site_filter