|
|
import asyncio
import json
import os
import ssl
from typing import List, Optional, Union
from urllib.parse import urlparse

import aiohttp
from bs4 import BeautifulSoup
from crawl4ai import *
from dotenv import load_dotenv
from pydantic import BaseModel, Field

from agents import function_tool
|
|
|
|
|
# Load environment variables from a local .env file (e.g. SERPER_API_KEY).
load_dotenv()

# Maximum number of characters of scraped page text kept per result.
CONTENT_LENGTH_LIMIT = 10000

# Search backend selector; defaults to "serper" unless overridden in the env.
SEARCH_PROVIDER = os.getenv("SEARCH_PROVIDER", "serper").lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ScrapeResult(BaseModel):
    """Scraped webpage content together with its search-result metadata."""

    url: str = Field(description="The URL of the webpage")
    text: str = Field(description="The full text content of the webpage")
    title: str = Field(description="The title of the webpage")
    description: str = Field(description="A short description of the webpage")
|
|
|
|
|
|
|
|
class WebpageSnippet(BaseModel):
    """A single search-engine result: URL, title, and optional snippet."""

    url: str = Field(description="The URL of the webpage")
    title: str = Field(description="The title of the webpage")
    description: Optional[str] = Field(description="A short description of the webpage")
|
|
|
|
|
|
|
|
class SearchResults(BaseModel):
    """List of filtered results; its JSON schema is embedded in FILTER_AGENT_INSTRUCTIONS."""

    results_list: List[WebpageSnippet]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Lazily-initialized module-level SerperClient, created on first web_search call.
_serper_client = None
|
|
|
|
|
|
|
|
@function_tool
async def web_search(query: str) -> Union[List[ScrapeResult], str]:
    """Perform a web search for a given query and get back the URLs along with their titles, descriptions and text contents.

    Args:
        query: The search query

    Returns:
        List of ScrapeResult objects which have the following fields:
            - url: The URL of the search result
            - title: The title of the search result
            - description: The description of the search result
            - text: The full text content of the search result
        On misconfiguration or any search failure, a human-readable error
        string is returned instead of a list.
    """
    if SEARCH_PROVIDER == "openai":
        # With the "openai" provider, search is handled by the model itself,
        # so this tool should never be invoked.
        return "The web_search function is not used when SEARCH_PROVIDER is set to 'openai'. Please check your configuration."

    try:
        # Lazily create one shared Serper client on first use.
        global _serper_client
        if _serper_client is None:
            _serper_client = SerperClient()

        search_results = await _serper_client.search(
            query, filter_for_relevance=True, max_results=5
        )
        return await scrape_urls(search_results)
    except Exception as e:
        # Surface the failure to the calling agent as text rather than raising.
        return f"Sorry, I encountered an error while searching: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# System prompt for an LLM that filters SERP results for relevance to the
