NavyDevilDoc committed on
Commit
6695d4a
·
verified ·
1 Parent(s): 4c360e5

Update src/rag_engine.py

Browse files

Refactored to use custom text-splitting code

Files changed (1) hide show
  1. src/rag_engine.py +114 -230
src/rag_engine.py CHANGED
@@ -1,254 +1,138 @@
1
  import os
2
- import shutil
3
- import time
4
- from langchain_text_splitters import RecursiveCharacterTextSplitter, TokenTextSplitter
5
- from langchain_chroma import Chroma
6
- from langchain_huggingface import HuggingFaceEmbeddings
7
- from langchain_community.docstore.document import Document
8
- from sentence_transformers import CrossEncoder # Re-added for Reranking
9
- import doc_loader
10
-
11
- # --- CONFIGURATION ---
12
- CHROMA_PATH = "chroma_db"
13
- UPLOAD_DIR = "temp_ingest" # Re-added directory constant
14
- EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
15
- RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2" # Re-added model name
16
-
17
- # --- LAZY LOADING GLOBALS ---
18
- # We use a global variable pattern to avoid loading heavy models
19
- # until the moment they are actually needed (saves startup RAM).
20
- _embedding_func = None
21
- _rerank_model = None
22
-
23
- def get_embedding_func():
24
- """Lazy loads the embedding model."""
25
- global _embedding_func
26
- if _embedding_func is None:
27
- print(f"⏳ Loading Embedding Model: {EMBED_MODEL_NAME}...")
28
- _embedding_func = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
29
- print("✅ Embedding Model Loaded.")
30
- return _embedding_func
31
-
32
- def get_rerank_model():
33
- """Lazy loads the Cross-Encoder model."""
34
- global _rerank_model
35
- if _rerank_model is None:
36
- print(f"⏳ Loading Reranker: {RERANK_MODEL_NAME}...")
37
- _rerank_model = CrossEncoder(RERANK_MODEL_NAME)
38
- print("✅ Reranker Loaded.")
39
- return _rerank_model
40
-
41
- # --- FILE OPERATIONS ---
42
- def save_uploaded_file(uploaded_file):
43
- """Saves uploaded file to the temp directory."""
44
- os.makedirs(UPLOAD_DIR, exist_ok=True)
45
- file_path = os.path.join(UPLOAD_DIR, uploaded_file.name)
46
-
47
- with open(file_path, "wb") as f:
48
- f.write(uploaded_file.getbuffer())
49
-
50
- return file_path
51
 
52
- # --- INGESTION PIPELINE ---
53
- def process_and_add_document(file_path, username, strategy, use_vision=False, api_key=None):
 
 
 
 
 
 
 
54
  """
55
- Ingests a document using the Universal Loader and adds it to the user's vector DB.
56
  """
57
- user_db_path = os.path.join(CHROMA_PATH, username)
58
-
59
  try:
60
- # 1. EXTRACT TEXT (Using doc_loader)
61
- # We need a pseudo-object because doc_loader expects a Streamlit object,
62
- # but we are reading from disk.
63
- with open(file_path, "rb") as f:
64
- class FileObj:
65
- def __init__(self, f, name):
66
- self.f = f
67
- self.name = name
68
- def read(self): return self.f.read()
69
-
70
- file_obj = FileObj(f, os.path.basename(file_path))
71
- raw_text = doc_loader.extract_text_from_file(file_obj, use_vision=use_vision, api_key=api_key)
72
-
73
- if not raw_text or not raw_text.strip():
74
- return False, "Document appears empty or could not be read."
75
-
76
- # 2. CHUNK TEXT
77
- chunks = []
78
- if strategy == "paragraph":
79
- splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
80
- chunks = splitter.split_text(raw_text)
81
- elif strategy == "token":
82
- splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=50)
83
- chunks = splitter.split_text(raw_text)
84
- elif strategy == "page":
85
- splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
86
- chunks = splitter.split_text(raw_text)
87
-
88
- # 3. CREATE DOCUMENTS
89
- docs = [
90
- Document(
91
- page_content=chunk,
92
- metadata={"source": os.path.basename(file_path), "strategy": strategy}
93
- )
94
- for chunk in chunks
95
  ]
