File size: 7,315 Bytes
cc37925 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import os
import pickle
import shutil
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter, MarkdownHeaderTextSplitter
from langchain_community.retrievers import BM25Retriever
from typing import List
class TextRAG:
    """Retrieval helper around a FAISS vector store.

    Loads text/markdown files from a directory, splits them into chunks,
    embeds them into FAISS, and exposes similarity / MMR / BM25 search plus
    save/load persistence of both the index and the raw chunks (for BM25).
    """

    def __init__(self, embed_model, vectorstore_dir: str = None):
        """
        Args:
            embed_model: a LangChain-compatible embeddings object.
            vectorstore_dir: optional directory used to persist/load the index.
                If None, the instance is purely in-memory.
                (BUGFIX: the original unconditionally called
                ``os.path.isdir(None)`` and crashed with TypeError when the
                default was used.)
        """
        self.embed_model = embed_model
        self.vector_store = None
        self.docs = None
        self.vectorstore_dir = vectorstore_dir
        if self.vectorstore_dir is None:
            return  # in-memory only; nothing to load or create on disk
        if os.path.isdir(self.vectorstore_dir) and os.listdir(self.vectorstore_dir):
            print(f"Found existing vector store at '{self.vectorstore_dir}', loading...")
            self.load_local(self.vectorstore_dir)
        else:
            print(f"Creating new vector store directory at '{self.vectorstore_dir}'")
            os.makedirs(self.vectorstore_dir, exist_ok=True)

    def _clear(self):
        """Drop the in-memory index/docs and wipe the on-disk directory, if any."""
        self.vector_store = None
        self.docs = None
        print("Cleared the vector store and documents from memory.")
        if self.vectorstore_dir is None:
            return  # nothing persisted to remove
        if os.path.isdir(self.vectorstore_dir):
            shutil.rmtree(self.vectorstore_dir)
            print(f"Removed local vector store directory: {self.vectorstore_dir}")
        # Recreate an empty directory so later saves have a target.
        os.makedirs(self.vectorstore_dir, exist_ok=True)

    def consume(self, source_dir: str, file_type: List[str] = None, chunk_size: int = 1000, chunk_overlap: int = 100, chunk_method: str = "recursive"):
        """Load files from ``source_dir``, chunk them, and add them to the index.

        Args:
            source_dir: directory to scan.
            file_type: glob patterns (defaults to ``["**/*.txt"]``; None default
                avoids the mutable-default-argument pitfall).
            chunk_size / chunk_overlap: splitter parameters (characters).
            chunk_method: 'recursive', 'character', or 'markdown'.

        Raises:
            ValueError: on an unknown ``chunk_method``.
        """
        if file_type is None:
            file_type = ["**/*.txt"]
        all_documents = []
        for pattern in file_type:
            loader = DirectoryLoader(source_dir, glob=pattern, loader_cls=TextLoader, show_progress=True, use_multithreading=True)
            all_documents.extend(loader.load())
        if not all_documents:
            print("No documents found for the specified file types.")
            return
        new_docs = self._split_documents(all_documents, chunk_size, chunk_overlap, chunk_method)
        if self.vector_store is None:
            # First ingestion: build a brand-new FAISS index.
            self.docs = new_docs
            self.vector_store = FAISS.from_documents(self.docs, self.embed_model)
            print(f"Successfully consumed {len(all_documents)} documents, creating {len(self.docs)} chunks and a new vector store.")
        else:
            # Incremental ingestion: append chunks to the existing index.
            if self.docs is None:
                self.docs = []
            self.docs.extend(new_docs)
            self.vector_store.add_documents(new_docs)
            print(f"Successfully added {len(all_documents)} documents ({len(new_docs)} chunks) to the existing vector store.")
        if self.vectorstore_dir is not None:
            # Persist after every ingestion, matching the original behavior;
            # skipped when running purely in-memory (BUGFIX: would have crashed).
            self.save_local()

    @staticmethod
    def _split_documents(all_documents, chunk_size: int, chunk_overlap: int, chunk_method: str):
        """Chunk loaded documents with the requested splitter; returns Document chunks."""
        if chunk_method == "recursive":
            splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap
            )
            return splitter.split_documents(all_documents)
        if chunk_method == "character":
            splitter = CharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                separator="\n"
            )
            return splitter.split_documents(all_documents)
        if chunk_method == "markdown":
            header_splitter = MarkdownHeaderTextSplitter(
                headers_to_split_on=[("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")],
                return_each_line=False,
                strip_headers=False
            )
            # NOTE(review): joining all page_content discards per-file source
            # metadata; header metadata added by the splitter is preserved.
            header_docs = header_splitter.split_text("\n\n".join(doc.page_content for doc in all_documents))
            # Second pass enforces the size limit on the header-based sections.
            size_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
            return size_splitter.split_documents(header_docs)
        raise ValueError(f"Unknown chunk_method: {chunk_method}")

    def search(self, query: str, metric: str = "similarity", k: int = 5, threshold: float = None):
        """Retrieve up to ``k`` chunks relevant to ``query``.

        Args:
            query: search string.
            metric: one of 'similarity', 'mmr', 'bm25'.
            k: maximum number of chunks to return.
            threshold: only used with 'similarity'; keeps documents whose
                normalized relevance score (0..1, higher is better) is
                >= threshold.

        Raises:
            ValueError: if no index is loaded, the metric is unknown, or BM25
                is requested without documents in memory.
        """
        if self.vector_store is None:
            raise ValueError("Vector store not initialized. Please run the 'consume' or 'load_local' method first.")
        if metric == "similarity":
            if threshold is not None:
                # BUGFIX: similarity_search_with_score returns raw FAISS L2
                # distances (lower = better), so 'score >= threshold' kept the
                # WORST matches. Normalized relevance scores are higher-is-
                # better, making the >= filter correct.
                results_with_scores = self.vector_store.similarity_search_with_relevance_scores(query, k=k)
                return [doc for doc, score in results_with_scores if score >= threshold]
            else:
                return self.vector_store.similarity_search(query, k=k)
        elif metric == "mmr":
            return self.vector_store.max_marginal_relevance_search(query, k=k)
        elif metric == "bm25":
            if self.docs is None:
                raise ValueError("Documents not available. BM25 requires consumed or loaded documents.")
            bm25_retriever = BM25Retriever.from_documents(self.docs)
            # BUGFIX: BM25Retriever ignores a k= kwarg passed to
            # get_relevant_documents; the result count is controlled by its
            # 'k' attribute (default 4), so the caller's k was silently dropped.
            bm25_retriever.k = k
            return bm25_retriever.get_relevant_documents(query)
        else:
            raise ValueError(f"Unsupported metric: '{metric}'. Supported metrics are 'similarity', 'mmr', and 'bm25'.")

    def save_local(self, folder_path: str = None):
        """Persist the FAISS index plus pickled chunks (needed for BM25) to disk.

        Args:
            folder_path: target directory; defaults to ``self.vectorstore_dir``.

        Raises:
            ValueError: if there is nothing to save, or no target directory is
                available (BUGFIX: previously crashed in os.makedirs(None)).
        """
        if folder_path is None:
            folder_path = self.vectorstore_dir
        if self.vector_store is None or self.docs is None:
            raise ValueError("Nothing to save. Please run 'consume' first.")
        if folder_path is None:
            raise ValueError("No folder_path given and no vectorstore_dir configured.")
        os.makedirs(folder_path, exist_ok=True)
        self.vector_store.save_local(folder_path)
        with open(os.path.join(folder_path, "docs.pkl"), "wb") as f:
            pickle.dump(self.docs, f)
        print(f"Successfully saved RAG state to {folder_path}")

    def load_local(self, folder_path: str):
        """Load a previously saved FAISS index and chunk list from disk.

        Best-effort: failures are reported, not raised, so __init__ can probe
        a possibly-empty directory without crashing.

        Raises:
            FileNotFoundError: if ``folder_path`` does not exist.
        """
        if not os.path.isdir(folder_path):
            raise FileNotFoundError(f"Folder not found: {folder_path}")
        try:
            # allow_dangerous_deserialization: the pickle inside the FAISS
            # folder is trusted because this class wrote it; do NOT point this
            # at untrusted directories.
            self.vector_store = FAISS.load_local(folder_path, self.embed_model, allow_dangerous_deserialization=True)
            docs_path = os.path.join(folder_path, "docs.pkl")
            if os.path.exists(docs_path):
                with open(docs_path, "rb") as f:
                    self.docs = pickle.load(f)
            else:
                self.docs = None
                print("Warning: docs.pkl not found. BM25 search will not be available.")
            print(f"Successfully loaded RAG state from {folder_path}")
        except Exception as e:
            print(f"Could not load from {folder_path}. It might be empty or corrupted. Error: {e}")
if __name__ == '__main__':
    from langchain_community.embeddings import SentenceTransformerEmbeddings

    # Embedding model backing the TextRAG demo below.
    embedder = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    index_dir = "rag_index"

    # Start from a blank slate so each demo run rebuilds the index.
    if os.path.exists(index_dir):
        shutil.rmtree(index_dir)

    # --- First run: Create and save the index ---
    # print("--- Initializing first RAG instance ---")
    # rag_demo = TextRAG(embed_model=embedder, vectorstore_dir=index_dir)
    # rag_demo.consume(source_dir="sample_data", file_type=["**/*.txt", "**/*.md"])
    # rag_demo.save_local()
|