# original query, emitting JSON that matches SearchResults' schema.
# NOTE(review): this constant is not referenced by any code visible in this
# file — presumably consumed by an agent defined elsewhere; confirm before removing.
FILTER_AGENT_INSTRUCTIONS = f"""
You are a search result filter. Your task is to analyze a list of SERP search results and determine which ones are relevant
to the original query based on the link, title and snippet. Return only the relevant results in the specified format.

- Remove any results that refer to entities that have similar names to the queried entity, but are not the same.
- E.g. if the query asks about a company "Amce Inc, acme.com", remove results with "acmesolutions.com" or "acme.net" in the link.

Only output JSON. Follow the JSON schema below. Do not output anything else. I will be parsing this with Pydantic so output valid JSON only:
{SearchResults.model_json_schema()}
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared SSL context for outbound scraping requests.
# WARNING: hostname checking and certificate verification are fully disabled,
# which exposes these requests to man-in-the-middle attacks. Presumably a
# deliberate trade-off to reach misconfigured sites — confirm this is intended.
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
# Lower OpenSSL's security level to accept older/weaker ciphers and keys
# (e.g. legacy TLS servers that default settings would refuse).
ssl_context.set_ciphers(
    "DEFAULT:@SECLEVEL=1"
)
|
|
|
|
|
|
|
|
class SerperClient:
    """A client for the Serper API to perform Google searches."""

    def __init__(self, api_key: Optional[str] = None):
        """Create a client.

        Args:
            api_key: Serper API key; falls back to the SERPER_API_KEY
                environment variable when not provided.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.getenv("SERPER_API_KEY")
        if not self.api_key:
            raise ValueError(
                "No API key provided. Set SERPER_API_KEY environment variable."
            )

        self.url = "https://google.serper.dev/search"
        self.headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}

    async def search(
        self, query: str, filter_for_relevance: bool = True, max_results: int = 5
    ) -> List[WebpageSnippet]:
        """Perform a Google search using Serper API and fetch basic details for top results.

        Args:
            query: The search query
            filter_for_relevance: When True, drop blocked domains and
                unreachable URLs before truncating; when False, simply
                return the first max_results organic hits.
            max_results: Maximum number of results to return

        Returns:
            List of WebpageSnippet objects built from the "organic" results
        """
        # Uses the module-level ssl_context (certificate verification disabled).
        connector = aiohttp.TCPConnector(ssl=ssl_context)
        async with aiohttp.ClientSession(connector=connector) as session:
            async with session.post(
                self.url, headers=self.headers, json={"q": query, "autocorrect": False}
            ) as response:
                response.raise_for_status()
                results = await response.json()
                results_list = [
                    WebpageSnippet(
                        url=result.get("link", ""),
                        title=result.get("title", ""),
                        description=result.get("snippet", ""),
                    )
                    for result in results.get("organic", [])
                ]

        if not results_list:
            return []

        if not filter_for_relevance:
            return results_list[:max_results]

        return await self._filter_results(results_list, query, max_results=max_results)

    async def _filter_results(
        self, results: List[WebpageSnippet], query: str, max_results: int = 5
    ) -> List[WebpageSnippet]:
        """Reduce results to at most max_results reachable entries.

        NOTE(review): despite the name, no LLM relevance filtering happens
        here and `query` is unused — results are only stripped of NCBI
        domains and of URLs that do not answer HTTP 200.
        """
        # Drop NCBI domains (presumably because they block scrapers — confirm).
        filtered_results = [
            res
            for res in results
            if "pmc.ncbi.nlm.nih.gov" not in res.url
            and "pubmed.ncbi.nlm.nih.gov" not in res.url
        ]

        async def fetch_url(session, url):
            # True when the URL answers HTTP 200 within the 5 s timeout;
            # any exception (DNS, TLS, timeout) counts as unreachable.
            try:
                async with session.get(url, timeout=5) as response:
                    return response.status == 200
            except Exception as e:
                print(f"Error accessing {url}: {str(e)}")
                return False

        async def filter_unreachable_urls(results):
            # Probe all URLs concurrently; keep input order for survivors.
            async with aiohttp.ClientSession() as session:
                tasks = [fetch_url(session, res.url) for res in results]
                reachable = await asyncio.gather(*tasks)
                return [
                    res for res, can_access in zip(results, reachable) if can_access
                ]

        reachable_results = await filter_unreachable_urls(filtered_results)

        return reachable_results[:max_results]
|
|
|
|
|
|
|
|
async def scrape_urls(items: List[WebpageSnippet]) -> List[ScrapeResult]:
    """Fetch text content from provided URLs.

    Args:
        items: List of SearchEngineResult items to extract content from

    Returns:
        List of ScrapeResult objects which have the following fields:
        - url: The URL of the search result
        - title: The title of the search result
        - description: The description of the search result
        - text: The full text content of the search result
    """
    connector = aiohttp.TCPConnector(ssl=ssl_context)
    async with aiohttp.ClientSession(connector=connector) as session:
        # One concurrent fetch task per item that actually carries a URL.
        pending = [
            fetch_and_process_url(session, item) for item in items if item.url
        ]
        outcomes = await asyncio.gather(*pending, return_exceptions=True)

    # Discard anything that raised; keep only successful scrape results.
    return [outcome for outcome in outcomes if isinstance(outcome, ScrapeResult)]
|
|
|
|
|
|
|
|
async def fetch_and_process_url(
    session: aiohttp.ClientSession, item: WebpageSnippet
) -> ScrapeResult:
    """Helper function to fetch and process a single URL.

    Never raises: every failure (restricted file extension, non-200 status,
    network or parsing exception) is reported inside the returned
    ScrapeResult's `text` field instead.

    Args:
        session: Open aiohttp session to issue the GET request on.
        item: Search result whose URL should be scraped.

    Returns:
        A ScrapeResult whose `text` holds either the extracted page text
        (truncated to CONTENT_LENGTH_LIMIT) or an error message.
    """

    def _error(reason: str) -> ScrapeResult:
        # Every failure path returns the snippet's metadata plus an error text.
        return ScrapeResult(
            url=item.url,
            title=item.title,
            description=item.description,
            text=f"Error fetching content: {reason}",
        )

    if not is_valid_url(item.url):
        return _error("URL contains restricted file extension")

    try:
        async with session.get(item.url, timeout=8) as response:
            if response.status != 200:
                return _error(f"HTTP {response.status}")

            content = await response.text()
            # HTML parsing is CPU-bound; run it off the event loop thread.
            text_content = await asyncio.get_event_loop().run_in_executor(
                None, html_to_text, content
            )
            return ScrapeResult(
                url=item.url,
                title=item.title,
                description=item.description,
                text=text_content[:CONTENT_LENGTH_LIMIT],  # cap payload size
            )
    except Exception as e:
        return _error(str(e))
