QueryMind / src /agent_graph /tool_stories_rag.py
7beshoyarnest's picture
Update src/agent_graph/tool_stories_rag.py
58f8537 verified
Raw
History Blame Contribute Delete
2.96 kB
from functools import lru_cache

from langchain_chroma import Chroma
from langchain_core.tools import tool
from langchain_huggingface import HuggingFaceEmbeddings

from agent_graph.load_tools_config import LoadToolsConfig
# Tools configuration, instantiated once when this module is imported.
# The stories RAG lookup in this file reads its stories_rag_* settings
# (embedding model, vector DB directory, k, collection name) from this object.
TOOLS_CFG = LoadToolsConfig()
class StoriesRAGTool:
    """
    Holder for a Chroma vector store of stories and its retrieval settings.

    On construction the class builds a Hugging Face embedding function from the
    given model name and opens the named collection of a Chroma database that is
    persisted under the given directory. It defines no search method of its own:
    callers run queries through the exposed ``vectordb`` attribute and use ``k``
    as the number of results to request.

    Attributes:
        embedding_model (str): Model name handed to ``HuggingFaceEmbeddings``.
        vectordb_dir (str): Directory passed to Chroma as ``persist_directory``.
        k (int): Number of stories a caller should retrieve per query.
        vectordb (Chroma): Chroma instance bound to the collection and embedding function.
    """

    def __init__(self, embedding_model: str, vectordb_dir: str, k: int, collection_name: str) -> None:
        """
        Store the retrieval settings and open the Chroma collection.

        Args:
            embedding_model (str): Name of the embedding model (e.g., "all-MiniLM-L6-v2")
                used to embed queries.
            vectordb_dir (str): Path of the directory where the Chroma database is persisted.
            k (int): How many nearest-neighbour stories to retrieve per query.
            collection_name (str): Name of the collection inside the database that holds the stories.
        """
        self.embedding_model = embedding_model
        self.vectordb_dir = vectordb_dir
        self.k = k

        # Build the embedding function first so the Chroma call stays readable.
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        self.vectordb = Chroma(
            collection_name=collection_name,
            persist_directory=vectordb_dir,
            embedding_function=embeddings,
        )

        # Diagnostic output: size of the opened collection. The count is read
        # through the wrapper's underscore-prefixed (non-public) `_collection`.
        vector_count = self.vectordb._collection.count()
        print("Number of vectors in vectordb:", vector_count, "\n\n")
@lru_cache(maxsize=1)
def _get_stories_rag_tool() -> "StoriesRAGTool":
    """Build the configured StoriesRAGTool once and reuse it on later calls.

    Constructing the tool instantiates the embedding model and opens the Chroma
    collection, so repeating it for every query is wasted work. Every argument
    comes from the module-level TOOLS_CFG, so one cached instance is enough.
    If construction raises, nothing is cached and the next call retries.

    Returns:
        StoriesRAGTool: The shared, lazily created tool instance.
    """
    return StoriesRAGTool(
        embedding_model=TOOLS_CFG.stories_rag_embedding_model,
        vectordb_dir=TOOLS_CFG.stories_rag_vectordb_directory,
        k=TOOLS_CFG.stories_rag_k,
        collection_name=TOOLS_CFG.stories_rag_collection_name,
    )


# NOTE: the docstring of `lookup_stories` is deliberately left verbatim. The
# `@tool` decorator uses it as the tool description, so it is runtime text
# rather than plain documentation.
#
# Behaviour: runs a similarity search for `query` against the stories vector
# store, asking for `k` results, and returns the `page_content` of the matching
# documents joined by blank lines (an empty string when nothing is returned).
@tool
def lookup_stories(query: str) -> str:
    """Search among the fictional stories and find the answer to the query. Input should be the query."""
    # Reuse the cached tool instead of rebuilding the embedding model and
    # reopening the vector store on every invocation.
    rag_tool = _get_stories_rag_tool()
    docs = rag_tool.vectordb.similarity_search(query, k=rag_tool.k)
    return "\n\n".join(doc.page_content for doc in docs)