NavyDevilDoc committed on
Commit
e0f2368
·
verified ·
1 Parent(s): ef97ac2

Update src/rag_engine.py

Browse files

refactored to make use of the document loading program

Files changed (1) hide show
  1. src/rag_engine.py +169 -172
src/rag_engine.py CHANGED
@@ -1,210 +1,207 @@
1
  import os
 
 
 
2
  from langchain_chroma import Chroma
3
  from langchain_huggingface import HuggingFaceEmbeddings
4
- from sentence_transformers import CrossEncoder
5
- from core.ChunkingManager import ChunkingManager, ChunkingStrategy
6
- import tracker
7
 
8
  # --- CONFIGURATION ---
9
- UPLOAD_DIR = "/tmp/rag_uploads"
10
- DB_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chroma_db")
11
- EMBEDDING_MODEL_NAME = "all-MiniLM-L12-v2"
12
- RERANKER_MODEL_NAME = "BAAI/bge-reranker-base"
13
-
14
- # --- LAZY LOADING SINGLETONS ---
15
- _embedding_fn = None
16
- _reranker = None
17
- _chunk_manager = None
18
-
19
- def get_embedding_function():
20
- global _embedding_fn
21
- if _embedding_fn is None:
22
- print("⚙️ Loading Embedding Model...")
23
- _embedding_fn = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
24
- return _embedding_fn
25
-
26
- def get_reranker_model():
27
- global _reranker
28
- if _reranker is None:
29
- print("⚙️ Loading Reranker Model...")
30
- _reranker = CrossEncoder(RERANKER_MODEL_NAME)
31
- return _reranker
32
-
33
- def get_chunk_manager():
34
- global _chunk_manager
35
- if _chunk_manager is None:
36
- print("⚙️ Loading Chunk Manager...")
37
- _chunk_manager = ChunkingManager(embedding_model_name=EMBEDDING_MODEL_NAME)
38
- return _chunk_manager
39
-
40
- # --- DATABASE OPERATIONS ---
41
- def get_vectorstore(username):
42
- safe_username = os.path.basename(username)
43
- user_db_path = os.path.join(DB_ROOT, safe_username)
44
-
45
- if not os.path.exists(user_db_path):
46
- os.makedirs(user_db_path, exist_ok=True)
47
-
48
- return Chroma(
49
- persist_directory=user_db_path,
50
- embedding_function=get_embedding_function(),
51
- collection_name=f"docs_{safe_username}"
52
- )
53
-
54
  def save_uploaded_file(uploaded_file):
55
- if not os.path.exists(UPLOAD_DIR):
56
- os.makedirs(UPLOAD_DIR)
57
- safe_filename = os.path.basename(uploaded_file.name)
58
- file_path = os.path.join(UPLOAD_DIR, safe_filename)
59
  with open(file_path, "wb") as f:
60
  f.write(uploaded_file.getbuffer())
 
61
  return file_path
62
 
63
- def process_and_add_document(file_path, username, strategy="paragraph"):
 
 
 
 
 
 
64
  try:
65
- print(f"🧠 Chunking {os.path.basename(file_path)}...")
66
-
67
- strat_map = {
68
- "token": ChunkingStrategy.TOKEN,
69
- "paragraph": ChunkingStrategy.PARAGRAPH,
70
- "page": ChunkingStrategy.PAGE
71
- }
72
- selected_strategy = strat_map.get(strategy, ChunkingStrategy.PARAGRAPH)
73
-
74
- manager = get_chunk_manager()
75
- chunks = manager.process_document(
76
- file_path=file_path,
77
- strategy=selected_strategy,
78
- preprocess=True
79
- )
80
-
81
- if not chunks:
82
- return False, "No text extracted. Is the file empty/scanned?"
83
-
84
- # FIX #1: Tag every chunk with the strategy used
85
- for chunk in chunks:
86
- chunk.metadata["strategy"] = strategy
87
-
88
- print(f"💾 Indexing {len(chunks)} chunks into Vector DB...")
89
- db = get_vectorstore(username)
90
- db.add_documents(chunks)
91
-
92
- tracker.upload_user_db(username)
93
-
94
- if os.path.exists(file_path):
95
- os.remove(file_path)
96
 
97
- return True, f"Successfully added {len(chunks)} chunks to Knowledge Base."
98
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  except Exception as e:
100
- print(f" Processing Error: {e}")
101
- return False, str(e)
102
 