96
-
97
- # 4. INDEX TO CHROMA
98
- if docs:
99
- # Use the getter function (Lazy Load)
100
- emb_fn = get_embedding_func()
101
- db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
102
- db.add_documents(docs)
103
- return True, f"Successfully indexed {len(docs)} chunks from {os.path.basename(file_path)}."
104
- else:
105
- return False, "No chunks created."
106
 
107
- except Exception as e:
108
- return False, f"Error processing document: {e}"
 
109
 
110
- # --- SEARCH PIPELINE (Now with Reranking!) ---
111
- def search_knowledge_base(query, username, k=10, final_k=4):
112
- """
113
- Retrieves top K chunks, then uses a Cross-Encoder to re-rank them
114
- and returns the top final_k most relevant chunks.
115
- """
116
- user_db_path = os.path.join(CHROMA_PATH, username)
117
- if not os.path.exists(user_db_path):
118
- return []
119
-
120
- try:
121
- # 1. INITIAL RETRIEVAL (Vector Similarity)
122
- emb_fn = get_embedding_func()
123
- db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
124
- # Fetch more candidates (k=10) to give the reranker options
125
- results = db.similarity_search_with_relevance_scores(query, k=k)
126
 
127
- if not results:
128
- return []
 
 
129
 
130
- # 2. RERANKING
131
- # Extract just the text for the cross-encoder
132
- candidate_docs = [doc for doc, _ in results]
133
- candidate_texts = [doc.page_content for doc in candidate_docs]
134
- if not candidate_texts:
135
- return []
136
-
137
- # Form pairs: (Query, Document Text)
138
- pairs = [[query, text] for text in candidate_texts]
139
-
140
- # Score pairs
141
- reranker = get_rerank_model()
142
- scores = reranker.predict(pairs)
143
-
144
- # Attach scores to documents and sort
145
- scored_docs = list(zip(candidate_docs, scores))
146
- # Sort by score descending (High score = Better match)
147
- scored_docs.sort(key=lambda x: x[1], reverse=True)
148
-
149
- # 3. RETURN TOP N
150
- # Return only the document objects of the top final_k
151
- final_docs = [doc for doc, score in scored_docs[:final_k]]
152
  return final_docs
153
 
154
  except Exception as e:
155
- print(f"RAG Error: {e}")
156
  return []
157
 
158
- # --- MANAGEMENT UTILS ---
159
- def list_documents(username):
160
- """Returns a list of unique sources in the user's DB."""
161
- user_db_path = os.path.join(CHROMA_PATH, username)
162
- if not os.path.exists(user_db_path):
163
- return []
164
-
165
- try:
166
- emb_fn = get_embedding_func()
167
- db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
168
- data = db.get()
169
- metadatas = data['metadatas']
170
-
171
- inventory = {}
172
- for m in metadatas:
173
- src = m.get('source', 'Unknown')
174
- if src not in inventory:
175
- inventory[src] = {"chunks": 0, "strategy": m.get('strategy', 'Unknown')}
176
- inventory[src]["chunks"] += 1
177
-
178
- return [{"filename": k, "chunks": v["chunks"], "strategy": v["strategy"], "source": k} for k, v in inventory.items()]
179
- except:
180
  return []
181
 
182
- def delete_document(username, source_name):
183
- """Removes all chunks associated with a specific source file."""
184
- user_db_path = os.path.join(CHROMA_PATH, username)
185
- try:
186
- emb_fn = get_embedding_func()
187
- db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
 
 
 
 
 
 
 
188
 
189
- data = db.get()
190
- ids_to_delete = []
191
- for i, meta in enumerate(data['metadatas']):
192
- if meta.get('source') == source_name:
193
- ids_to_delete.append(data['ids'][i])
194
-
195
- if ids_to_delete:
196
- db.delete(ids=ids_to_delete)
197
- return True, f"Deleted {source_name}."
198
  else:
