# VuvanAn's picture
# Upload folder using huggingface_hub
# cc37925 verified
import os
import pickle
import shutil
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter, MarkdownHeaderTextSplitter
from langchain_community.retrievers import BM25Retriever
from typing import List
class TextRAG:
    """Retrieval helper that indexes text files into a FAISS vector store.

    Wraps document loading, chunking, FAISS indexing, and three search
    strategies (similarity, MMR, BM25), with optional persistence of both
    the index and the chunked documents to a local directory.
    """

    def __init__(self, embed_model, vectorstore_dir: str = None):
        """Initialize the helper and load or create the persistence directory.

        Args:
            embed_model: LangChain-compatible embedding model used by FAISS.
            vectorstore_dir: Directory for persisting/loading state. If it
                exists and is non-empty, state is loaded from it; otherwise
                it is created. If None, the store is kept in memory only.
        """
        self.embed_model = embed_model
        self.vector_store = None  # FAISS index, built lazily by consume()/load_local()
        self.docs = None  # chunked documents; required for BM25 search
        self.vectorstore_dir = vectorstore_dir
        # Fix: the original called os.path.isdir(None) when the documented
        # default (None) was used, raising TypeError. None now means
        # "memory-only, no persistence".
        if self.vectorstore_dir is not None:
            if os.path.isdir(self.vectorstore_dir) and os.listdir(self.vectorstore_dir):
                print(f"Found existing vector store at '{self.vectorstore_dir}', loading...")
                self.load_local(self.vectorstore_dir)
            else:
                print(f"Creating new vector store directory at '{self.vectorstore_dir}'")
                os.makedirs(self.vectorstore_dir, exist_ok=True)

    def _clear(self):
        """Reset in-memory state and recreate an empty on-disk directory."""
        self.vector_store = None
        self.docs = None
        print("Cleared the vector store and documents from memory.")
        # Guard against memory-only instances (vectorstore_dir is None).
        if self.vectorstore_dir is not None and os.path.isdir(self.vectorstore_dir):
            shutil.rmtree(self.vectorstore_dir)
            print(f"Removed local vector store directory: {self.vectorstore_dir}")
            os.makedirs(self.vectorstore_dir, exist_ok=True)

    def consume(self, source_dir: str, file_type: List[str] = None, chunk_size: int = 1000, chunk_overlap: int = 100, chunk_method: str = "recursive"):
        """Load, chunk, and index documents found under ``source_dir``.

        Args:
            source_dir: Root directory scanned for documents.
            file_type: Glob patterns to load (defaults to ["**/*.txt"]).
            chunk_size: Maximum characters per chunk.
            chunk_overlap: Character overlap between consecutive chunks.
            chunk_method: One of 'recursive', 'character', 'markdown'.

        Raises:
            ValueError: If ``chunk_method`` is not a supported value.
        """
        if file_type is None:
            file_type = ["**/*.txt"]
        all_documents = []
        for pattern in file_type:
            loader = DirectoryLoader(source_dir, glob=pattern, loader_cls=TextLoader, show_progress=True, use_multithreading=True)
            all_documents.extend(loader.load())
        if not all_documents:
            print("No documents found for the specified file types.")
            return
        new_docs = self._split_documents(all_documents, chunk_size, chunk_overlap, chunk_method)
        if self.vector_store is None:
            self.docs = new_docs
            self.vector_store = FAISS.from_documents(self.docs, self.embed_model)
            print(f"Successfully consumed {len(all_documents)} documents, creating {len(self.docs)} chunks and a new vector store.")
        else:
            if self.docs is None:
                self.docs = []
            self.docs.extend(new_docs)
            self.vector_store.add_documents(new_docs)
            print(f"Successfully added {len(all_documents)} documents ({len(new_docs)} chunks) to the existing vector store.")
        # Persist only when a target directory was configured; save_local()
        # would otherwise raise for memory-only instances.
        if self.vectorstore_dir is not None:
            self.save_local()

    def _split_documents(self, documents, chunk_size, chunk_overlap, chunk_method):
        """Chunk ``documents`` with the requested splitter; return the chunks."""
        if chunk_method == "recursive":
            splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap
            )
            return splitter.split_documents(documents)
        if chunk_method == "character":
            splitter = CharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                separator="\n"
            )
            return splitter.split_documents(documents)
        if chunk_method == "markdown":
            header_splitter = MarkdownHeaderTextSplitter(
                headers_to_split_on=[("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")],
                return_each_line=False,
                strip_headers=False
            )
            # NOTE(review): joining every page_content into one string discards
            # per-file metadata (e.g. source paths) — confirm this is intended
            # for markdown input.
            header_docs = header_splitter.split_text("\n\n".join(doc.page_content for doc in documents))
            # Second pass enforces the size limit on the header-based sections.
            size_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
            return size_splitter.split_documents(header_docs)
        raise ValueError(f"Unknown chunk_method: {chunk_method}")

    def search(self, query: str, metric: str = "similarity", k: int = 5, threshold: float = None):
        """Search the indexed documents.

        metric = ['similarity', 'mmr', 'bm25']

        Args:
            query: Natural-language query string.
            metric: Search strategy to use.
            k: Maximum number of documents to return.
            threshold: Optional score cutoff (similarity metric only).

        Returns:
            A list of matching Documents.

        Raises:
            ValueError: If the store is uninitialized, the metric is unknown,
                or BM25 is requested without loaded documents.
        """
        if self.vector_store is None:
            raise ValueError("Vector store not initialized. Please run the 'consume' or 'load_local' method first.")
        if metric == "similarity":
            if threshold is not None:
                # NOTE(review): FAISS scores are distances by default (lower is
                # better), so 'score >= threshold' may keep the *worst* matches.
                # Behavior preserved; confirm intended score semantics.
                results_with_scores = self.vector_store.similarity_search_with_score(query, k=k)
                return [doc for doc, score in results_with_scores if score >= threshold]
            return self.vector_store.similarity_search(query, k=k)
        if metric == "mmr":
            return self.vector_store.max_marginal_relevance_search(query, k=k)
        if metric == "bm25":
            if self.docs is None:
                raise ValueError("Documents not available. BM25 requires consumed or loaded documents.")
            bm25_retriever = BM25Retriever.from_documents(self.docs)
            # Fix: BM25Retriever limits results via its 'k' attribute; passing
            # k= to get_relevant_documents() did not cap the result count.
            bm25_retriever.k = k
            return bm25_retriever.get_relevant_documents(query)
        raise ValueError(f"Unsupported metric: '{metric}'. Supported metrics are 'similarity', 'mmr', and 'bm25'.")

    def save_local(self, folder_path: str = None):
        """Persist the FAISS index and chunked docs to ``folder_path``.

        Args:
            folder_path: Target directory; defaults to ``self.vectorstore_dir``.

        Raises:
            ValueError: If no target directory is available or there is
                nothing to save.
        """
        if folder_path is None:
            folder_path = self.vectorstore_dir
        # Fix: previously fell through to os.makedirs(None) when both the
        # argument and vectorstore_dir were None.
        if folder_path is None:
            raise ValueError("Nothing to save. Please run 'consume' first.")
        if self.vector_store is None or self.docs is None:
            raise ValueError("Nothing to save. Please run 'consume' first.")
        os.makedirs(folder_path, exist_ok=True)
        self.vector_store.save_local(folder_path)
        with open(os.path.join(folder_path, "docs.pkl"), "wb") as f:
            pickle.dump(self.docs, f)
        print(f"Successfully saved RAG state to {folder_path}")

    def load_local(self, folder_path: str):
        """Load a previously saved FAISS index and document list.

        Args:
            folder_path: Directory produced by an earlier ``save_local``.

        Raises:
            FileNotFoundError: If ``folder_path`` does not exist.
        """
        if not os.path.isdir(folder_path):
            raise FileNotFoundError(f"Folder not found: {folder_path}")
        try:
            # SECURITY: both FAISS deserialization and pickle.load execute
            # arbitrary code on malicious files — only load trusted directories.
            self.vector_store = FAISS.load_local(folder_path, self.embed_model, allow_dangerous_deserialization=True)
            docs_path = os.path.join(folder_path, "docs.pkl")
            if os.path.exists(docs_path):
                with open(docs_path, "rb") as f:
                    self.docs = pickle.load(f)
            else:
                self.docs = None
                print("Warning: docs.pkl not found. BM25 search will not be available.")
            print(f"Successfully loaded RAG state from {folder_path}")
        except Exception as e:
            # Best-effort load: report and continue with an empty store.
            print(f"Could not load from {folder_path}. It might be empty or corrupted. Error: {e}")
def _demo():
    """Demo entry point: build an embedding model and reset the local index.

    The index-building steps are left commented out (as in the original)
    because they require a populated 'sample_data' directory.
    """
    # Imported here so merely importing this module does not pull in the
    # (heavy) sentence-transformers stack.
    from langchain_community.embeddings import SentenceTransformerEmbeddings

    embedding_model = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vector_store_path = "rag_index"
    # Clean up previous runs for a clean demonstration
    if os.path.exists(vector_store_path):
        shutil.rmtree(vector_store_path)
    # --- First run: Create and save the index ---
    # print("--- Initializing first RAG instance ---")
    # rag_system_1 = TextRAG(embed_model=embedding_model, vectorstore_dir=vector_store_path)
    # rag_system_1.consume(source_dir="sample_data", file_type=["**/*.txt", "**/*.md"])
    # rag_system_1.save_local()


if __name__ == '__main__':
    _demo()