import os
import pickle
import shutil
from typing import List

from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    CharacterTextSplitter,
    MarkdownHeaderTextSplitter,
)
from langchain_community.retrievers import BM25Retriever


class TextRAG:
    """A small text RAG pipeline.

    Loads text files from a directory, chunks them, embeds the chunks into a
    FAISS vector store, and answers queries via vector similarity, MMR, or
    BM25. State (FAISS index + pickled chunk list) is persisted under
    ``vectorstore_dir`` and reloaded automatically on construction.
    """

    def __init__(self, embed_model, vectorstore_dir: str = None):
        """
        Args:
            embed_model: a LangChain-compatible embeddings object used for
                building and querying the FAISS index.
            vectorstore_dir: directory used to persist/load the index. If
                None, the instance is memory-only and nothing is persisted.
        """
        self.embed_model = embed_model
        self.vector_store = None
        self.docs = None
        self.vectorstore_dir = vectorstore_dir
        # BUG FIX: the original called os.path.isdir(self.vectorstore_dir)
        # unconditionally, so the documented default (None) raised TypeError.
        if self.vectorstore_dir is None:
            return
        if os.path.isdir(self.vectorstore_dir) and os.listdir(self.vectorstore_dir):
            print(f"Found existing vector store at '{self.vectorstore_dir}', loading...")
            self.load_local(self.vectorstore_dir)
        else:
            print(f"Creating new vector store directory at '{self.vectorstore_dir}'")
            os.makedirs(self.vectorstore_dir, exist_ok=True)

    def _clear(self):
        """Drop all in-memory state and wipe the on-disk store, if any."""
        self.vector_store = None
        self.docs = None
        print("Cleared the vector store and documents from memory.")
        # Guard against vectorstore_dir=None (memory-only instances).
        if self.vectorstore_dir and os.path.isdir(self.vectorstore_dir):
            shutil.rmtree(self.vectorstore_dir)
            print(f"Removed local vector store directory: {self.vectorstore_dir}")
            os.makedirs(self.vectorstore_dir, exist_ok=True)

    def consume(self, source_dir: str, file_type: List[str] = None,
                chunk_size: int = 1000, chunk_overlap: int = 100,
                chunk_method: str = "recursive"):
        """Load, chunk, and index text files from ``source_dir``.

        Creates a new vector store on first call; later calls append to it.
        Persists the state via :meth:`save_local` when a directory is set.

        Args:
            source_dir: directory to scan for files.
            file_type: glob patterns to load (default ``["**/*.txt"]``).
            chunk_size: maximum characters per chunk.
            chunk_overlap: characters of overlap between adjacent chunks.
            chunk_method: one of 'recursive', 'character', 'markdown'.

        Raises:
            ValueError: if ``chunk_method`` is not recognized.
        """
        if file_type is None:
            file_type = ["**/*.txt"]

        all_documents = []
        for pattern in file_type:
            loader = DirectoryLoader(source_dir, glob=pattern, loader_cls=TextLoader,
                                     show_progress=True, use_multithreading=True)
            all_documents.extend(loader.load())
        if not all_documents:
            print("No documents found for the specified file types.")
            return

        # Validate chunk_method before doing any splitting work.
        if chunk_method == "recursive":
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
        elif chunk_method == "character":
            text_splitter = CharacterTextSplitter(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator="\n"
            )
        elif chunk_method == "markdown":
            text_splitter = MarkdownHeaderTextSplitter(
                headers_to_split_on=[("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")],
                return_each_line=False,
                strip_headers=False,
            )
        else:
            raise ValueError(f"Unknown chunk_method: {chunk_method}")

        if chunk_method == "markdown":
            # Header-split first (split_text yields Documents carrying header
            # metadata), then size-limit each section recursively.
            # NOTE(review): joining all page_content into one string discards
            # per-file metadata — confirm this is acceptable for markdown.
            new_docs = text_splitter.split_text(
                "\n\n".join([doc.page_content for doc in all_documents])
            )
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
            new_docs = text_splitter.split_documents(new_docs)
        else:
            new_docs = text_splitter.split_documents(all_documents)

        if self.vector_store is None:
            self.docs = new_docs
            self.vector_store = FAISS.from_documents(self.docs, self.embed_model)
            print(f"Successfully consumed {len(all_documents)} documents, creating {len(self.docs)} chunks and a new vector store.")
        else:
            if self.docs is None:
                self.docs = []
            self.docs.extend(new_docs)
            self.vector_store.add_documents(new_docs)
            print(f"Successfully added {len(all_documents)} documents ({len(new_docs)} chunks) to the existing vector store.")
        self.save_local()

    def search(self, query: str, metric: str = "similarity", k: int = 5,
               threshold: float = None):
        """Retrieve documents relevant to ``query``.

        metric = ['similarity', 'mmr', 'bm25']

        Args:
            query: the search string.
            metric: retrieval strategy — vector similarity, maximal marginal
                relevance, or BM25 keyword ranking over the stored chunks.
            k: number of results to return.
            threshold: optional score cut-off, used only with 'similarity'.

        Raises:
            ValueError: if the store is not initialized, if BM25 is requested
                without loaded documents, or on an unsupported metric.
        """
        if self.vector_store is None:
            raise ValueError("Vector store not initialized. Please run the 'consume' or 'load_local' method first.")
        if metric == "similarity":
            if threshold is not None:
                # NOTE(review): with FAISS's default distance strategy the
                # score is a distance (lower = more similar), so `>=` keeps
                # the *less* similar hits — confirm the intended semantics
                # before relying on threshold filtering.
                results_with_scores = self.vector_store.similarity_search_with_score(query, k=k)
                return [doc for doc, score in results_with_scores if score >= threshold]
            else:
                return self.vector_store.similarity_search(query, k=k)
        elif metric == "mmr":
            return self.vector_store.max_marginal_relevance_search(query, k=k)
        elif metric == "bm25":
            if self.docs is None:
                raise ValueError("Documents not available. BM25 requires consumed or loaded documents.")
            bm25_retriever = BM25Retriever.from_documents(self.docs)
            # BUG FIX: BM25Retriever reads the result count from its `k`
            # attribute; passing k=... to get_relevant_documents was ignored,
            # so the default of 4 results was always returned.
            bm25_retriever.k = k
            return bm25_retriever.get_relevant_documents(query)
        else:
            raise ValueError(f"Unsupported metric: '{metric}'. Supported metrics are 'similarity', 'mmr', and 'bm25'.")

    def save_local(self, folder_path: str = None):
        """Persist the FAISS index and the chunk list to ``folder_path``.

        Falls back to ``self.vectorstore_dir``; a memory-only instance
        (no directory configured) skips saving instead of crashing.

        Raises:
            ValueError: if there is nothing to save yet.
        """
        if folder_path is None:
            folder_path = self.vectorstore_dir
        # Guard against os.makedirs(None) for memory-only instances.
        if folder_path is None:
            print("No vector store directory configured; skipping save.")
            return
        if self.vector_store is None or self.docs is None:
            raise ValueError("Nothing to save. Please run 'consume' first.")
        os.makedirs(folder_path, exist_ok=True)
        self.vector_store.save_local(folder_path)
        # Chunks are pickled alongside the index so BM25 works after reload.
        with open(os.path.join(folder_path, "docs.pkl"), "wb") as f:
            pickle.dump(self.docs, f)
        print(f"Successfully saved RAG state to {folder_path}")

    def load_local(self, folder_path: str):
        """Load a previously saved FAISS index and chunk list.

        Args:
            folder_path: directory produced by :meth:`save_local`.

        Raises:
            FileNotFoundError: if ``folder_path`` does not exist.
        """
        if not os.path.isdir(folder_path):
            raise FileNotFoundError(f"Folder not found: {folder_path}")
        try:
            # SECURITY: both the FAISS load and pickle.load execute arbitrary
            # code from the files — only load directories you wrote yourself.
            self.vector_store = FAISS.load_local(folder_path, self.embed_model,
                                                 allow_dangerous_deserialization=True)
            docs_path = os.path.join(folder_path, "docs.pkl")
            if os.path.exists(docs_path):
                with open(docs_path, "rb") as f:
                    self.docs = pickle.load(f)
            else:
                self.docs = None
                print("Warning: docs.pkl not found. BM25 search will not be available.")
            print(f"Successfully loaded RAG state from {folder_path}")
        except Exception as e:
            # Best-effort load: a missing/corrupt store leaves the instance
            # empty rather than aborting construction.
            print(f"Could not load from {folder_path}. It might be empty or corrupted. Error: {e}")


if __name__ == '__main__':
    from langchain_community.embeddings import SentenceTransformerEmbeddings

    embedding_model = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vector_store_path = "rag_index"

    # Clean up previous runs for a clean demonstration
    if os.path.exists(vector_store_path):
        shutil.rmtree(vector_store_path)

    # --- First run: Create and save the index ---
    # print("--- Initializing first RAG instance ---")
    # rag_system_1 = TextRAG(embed_model=embedding_model, vectorstore_dir=vector_store_path)
    # rag_system_1.consume(source_dir="sample_data", file_type=["**/*.txt", "**/*.md"])
    # rag_system_1.save_local()