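"""TextRAG: a small retrieval pipeline that loads and chunks local text files,
indexes them in a FAISS vector store, and answers queries via similarity,
MMR, or BM25 retrieval."""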
import os
import pickle
import shutil
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter, MarkdownHeaderTextSplitter
from langchain_community.retrievers import BM25Retriever
from typing import List, Optional


class TextRAG:
    """Wraps document loading, chunking, FAISS persistence, and retrieval
    (similarity, MMR, BM25) behind a single class."""

    def __init__(self, embed_model, vectorstore_dir: Optional[str] = None):
        self.embed_model = embed_model
        self.vector_store = None
        self.docs = None
        self.vectorstore_dir = vectorstore_dir

        # Only touch the filesystem when a persistence directory is configured;
        # with the default of None the store lives purely in memory.
        if self.vectorstore_dir is not None:
            if os.path.isdir(self.vectorstore_dir) and os.listdir(self.vectorstore_dir):
                print(f"Found existing vector store at '{self.vectorstore_dir}', loading...")
                self.load_local(self.vectorstore_dir)
            else:
                print(f"Creating new vector store directory at '{self.vectorstore_dir}'")
                os.makedirs(self.vectorstore_dir, exist_ok=True)
    def _clear(self):
        """Drop the in-memory state and delete the on-disk index, if any."""
        self.vector_store = None
        self.docs = None
        print("Cleared the vector store and documents from memory.")
        if self.vectorstore_dir and os.path.isdir(self.vectorstore_dir):
            shutil.rmtree(self.vectorstore_dir)
            print(f"Removed local vector store directory: {self.vectorstore_dir}")
            os.makedirs(self.vectorstore_dir, exist_ok=True)
    def consume(self, source_dir: str, file_type: Optional[List[str]] = None,
                chunk_size: int = 1000, chunk_overlap: int = 100,
                chunk_method: str = "recursive"):
        """Load text files from source_dir, chunk them, and index the chunks."""
        if file_type is None:
            file_type = ["**/*.txt"]

        all_documents = []
        for pattern in file_type:
            loader = DirectoryLoader(source_dir, glob=pattern, loader_cls=TextLoader,
                                     show_progress=True, use_multithreading=True)
            all_documents.extend(loader.load())

        if not all_documents:
            print("No documents found for the specified file types.")
            return

        if chunk_method == "recursive":
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
        elif chunk_method == "character":
            text_splitter = CharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                separator="\n",
            )
        elif chunk_method == "markdown":
            text_splitter = MarkdownHeaderTextSplitter(
                headers_to_split_on=[("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")],
                return_each_line=False,
                strip_headers=False,
            )
        else:
            raise ValueError(f"Unknown chunk_method: {chunk_method}")

        if chunk_method == "markdown":
            # MarkdownHeaderTextSplitter works on raw text, so join all documents
            # first (note: this discards per-file metadata), split on headers,
            # then re-split oversized header sections to respect chunk_size.
            new_docs = text_splitter.split_text("\n\n".join(doc.page_content for doc in all_documents))
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
            new_docs = text_splitter.split_documents(new_docs)
        else:
            new_docs = text_splitter.split_documents(all_documents)

        if self.vector_store is None:
            self.docs = new_docs
            self.vector_store = FAISS.from_documents(self.docs, self.embed_model)
            print(f"Successfully consumed {len(all_documents)} documents, creating {len(self.docs)} chunks and a new vector store.")
        else:
            if self.docs is None:
                self.docs = []
            self.docs.extend(new_docs)
            self.vector_store.add_documents(new_docs)
            print(f"Successfully added {len(all_documents)} documents ({len(new_docs)} chunks) to the existing vector store.")

        # Persist only when a directory was configured at construction time.
        if self.vectorstore_dir is not None:
            self.save_local()
    def search(self, query: str, metric: str = "similarity", k: int = 5,
               threshold: Optional[float] = None):
        """Retrieve the top-k chunks for a query.

        metric: one of 'similarity', 'mmr', or 'bm25'.
        threshold: for 'similarity' only. FAISS returns L2 distances by
            default, so lower scores mean closer matches; results scoring
            above the threshold are dropped.
        """
        if self.vector_store is None:
            raise ValueError("Vector store not initialized. Please run the 'consume' or 'load_local' method first.")

        if metric == "similarity":
            if threshold is not None:
                results_with_scores = self.vector_store.similarity_search_with_score(query, k=k)
                # Scores are distances (lower is better), so keep only results
                # at or below the threshold.
                return [doc for doc, score in results_with_scores if score <= threshold]
            return self.vector_store.similarity_search(query, k=k)
        elif metric == "mmr":
            return self.vector_store.max_marginal_relevance_search(query, k=k)
        elif metric == "bm25":
            if self.docs is None:
                raise ValueError("Documents not available. BM25 requires consumed or loaded documents.")
            bm25_retriever = BM25Retriever.from_documents(self.docs)
            # BM25Retriever reads the result count from its 'k' attribute rather
            # than a keyword argument on the retrieval call.
            bm25_retriever.k = k
            return bm25_retriever.get_relevant_documents(query)
        else:
            raise ValueError(f"Unsupported metric: '{metric}'. Supported metrics are 'similarity', 'mmr', and 'bm25'.")
    def save_local(self, folder_path: Optional[str] = None):
        """Persist the FAISS index and the chunked documents to disk."""
        if folder_path is None:
            folder_path = self.vectorstore_dir
        if folder_path is None:
            raise ValueError("No folder path given and no vectorstore_dir configured.")

        if self.vector_store is None or self.docs is None:
            raise ValueError("Nothing to save. Please run 'consume' first.")

        os.makedirs(folder_path, exist_ok=True)
        self.vector_store.save_local(folder_path)

        # Pickle the raw chunks alongside the index so BM25 search still works
        # after a reload.
        with open(os.path.join(folder_path, "docs.pkl"), "wb") as f:
            pickle.dump(self.docs, f)

        print(f"Successfully saved RAG state to {folder_path}")
    def load_local(self, folder_path: str):
        """Load a previously saved FAISS index and document pickle."""
        if not os.path.isdir(folder_path):
            raise FileNotFoundError(f"Folder not found: {folder_path}")

        try:
            # allow_dangerous_deserialization is required because the FAISS
            # docstore is pickled; only load indexes you created yourself.
            self.vector_store = FAISS.load_local(folder_path, self.embed_model,
                                                 allow_dangerous_deserialization=True)

            docs_path = os.path.join(folder_path, "docs.pkl")
            if os.path.exists(docs_path):
                with open(docs_path, "rb") as f:
                    self.docs = pickle.load(f)
            else:
                self.docs = None
                print("Warning: docs.pkl not found. BM25 search will not be available.")

            print(f"Successfully loaded RAG state from {folder_path}")
        except Exception as e:
            print(f"Could not load from {folder_path}. It might be empty or corrupted. Error: {e}")
if __name__ == '__main__':
    from langchain_community.embeddings import SentenceTransformerEmbeddings

    embedding_model = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vector_store_path = "rag_index"

    # Start the demo from a clean slate by removing any previous index.
    if os.path.exists(vector_store_path):
        shutil.rmtree(vector_store_path)
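
    # Minimal end-to-end sketch of the class above. "sample_docs" is a
    # hypothetical directory of .txt files -- point it at a real corpus before
    # running; the query string is likewise only illustrative.
    rag = TextRAG(embedding_model, vectorstore_dir=vector_store_path)
    rag.consume("sample_docs", file_type=["**/*.txt"], chunk_size=500, chunk_overlap=50)

    for metric in ("similarity", "mmr", "bm25"):
        print(f"\n--- {metric} ---")
        for doc in rag.search("What is this corpus about?", metric=metric, k=3):
            print(doc.page_content[:120].replace("\n", " "))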