NavyDevilDoc committed on
Commit
c88e290
·
verified ·
1 Parent(s): 2201a66

Update src/rag_engine.py

Browse files

Increased the number of retrieved documents (k) from 3 to 10.

Files changed (1) hide show
  1. src/rag_engine.py +197 -197
src/rag_engine.py CHANGED
@@ -1,198 +1,198 @@
1
- import os
2
- from langchain_chroma import Chroma
3
- from langchain_huggingface import HuggingFaceEmbeddings
4
- from sentence_transformers import CrossEncoder
5
- from core.ChunkingManager import ChunkingManager, ChunkingStrategy
6
- import tracker # To trigger syncs
7
-
8
- # --- CONFIGURATION ---
9
- UPLOAD_DIR = "/tmp/rag_uploads"
10
- DB_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chroma_db")
11
- EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
12
- RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
13
-
14
- # --- LAZY LOADING SINGLETONS ---
15
- # We use these globals to store the models once loaded, so we don't reload them
16
- # every time a function is called, but we also don't load them on import.
17
- _embedding_fn = None
18
- _reranker = None
19
- _chunk_manager = None
20
-
21
- def get_embedding_function():
22
- """Lazy loads the embedding model only when needed."""
23
- global _embedding_fn
24
- if _embedding_fn is None:
25
- print("⚙️ Loading Embedding Model...")
26
- _embedding_fn = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
27
- return _embedding_fn
28
-
29
- def get_reranker_model():
30
- """Lazy loads the CrossEncoder only when needed."""
31
- global _reranker
32
- if _reranker is None:
33
- print("⚙️ Loading Reranker Model...")
34
- _reranker = CrossEncoder(RERANKER_MODEL_NAME)
35
- return _reranker
36
-
37
- def get_chunk_manager():
38
- """Lazy loads the Chunking Manager."""
39
- global _chunk_manager
40
- if _chunk_manager is None:
41
- print("⚙️ Loading Chunk Manager...")
42
- _chunk_manager = ChunkingManager(embedding_model_name=EMBEDDING_MODEL_NAME)
43
- return _chunk_manager
44
-
45
- # --- DATABASE OPERATIONS ---
46
- def get_vectorstore(username):
47
- """Returns the persistent ChromaDB for a SPECIFIC USER."""
48
- # Safety: Ensure username doesn't contain path traversal characters
49
- safe_username = os.path.basename(username)
50
- user_db_path = os.path.join(DB_ROOT, safe_username)
51
-
52
- if not os.path.exists(user_db_path):
53
- os.makedirs(user_db_path, exist_ok=True)
54
-
55
- return Chroma(
56
- persist_directory=user_db_path,
57
- embedding_function=get_embedding_function(),
58
- collection_name=f"docs_{safe_username}"
59
- )
60
-
61
- def save_uploaded_file(uploaded_file):
62
- """Saves upload to temp, sanitizing the filename."""
63
- if not os.path.exists(UPLOAD_DIR):
64
- os.makedirs(UPLOAD_DIR)
65
-
66
- # SECURITY FIX: Sanitize filename to prevent directory traversal
67
- safe_filename = os.path.basename(uploaded_file.name)
68
- file_path = os.path.join(UPLOAD_DIR, safe_filename)
69
-
70
- with open(file_path, "wb") as f:
71
- f.write(uploaded_file.getbuffer())
72
- return file_path
73
-
74
- def process_and_add_document(file_path, username, strategy="paragraph"):
75
- try:
76
- print(f"🧠 Chunking {os.path.basename(file_path)}...")
77
-
78
- strat_map = {
79
- "token": ChunkingStrategy.TOKEN,
80
- "paragraph": ChunkingStrategy.PARAGRAPH,
81
- "page": ChunkingStrategy.PAGE
82
- }
83
- selected_strategy = strat_map.get(strategy, ChunkingStrategy.PARAGRAPH)
84
-
85
- # Use the lazy-loaded chunk manager
86
- manager = get_chunk_manager()
87
- chunks = manager.process_document(
88
- file_path=file_path,
89
- strategy=selected_strategy,
90
- preprocess=True
91
- )
92
-
93
- if not chunks:
94
- return False, "No text extracted. Is the file empty/scanned?"
95
-
96
- print(f"💾 Indexing {len(chunks)} chunks into Vector DB...")
97
- db = get_vectorstore(username)
98
- db.add_documents(chunks)
99
-
100
- # Sync immediately
101
- tracker.upload_user_db(username)
102
-
103
- if os.path.exists(file_path):
104
- os.remove(file_path)
105
-
106
- return True, f"Successfully added {len(chunks)} chunks to Knowledge Base."
107
-
108
- except Exception as e:
109
- print(f"❌ Processing Error: {e}")
110
- return False, str(e)
111
-
112
- # --- RETRIEVAL ENGINE ---
113
- def search_knowledge_base(query, username, k=3):
114
- """
115
- Two-Stage Retrieval System (RAG):
116
- 1. Retrieval: Get 10 candidates via fast Vector Search.
117
- 2. Reranking: Sort them via Cross-Encoder (Slow/Precise).
118
- 3. Return top k.
119
- """
120
- db = get_vectorstore(username)
121
- reranker = get_reranker_model()
122
-
123
- # 1. Broad Search (Retrieve more than needed to filter later)
124
- results = db.similarity_search(query, k=10)
125
-
126
- if not results:
127
- return []
128
-
129
- # 2. Reranking
130
- # Prepare pairs: [[Query, Text1], [Query, Text2]...]
131
- passages = [doc.page_content for doc in results]
132
- ranks = reranker.rank(query, passages)
133
-
134
- # 3. Sort and Filter
135
- # Reranker returns list of dicts: {'corpus_id': 0, 'score': 0.9}
136
- top_results = []
137
-
138
- # Sort ranks by score descending just to be safe (though .rank() usually sorts)
139
- sorted_ranks = sorted(ranks, key=lambda x: x['score'], reverse=True)
140
-
141
- for rank in sorted_ranks[:k]:
142
- doc_index = rank['corpus_id']
143
- doc = results[doc_index]
144
- # Append score for transparency
145
- doc.metadata["relevance_score"] = round(rank['score'], 4)
146
- top_results.append(doc)
147
-
148
- return top_results
149
-
150
- def list_documents(username):
151
- """
152
- Returns a list of unique files currently in the user's database.
153
- WARNING: This pulls all metadata. Performance degrades >10k chunks.
154
- """
155
- try:
156
- db = get_vectorstore(username)
157
- data = db.get()
158
- metadatas = data['metadatas']
159
-
160
- file_stats = {}
161
-
162
- for meta in metadatas:
163
- src = meta.get('source', 'unknown')
164
- filename = os.path.basename(src)
165
-
166
- if src not in file_stats:
167
- file_stats[src] = {'source': src, 'filename': filename, 'chunks': 0}
168
- file_stats[src]['chunks'] += 1
169
-
170
- return list(file_stats.values())
171
-
172
- except Exception as e:
173
- print(f"❌ Error listing docs: {e}")
174
- return []
175
-
176
- def delete_document(username, source_path):
177
- """Removes all chunks associated with a specific source file."""
178
- try:
179
- print(f"🗑️ Deleting {source_path} for {username}...")
180
- db = get_vectorstore(username)
181
-
182
- db.delete(where={"source": source_path})
183
-
184
- tracker.upload_user_db(username)
185
- return True, f"Deleted {os.path.basename(source_path)}"
186
-
187
- except Exception as e:
188
- return False, str(e)
189
-
190
- def reset_knowledge_base(username):
191
- """Nuke option: Clears the entire database for the user."""
192
- try:
193
- db = get_vectorstore(username)
194
- db.delete_collection()
195
- tracker.upload_user_db(username)
196
- return True, "Knowledge Base completely reset."
197
- except Exception as e:
198
  return False, str(e)
 