103
- # --- RETRIEVAL ENGINE ---
104
- def search_knowledge_base(query, username, k=6):
105
  """
106
- Two-Stage Retrieval System (RAG):
107
- 1. Retrieval: Get 10 candidates via fast Vector Search.
108
- 2. Reranking: Sort them via Cross-Encoder (Slow/Precise).
109
- 3. Return top k.
110
  """
 
 
 
 
111
  try:
112
- db = get_vectorstore(username)
113
- if db._collection.count() == 0:
 
 
 
 
 
114
  return []
115
 
116
- reranker = get_reranker_model()
117
-
118
- # 1. Vector Search (Broad Net)
119
- vector_results = db.similarity_search(query, k=25)
120
 
121
- # 2. "Poor Man's" Keyword Search (The Safety Net)
122
- # We perform a basic text search for unique terms in the query
123
- # This catches acronyms like "C&D" if we normalize them
124
 
125
- # Normalize query acronyms (e.g., "C&D" -> "C D")
126
- normalized_query = query.replace("&", " ")
127
- keyword_results = []
128
 
129
- # (Optional: In a production DB like Pinecone/Weaviate, this is built-in.
130
- # For Chroma local, we rely on the vector net mostly, but we can
131
- # extend k significantly to catch edge cases).
 
132
 
133
- # STRATEGY: Just widen the net significantly.
134
- # Vector models often hide the match at rank 30 or 40 if the spelling differs.
135
- results = db.similarity_search(query, k=50) # Widen from 25 to 50
136
-
137
- if not results:
138
- return []
139
 
140
- # 2. Reranking
141
- passages = [doc.page_content for doc in results]
142
- ranks = reranker.rank(query, passages)
143
-
144
- top_results = []
145
- sorted_ranks = sorted(ranks, key=lambda x: x['score'], reverse=True)
146
-
147
- # Return the top k results
148
- for rank in sorted_ranks[:k]:
149
- doc_index = rank['corpus_id']
150
- doc = results[doc_index]
151
- doc.metadata["relevance_score"] = round(rank['score'], 4)
152
- top_results.append(doc)
153
-
154
- return top_results
155
-
156
  except Exception as e:
157
- print(f"⚠️ Search Error (likely empty DB): {e}")
158
  return []
159
 
 
160
  def list_documents(username):
 
 
 
 
 
161
  try:
162
- db = get_vectorstore(username)
163
- # Check if empty before fetching to prevent errors
164
- if db._collection.count() == 0:
165
- return []
166
-
167
- data = db.get()
168
  metadatas = data['metadatas']
169
 
170
- file_stats = {}
171
-
172
- for meta in metadatas:
173
- src = meta.get('source', 'unknown')
174
- filename = os.path.basename(src)
175
- # FIX #2: Retrieve the strategy (Default to 'unknown' for old docs)
176
- strat = meta.get('strategy', 'unknown')
177
-
178
- if src not in file_stats:
179
- file_stats[src] = {
180
- 'source': src,
181
- 'filename': filename,
182
- 'chunks': 0,
183
- 'strategy': strat
184
- }
185
- file_stats[src]['chunks'] += 1
186
 
187
- return list(file_stats.values())
188
-
189
- except Exception as e:
190
- print(f"❌ Error listing docs: {e}")
191
  return []
192
 
193
- def delete_document(username, source_path):
 
 
194
  try:
195
- print(f"🗑️ Deleting {source_path} for {username}...")
196
- db = get_vectorstore(username)
197
- db.delete(where={"source": source_path})
198
- tracker.upload_user_db(username)
199
- return True, f"Deleted {os.path.basename(source_path)}"
 
 
 
 
 
 
 
 
 
200
  except Exception as e:
201
- return False, str(e)
202
 
203
  def reset_knowledge_base(username):
204
- try:
205
- db = get_vectorstore(username)
206
- db.delete_collection()
207
- tracker.upload_user_db(username)
208
- return True, "Knowledge Base completely reset."
209
- except Exception as e:
210
- return False, str(e)
 
1
  import os
2
+ import shutil
3
+ import time
4
+ from langchain_text_splitters import RecursiveCharacterTextSplitter, TokenTextSplitter
5
  from langchain_chroma import Chroma
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
+ from langchain_community.docstore.document import Document
8
+ from sentence_transformers import CrossEncoder # Re-added for Reranking
9
+ import doc_loader
10
 
11
  # --- CONFIGURATION ---
