Spaces:
Sleeping
Sleeping
| from langchain_chroma import Chroma | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_core.tools import tool | |
| from agent_graph.load_tools_config import LoadToolsConfig | |
# Tool configuration object, created once when this module is imported.
# lookup_stories reads its stories_rag_* attributes (embedding model,
# vector DB directory, k, collection name) from this instance.
TOOLS_CFG = LoadToolsConfig()
class StoriesRAGTool:
    """
    Holder for a Chroma collection of stories plus its retrieval settings.

    Constructing an instance builds a Hugging Face embedding function for the
    given model name and opens the named collection from the given persist
    directory. The class defines no query method of its own: callers run
    searches through the ``vectordb`` attribute and pass ``k`` as the number
    of results to return.

    Attributes:
        embedding_model (str): Name of the Hugging Face embedding model handed
            to ``HuggingFaceEmbeddings``.
        vectordb_dir (str): Directory passed to Chroma as ``persist_directory``.
        k (int): Number of results callers are expected to request per query.
        vectordb (Chroma): The Chroma instance bound to the collection and the
            embedding function.
    """

    def __init__(self, embedding_model: str, vectordb_dir: str, k: int, collection_name: str) -> None:
        """
        Store the retrieval settings and open the Chroma collection.

        Args:
            embedding_model (str): Embedding model name (e.g., "all-MiniLM-L6-v2")
                used to embed queries.
            vectordb_dir (str): Directory path of the persisted Chroma database.
            k (int): How many nearest-neighbour stories to retrieve per query.
            collection_name (str): Name of the collection inside the vector
                database that holds the stories.
        """
        self.embedding_model = embedding_model
        self.vectordb_dir = vectordb_dir
        self.k = k

        # Build the embedding function first, then bind it to the collection.
        embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model)
        store = Chroma(
            collection_name=collection_name,
            persist_directory=self.vectordb_dir,
            embedding_function=embeddings,
        )
        self.vectordb = store

        # Startup diagnostic: report how many vectors the collection holds.
        # This reads the count through Chroma's private `_collection` handle.
        vector_count = self.vectordb._collection.count()
        print("Number of vectors in vectordb:", vector_count, "\n\n")
# Lazily-built, process-wide StoriesRAGTool shared by every lookup_stories call.
_STORIES_RAG_TOOL = None


def _get_stories_rag_tool() -> "StoriesRAGTool":
    """Return the shared StoriesRAGTool, constructing it on first use."""
    global _STORIES_RAG_TOOL
    if _STORIES_RAG_TOOL is None:
        # Constructing the tool instantiates the embedding model and opens the
        # Chroma collection, so do it once instead of on every query.
        _STORIES_RAG_TOOL = StoriesRAGTool(
            embedding_model=TOOLS_CFG.stories_rag_embedding_model,
            vectordb_dir=TOOLS_CFG.stories_rag_vectordb_directory,
            k=TOOLS_CFG.stories_rag_k,
            collection_name=TOOLS_CFG.stories_rag_collection_name,
        )
    return _STORIES_RAG_TOOL


# NOTE(review): `tool` is imported from langchain_core.tools at the top of the
# file, but no decorator is visible on this function here. If an `@tool`
# decorator was lost, restore it; confirm against the callers.
# The docstring below is intentionally left unchanged: under `@tool` it would
# serve as the tool description.
def lookup_stories(query: str) -> str:
    """Search among the fictional stories and find the answer to the query. Input should be the query."""
    # Args:    query - free-text question to match against the stories.
    # Returns: page contents of the top-k matching documents, separated by
    #          blank lines; an empty string when nothing is returned.
    rag_tool = _get_stories_rag_tool()
    docs = rag_tool.vectordb.similarity_search(query, k=rag_tool.k)
    return "\n\n".join(doc.page_content for doc in docs)