1
+ import os
2
+ from langchain_chroma import Chroma
3
+ from langchain_huggingface import HuggingFaceEmbeddings
4
+ from sentence_transformers import CrossEncoder
5
+ from core.ChunkingManager import ChunkingManager, ChunkingStrategy
6
+ import tracker # To trigger syncs
7
+
8
+ # --- CONFIGURATION ---
9
+ UPLOAD_DIR = "/tmp/rag_uploads"
10
+ DB_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chroma_db")
11
+ EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
12
+ RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
13
+
14
+ # --- LAZY LOADING SINGLETONS ---
15
+ # We use these globals to store the models once loaded, so we don't reload them
16
+ # every time a function is called, but we also don't load them on import.
17
+ _embedding_fn = None
18
+ _reranker = None
19
+ _chunk_manager = None
20
+
21
def get_embedding_function():
    """Return the shared HuggingFace embedding model, loading it on first use.

    Caches the instance in the module-level ``_embedding_fn`` so the model is
    loaded at most once per process, and never at import time.
    """
    global _embedding_fn
    if _embedding_fn is not None:
        return _embedding_fn
    print("⚙️ Loading Embedding Model...")
    _embedding_fn = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    return _embedding_fn
28
+
29
def get_reranker_model():
    """Return the shared CrossEncoder reranker, loading it on first use.

    Caches the instance in the module-level ``_reranker`` so the model is
    loaded at most once per process, and never at import time.
    """
    global _reranker
    if _reranker is not None:
        return _reranker
    print("⚙️ Loading Reranker Model...")
    _reranker = CrossEncoder(RERANKER_MODEL_NAME)
    return _reranker
