|
|
from typing import Any, Optional |
|
|
from smolagents.tools import Tool |
|
|
import html |
|
|
import xml |
|
|
import requests |
|
|
|
|
|
class WebSearchTool(Tool):
    """Search the web and return the top hits as a markdown string.

    Two backends are supported, selected by ``engine``: ``"duckduckgo"``
    (parses the DuckDuckGo Lite HTML page) and ``"bing"`` (parses the RSS
    output of a Bing search). Imports are kept inside the methods that use
    them, as in the original implementation.
    """

    name = "web_search"
    description = "Performs a web search for a query and returns a string of the top search results formatted as markdown with titles, links, and descriptions."
    inputs = {'query': {'type': 'string', 'description': 'The search query to perform.'}}
    output_type = "string"

    def __init__(self, max_results: int = 10, engine: str = "duckduckgo", timeout: float = 10.0):
        """Configure the search tool.

        Args:
            max_results: Maximum number of results returned by a search.
            engine: Backend to query, either ``"duckduckgo"`` or ``"bing"``.
            timeout: Seconds to wait for the search engine to respond before
                the HTTP request is aborted.
        """
        super().__init__()
        self.max_results = max_results
        self.engine = engine
        self.timeout = timeout

    def forward(self, query: str) -> str:
        """Run the search and return the results formatted as markdown.

        Args:
            query: The search query to perform.

        Returns:
            A markdown string with one entry per result.

        Raises:
            Exception: If the search produced no results.
            ValueError: If ``self.engine`` is not a supported engine.
        """
        results = self.search(query)
        if not results:
            raise Exception("No results found! Try a less restrictive/shorter query.")
        return self.parse_results(results)

    def search(self, query: str) -> list:
        """Dispatch the query to the configured engine.

        Returns:
            A list of dicts with ``title``, ``link`` and ``description`` keys.

        Raises:
            ValueError: If ``self.engine`` is not ``"duckduckgo"`` or ``"bing"``.
        """
        if self.engine == "duckduckgo":
            return self.search_duckduckgo(query)
        elif self.engine == "bing":
            return self.search_bing(query)
        else:
            raise ValueError(f"Unsupported engine: {self.engine}")

    def parse_results(self, results: list) -> str:
        """Render result dicts as a markdown document.

        Each result becomes a ``[title](link)`` line followed by its
        description; entries are separated by a blank line.
        """
        return "## Search Results\n\n" + "\n\n".join(
            f"[{result['title']}]({result['link']})\n{result['description']}" for result in results
        )

    def search_duckduckgo(self, query: str) -> list:
        """Query DuckDuckGo Lite and return at most ``max_results`` results.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            requests.Timeout: If no response arrives within ``self.timeout``.
        """
        import requests

        response = requests.get(
            "https://lite.duckduckgo.com/lite/",
            params={"q": query},
            headers={"User-Agent": "Mozilla/5.0"},
            # Without a timeout, requests waits indefinitely on a stalled server.
            timeout=self.timeout,
        )
        response.raise_for_status()
        parser = self._create_duckduckgo_parser()
        parser.feed(response.text)
        # Apply the same cap as the Bing backend so max_results is honoured
        # regardless of the engine.
        return parser.results[: self.max_results]

    def _create_duckduckgo_parser(self):
        """Build an HTML parser that extracts results from a DuckDuckGo Lite page.

        The returned parser exposes a ``results`` list that is filled while
        HTML is fed to it. Each entry is a dict with ``title``, ``link`` and
        ``description`` keys.
        """
        from html.parser import HTMLParser

        class SimpleResultParser(HTMLParser):
            def __init__(self):
                super().__init__()
                self.results = []
                # Result currently being assembled; its parts arrive in
                # separate elements and are collected until a row closes
                # with all three present.
                self.current = {}
                self.capture_title = False
                self.capture_description = False
                self.capture_link = False

            def handle_starttag(self, tag, attrs):
                # Turn on the capture flag matching the element being opened.
                attrs = dict(attrs)
                if tag == "a" and attrs.get("class") == "result-link":
                    self.capture_title = True
                elif tag == "td" and attrs.get("class") == "result-snippet":
                    self.capture_description = True
                elif tag == "span" and attrs.get("class") == "link-text":
                    self.capture_link = True

            def handle_endtag(self, tag):
                if tag == "a" and self.capture_title:
                    self.capture_title = False
                elif tag == "td" and self.capture_description:
                    self.capture_description = False
                elif tag == "span" and self.capture_link:
                    self.capture_link = False
                elif tag == "tr":
                    # Store the current result only once all parts are present,
                    # and only then start a fresh one.
                    if {"title", "description", "link"} <= self.current.keys():
                        self.current["description"] = " ".join(self.current["description"])
                        self.results.append(self.current)
                        self.current = {}

            def handle_data(self, data):
                if self.capture_title:
                    self.current["title"] = data.strip()
                elif self.capture_description:
                    # A snippet can arrive as several text fragments (for
                    # example around inline tags); collect them and join them
                    # when the row ends. Whitespace-only fragments are skipped
                    # so the joined text has no doubled spaces.
                    fragments = self.current.setdefault("description", [])
                    text = data.strip()
                    if text:
                        fragments.append(text)
                elif self.capture_link:
                    # The captured link text has the scheme prepended here.
                    self.current["link"] = "https://" + data.strip()

        return SimpleResultParser()

    def search_bing(self, query: str) -> list:
        """Query Bing's RSS output and return at most ``max_results`` results.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            requests.Timeout: If no response arrives within ``self.timeout``.
            xml.etree.ElementTree.ParseError: If the response is not
                well-formed XML.
        """
        import xml.etree.ElementTree as ET

        import requests

        response = requests.get(
            "https://www.bing.com/search",
            params={"q": query, "format": "rss"},
            # Without a timeout, requests waits indefinitely on a stalled server.
            timeout=self.timeout,
        )
        response.raise_for_status()
        # NOTE(review): this parses remote content with xml.etree, which the
        # Python docs describe as not secure against maliciously constructed
        # data; confirm that is acceptable for this deployment.
        root = ET.fromstring(response.text)
        items = root.findall(".//item")
        return [
            {
                # Default to "" so a missing child element does not show up as
                # the literal text "None" in the rendered markdown.
                "title": item.findtext("title", default=""),
                "link": item.findtext("link", default=""),
                "description": item.findtext("description", default=""),
            }
            for item in items[: self.max_results]
        ]
|
|
|