# rag_agent/agent/tools.py
# Cheh Kit Hong -- "patched tavily to latest" (c92c6b2)
import json
from typing import List
from langchain_core.tools import tool
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_tavily import TavilySearch
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from dotenv import load_dotenv
from config import configs
# Load variables from a .env file into the process environment at import
# time. NOTE(review): presumably this is where API credentials used by the
# search tools below (e.g. Tavily) come from -- confirm.
load_dotenv()
def intialize_chroma_vectorstore():
    """Build the Chroma vector store described by the project ``configs``.

    Returns:
        A ``Chroma`` instance opened on ``configs["PERSIST_PATH"]`` for the
        collection ``configs["COLLECTION_NAME"]``, embedding with the
        HuggingFace model named by ``configs["EMBEDDING_MODEL_NAME"]``.
    """
    # Create the embedding model first; the store needs it to embed queries.
    embedder = HuggingFaceEmbeddings(model_name=configs["EMBEDDING_MODEL_NAME"])
    chroma_kwargs = {
        "persist_directory": configs["PERSIST_PATH"],
        "embedding_function": embedder,
        "collection_name": configs["COLLECTION_NAME"],
    }
    return Chroma(**chroma_kwargs)
@tool
def web_search_tavily(query: str) -> dict:
    """Search Tavily for a query and return up to 3 results.

    Args:
        query: The search query.

    Returns:
        dict with key 'web_results'. On success the value is a list of
        search results, each with 'title', 'url', and 'content'. On failure
        the value is a string starting with 'Error retrieving results: '.
    """
    try:
        response = TavilySearch(
            max_results=3,
            topic="general",
        ).invoke({"query": query})

        # NOTE(review): TavilySearch appears to report some failures as a
        # payload without 'results' (e.g. a dict carrying an 'error' entry)
        # instead of raising -- confirm against the installed version.
        # Surface that detail rather than the opaque KeyError text "'results'".
        if not isinstance(response, dict) or "results" not in response:
            if isinstance(response, dict):
                detail = response.get("error", response)
            else:
                detail = response
            return {"web_results": f"Error retrieving results: {detail}"}

        results = [
            {
                "title": doc.get("title", ""),
                "url": doc.get("url", ""),
                "content": doc.get("content", ""),
            }
            for doc in response["results"]
        ]
        return {"web_results": results}
    except Exception as e:
        # Best-effort tool: report the failure to the caller instead of
        # raising, so the agent loop can continue.
        return {"web_results": f"Error retrieving results: {str(e)}"}
@tool
def wikipedia_search(query: str) -> dict:
    """Search Wikipedia for a query and return up to 3 results.

    Args:
        query: The search query.

    Returns:
        dict with key 'wiki_results'. On success the value is a list of
        search results, each with 'title', 'url', and 'snippet'. On failure
        the value is a string starting with 'Error retrieving results: '.
    """
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=3).load()
        results = []
        for doc in search_docs:
            meta = doc.metadata
            results.append(
                {
                    "title": meta.get("title", ""),
                    # WikipediaLoader stores the page URL under 'source';
                    # 'url' is kept as a fallback for other metadata layouts.
                    "url": meta.get("source", meta.get("url", "")),
                    "snippet": doc.page_content,
                }
            )
        return {"wiki_results": results}
    except Exception as e:
        # Best-effort tool: report the failure to the caller instead of
        # raising, so the agent loop can continue.
        return {"wiki_results": f"Error retrieving results: {str(e)}"}
@tool
def arxiv_search(query: str) -> dict:
    """Search Arxiv for a query and return up to 3 results.

    Args:
        query: The search query.

    Returns:
        dict with key 'arxiv_results'. On success the value is a list of
        search results, each with 'title', 'url', and 'snippet'. On failure
        the value is a string starting with 'Error retrieving results: '.
    """
    try:
        # load_all_available_meta=True is needed for the loader to include
        # 'entry_id' (the paper's arXiv link) in each document's metadata.
        search_docs = ArxivLoader(
            query=query,
            load_max_docs=3,
            load_all_available_meta=True,
        ).load()
        results = []
        for doc in search_docs:
            meta = doc.metadata
            results.append(
                {
                    # ArxivLoader uses capitalised keys ('Title', 'Authors',
                    # ...); the lowercase key is kept as a fallback.
                    "title": meta.get("Title", meta.get("title", "")),
                    "url": meta.get("entry_id", meta.get("url", "")),
                    "snippet": doc.page_content,
                }
            )
        return {"arxiv_results": results}
    except Exception as e:
        # Best-effort tool: report the failure to the caller instead of
        # raising, so the agent loop can continue.
        return {"arxiv_results": f"Error retrieving results: {str(e)}"}
# Script entry point: currently a no-op placeholder, so running this module
# directly does nothing beyond the import-time setup above.
if __name__ == "__main__":
    pass