36
+
37
def get_chunk_manager():
    """Return the shared ChunkingManager, constructing it on first use.

    Caches the instance in the module-level ``_chunk_manager`` so it is built
    at most once per process, and never at import time.
    """
    global _chunk_manager
    if _chunk_manager is not None:
        return _chunk_manager
    print("⚙️ Loading Chunk Manager...")
    _chunk_manager = ChunkingManager(embedding_model_name=EMBEDDING_MODEL_NAME)
    return _chunk_manager
44
+
45
+ # --- DATABASE OPERATIONS ---
46
def get_vectorstore(username):
    """Return the persistent ChromaDB collection for a SPECIFIC USER.

    The username is reduced to its basename so a malicious value such as
    ``../other_user`` cannot escape DB_ROOT (path-traversal guard).

    Args:
        username: Identifier of the user whose collection is wanted.

    Returns:
        A ``Chroma`` store persisted under ``DB_ROOT/<username>`` with
        collection name ``docs_<username>``.

    Raises:
        ValueError: If the sanitized username is empty (e.g. "" or "/"),
            which would otherwise silently collapse such callers onto one
            shared collection at DB_ROOT.
    """
    safe_username = os.path.basename(username)
    if not safe_username:
        raise ValueError("username must not be empty after sanitization")

    user_db_path = os.path.join(DB_ROOT, safe_username)

    # exist_ok=True is safe under concurrent calls; the previous
    # exists()-then-makedirs() pair was a needless TOCTOU race.
    os.makedirs(user_db_path, exist_ok=True)

    return Chroma(
        persist_directory=user_db_path,
        embedding_function=get_embedding_function(),
        collection_name=f"docs_{safe_username}",
    )
60
+
61
def save_uploaded_file(uploaded_file):
    """Persist an uploaded file into UPLOAD_DIR and return its path.

    Args:
        uploaded_file: Upload object exposing ``.name`` and ``.getbuffer()``
            (Streamlit-style, judging by usage — TODO confirm caller type).

    Returns:
        The path of the file written under UPLOAD_DIR.
    """
    # exist_ok=True replaces the racy exists()-then-makedirs() pair.
    os.makedirs(UPLOAD_DIR, exist_ok=True)

    # SECURITY: strip directory components so names like "../../etc/passwd"
    # cannot escape UPLOAD_DIR.
    safe_filename = os.path.basename(uploaded_file.name)
    file_path = os.path.join(UPLOAD_DIR, safe_filename)

    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return file_path