199
- return False, "File not found in index."
200
- except Exception as e:
201
- return False, f"Delete failed: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
- def reset_knowledge_base(username):
204
- """Wipes the entire user database."""
205
- user_db_path = os.path.join(CHROMA_PATH, username)
206
- if os.path.exists(user_db_path):
207
- shutil.rmtree(user_db_path)
208
- return True, "Database Reset."
209
- return False, "Database already empty."
210
 
211
- def process_and_add_text(raw_text, source_name, username, strategy="paragraph"):
 
 
 
212
  """
213
- Directly indexes a raw text string into the user's vector DB.
214
- Useful for indexing content generated by the LLM (like flattened notes).
215
  """
216
- user_db_path = os.path.join(CHROMA_PATH, username)
 
 
 
 
 
 
 
217
 
218
- try:
219
- if not raw_text or not raw_text.strip():
220
- return False, "Content appears empty."
221
-
222
- # 1. CHUNK TEXT (Reusing the standard logic)
223
- chunks = []
224
- if strategy == "paragraph":
225
- splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
226
- chunks = splitter.split_text(raw_text)
227
- elif strategy == "token":
228
- splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=50)
229
- chunks = splitter.split_text(raw_text)
230
- elif strategy == "page":
231
- splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
232
- chunks = splitter.split_text(raw_text)
233
-
234
- # 2. CREATE DOCUMENTS
235
- # We append "_flattened" to the source name so you can distinguish it from the original
236
- docs = [
237
- Document(
238
- page_content=chunk,
239
- metadata={"source": source_name, "strategy": f"{strategy}-flattened"}
240
- )
241
- for chunk in chunks
242
- ]
243
-
244
- # 3. INDEX TO CHROMA
245
- if docs:
246
- emb_fn = get_embedding_func()
247
- db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
248
- db.add_documents(docs)
249
- return True, f"Successfully indexed {len(docs)} flattened chunks."
250
- else:
251
- return False, "No chunks created."
252
-
253
- except Exception as e:
254
- return False, f"Error processing text: {e}"
 
1
  import os
