NavyDevilDoc committed on
Commit
38ce8e2
·
verified ·
1 Parent(s): cdc0512

Update src/rag_engine.py

Browse files
Files changed (1) hide show
  1. src/rag_engine.py +186 -88
src/rag_engine.py CHANGED
@@ -1,53 +1,80 @@
1
  import os
 
2
  import logging
3
- from typing import List, Literal
4
 
5
- # LangChain imports for the Markdown logic
 
 
6
  from langchain_core.documents import Document
7
- from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
 
8
 
9
- # Custom Core Imports
10
  from core.ParagraphChunker import ParagraphChunker
11
  from core.TokenChunker import TokenChunker
12
 
 
 
 
 
 
 
13
  # Configure Logging
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def _process_markdown(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 100) -> List[Document]:
18
- """
19
- Internal helper to process Markdown files using Header Semantic Splitting.
20
- """
21
  try:
22
  with open(file_path, 'r', encoding='utf-8') as f:
23
  markdown_text = f.read()
24
 
25
- # Define headers to split on (Logic: Keep context attached to the section)
26
  headers_to_split_on = [
27
  ("#", "Header 1"),
28
  ("##", "Header 2"),
29
  ("###", "Header 3"),
30
  ]
31
 
32
- # Stage 1: Split by Structure (Headers)
33
  markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
34
  md_header_splits = markdown_splitter.split_text(markdown_text)
35
 
36
- # Stage 2: Split by Size (Recursively split long sections)
37
  text_splitter = RecursiveCharacterTextSplitter(
38
  chunk_size=chunk_size,
39
  chunk_overlap=chunk_overlap
40
  )
41
  final_docs = text_splitter.split_documents(md_header_splits)
42
 
43
- # Add source metadata
44
  for doc in final_docs:
45
- doc.metadata['source'] = file_path
46
  doc.metadata['file_type'] = 'md'
 
47
 
48
- logger.info(f"Markdown processing complete: {len(final_docs)} chunks created.")
49
  return final_docs
50
-
51
  except Exception as e:
52
  logger.error(f"Error processing Markdown file {file_path}: {e}")
53
  return []
@@ -57,32 +84,25 @@ def process_file(
57
  chunking_strategy: Literal["paragraph", "token"] = "paragraph",
58
  chunk_size: int = 512,
59
  chunk_overlap: int = 50,
60
- model_name: str = "gpt-4o" # Used for token counting in your custom classes
61
  ) -> List[Document]:
62
  """
63
- Main entry point for processing a single file.
64
- Routes to the correct custom chunker or markdown handler based on extension.
65
  """
66
-
67
  if not os.path.exists(file_path):
68
  logger.error(f"File not found: {file_path}")
69
  return []
70
 
71
  file_extension = os.path.splitext(file_path)[1].lower()
72
- logger.info(f"Processing {file_path} using strategy: {chunking_strategy}")
 
73
 
74
- # ---------------------------------------------------------
75
- # 1. Handle Markdown (Specialized Logic)
76
- # ---------------------------------------------------------
77
  if file_extension == ".md":
78
  return _process_markdown(file_path, chunk_size, chunk_overlap)
79
 
80
- # ---------------------------------------------------------
81
- # 2. Handle PDF and TXT (Custom Core Logic)
82
- # ---------------------------------------------------------
83
  elif file_extension in [".pdf", ".txt"]:
84
-
85
- # Initialize the appropriate Custom Chunker
86
  if chunking_strategy == "token":
