File size: 7,315 Bytes
cc37925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import pickle
import shutil
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter, MarkdownHeaderTextSplitter
from langchain_community.retrievers import BM25Retriever
from typing import List

class TextRAG:
    """Text retrieval (RAG) store combining a FAISS vector index with BM25.

    Loads text files from a directory, chunks them, embeds the chunks with the
    supplied embedding model, and exposes similarity / MMR / BM25 search.
    State (the FAISS index plus a pickled chunk list) can be persisted to and
    reloaded from ``vectorstore_dir``.
    """

    def __init__(self, embed_model, vectorstore_dir: str = None):
        """
        Args:
            embed_model: LangChain-compatible embedding model, used for both
                indexing and querying.
            vectorstore_dir: Directory for persisting/loading the index. An
                existing non-empty directory is loaded; a missing one is
                created. If None, the store is kept in memory only.
        """
        self.embed_model = embed_model
        self.vector_store = None  # FAISS index; created by consume()/load_local()
        self.docs = None          # chunk Documents; needed for BM25 and for saving

        self.vectorstore_dir = vectorstore_dir

        # BUGFIX: the original passed the declared default (None) straight to
        # os.path.isdir(), raising TypeError. None now means "in-memory only".
        if self.vectorstore_dir is None:
            return

        if os.path.isdir(self.vectorstore_dir) and os.listdir(self.vectorstore_dir):
            print(f"Found existing vector store at '{self.vectorstore_dir}', loading...")
            self.load_local(self.vectorstore_dir)
        else:
            print(f"Creating new vector store directory at '{self.vectorstore_dir}'")
            os.makedirs(self.vectorstore_dir, exist_ok=True)

    def _clear(self):
        """Drop in-memory state and wipe the on-disk store (if one exists)."""
        self.vector_store = None
        self.docs = None
        print("Cleared the vector store and documents from memory.")
        # BUGFIX: guard against vectorstore_dir=None (in-memory mode), which
        # would make os.path.isdir raise TypeError.
        if self.vectorstore_dir is not None and os.path.isdir(self.vectorstore_dir):
            shutil.rmtree(self.vectorstore_dir)
            print(f"Removed local vector store directory: {self.vectorstore_dir}")
            os.makedirs(self.vectorstore_dir, exist_ok=True)

    def _split_documents(self, documents, chunk_size: int, chunk_overlap: int, chunk_method: str):
        """Chunk loaded documents using the requested splitter strategy.

        Raises:
            ValueError: if chunk_method is not 'recursive', 'character' or
                'markdown'.
        """
        if chunk_method == "recursive":
            splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap
            )
            return splitter.split_documents(documents)

        if chunk_method == "character":
            splitter = CharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                separator="\n"
            )
            return splitter.split_documents(documents)

        if chunk_method == "markdown":
            # First split on markdown headers, then re-chunk oversized
            # sections to respect chunk_size.
            header_splitter = MarkdownHeaderTextSplitter(
                headers_to_split_on=[("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")],
                return_each_line=False,
                strip_headers=False
            )
            # NOTE(review): joining every file into one string discards
            # per-file metadata — behavior kept from the original; confirm
            # this is intended.
            header_docs = header_splitter.split_text(
                "\n\n".join(doc.page_content for doc in documents)
            )
            size_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
            return size_splitter.split_documents(header_docs)

        raise ValueError(f"Unknown chunk_method: {chunk_method}")

    def consume(self, source_dir: str, file_type: List[str] = None, chunk_size: int = 1000, chunk_overlap: int = 100, chunk_method: str = "recursive"):
        """Load, chunk, embed and index text files from source_dir.

        Args:
            source_dir: Root directory to scan for files.
            file_type: Glob patterns to load (defaults to ["**/*.txt"]).
            chunk_size: Target chunk size in characters.
            chunk_overlap: Overlap between consecutive chunks.
            chunk_method: 'recursive', 'character' or 'markdown'.
        """
        if file_type is None:
            file_type = ["**/*.txt"]

        all_documents = []
        for pattern in file_type:
            loader = DirectoryLoader(source_dir, glob=pattern, loader_cls=TextLoader, show_progress=True, use_multithreading=True)
            all_documents.extend(loader.load())

        if not all_documents:
            print("No documents found for the specified file types.")
            return

        new_docs = self._split_documents(all_documents, chunk_size, chunk_overlap, chunk_method)

        if self.vector_store is None:
            # First ingestion: build a fresh index from the chunks.
            self.docs = new_docs
            self.vector_store = FAISS.from_documents(self.docs, self.embed_model)
            print(f"Successfully consumed {len(all_documents)} documents, creating {len(self.docs)} chunks and a new vector store.")
        else:
            # Incremental ingestion: extend both the chunk list (for BM25)
            # and the existing FAISS index.
            if self.docs is None:
                self.docs = []
            self.docs.extend(new_docs)
            self.vector_store.add_documents(new_docs)
            print(f"Successfully added {len(all_documents)} documents ({len(new_docs)} chunks) to the existing vector store.")

        # BUGFIX: only persist when a storage directory was configured;
        # the original crashed here in the vectorstore_dir=None case.
        if self.vectorstore_dir is not None:
            self.save_local()

    def search(self, query: str, metric: str = "similarity", k: int = 5, threshold: float = None):
        """Retrieve up to k chunks relevant to query.

        Args:
            query: The search query text.
            metric: One of 'similarity', 'mmr', 'bm25'.
            k: Number of results to return.
            threshold: Optional score cutoff, only used with 'similarity'.

        Raises:
            ValueError: if no index is loaded, or metric is unsupported, or
                BM25 is requested without documents in memory.
        """
        if self.vector_store is None:
            raise ValueError("Vector store not initialized. Please run the 'consume' or 'load_local' method first.")

        if metric == "similarity":
            if threshold is not None:
                # NOTE(review): FAISS similarity_search_with_score returns a
                # raw distance (lower = more similar) rather than a relevance
                # score, so `score >= threshold` may keep the LEAST similar
                # hits. Preserved as-is — confirm the intended semantics.
                results_with_scores = self.vector_store.similarity_search_with_score(query, k=k)
                return [doc for doc, score in results_with_scores if score >= threshold]
            return self.vector_store.similarity_search(query, k=k)

        if metric == "mmr":
            return self.vector_store.max_marginal_relevance_search(query, k=k)

        if metric == "bm25":
            if self.docs is None:
                raise ValueError("Documents not available. BM25 requires consumed or loaded documents.")
            # BUGFIX: k is a BM25Retriever attribute, not a kwarg of
            # get_relevant_documents; passing k=k to the call was ignored (or
            # raised, depending on the langchain version installed).
            bm25_retriever = BM25Retriever.from_documents(self.docs)
            bm25_retriever.k = k
            return bm25_retriever.get_relevant_documents(query)

        raise ValueError(f"Unsupported metric: '{metric}'. Supported metrics are 'similarity', 'mmr', and 'bm25'.")

    def save_local(self, folder_path: str = None):
        """Persist the FAISS index and the pickled chunk list to folder_path.

        Falls back to self.vectorstore_dir when folder_path is omitted.

        Raises:
            ValueError: if there is nothing to save or no target directory.
        """
        if folder_path is None:
            folder_path = self.vectorstore_dir

        # BUGFIX: with no configured directory the original crashed inside
        # os.makedirs(None); fail with a clear message instead.
        if folder_path is None:
            raise ValueError("No folder path given and no vectorstore_dir configured; cannot save an in-memory store.")

        if self.vector_store is None or self.docs is None:
            raise ValueError("Nothing to save. Please run 'consume' first.")

        os.makedirs(folder_path, exist_ok=True)
        self.vector_store.save_local(folder_path)

        # The chunk list is pickled alongside the index so BM25 search works
        # after a reload.
        with open(os.path.join(folder_path, "docs.pkl"), "wb") as f:
            pickle.dump(self.docs, f)

        print(f"Successfully saved RAG state to {folder_path}")

    def load_local(self, folder_path: str):
        """Load a previously saved FAISS index (and chunk list) from disk.

        Best-effort: failures are reported via print rather than raised, so a
        corrupt directory leaves the instance empty instead of crashing.

        Raises:
            FileNotFoundError: if folder_path is not a directory.
        """
        if not os.path.isdir(folder_path):
            raise FileNotFoundError(f"Folder not found: {folder_path}")

        try:
            # allow_dangerous_deserialization is required because the FAISS
            # docstore is pickled; only load directories you trust.
            self.vector_store = FAISS.load_local(folder_path, self.embed_model, allow_dangerous_deserialization=True)

            docs_path = os.path.join(folder_path, "docs.pkl")
            if os.path.exists(docs_path):
                with open(docs_path, "rb") as f:
                    self.docs = pickle.load(f)
            else:
                self.docs = None 
                print("Warning: docs.pkl not found. BM25 search will not be available.")

            print(f"Successfully loaded RAG state from {folder_path}")
        except Exception as e:
            print(f"Could not load from {folder_path}. It might be empty or corrupted. Error: {e}")

if __name__ == '__main__':
    # Demo entry point: build a small RAG index using a local
    # sentence-transformers embedding model.
    from langchain_community.embeddings import SentenceTransformerEmbeddings

    # all-MiniLM-L6-v2 is a small, CPU-friendly embedding model.
    embedding_model = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vector_store_path = "rag_index"

    # Clean up previous runs for a clean demonstration
    if os.path.exists(vector_store_path):
        shutil.rmtree(vector_store_path)

    # --- First run: Create and save the index ---
    # print("--- Initializing first RAG instance ---")
    # rag_system_1 = TextRAG(embed_model=embedding_model, vectorstore_dir=vector_store_path)

    # rag_system_1.consume(source_dir="sample_data", file_type=["**/*.txt", "**/*.md"])
    # rag_system_1.save_local()