12
+ CHROMA_PATH = "chroma_db"
13
+ UPLOAD_DIR = "temp_ingest" # Re-added directory constant
14
+ EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
15
+ RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2" # Re-added model name
16
+
17
+ # --- LAZY LOADING GLOBALS ---
18
+ # We use a global variable pattern to avoid loading heavy models
19
+ # until the moment they are actually needed (saves startup RAM).
20
+ _embedding_func = None
21
+ _rerank_model = None
22
+
23
+ def get_embedding_func():
24
+ """Lazy loads the embedding model."""
25
+ global _embedding_func
26
+ if _embedding_func is None:
27
+ print(f"⏳ Loading Embedding Model: {EMBED_MODEL_NAME}...")
28
+ _embedding_func = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
29
+ print("✅ Embedding Model Loaded.")
30
+ return _embedding_func
31
+
32
+ def get_rerank_model():
33
+ """Lazy loads the Cross-Encoder model."""
34
+ global _rerank_model
35
+ if _rerank_model is None:
36
+ print(f"⏳ Loading Reranker: {RERANK_MODEL_NAME}...")
37
+ _rerank_model = CrossEncoder(RERANK_MODEL_NAME)
38
+ print("✅ Reranker Loaded.")
39
+ return _rerank_model
40
+
41
+ # --- FILE OPERATIONS ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
def save_uploaded_file(uploaded_file):
    """Persist an uploaded file into the temp ingest directory.

    Args:
        uploaded_file: File-like upload object exposing ``.name`` and
            ``.getbuffer()`` (e.g. a Streamlit ``UploadedFile``).

    Returns:
        str: Path of the saved file inside ``UPLOAD_DIR``.
    """
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    # basename() strips any client-supplied directory components so a
    # crafted name like "../../evil" cannot escape UPLOAD_DIR (the
    # pre-refactor version sanitized the name the same way).
    safe_name = os.path.basename(uploaded_file.name)
    file_path = os.path.join(UPLOAD_DIR, safe_name)

    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    return file_path
51
 
52
+ # --- INGESTION PIPELINE ---
53
# --- INGESTION PIPELINE ---
def process_and_add_document(file_path, username, strategy="paragraph", use_vision=False, api_key=None):
    """Ingest a document with the universal loader and index it for a user.

    Pipeline: extract text via ``doc_loader`` -> split into chunks by
    ``strategy`` -> wrap as ``Document`` objects -> add to the user's
    Chroma collection.

    Args:
        file_path: Path of the file on disk to ingest.
        username: Owner; selects the per-user Chroma directory.
        strategy: "paragraph", "token", or "page"; unrecognized values
            fall back to "paragraph" (matching the pre-refactor behavior
            instead of silently creating zero chunks).
        use_vision: Passed through to ``doc_loader`` for vision-based OCR.
        api_key: Passed through to ``doc_loader`` when vision is used.

    Returns:
        tuple[bool, str]: (success flag, human-readable status message).
    """
    user_db_path = os.path.join(CHROMA_PATH, username)

    try:
        # 1. EXTRACT TEXT: doc_loader expects a Streamlit-style upload
        # object, but we are reading from disk, so adapt with a tiny shim.
        with open(file_path, "rb") as f:
            class _FileShim:
                def __init__(self, handle, name):
                    self.handle = handle
                    self.name = name

                def read(self):
                    return self.handle.read()

            file_obj = _FileShim(f, os.path.basename(file_path))
            raw_text = doc_loader.extract_text_from_file(
                file_obj, use_vision=use_vision, api_key=api_key
            )

        if not raw_text or not raw_text.strip():
            return False, "Document appears empty or could not be read."

        # 2. CHUNK TEXT. Dispatch table instead of if/elif; .get() restores
        # the old paragraph fallback for unknown strategy names.
        splitter_factories = {
            "paragraph": lambda: RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100),
            "token": lambda: TokenTextSplitter(chunk_size=512, chunk_overlap=50),
            "page": lambda: RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200),
        }
        make_splitter = splitter_factories.get(strategy, splitter_factories["paragraph"])
        chunks = make_splitter().split_text(raw_text)

        # 3. CREATE DOCUMENTS, tagging each chunk with its source file and
        # the chunking strategy so list_documents() can report them.
        docs = [
            Document(
                page_content=chunk,
                metadata={"source": os.path.basename(file_path), "strategy": strategy}
            )
            for chunk in chunks
        ]

        # 4. INDEX TO CHROMA (embedding model is lazy-loaded on first use).
        if not docs:
            return False, "No chunks created."

        emb_fn = get_embedding_func()
        db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
        db.add_documents(docs)
        return True, f"Successfully indexed {len(docs)} chunks from {os.path.basename(file_path)}."

    except Exception as e:
        return False, f"Error processing document: {e}"
 