|
|
|
|
|
|
|
|
def html_to_text(html_content: str) -> str:
    """
    Strips out all of the unnecessary elements from the HTML context to prepare it for text extraction / LLM processing.

    Only text inside heading, paragraph, list-item and blockquote tags is
    kept — one stripped line per element; empty elements are dropped.
    """
    soup = BeautifulSoup(html_content, "lxml")

    tags_to_extract = ("h1", "h2", "h3", "h4", "h5", "h6", "p", "li", "blockquote")

    # Compute each element's stripped text exactly once (the previous
    # version called get_text(strip=True) twice per element).
    texts = (
        element.get_text(strip=True) for element in soup.find_all(tags_to_extract)
    )
    return "\n".join(text for text in texts if text)
|
|
|
|
|
|
|
|
# Extensions whose URLs should not be scraped (binary/media/non-HTML assets).
_RESTRICTED_EXTENSIONS = (
    ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
    ".zip", ".rar", ".7z", ".txt", ".js", ".xml", ".css",
    ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg", ".webp",
    ".mp3", ".mp4", ".avi", ".mov", ".wmv", ".flv", ".wma",
    ".wav", ".m4a", ".m4v", ".m4b", ".m4p", ".m4u",
)


def is_valid_url(url: str) -> bool:
    """Check that a URL does not point at a restricted file extension.

    The previous implementation tested each extension as a substring of the
    whole URL, which wrongly rejected URLs like
    ``https://example.com/data.json`` (contains ".js") and missed
    uppercase extensions. Only the final suffix of the URL *path* is now
    compared, case-insensitively; query strings and fragments are ignored.
    Compound Office extensions (.docx/.xlsx/.pptx), which the substring
    check matched implicitly, are listed explicitly.
    """
    path = urlparse(url).path.lower()
    return not path.endswith(_RESTRICTED_EXTENSIONS)
|
|
|
|
|
|
|
|
async def url_to_contents(url):
    """Crawl a single URL with crawl4ai and return its markdown rendering."""
    async with AsyncWebCrawler() as crawler:
        crawl_result = await crawler.arun(url=url)
        return crawl_result.markdown
|
|
|
|
|
|
|
|
async def url_to_fit_contents(res):
    """Crawl `res.url` with crawl4ai and return pruned "fit" markdown.

    Falls back to `res.text` (truncated) on timeout or any crawl failure,
    so `res` is presumably a ScrapeResult — it must expose .url and .text.
    """
    # Hard cap on the number of characters returned by every code path.
    str_fit_max = 40000

    browser_config = BrowserConfig(
        headless=True,
        verbose=True,
    )
    run_config = CrawlerRunConfig(
        cache_mode=CacheMode.DISABLED,
        markdown_generator=DefaultMarkdownGenerator(
            # NOTE(review): threshold=1.0 with "fixed" type and
            # min_word_threshold=0 — presumably tuned for aggressive pruning;
            # confirm against crawl4ai's PruningContentFilter docs.
            content_filter=PruningContentFilter(
                threshold=1.0, threshold_type="fixed", min_word_threshold=0
            )
        ),
    )

    try:
        async with AsyncWebCrawler(config=browser_config) as crawler:
            # Bound the crawl so one slow page cannot stall the pipeline.
            result = await asyncio.wait_for(
                crawler.arun(url=res.url, config=run_config), timeout=15
            )
            print(f"char before filtering {len(result.markdown.raw_markdown)}.")
            print(f"char after filtering {len(result.markdown.fit_markdown)}.")
            return result.markdown.fit_markdown[:str_fit_max]
    except asyncio.TimeoutError:
        # Timed out: fall back to the already-scraped plain text.
        print(f"Timeout occurred while accessing {res.url}.")
        return res.text[:str_fit_max]
    except Exception as e:
        # Any other crawl failure: same fallback, best-effort.
        print(f"Exception occurred: {str(e)}")
        return res.text[:str_fit_max]
|
|
|