87
  chunker = TokenChunker(
88
  model_name=model_name,
@@ -90,88 +110,166 @@ def process_file(
90
  chunk_overlap=chunk_overlap
91
  )
92
  else:
93
- # Paragraph chunker relies on semantic boundaries, not strict sizes
94
  chunker = ParagraphChunker(model_name=model_name)
95
 
96
- # Process based on file type
97
  try:
98
  if file_extension == ".pdf":
99
- # Uses OCREnhancedPDFLoader internally via BaseChunker
100
- return chunker.process_document(file_path)
101
-
102
  elif file_extension == ".txt":
103
- # Uses direct text reading with paragraph preservation
104
- return chunker.process_text_file(file_path)
105
-
 
 
 
 
 
 
106
  except Exception as e:
107
- logger.error(f"Error using {chunking_strategy} chunker on {file_path}: {e}")
108
  return []
109
-
110
  else:
111
  logger.warning(f"Unsupported file extension: {file_extension}")
112
  return []
113
 
114
- def load_documents_from_directory(
115
- directory_path: str,
116
- chunking_strategy: Literal["paragraph", "token"] = "paragraph"
117
- ) -> List[Document]:
118
- """
119
- Batch helper to process a directory of files.
120
- """
121
- all_docs = []
122
- for root, _, files in os.walk(directory_path):
123
- for file in files:
124
- file_path = os.path.join(root, file)
125
- # Only process supported extensions
126
- if file.lower().endswith(('.pdf', '.txt', '.md')):
127
- docs = process_file(file_path, chunking_strategy=chunking_strategy)
128
- all_docs.extend(docs)
129
-
130
- return all_docs
131
-
132
- def list_documents(username: str = "default") -> List[str]:
133
  """
134
- Lists all supported documents for a specific user.
135
- Adjust 'source_documents' if your folder is named differently.
136
  """
137
- # Define your source directory (Update this path if you use a different one!)
138
- base_dir = "source_documents"
139
- user_dir = os.path.join(base_dir, username)
140
-
141
- if not os.path.exists(user_dir):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  return []
143
 
144
- files = []
145
- for f in os.listdir(user_dir):
146
- if f.lower().endswith(('.pdf', '.txt', '.md')):
147
- files.append(f)
148
-
149
- return files
 
 
150
 
151
- def save_uploaded_file(uploaded_file, username: str = "default") -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  """
153
- Saves a StreamlitUploadedFile to a temporary location on disk.
154
- Returns the absolute path to the saved file.
155
  """
156
- try:
157
- # Define the directory where files will be stored
158
- # You can customize "source_documents" to match your preferred structure
159
- base_dir = "source_documents"
160
- user_dir = os.path.join(base_dir, username)
161
 
162
- # Create the directory if it doesn't exist
163
- os.makedirs(user_dir, exist_ok=True)
 
164
 
165
- # Create the full file path
166
- file_path = os.path.join(user_dir, uploaded_file.name)
 
167
 
168
- # Write the file content
169
- with open(file_path, "wb") as f:
170
- f.write(uploaded_file.getbuffer())
 
 
 
 
 
 
 
171
 
172
- logger.info(f"File saved successfully at: {file_path}")
173
- return file_path
 
 
 
 
 
 
 
 
 
 
 
 
174
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  except Exception as e:
176
- logger.error(f"Error saving uploaded file: {e}")
177
- return None
 
 
 
 
 
 
 
 
1
  import os
2
+ import shutil
3
  import logging
4
+ from typing import List, Literal, Tuple
5
 
6
+ # --- LANGCHAIN & DB IMPORTS ---
7
+ from langchain_chroma import Chroma
8
+ from langchain_huggingface import HuggingFaceEmbeddings
9
  from langchain_core.documents import Document
10
+ from langchain.text_splitter import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
11
+ from sentence_transformers import CrossEncoder
12
 
13
+ # --- CUSTOM CORE IMPORTS ---
14
  from core.ParagraphChunker import ParagraphChunker
15
  from core.TokenChunker import TokenChunker
16
 
17
+ # --- CONFIGURATION ---
18
+ CHROMA_PATH = "chroma_db"
19
+ UPLOAD_DIR = "source_documents"
20
+ EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
21
+ RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
22
+
23
  # Configure Logging
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
26
 
27
+ # --- LAZY LOADING GLOBALS ---
28
+ _embedding_func = None
29
+ _rerank_model = None
30
+
31
def get_embedding_func():
    """Return the process-wide HuggingFace embedding function, creating it on first use.

    Lazy initialization keeps app startup cheap; subsequent calls reuse the
    cached instance stored in the module-level ``_embedding_func`` global.
    """
    global _embedding_func
    if _embedding_func is not None:
        return _embedding_func
    logger.info(f"⏳ Loading Embedding Model: {EMBED_MODEL_NAME}...")
    _embedding_func = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
    logger.info("✅ Embedding Model Loaded.")
    return _embedding_func
39
+
40
def get_rerank_model():
    """Return the process-wide Cross-Encoder reranker, creating it on first use.

    Mirrors ``get_embedding_func``: the model is loaded once and cached in the
    module-level ``_rerank_model`` global.
    """
    global _rerank_model
    if _rerank_model is not None:
        return _rerank_model
    logger.info(f"⏳ Loading Reranker: {RERANK_MODEL_NAME}...")
    _rerank_model = CrossEncoder(RERANK_MODEL_NAME)
    logger.info("✅ Reranker Loaded.")
    return _rerank_model
48
+
49
+ # --- PART 1: CHUNKING LOGIC (The New System) ---
50
+
51
def _process_markdown(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 100) -> List[Document]:
    """Chunk a Markdown file with header-aware semantic splitting.

    Two stages: first split on #/##/### headers so each chunk keeps its
    section context in metadata, then recursively split any oversized
    section down to ``chunk_size`` characters with ``chunk_overlap`` overlap.

    Returns the chunked documents, or an empty list on any failure
    (the error is logged, never raised — callers treat [] as "no chunks").
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            raw_markdown = handle.read()

        header_levels = [
            ("#", "Header 1"),
            ("##", "Header 2"),
            ("###", "Header 3"),
        ]

        # Stage 1: structural split — header text is attached as metadata.
        sections = MarkdownHeaderTextSplitter(
            headers_to_split_on=header_levels
        ).split_text(raw_markdown)

        # Stage 2: size-bound split for sections longer than chunk_size.
        chunks = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        ).split_documents(sections)

        # Tag every chunk so retrieval/deletion can group by source file.
        for chunk in chunks:
            chunk.metadata['source'] = os.path.basename(file_path)
            chunk.metadata['file_type'] = 'md'
            chunk.metadata['strategy'] = 'markdown_header'

        return chunks
    except Exception as e:
        logger.error(f"Error processing Markdown file {file_path}: {e}")
        return []
 
84
  chunking_strategy: Literal["paragraph", "token"] = "paragraph",
85
  chunk_size: int = 512,
86
  chunk_overlap: int = 50,
87
+ model_name: str = "gpt-4"
88
  ) -> List[Document]:
89
  """
90
+ Main chunking engine. Routes file to specific chunkers based on type/strategy.
 
91
  """
 
92
  if not os.path.exists(file_path):
93
  logger.error(f"File not found: {file_path}")
94
  return []
95
 
96
  file_extension = os.path.splitext(file_path)[1].lower()
97
+ file_name = os.path.basename(file_path)
98
+ logger.info(f"Processing {file_name} using strategy: {chunking_strategy}")
99
 
100
+ # 1. Handle Markdown
 
 
101
  if file_extension == ".md":
102
  return _process_markdown(file_path, chunk_size, chunk_overlap)
103
 
104
+ # 2. Handle PDF and TXT
 
 
105
  elif file_extension in [".pdf", ".txt"]:
 
 
106
  if chunking_strategy == "token":
107
  chunker = TokenChunker(
108
  model_name=model_name,
 
110
  chunk_overlap=chunk_overlap
111
  )
112
  else:
 
113
  chunker = ParagraphChunker(model_name=model_name)
114
 
 
115
  try:
116
  if file_extension == ".pdf":
117
+ docs = chunker.process_document(file_path)
 
 
118
  elif file_extension == ".txt":
119
+ docs = chunker.process_text_file(file_path)
120
+
121
+ # Ensure metadata consistency
122
+ for doc in docs:
123
+ doc.metadata["source"] = file_name
124
+ doc.metadata["strategy"] = chunking_strategy
125
+
126
+ return docs
127
+
128
  except Exception as e:
129
+ logger.error(f"Error using {chunking_strategy} chunker on {file_name}: {e}")
130
  return []
 
131
  else:
132
  logger.warning(f"Unsupported file extension: {file_extension}")
133
  return []
134
 
135
+ # --- PART 2: DATABASE & FILE MANAGEMENT (The Old Stable System) ---
136
+
137
def save_uploaded_file(uploaded_file, username: str = "default") -> str:
    """Persist a Streamlit ``UploadedFile`` to disk so the loaders can read it.

    Args:
        uploaded_file: object exposing ``.name`` and ``.getbuffer()``
            (Streamlit's UploadedFile contract).
        username: per-user subdirectory under ``UPLOAD_DIR``.

    Returns:
        The saved file path, or ``None`` on failure (error is logged).
    """
    try:
        user_dir = os.path.join(UPLOAD_DIR, username)
        os.makedirs(user_dir, exist_ok=True)

        # basename() strips any client-supplied directory components,
        # preventing a crafted filename (e.g. "../../x") from escaping
        # the user's upload folder.
        safe_name = os.path.basename(uploaded_file.name)
        file_path = os.path.join(user_dir, safe_name)

        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        logger.info(f"File saved: {file_path}")
        return file_path
    except Exception as e:
        logger.error(f"Error saving file: {e}")
        return None
152
+
153
def ingest_file(file_path: str, username: str, strategy: str = "paragraph") -> Tuple[bool, str]:
    """Chunk a file and persist the chunks into the user's Chroma collection.

    High-level bridge between the chunking engine (``process_file``) and the
    vector store. Returns ``(success, message)`` — never raises.
    """
    try:
        # Step 1: run the file through the chunking engine.
        docs = process_file(file_path, chunking_strategy=strategy)
        if not docs:
            return False, "No valid chunks generated from file."

        # Step 2: embed and store in the per-user Chroma collection.
        user_db_path = os.path.join(CHROMA_PATH, username)
        db = Chroma(
            persist_directory=user_db_path,
            embedding_function=get_embedding_func(),
        )
        db.add_documents(docs)

        return True, f"Successfully indexed {len(docs)} chunks."
    except Exception as e:
        logger.error(f"Ingestion failed: {e}")
        return False, f"System Error: {str(e)}"
177
+
178
def search_knowledge_base(query: str, username: str, k: int = 10, final_k: int = 4) -> List[Document]:
    """Retrieve ``k`` candidate chunks by vector similarity, then re-rank them
    with the Cross-Encoder and return the top ``final_k`` documents.

    Returns [] when the user has no database, no hits are found, or any
    error occurs (logged, never raised).
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    if not os.path.exists(user_db_path):
        return []

    try:
        # Stage 1: dense retrieval from the user's Chroma collection.
        db = Chroma(
            persist_directory=user_db_path,
            embedding_function=get_embedding_func(),
        )
        hits = db.similarity_search_with_relevance_scores(query, k=k)
        if not hits:
            return []

        # Stage 2: cross-encoder re-ranking — score each (query, chunk) pair.
        candidates = [doc for doc, _ in hits]
        pairs = [[query, doc.page_content] for doc in candidates]
        rerank_scores = get_rerank_model().predict(pairs)

        ranked = sorted(
            zip(candidates, rerank_scores),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return [doc for doc, _ in ranked[:final_k]]
    except Exception as e:
        logger.error(f"Search Error: {e}")
        return []
210
+
211
def list_documents(username: str) -> List[dict]:
    """Return one entry per unique source file in the user's vector database.

    Each entry is ``{"filename", "chunks", "strategy"}``; the strategy is
    taken from the first chunk seen for that source. Used to populate the
    sidebar document list. Returns [] if the database is missing or on error.
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    if not os.path.exists(user_db_path):
        return []

    try:
        db = Chroma(
            persist_directory=user_db_path,
            embedding_function=get_embedding_func(),
        )

        # Chroma's .get() with no filter returns metadata for every chunk.
        metadatas = db.get()['metadatas']

        inventory = {}
        for meta in metadatas:
            # Metadata keys may vary slightly between ingests — be lenient.
            src = meta.get('source', 'Unknown')
            entry = inventory.setdefault(
                src,
                {"chunks": 0, "strategy": meta.get('strategy', 'unknown')},
            )
            entry["chunks"] += 1

        return [
            {"filename": name, "chunks": info["chunks"], "strategy": info["strategy"]}
            for name, info in inventory.items()
        ]
    except Exception as e:
        logger.error(f"Error listing docs: {e}")
        return []
246
+
247
def delete_document(username: str, filename: str) -> Tuple[bool, str]:
    """Remove every chunk whose ``source`` metadata matches *filename*.

    Args:
        username: owner of the vector database.
        filename: the ``source`` value stored at ingest time (a basename).

    Returns:
        ``(success, human-readable message)``.
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    # Guard before touching Chroma: instantiating it on a missing path would
    # side-effect an empty DB directory (the sibling search/list/reset
    # functions all guard this way; the result tuple is unchanged).
    if not os.path.exists(user_db_path):
        return False, "File not found in index."

    try:
        emb_fn = get_embedding_func()
        db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)

        data = db.get()
        ids_to_delete = [
            doc_id
            for doc_id, meta in zip(data['ids'], data['metadatas'])
            if meta.get('source') == filename
        ]

        if ids_to_delete:
            db.delete(ids=ids_to_delete)
            # Report what was actually removed — the previous message was a
            # broken placeholder (f-string with no interpolation: "Deleted (unknown).").
            return True, f"Deleted {len(ids_to_delete)} chunks from {filename}."
        else:
            return False, "File not found in index."

    except Exception as e:
        return False, f"Delete failed: {e}"
268
+
269
def reset_knowledge_base(username: str) -> Tuple[bool, str]:
    """Irreversibly delete the user's entire Chroma database directory.

    Returns ``(success, message)``; success is False when there was
    nothing to delete.
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    if not os.path.exists(user_db_path):
        return False, "Database already empty."
    shutil.rmtree(user_db_path)
    return True, "Database Reset."