73
+
74
def process_and_add_document(file_path, username, strategy="paragraph"):
    """Chunk a document, index it into the user's vector DB, sync, and clean up.

    Args:
        file_path: Path of the (already saved) document to ingest.
        username: Owner of the target vector store.
        strategy: One of "token" / "paragraph" / "page"; unknown values
            silently fall back to paragraph chunking.

    Returns:
        A ``(success, message)`` tuple — failures are reported, not raised.
    """
    try:
        print(f"🧠 Chunking {os.path.basename(file_path)}...")

        selected_strategy = {
            "token": ChunkingStrategy.TOKEN,
            "paragraph": ChunkingStrategy.PARAGRAPH,
            "page": ChunkingStrategy.PAGE,
        }.get(strategy, ChunkingStrategy.PARAGRAPH)

        # Lazy-loaded manager does the extraction + chunking.
        chunks = get_chunk_manager().process_document(
            file_path=file_path,
            strategy=selected_strategy,
            preprocess=True,
        )

        if not chunks:
            return False, "No text extracted. Is the file empty/scanned?"

        print(f"💾 Indexing {len(chunks)} chunks into Vector DB...")
        get_vectorstore(username).add_documents(chunks)

        # Push the updated DB to remote storage immediately.
        tracker.upload_user_db(username)

        # Chunks are indexed; the temp upload is no longer needed.
        if os.path.exists(file_path):
            os.remove(file_path)

        return True, f"Successfully added {len(chunks)} chunks to Knowledge Base."

    except Exception as e:
        print(f"❌ Processing Error: {e}")
        return False, str(e)
111
+
112
+ # --- RETRIEVAL ENGINE ---
113
def search_knowledge_base(query, username, k=10):
    """Two-stage retrieval: broad vector search, then cross-encoder rerank.

    1. Retrieval: fetch a candidate pool via fast vector similarity.
    2. Reranking: score every (query, passage) pair with the CrossEncoder.
    3. Return the top ``k`` documents, best first.

    Args:
        query: Natural-language search string.
        username: Owner of the vector store to search.
        k: Number of documents to return (default 10).

    Returns:
        List of documents, each with ``metadata["relevance_score"]`` set.
    """
    db = get_vectorstore(username)
    reranker = get_reranker_model()

    # 1. Broad search. Previously hard-coded to 10 candidates, which silently
    # capped the result set whenever k > 10. Fetch at least k (still 10 for
    # small k, so default behavior is unchanged).
    candidate_k = max(k, 10)
    results = db.similarity_search(query, k=candidate_k)

    if not results:
        return []

    # 2. Rerank. CrossEncoder.rank scores each passage against the query and
    # returns dicts like {'corpus_id': 0, 'score': 0.9}.
    passages = [doc.page_content for doc in results]
    ranks = reranker.rank(query, passages)

    # 3. Sort by score descending defensively (.rank() usually sorts already)
    # and keep the top k.
    top_results = []
    for rank in sorted(ranks, key=lambda r: r['score'], reverse=True)[:k]:
        doc = results[rank['corpus_id']]
        # Surface the rerank score so callers can display/threshold relevance.
        doc.metadata["relevance_score"] = round(rank['score'], 4)
        top_results.append(doc)

    return top_results
149
+
150
def list_documents(username):
    """Return the unique source files in the user's DB with chunk counts.

    WARNING: This pulls all metadata. Performance degrades >10k chunks.

    Returns:
        List of ``{'source', 'filename', 'chunks'}`` dicts; empty on error.
    """
    try:
        metadatas = get_vectorstore(username).get()['metadatas']

        # Aggregate chunk counts per source path.
        stats = {}
        for meta in metadatas:
            src = meta.get('source', 'unknown')
            entry = stats.get(src)
            if entry is None:
                entry = {'source': src,
                         'filename': os.path.basename(src),
                         'chunks': 0}
                stats[src] = entry
            entry['chunks'] += 1

        return list(stats.values())

    except Exception as e:
        print(f"❌ Error listing docs: {e}")
        return []
175
+
176
def delete_document(username, source_path):
    """Remove every chunk whose metadata 'source' equals ``source_path``.

    Returns:
        ``(success, message)`` tuple — failures are reported, not raised.
    """
    try:
        print(f"🗑️ Deleting {source_path} for {username}...")

        # Delete by metadata filter, then push the change to remote storage.
        get_vectorstore(username).delete(where={"source": source_path})
        tracker.upload_user_db(username)

        return True, f"Deleted {os.path.basename(source_path)}"
    except Exception as e:
        return False, str(e)
189
+
190
def reset_knowledge_base(username):
    """Nuke option: drop the user's entire collection, then sync the deletion.

    Returns:
        ``(success, message)`` tuple — failures are reported, not raised.
    """
    try:
        get_vectorstore(username).delete_collection()
        tracker.upload_user_db(username)
        return True, "Knowledge Base completely reset."
    except Exception as e:
        return False, str(e)