Spaces:
Sleeping
Sleeping
Update src/rag_engine.py
Browse files
added support for new AcronymManager.py file
- src/rag_engine.py +22 -10
src/rag_engine.py
CHANGED
|
@@ -13,6 +13,7 @@ from sentence_transformers import CrossEncoder
|
|
| 13 |
# --- CUSTOM CORE IMPORTS ---
|
| 14 |
from core.ParagraphChunker import ParagraphChunker
|
| 15 |
from core.TokenChunker import TokenChunker
|
|
|
|
| 16 |
|
| 17 |
# --- CONFIGURATION ---
|
| 18 |
CHROMA_PATH = "chroma_db"
|
|
@@ -180,17 +181,20 @@ def process_and_add_text(text: str, source_name: str, username: str) -> Tuple[bo
|
|
| 180 |
return False, f"Error: {str(e)}"
|
| 181 |
|
| 182 |
def ingest_file(file_path: str, username: str, strategy: str = "paragraph") -> Tuple[bool, str]:
|
| 183 |
-
"""
|
| 184 |
-
The High-Level Bridge: Takes a file path, chunks it, and saves to Vector DB.
|
| 185 |
-
Replaces the old 'process_and_add_document'.
|
| 186 |
-
"""
|
| 187 |
try:
|
| 188 |
-
# 1. Chunk the file
|
| 189 |
docs = process_file(file_path, chunking_strategy=strategy)
|
| 190 |
|
| 191 |
if not docs:
|
| 192 |
return False, "No valid chunks generated from file."
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
# 2. Add to Chroma DB
|
| 195 |
user_db_path = os.path.join(CHROMA_PATH, username)
|
| 196 |
emb_fn = get_embedding_func()
|
|
@@ -205,24 +209,32 @@ def ingest_file(file_path: str, username: str, strategy: str = "paragraph") -> T
|
|
| 205 |
return False, f"System Error: {str(e)}"
|
| 206 |
|
| 207 |
def search_knowledge_base(query: str, username: str, k: int = 10, final_k: int = 4) -> List[Document]:
|
| 208 |
-
"""Retrieves top K chunks, then uses Cross-Encoder to re-rank them."""
|
| 209 |
user_db_path = os.path.join(CHROMA_PATH, username)
|
| 210 |
if not os.path.exists(user_db_path):
|
| 211 |
return []
|
| 212 |
|
| 213 |
try:
|
| 214 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
emb_fn = get_embedding_func()
|
| 216 |
db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
|
| 217 |
-
results = db.similarity_search_with_relevance_scores(
|
| 218 |
|
| 219 |
if not results:
|
| 220 |
return []
|
| 221 |
|
| 222 |
-
# 2. Reranking
|
| 223 |
candidate_docs = [doc for doc, _ in results]
|
| 224 |
candidate_texts = [doc.page_content for doc in candidate_docs]
|
| 225 |
-
pairs = [[
|
| 226 |
|
| 227 |
reranker = get_rerank_model()
|
| 228 |
scores = reranker.predict(pairs)
|
|
|
|
| 13 |
# --- CUSTOM CORE IMPORTS ---
|
| 14 |
from core.ParagraphChunker import ParagraphChunker
|
| 15 |
from core.TokenChunker import TokenChunker
|
| 16 |
+
from core.AcronymManager import AcronymManager
|
| 17 |
|
| 18 |
# --- CONFIGURATION ---
|
| 19 |
CHROMA_PATH = "chroma_db"
|
|
|
|
| 181 |
return False, f"Error: {str(e)}"
|
| 182 |
|
| 183 |
def ingest_file(file_path: str, username: str, strategy: str = "paragraph") -> Tuple[bool, str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
try:
|
| 185 |
+
# 1. Chunk the file
|
| 186 |
docs = process_file(file_path, chunking_strategy=strategy)
|
| 187 |
|
| 188 |
if not docs:
|
| 189 |
return False, "No valid chunks generated from file."
|
| 190 |
|
| 191 |
+
# --- ACRONYM SCANNING ---
|
| 192 |
+
# We scan the raw text of the chunks to learn new definitions
|
| 193 |
+
acronym_mgr = AcronymManager()
|
| 194 |
+
for doc in docs:
|
| 195 |
+
acronym_mgr.scan_text_for_acronyms(doc.page_content)
|
| 196 |
+
# -----------------------------
|
| 197 |
+
|
| 198 |
# 2. Add to Chroma DB
|
| 199 |
user_db_path = os.path.join(CHROMA_PATH, username)
|
| 200 |
emb_fn = get_embedding_func()
|
|
|
|
| 209 |
return False, f"System Error: {str(e)}"
|
| 210 |
|
| 211 |
def search_knowledge_base(query: str, username: str, k: int = 10, final_k: int = 4) -> List[Document]:
|
|
|
|
| 212 |
user_db_path = os.path.join(CHROMA_PATH, username)
|
| 213 |
if not os.path.exists(user_db_path):
|
| 214 |
return []
|
| 215 |
|
| 216 |
try:
|
| 217 |
+
# --- NEW: QUERY EXPANSION ---
|
| 218 |
+
acronym_mgr = AcronymManager()
|
| 219 |
+
expanded_query = acronym_mgr.expand_query(query)
|
| 220 |
+
if expanded_query != query:
|
| 221 |
+
logger.info(f"Query Expanded: '{query}' -> '{expanded_query}'")
|
| 222 |
+
else:
|
| 223 |
+
expanded_query = query
|
| 224 |
+
# ----------------------------
|
| 225 |
+
|
| 226 |
+
# 1. Vector Retrieval (Use expanded_query instead of query)
|
| 227 |
emb_fn = get_embedding_func()
|
| 228 |
db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
|
| 229 |
+
results = db.similarity_search_with_relevance_scores(expanded_query, k=k) # <--- UPDATED VAR
|
| 230 |
|
| 231 |
if not results:
|
| 232 |
return []
|
| 233 |
|
| 234 |
+
# 2. Reranking (Pass expanded_query here too)
|
| 235 |
candidate_docs = [doc for doc, _ in results]
|
| 236 |
candidate_texts = [doc.page_content for doc in candidate_docs]
|
| 237 |
+
pairs = [[expanded_query, text] for text in candidate_texts] # <--- UPDATED VAR
|
| 238 |
|
| 239 |
reranker = get_rerank_model()
|
| 240 |
scores = reranker.predict(pairs)
|