NavyDevilDoc committed on
Commit
d469e88
·
verified ·
1 Parent(s): b417daa

Update src/rag_engine.py

Browse files

Added support for the new AcronymManager.py module: ingest_file now scans chunk text to learn acronym definitions, and search_knowledge_base expands acronyms in the query before retrieval and reranking.

Files changed (1) hide show
  1. src/rag_engine.py +22 -10
src/rag_engine.py CHANGED
@@ -13,6 +13,7 @@ from sentence_transformers import CrossEncoder
13
  # --- CUSTOM CORE IMPORTS ---
14
  from core.ParagraphChunker import ParagraphChunker
15
  from core.TokenChunker import TokenChunker
 
16
 
17
  # --- CONFIGURATION ---
18
  CHROMA_PATH = "chroma_db"
@@ -180,17 +181,20 @@ def process_and_add_text(text: str, source_name: str, username: str) -> Tuple[bo
180
  return False, f"Error: {str(e)}"
181
 
182
  def ingest_file(file_path: str, username: str, strategy: str = "paragraph") -> Tuple[bool, str]:
183
- """
184
- The High-Level Bridge: Takes a file path, chunks it, and saves to Vector DB.
185
- Replaces the old 'process_and_add_document'.
186
- """
187
  try:
188
- # 1. Chunk the file using the new engine
189
  docs = process_file(file_path, chunking_strategy=strategy)
190
 
191
  if not docs:
192
  return False, "No valid chunks generated from file."
193
 
 
 
 
 
 
 
 
194
  # 2. Add to Chroma DB
195
  user_db_path = os.path.join(CHROMA_PATH, username)
196
  emb_fn = get_embedding_func()
@@ -205,24 +209,32 @@ def ingest_file(file_path: str, username: str, strategy: str = "paragraph") -> T
205
  return False, f"System Error: {str(e)}"
206
 
207
  def search_knowledge_base(query: str, username: str, k: int = 10, final_k: int = 4) -> List[Document]:
208
- """Retrieves top K chunks, then uses Cross-Encoder to re-rank them."""
209
  user_db_path = os.path.join(CHROMA_PATH, username)
210
  if not os.path.exists(user_db_path):
211
  return []
212
 
213
  try:
214
- # 1. Vector Retrieval
 
 
 
 
 
 
 
 
 
215
  emb_fn = get_embedding_func()
216
  db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
217
- results = db.similarity_search_with_relevance_scores(query, k=k)
218
 
219
  if not results:
220
  return []
221
 
222
- # 2. Reranking
223
  candidate_docs = [doc for doc, _ in results]
224
  candidate_texts = [doc.page_content for doc in candidate_docs]
225
- pairs = [[query, text] for text in candidate_texts]
226
 
227
  reranker = get_rerank_model()
228
  scores = reranker.predict(pairs)
 
13
  # --- CUSTOM CORE IMPORTS ---
14
  from core.ParagraphChunker import ParagraphChunker
15
  from core.TokenChunker import TokenChunker
16
+ from core.AcronymManager import AcronymManager
17
 
18
  # --- CONFIGURATION ---
19
  CHROMA_PATH = "chroma_db"
 
181
  return False, f"Error: {str(e)}"
182
 
183
  def ingest_file(file_path: str, username: str, strategy: str = "paragraph") -> Tuple[bool, str]:
 
 
 
 
184
  try:
185
+ # 1. Chunk the file
186
  docs = process_file(file_path, chunking_strategy=strategy)
187
 
188
  if not docs:
189
  return False, "No valid chunks generated from file."
190
 
191
+ # --- ACRONYM SCANNING ---
192
+ # We scan the raw text of the chunks to learn new definitions
193
+ acronym_mgr = AcronymManager()
194
+ for doc in docs:
195
+ acronym_mgr.scan_text_for_acronyms(doc.page_content)
196
+ # -----------------------------
197
+
198
  # 2. Add to Chroma DB
199
  user_db_path = os.path.join(CHROMA_PATH, username)
200
  emb_fn = get_embedding_func()
 
209
  return False, f"System Error: {str(e)}"
210
 
211
  def search_knowledge_base(query: str, username: str, k: int = 10, final_k: int = 4) -> List[Document]:
 
212
  user_db_path = os.path.join(CHROMA_PATH, username)
213
  if not os.path.exists(user_db_path):
214
  return []
215
 
216
  try:
217
+ # --- NEW: QUERY EXPANSION ---
218
+ acronym_mgr = AcronymManager()
219
+ expanded_query = acronym_mgr.expand_query(query)
220
+ if expanded_query != query:
221
+ logger.info(f"Query Expanded: '{query}' -> '{expanded_query}'")
222
+ else:
223
+ expanded_query = query
224
+ # ----------------------------
225
+
226
+ # 1. Vector Retrieval (Use expanded_query instead of query)
227
  emb_fn = get_embedding_func()
228
  db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
229
+ results = db.similarity_search_with_relevance_scores(expanded_query, k=k) # <--- UPDATED VAR
230
 
231
  if not results:
232
  return []
233
 
234
+ # 2. Reranking (Pass expanded_query here too)
235
  candidate_docs = [doc for doc, _ in results]
236
  candidate_texts = [doc.page_content for doc in candidate_docs]
237
+ pairs = [[expanded_query, text] for text in candidate_texts] # <--- UPDATED VAR
238
 
239
  reranker = get_rerank_model()
240
  scores = reranker.predict(pairs)