# FinalAssignment-AliA / search_tools.py
# Author: AliA1997
# Completed Final Assignment for the Hugging Face Agents Course (commit a6dbfdf)
import os
# import chromadb
from dotenv import load_dotenv
from langchain_core.tools import tool
from langchain_tavily import TavilySearch
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from supabase.client import Client, create_client
from langchain_core.tools import create_retriever_tool
load_dotenv()
@tool
def wiki_search(input: str) -> str:
    """Search Wikipedia for a query and return a maximum of 2 results.

    Args:
        input: The search query.

    Returns:
        The matching pages concatenated as ``<Document .../>`` blocks,
        separated by ``---`` dividers.
    """
    search_docs = WikipediaLoader(query=input, load_max_docs=2).load()
    # Use .get for "source" as well (it is already used for "page") so a
    # document with missing metadata cannot raise KeyError.
    return "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" '
        f'page="{doc.metadata.get("page", "")}"/>\n'
        f'{doc.page_content}\n</Document>'
        for doc in search_docs
    )
@tool
def web_search(input: str) -> str:
    """Search Tavily for a query and return a maximum of 3 results.

    Args:
        input: The search query.

    Returns:
        The results concatenated as ``<Document .../>`` blocks,
        separated by ``---`` dividers.
    """
    results = TavilySearch(max_results=3).invoke(input)
    # langchain_tavily's TavilySearch returns a dict payload such as
    # {"query": ..., "results": [...]}. Iterating that dict directly yields
    # its string KEYS, so the per-item dict branch below would never fire.
    # Unwrap the "results" list first; keep the raw value as a fallback for
    # older versions that returned a plain list.
    if isinstance(results, dict):
        results = results.get("results", [])

    formatted_items = []
    for item in results:
        # Case 1: item is a dict (new Tavily format)
        if isinstance(item, dict):
            url = item.get("url", "")
            content = item.get("content", "")
            formatted_items.append(
                f'<Document source="{url}"/>\n{content}\n</Document>'
            )
        # Case 2: item is a string (fallback format)
        else:
            formatted_items.append(
                f'<Document source=""/>\n{str(item)}\n</Document>'
            )
    return "\n\n---\n\n".join(formatted_items)
@tool
def arvix_search(input: str) -> str:
    """Search Arxiv for a query and return a maximum of 3 results.

    Args:
        input: The search query.

    Returns:
        Up to 3 matching papers (first 1000 chars each) concatenated as
        ``<Document .../>`` blocks, separated by ``---`` dividers.
    """
    # NOTE: tool name keeps the original "arvix" spelling — it is the
    # published tool interface, so renaming would break existing callers.
    search_docs = ArxivLoader(query=input, load_max_docs=3).load()
    # ArxivLoader documents expose metadata keys like "Published", "Title",
    # "Authors", "Summary" — there is no "source" key, so the original
    # doc.metadata["source"] raised KeyError. Fall back to the paper title.
    return "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" '
        f'page="{doc.metadata.get("page", "")}"/>\n'
        f'{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
# --- Module-level retrieval setup -------------------------------------------
# Runs at import time; requires SUPABASE_URL and SUPABASE_SERVICE_KEY to be
# set (loaded above via load_dotenv), otherwise os.environ[...] raises KeyError.

# Build embeddings used to encode queries for similarity search.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)

# Connect to Supabase using the service key (full-access credential —
# keep it server-side only).
supabase_url = os.environ["SUPABASE_URL"]
supabase_service_key = os.environ["SUPABASE_SERVICE_KEY"]
supabase = create_client(supabase_url, supabase_service_key)

# Create the Supabase vector store backed by the "documents" table and the
# "match_documents_langchain" RPC similarity function.
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",  # your table
    query_name="match_documents_langchain"  # your RPC function
)

# Convert to a retriever consumed by the question_search tool below.
retriever = vector_store.as_retriever()
@tool
def question_search(input: str):
    """Retrieve similar questions from the Supabase vector store.

    Args:
        input: The question text to match against stored questions.

    Returns:
        The matching documents' contents, joined by blank lines.
    """
    matches = retriever.invoke(input)
    return "\n\n".join(doc.page_content for doc in matches)