109
 
110
+ # --- SEARCH PIPELINE (Now with Reranking!) ---
111
# --- SEARCH PIPELINE (Now with Reranking!) ---
def search_knowledge_base(query, username, k=10, final_k=4):
    """Two-stage retrieval: vector search for `k` candidates, then
    cross-encoder reranking, returning the `final_k` best documents.

    Args:
        query: Natural-language search string.
        username: Owner; selects the per-user Chroma directory.
        k: Number of candidates fetched by vector similarity.
        final_k: Number of reranked documents returned.

    Returns:
        list: Top `final_k` Document objects, or [] on miss/error.
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    if not os.path.exists(user_db_path):
        return []

    try:
        # Stage 1: broad vector-similarity retrieval (lazy-loaded embedder).
        db = Chroma(
            persist_directory=user_db_path,
            embedding_function=get_embedding_func(),
        )
        hits = db.similarity_search_with_relevance_scores(query, k=k)
        if not hits:
            return []

        # Stage 2: precise reranking. Score (query, passage) pairs with the
        # cross-encoder, then keep the highest-scoring final_k documents.
        candidates = [doc for doc, _ in hits]
        pairs = [[query, doc.page_content] for doc in candidates]
        rerank_scores = get_rerank_model().predict(pairs)

        ranked = sorted(
            zip(candidates, rerank_scores),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return [doc for doc, _ in ranked[:final_k]]

    except Exception as e:
        print(f"RAG Error: {e}")
        return []
155
 
156
+ # --- MANAGEMENT UTILS ---
157
def list_documents(username):
    """Summarize the unique source files stored in the user's DB.

    Args:
        username: Owner; selects the per-user Chroma directory.

    Returns:
        list[dict]: One entry per unique source, with keys "filename",
        "chunks", "strategy", and "source"; [] when the DB is missing,
        empty, or an error occurs.
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    if not os.path.exists(user_db_path):
        return []

    try:
        emb_fn = get_embedding_func()
        db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
        metadatas = db.get()['metadatas']

        # Aggregate chunk counts per source; strategy is taken from the
        # first chunk seen for each source.
        inventory = {}
        for m in metadatas:
            src = m.get('source', 'Unknown')
            if src not in inventory:
                inventory[src] = {"chunks": 0, "strategy": m.get('strategy', 'Unknown')}
            inventory[src]["chunks"] += 1

        return [
            {"filename": k, "chunks": v["chunks"], "strategy": v["strategy"], "source": k}
            for k, v in inventory.items()
        ]
    except Exception as e:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate and the failure is at least visible in the logs.
        print(f"Error listing documents: {e}")
        return []
179
 
180
def delete_document(username, source_name):
    """Remove every chunk whose metadata 'source' matches source_name.

    Args:
        username: Owner; selects the per-user Chroma directory.
        source_name: Filename recorded in each chunk's metadata.

    Returns:
        tuple[bool, str]: (success flag, status message).
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    try:
        db = Chroma(
            persist_directory=user_db_path,
            embedding_function=get_embedding_func(),
        )
        snapshot = db.get()

        # Pair each stored id with its metadata and keep ids whose
        # source matches the requested file.
        matching_ids = [
            chunk_id
            for chunk_id, meta in zip(snapshot['ids'], snapshot['metadatas'])
            if meta.get('source') == source_name
        ]

        if not matching_ids:
            return False, "File not found in index."

        db.delete(ids=matching_ids)
        return True, f"Deleted {source_name}."
    except Exception as e:
        return False, f"Delete failed: {e}"
200
 
201
def reset_knowledge_base(username):
    """Wipe the user's entire on-disk vector database.

    Args:
        username: Owner; selects the per-user Chroma directory.

    Returns:
        tuple[bool, str]: (success flag, status message).
    """
    user_db_path = os.path.join(CHROMA_PATH, username)

    # Guard clause: nothing on disk means nothing to delete.
    if not os.path.exists(user_db_path):
        return False, "Database already empty."

    shutil.rmtree(user_db_path)
    return True, "Database Reset."