NavyDevilDoc committed on
Commit
ff3310f
·
verified ·
1 Parent(s): 4de416e

Update src/rag_engine.py

Browse files

removed chroma support and added pinecone

Files changed (1) hide show
  1. src/rag_engine.py +89 -116
src/rag_engine.py CHANGED
@@ -4,22 +4,23 @@ import logging
4
  from typing import List, Literal, Tuple
5
 
6
  # --- LANGCHAIN & DB IMPORTS ---
7
- from langchain_chroma import Chroma
8
  from langchain_huggingface import HuggingFaceEmbeddings
9
  from langchain_core.documents import Document
10
  from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
11
  from sentence_transformers import CrossEncoder
12
 
 
13
  # --- CUSTOM CORE IMPORTS ---
 
14
  from core.ParagraphChunker import ParagraphChunker
15
  from core.TokenChunker import TokenChunker
16
  from core.AcronymManager import AcronymManager
17
 
18
  # --- CONFIGURATION ---
19
- CHROMA_PATH = "chroma_db"
20
  UPLOAD_DIR = "source_documents"
21
  EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
22
  RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
 
23
 
24
  # Configure Logging
25
  logging.basicConfig(level=logging.INFO)
@@ -133,7 +134,7 @@ def process_file(
133
  logger.warning(f"Unsupported file extension: {file_extension}")
134
  return []
135
 
136
- # --- PART 2: DATABASE & FILE MANAGEMENT (The Old Stable System) ---
137
 
138
  def save_uploaded_file(uploaded_file, username: str = "default") -> str:
139
  """Saves a StreamlitUploadedFile to disk so the loaders can read it."""
@@ -144,102 +145,99 @@ def save_uploaded_file(uploaded_file, username: str = "default") -> str:
144
 
145
  with open(file_path, "wb") as f:
146
  f.write(uploaded_file.getbuffer())
147
-
148
- logger.info(f"File saved: {file_path}")
149
  return file_path
150
  except Exception as e:
151
  logger.error(f"Error saving file: {e}")
152
  return None
153
 
154
- def process_and_add_text(text: str, source_name: str, username: str) -> Tuple[bool, str]:
155
- """
156
- Ingests raw text string (e.g., from the Flattener tool) directly into Chroma.
157
- """
158
  try:
159
- user_db_path = os.path.join(CHROMA_PATH, username)
160
- emb_fn = get_embedding_func()
 
 
161
 
162
- # Initialize DB
163
- db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
164
 
165
- # Create a Document object directly
 
 
 
 
166
  doc = Document(
167
  page_content=text,
168
- metadata={
169
- "source": source_name,
170
- "strategy": "flattened_text",
171
- "file_type": "generated"
172
- }
173
  )
174
-
175
- # Add single document
176
- db.add_documents([doc])
177
- return True, f"Successfully indexed flattened text: {source_name}"
178
 
 
 
 
 
 
179
  except Exception as e:
180
- logger.error(f"Error indexing raw text: {e}")
181
- return False, f"Error: {str(e)}"
 
 
 
 
182
 
183
- def ingest_file(file_path: str, username: str, strategy: str = "paragraph") -> Tuple[bool, str]:
184
  try:
185
- # 1. Chunk the file
186
  docs = process_file(file_path, chunking_strategy=strategy)
187
-
188
- if not docs:
189
- return False, "No valid chunks generated from file."
190
 
191
- # --- ACRONYM SCANNING ---
192
- # We scan the raw text of the chunks to learn new definitions
193
  acronym_mgr = AcronymManager()
194
  for doc in docs:
195
  acronym_mgr.scan_text_for_acronyms(doc.page_content)
196
- # -----------------------------
197
 
198
- # 2. Add to Chroma DB
199
- user_db_path = os.path.join(CHROMA_PATH, username)
 
 
 
 
200
  emb_fn = get_embedding_func()
201
-
202
- db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
203
- db.add_documents(docs)
204
 
205
  return True, f"Successfully indexed {len(docs)} chunks."
206
 
207
  except Exception as e:
208
  logger.error(f"Ingestion failed: {e}")
209
- return False, f"System Error: {str(e)}"
210
 
211
- def search_knowledge_base(query: str, username: str, k: int = 10, final_k: int = 4) -> List[Document]:
212
- user_db_path = os.path.join(CHROMA_PATH, username)
213
- if not os.path.exists(user_db_path):
214
- return []
215
-
216
  try:
217
- # --- NEW: QUERY EXPANSION ---
218
  acronym_mgr = AcronymManager()
219
  expanded_query = acronym_mgr.expand_query(query)
220
- if expanded_query != query:
221
- logger.info(f"Query Expanded: '{query}' -> '{expanded_query}'")
222
- else:
223
- expanded_query = query
224
- # ----------------------------
225
 
226
- # 1. Vector Retrieval (Use expanded_query instead of query)
 
227
  emb_fn = get_embedding_func()
228
- db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
229
- results = db.similarity_search_with_relevance_scores(expanded_query, k=k) # <--- UPDATED VAR
230
 
231
- if not results:
232
- return []
233
 
234
- # 2. Reranking (Pass expanded_query here too)
235
- candidate_docs = [doc for doc, _ in results]
236
  candidate_texts = [doc.page_content for doc in candidate_docs]
237
- pairs = [[expanded_query, text] for text in candidate_texts] # <--- UPDATED VAR
238
 
239
  reranker = get_rerank_model()
240
  scores = reranker.predict(pairs)
241
 
242
- # Sort by new score
243
  scored_docs = list(zip(candidate_docs, scores))
244
  scored_docs.sort(key=lambda x: x[1], reverse=True)
245
 
@@ -251,67 +249,42 @@ def search_knowledge_base(query: str, username: str, k: int = 10, final_k: int =
251
 
252
  def list_documents(username: str) -> List[dict]:
253
  """
254
- Returns a list of unique files currently in the vector database.
255
- (Used for the sidebar list)
 
256
  """
257
- user_db_path = os.path.join(CHROMA_PATH, username)
258
- if not os.path.exists(user_db_path):
259
- return []
260
-
261
- try:
262
- emb_fn = get_embedding_func()
263
- db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
264
-
265
- # Chroma's .get() returns all metadata
266
- data = db.get()
267
- metadatas = data['metadatas']
268
-
269
- inventory = {}
270
- for m in metadatas:
271
- # Metadata keys might differ slightly, handle gracefully
272
- src = m.get('source', 'Unknown')
273
- if src not in inventory:
274
- inventory[src] = {
275
- "chunks": 0,
276
- "strategy": m.get('strategy', 'unknown')
277
- }
278
- inventory[src]["chunks"] += 1
279
-
280
- # FIXED: Added "source": k to the dictionary below
281
- return [
282
- {"filename": k, "chunks": v["chunks"], "strategy": v["strategy"], "source": k}
283
- for k, v in inventory.items()
284
- ]
285
- except Exception as e:
286
- logger.error(f"Error listing docs: {e}")
287
- return []
288
-
289
- def delete_document(username: str, filename: str) -> Tuple[bool, str]:
290
- """Removes a document from the vector database."""
291
- user_db_path = os.path.join(CHROMA_PATH, username)
292
  try:
293
- emb_fn = get_embedding_func()
294
- db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
 
295
 
296
- data = db.get()
297
- ids_to_delete = []
298
- for i, meta in enumerate(data['metadatas']):
299
- if meta.get('source') == filename:
300
- ids_to_delete.append(data['ids'][i])
301
-
302
- if ids_to_delete:
303
- db.delete(ids=ids_to_delete)
304
- return True, f"Deleted {filename}."
305
- else:
306
- return False, "File not found in index."
307
 
 
308
  except Exception as e:
309
- return False, f"Delete failed: {e}"
310
 
311
  def reset_knowledge_base(username: str) -> Tuple[bool, str]:
312
- """Nukes the user's database folder."""
313
- user_db_path = os.path.join(CHROMA_PATH, username)
314
- if os.path.exists(user_db_path):
315
- shutil.rmtree(user_db_path)
316
- return True, "Database Reset."
317
- return False, "Database already empty."
 
 
4
  from typing import List, Literal, Tuple
5
 
6
  # --- LANGCHAIN & DB IMPORTS ---
 
7
  from langchain_huggingface import HuggingFaceEmbeddings
8
  from langchain_core.documents import Document
9
  from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
10
  from sentence_transformers import CrossEncoder
11
 
12
+
13
  # --- CUSTOM CORE IMPORTS ---
14
+ from core.PineconeManager import PineconeManager
15
  from core.ParagraphChunker import ParagraphChunker
16
  from core.TokenChunker import TokenChunker
17
  from core.AcronymManager import AcronymManager
18
 
19
  # --- CONFIGURATION ---
 
20
  UPLOAD_DIR = "source_documents"
21
  EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
22
  RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
23
+ PINECONE_KEY = os.getenv("PINECONE_API_KEY")
24
 
25
  # Configure Logging
26
  logging.basicConfig(level=logging.INFO)
 
134
  logger.warning(f"Unsupported file extension: {file_extension}")
135
  return []
136
 
137
+ # --- PART 2: DATABASE & FILE MANAGEMENT (Pinecone Version) ---
138
 
139
  def save_uploaded_file(uploaded_file, username: str = "default") -> str:
140
  """Saves a StreamlitUploadedFile to disk so the loaders can read it."""
 
145
 
146
  with open(file_path, "wb") as f:
147
  f.write(uploaded_file.getbuffer())
 
 
148
  return file_path
149
  except Exception as e:
150
  logger.error(f"Error saving file: {e}")
151
  return None
152
 
153
def process_and_add_text(text: str, source_name: str, username: str, index_name: str) -> Tuple[bool, str]:
    """Back up raw text to disk, then index it in Pinecone as one document.

    The text is written to ``UPLOAD_DIR/<username>/<name>`` and then added to
    the vector store for ``index_name`` under the namespace ``username``.

    Args:
        text: Raw text to index (e.g. output of the Flattener tool).
        source_name: Name used for the backup file and for the ``source``
            metadata. Only its final path component is used, so the backup
            cannot be written outside the user's directory.
        username: Owner of the text; selects the backup folder and the
            vector-store namespace.
        index_name: Name of the Pinecone index to write to.

    Returns:
        ``(True, message)`` on success, otherwise ``(False, reason)``.
        Exceptions are logged and reported through the tuple, not raised.
    """
    if not PINECONE_KEY or not index_name:
        return False, "Pinecone Configuration Missing."

    # Nothing useful can be embedded from empty or whitespace-only text.
    if not text or not text.strip():
        return False, "No text provided."

    # Strip directory components: os.path.join would otherwise honour "../"
    # segments or an absolute path and write outside UPLOAD_DIR/<username>.
    safe_name = os.path.basename(source_name or "")
    if safe_name in ("", ".", ".."):
        return False, "Invalid source name."

    try:
        # 1. SAVE PHYSICAL BACKUP (For Quiz Engine)
        user_docs_dir = os.path.join(UPLOAD_DIR, username)
        os.makedirs(user_docs_dir, exist_ok=True)
        backup_path = os.path.join(user_docs_dir, safe_name)

        with open(backup_path, "w", encoding="utf-8") as f:
            f.write(text)

        # 2. UPLOAD TO PINECONE
        pm = PineconeManager(PINECONE_KEY)
        emb_fn = get_embedding_func()

        # The same name is used on disk and in the metadata so that a single
        # filename identifies both copies of the document.
        doc = Document(
            page_content=text,
            metadata={"source": safe_name, "strategy": "flattened", "file_type": "generated"},
        )

        # Add to VectorStore (Namespace = Username)
        vstore = pm.get_vectorstore(index_name, emb_fn, namespace=username)
        vstore.add_documents([doc])

        return True, f"Indexed and backed up: {safe_name}"
    except Exception as e:
        logger.error(f"Error indexing text: {e}")
        return False, str(e)
184
+
185
def ingest_file(file_path: str, username: str, index_name: str, strategy: str = "paragraph") -> Tuple[bool, str]:
    """Chunk a file, learn its acronyms, and upload the chunks to Pinecone.

    Args:
        file_path: Path of the file to ingest; passed to ``process_file``.
        username: Owner of the file; used as the vector-store namespace.
        index_name: Name of the Pinecone index to write to.
        strategy: Chunking strategy forwarded to ``process_file``.

    Returns:
        ``(True, message)`` when the chunks were uploaded, otherwise
        ``(False, reason)``. Exceptions are logged and reported through the
        tuple, not raised.
    """
    if not PINECONE_KEY or not index_name:
        return False, "Pinecone Configuration Missing."

    try:
        # 1. Chunking
        docs = process_file(file_path, chunking_strategy=strategy)
        if not docs:
            return False, "No valid chunks generated."

        # 2. Acronym Learning
        acronym_mgr = AcronymManager()
        for doc in docs:
            acronym_mgr.scan_text_for_acronyms(doc.page_content)

        # 3. Pinecone Safety Check
        # Probe the embedder for its real output width instead of assuming
        # one: a hard-coded size goes stale whenever the embedding model
        # changes, and would then reject a correctly sized index (or accept
        # a wrongly sized one).
        emb_fn = get_embedding_func()
        embed_dim = len(emb_fn.embed_query("dimension probe"))

        pm = PineconeManager(PINECONE_KEY)
        if not pm.check_dimension_compatibility(index_name, embed_dim):
            return False, f"Dimension Mismatch! Index {index_name} is not {embed_dim}d."

        # 4. Upload (Namespace = Username)
        vstore = pm.get_vectorstore(index_name, emb_fn, namespace=username)
        vstore.add_documents(docs)

        return True, f"Successfully indexed {len(docs)} chunks."

    except Exception as e:
        logger.error(f"Ingestion failed: {e}")
        return False, str(e)
214
 
215
+ def search_knowledge_base(query: str, username: str, index_name: str, k: int = 10, final_k: int = 4) -> List[Document]:
216
+ """Retrieves from Pinecone -> Reranks."""
217
+ if not PINECONE_KEY or not index_name: return []
218
+
 
219
  try:
220
+ # 1. Expand Query (Acronyms)
221
  acronym_mgr = AcronymManager()
222
  expanded_query = acronym_mgr.expand_query(query)
 
 
 
 
 
223
 
224
+ # 2. Vector Search
225
+ pm = PineconeManager(PINECONE_KEY)
226
  emb_fn = get_embedding_func()
227
+ vstore = pm.get_vectorstore(index_name, emb_fn, namespace=username)
 
228
 
229
+ results = vstore.similarity_search(expanded_query, k=k)
230
+ if not results: return []
231
 
232
+ # 3. Reranking
233
+ candidate_docs = results
234
  candidate_texts = [doc.page_content for doc in candidate_docs]
235
+ pairs = [[expanded_query, text] for text in candidate_texts]
236
 
237
  reranker = get_rerank_model()
238
  scores = reranker.predict(pairs)
239
 
240
+ # Sort
241
  scored_docs = list(zip(candidate_docs, scores))
242
  scored_docs.sort(key=lambda x: x[1], reverse=True)
243
 
 
249
 
250
def list_documents(username: str) -> List[dict]:
    """List the user's locally cached source files.

    This does not query Pinecone. It reads ``UPLOAD_DIR/<username>`` and uses
    the files found there as a proxy for what is available to the Quiz Engine.

    Args:
        username: Owner whose cache folder is listed.

    Returns:
        One dict per ``.pdf``, ``.txt`` or ``.md`` entry (matched
        case-insensitively) with the keys ``filename``, ``source`` and
        ``strategy``. Returns an empty list when the folder does not exist.
    """
    cache_dir = os.path.join(UPLOAD_DIR, username)
    if os.path.exists(cache_dir):
        return [
            {"filename": entry, "source": entry, "strategy": "local_cache"}
            for entry in os.listdir(cache_dir)
            if entry.lower().endswith(('.pdf', '.txt', '.md'))
        ]
    return []
264
+
265
def delete_document(username: str, filename: str, index_name: str) -> Tuple[bool, str]:
    """Delete a document from Pinecone and remove its local cached copy.

    Args:
        username: Owner of the document; used as the Pinecone namespace and
            as the cache sub-folder under ``UPLOAD_DIR``.
        filename: Name of the document, passed to ``PineconeManager.delete_file``
            and used to locate the cached file on disk.
        index_name: Name of the Pinecone index to delete from.

    Returns:
        ``(True, message)`` on success, otherwise ``(False, reason)``.
        Exceptions are logged and reported through the tuple, not raised.
    """
    if not PINECONE_KEY or not index_name:
        return False, "Config Missing."

    try:
        # 1. Delete from Pinecone
        pm = PineconeManager(PINECONE_KEY)
        pm.delete_file(index_name, filename, namespace=username)

        # 2. Delete from Disk (Clean up Quiz Cache)
        # Resolve both paths and only remove the target when it sits inside
        # the user's own folder; a filename containing "../" or an absolute
        # path must not be able to delete anything elsewhere.
        user_dir = os.path.realpath(os.path.join(UPLOAD_DIR, username))
        local_path = os.path.realpath(os.path.join(user_dir, filename))
        if not local_path.startswith(user_dir + os.sep):
            logger.warning(f"Skipped disk delete outside user folder: {filename}")
        elif os.path.exists(local_path):
            os.remove(local_path)

        return True, f"Deleted {filename} from Index and Disk."
    except Exception as e:
        # Log like the other operations in this module instead of failing silently.
        logger.error(f"Delete failed: {e}")
        return False, str(e)
282
 
283
def reset_knowledge_base(username: str) -> Tuple[bool, str]:
    """Refuse to reset the knowledge base; nothing is deleted.

    Bulk reset is deliberately disabled. The function performs no Pinecone or
    disk operation and always reports failure, directing the caller to delete
    documents individually instead.

    Args:
        username: Accepted for interface compatibility; currently unused.

    Returns:
        Always ``(False, message)`` explaining that reset is disabled.
    """
    # A namespace-scoped wipe (delete_all=True, namespace=username) would be
    # the way to implement this; it is intentionally left out so that a reset
    # can never affect more than the caller intended.
    return False, "Resetting entire DB via API is disabled for safety. Use Delete."