2
+ import logging
3
+ from typing import List, Literal
4
+
5
+ # LangChain imports for the Markdown logic
6
+ from langchain_core.documents import Document
7
+ from langchain.text_splitter import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Custom Core Imports
10
+ from core.ParagraphChunker import ParagraphChunker
11
+ from core.TokenChunker import TokenChunker
12
+
13
+ # Configure Logging
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
def _process_markdown(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 100) -> List[Document]:
    """Process a Markdown file with two-stage, header-aware splitting.

    Stage 1 splits on H1-H3 headers so every chunk stays attached to its
    section context; stage 2 recursively re-splits any section that is
    longer than ``chunk_size``.

    Args:
        file_path: Path to the ``.md`` file on disk.
        chunk_size: Maximum characters per final chunk.
        chunk_overlap: Character overlap between adjacent chunks.

    Returns:
        A list of LangChain ``Document`` chunks tagged with ``source`` and
        ``file_type`` metadata, or ``[]`` on any error (best-effort
        contract: callers treat a failure as "no chunks").
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            markdown_text = f.read()

        # Split on the three top header levels; deeper headers stay
        # attached to their parent section's chunk.
        headers_to_split_on = [
            ("#", "Header 1"),
            ("##", "Header 2"),
            ("###", "Header 3"),
        ]

        # Stage 1: structural split by headers (header text becomes metadata).
        markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
        md_header_splits = markdown_splitter.split_text(markdown_text)

        # Stage 2: size-bound split of any section exceeding chunk_size.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        final_docs = text_splitter.split_documents(md_header_splits)

        # Tag provenance so downstream indexing can filter by origin.
        for doc in final_docs:
            doc.metadata["source"] = file_path
            doc.metadata["file_type"] = "md"

        # Lazy %-args: the message is only built if the record is emitted.
        logger.info("Markdown processing complete: %d chunks created.", len(final_docs))
        return final_docs

    except Exception:
        # Broad catch is deliberate: ingestion is best-effort and must not
        # crash the pipeline. logger.exception records the full traceback.
        logger.exception("Error processing Markdown file %s", file_path)
        return []
54
 
55
def process_file(
    file_path: str,
    chunking_strategy: Literal["paragraph", "token"] = "paragraph",
    chunk_size: int = 512,
    chunk_overlap: int = 50,
    model_name: str = "gpt-4o"  # Used for token counting in the custom chunker classes
) -> "List[Document]":
    """Chunk a single file into LangChain Documents.

    Routing by extension:
      * ``.md``  -> header-aware markdown splitting (``_process_markdown``)
      * ``.pdf`` -> custom chunker via ``process_document``
      * ``.txt`` -> custom chunker via ``process_text_file``

    Args:
        file_path: Path to the file on disk.
        chunking_strategy: "token" for size-bounded chunks, "paragraph"
            for semantic-boundary chunks (which ignores chunk_size).
        chunk_size: Target chunk size (tokens for the token strategy,
            characters for markdown).
        chunk_overlap: Overlap between adjacent chunks.
        model_name: Tokenizer model name passed to the custom chunkers.

    Returns:
        A list of Document chunks; ``[]`` for a missing file, an
        unsupported extension, or a chunker error (best-effort contract).
    """
    # NOTE(review): os.path.exists also accepts directories — presumably
    # callers only pass file paths; confirm before tightening to isfile.
    if not os.path.exists(file_path):
        logger.error("File not found: %s", file_path)
        return []

    file_extension = os.path.splitext(file_path)[1].lower()
    logger.info("Processing %s using strategy: %s", file_path, chunking_strategy)

    # 1. Markdown gets its own semantic (header-based) pipeline.
    if file_extension == ".md":
        return _process_markdown(file_path, chunk_size, chunk_overlap)

    # 2. PDF and TXT go through the custom core chunkers.
    if file_extension in (".pdf", ".txt"):
        if chunking_strategy == "token":
            chunker = TokenChunker(
                model_name=model_name,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
        else:
            # Paragraph chunking follows semantic boundaries, not strict sizes.
            chunker = ParagraphChunker(model_name=model_name)

        try:
            if file_extension == ".pdf":
                # PDF path uses OCR-enhanced loading inside the chunker.
                return chunker.process_document(file_path)
            # ".txt": direct text reading with paragraph preservation.
            return chunker.process_text_file(file_path)
        except Exception:
            # Best-effort: one bad file must not abort batch ingestion;
            # logger.exception keeps the traceback for diagnosis.
            logger.exception(
                "Error using %s chunker on %s", chunking_strategy, file_path
            )
            return []

    # 3. Anything else is unsupported.
    logger.warning("Unsupported file extension: %s", file_extension)
    return []
 
 
 
 
113
 
114
def load_documents_from_directory(
    directory_path: str,
    chunking_strategy: Literal["paragraph", "token"] = "paragraph",
    chunk_size: int = 512,
    chunk_overlap: int = 50,
    model_name: str = "gpt-4o",
) -> "List[Document]":
    """Recursively chunk every supported file under a directory.

    Walks ``directory_path`` and runs :func:`process_file` on each
    ``.pdf``, ``.txt``, or ``.md`` file, concatenating the results.
    The chunking parameters are forwarded to ``process_file`` and
    default to its defaults, so existing callers are unaffected.

    Args:
        directory_path: Root directory to walk.
        chunking_strategy: Strategy forwarded to ``process_file``.
        chunk_size: Chunk size forwarded to ``process_file``.
        chunk_overlap: Overlap forwarded to ``process_file``.
        model_name: Tokenizer model forwarded to ``process_file``.

    Returns:
        All Document chunks from every supported file (possibly empty).
    """
    all_docs = []
    for root, _, files in os.walk(directory_path):
        for file in files:
            # Skip anything we have no handler for.
            if not file.lower().endswith((".pdf", ".txt", ".md")):
                continue
            all_docs.extend(
                process_file(
                    os.path.join(root, file),
                    chunking_strategy=chunking_strategy,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    model_name=model_name,
                )
            )
    return all_docs
131
+
132
# Manual smoke test: run `python rag_engine.py` directly.
if __name__ == "__main__":
    print("--- Testing Rag Engine ---")
    # Uncomment and point at a real file to exercise the pipeline:
    # docs = process_file("test_data/navy_manual.pdf", chunking_strategy="paragraph")
    # print(f"Loaded {len(docs)} chunks.")