Spaces:
Running
Running
#
# Based on: https://medium.com/@laurentkubaski/smolagents-duckduckgosearchtool-to-search-in-wikipedia-2578973bb131
# (using smolagents' DuckDuckGoSearchTool with a customized ddgs Wikipedia engine).
#
| from __future__ import annotations | |
| import logging | |
| from typing import Any | |
| from urllib.parse import quote | |
| from ddgs import ddgs | |
| from ddgs.base import BaseSearchEngine | |
| from ddgs.results import TextResult | |
| # from ddgs.utils import json_loads | |
| from json import loads as json_loads | |
| from smolagents import DuckDuckGoSearchTool | |
# Logger named after this module; CustomWikipedia reports body-fetch failures through it.
logger = logging.getLogger(name=__name__)
class CustomWikipedia(BaseSearchEngine[TextResult]):
    """A customized ddgs Wikipedia search engine that returns multiple results.

    Adapted from ddgs' built-in Wikipedia engine. The "&limit=1" OpenSearch parameter
    is dropped, so every returned match becomes a result. Each result's body is filled
    from a second per-title API request for the page extract.
    """

    name = "wikipedia"
    category = "text"
    provider = "wikipedia"
    priority = 2

    search_url = "https://{lang}.wikipedia.org/w/api.php?action=opensearch&search={query}"
    search_method = "GET"

    def build_payload(
        self, query: str, region: str, safesearch: str, timelimit: str | None, page: int = 1, **kwargs: Any
    ) -> dict[str, Any]:
        """Point ``self.search_url`` at the OpenSearch endpoint for *query*.

        This is mostly a copy of the original method, without the "&limit=1" query parameter.

        Args:
            query: Search terms; URL-encoded before being placed in the URL.
            region: Must contain exactly one "-". The lower-cased part after it is used as
                the Wikipedia language subdomain. Any other shape makes the unpacking
                raise ValueError.
            safesearch: Accepted for interface compatibility; not used.
            timelimit: Accepted for interface compatibility; not used.
            page: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            An empty dict, because all parameters travel in the URL.
        """
        _country, lang = region.lower().split("-")
        encoded_query = quote(query)
        self.search_url = (
            f"https://{lang}.wikipedia.org/w/api.php?action=opensearch&profile=fuzzy&search={encoded_query}"
        )
        # Remembered so the per-title extract requests hit the same language wiki.
        self.lang = lang
        return {}

    def extract_results(self, html_text: str) -> list[TextResult]:
        """Parse an OpenSearch response; see ``extract_results_with_body``."""
        return self.extract_results_with_body(html_text)

    def extract_results_with_body(self, html_text: str) -> list[TextResult]:
        """Turn an OpenSearch JSON response into TextResults with page extracts as bodies.

        This is mostly a copy of the original method, except that it loops over all
        results instead of returning only the first one.

        Element 1 of the response holds the titles and element 3 the matching URLs.
        Each (title, url) pair becomes a result whose body comes from
        ``_fetch_extract``. The body is "" when the extract could not be fetched.
        Results whose body contains "may refer to:" (disambiguation pages) are dropped.

        Returns:
            The results, or [] when the response contains no titles.
        """
        json_data = json_loads(html_text)
        if not json_data[1]:
            return []

        results: list[TextResult] = []
        for title, href in zip(json_data[1], json_data[3]):
            result = TextResult()
            result.title = title
            result.href = href
            result.body = self._fetch_extract(title)
            if "may refer to:" not in result.body:
                results.append(result)
        return results

    def _fetch_extract(self, title: str) -> str:
        """Return the extract of the page called *title*, or "" if it cannot be obtained."""
        # NOTE: MediaWiki treats boolean parameters as true whenever they are present,
        # whatever their value. So "explaintext=0&exintro=0" actually requests a
        # plain-text, intro-only extract.
        resp_data = self.request(
            "GET",
            f"https://{self.lang}.wikipedia.org/w/api.php?action=query&format=json&prop=extracts"
            f"&titles={quote(title)}&explaintext=0&exintro=0&redirects=1",
        )
        if not resp_data:
            return ""
        try:
            pages = json_loads(resp_data)["query"]["pages"]
            # Only one title is requested, so the first (and only) page entry is the one wanted.
            return next(iter(pages.values()))["extract"]
        except (KeyError, StopIteration, TypeError, ValueError) as ex:
            # Missing page/extract, empty "pages", or a non-JSON body: keep the result
            # without a body rather than failing the whole search.
            logger.warning("Error getting body from Wikipedia for title=%s: %s", title, ex)
            return ""
class CustomDuckDuckGoSearchTool(DuckDuckGoSearchTool):
    """A customized smolagents DuckDuckGoSearchTool that lets the caller choose the ddgs backend.

    When "wikipedia" is among the requested backends, ddgs' Wikipedia text engine is
    replaced by CustomWikipedia, which returns several results instead of one.
    """

    name = "web_search"
    description = "Performs a web search for a query and returns a list of the top search results formatted as markdown with page titles and urls."
    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
    output_type = "string"

    def __init__(self, max_results: int = 10, rate_limit: float | None = 1.0, backend: str = "auto", **kwargs):
        """Create the tool.

        Args:
            max_results: Forwarded to the parent tool and passed to ``ddgs.text``.
            rate_limit: Forwarded to the parent tool.
            backend: ddgs backend string passed to ``ddgs.text``. It may name one backend
                or several, comma-separated.
            **kwargs: Forwarded to the parent tool.

        NOTE: the engine swap writes into the module-level ``ddgs.ENGINES["text"]`` dict.
        It is never undone, so it presumably affects every later ddgs text search in the
        process, not just this tool. Verify this before mixing tools with different
        backends.
        """
        super().__init__(max_results=max_results, rate_limit=rate_limit, **kwargs)
        self.backend = backend
        requested_backends = {part.strip() for part in backend.split(",")}
        if "wikipedia" in requested_backends:
            ddgs.ENGINES["text"]["wikipedia"] = CustomWikipedia

    def forward(self, query: str) -> str:
        """Search *query* with ``self.backend`` and return the hits as a markdown list.

        This is mostly a copy of the original method; the difference is that
        ``self.backend`` is passed to ``self.ddgs.text()``.

        Returns:
            "## Search Results" followed by one "[title](href)\\nbody" entry per hit,
            with entries separated by blank lines.

        Raises:
            Exception: if the search returns no results.
        """
        self._enforce_rate_limit()
        results = self.ddgs.text(
            query=query,
            max_results=self.max_results,
            backend=self.backend,
        )
        if not results:
            raise Exception("No results found! Try a less restrictive/shorter query.")
        postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
        return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
if __name__ == "__main__":
    # Manual smoke test (needs network access): search Wikipedia for "Leopard"
    # with max_results=3 and print the formatted markdown.
    wiki_search = CustomDuckDuckGoSearchTool(backend="wikipedia", max_results=3, rate_limit=1.0)
    print(wiki_search(query="Leopard"))