# AI_Toolkit/src/rag_engine.py
import os
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from core.ChunkingManager import ChunkingManager, ChunkingStrategy
import tracker
# --- CONFIGURATION ---
UPLOAD_DIR = "/tmp/rag_uploads"
DB_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chroma_db")
EMBEDDING_MODEL_NAME = "all-MiniLM-L12-v2"
RERANKER_MODEL_NAME = "BAAI/bge-reranker-base"
# --- LAZY LOADING SINGLETONS ---
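# Models and the chunker are loaded lazily: nothing heavy is instantiated at
# import time, and each object is created once on first use and then reused
# across calls via the module-level singletons below.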
_embedding_fn = None
_reranker = None
_chunk_manager = None
def get_embedding_function():
    global _embedding_fn
    if _embedding_fn is None:
        print("⚙️ Loading Embedding Model...")
        _embedding_fn = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    return _embedding_fn
def get_reranker_model():
    global _reranker
    if _reranker is None:
        print("⚙️ Loading Reranker Model...")
        _reranker = CrossEncoder(RERANKER_MODEL_NAME)
    return _reranker
def get_chunk_manager():
    global _chunk_manager
    if _chunk_manager is None:
        print("⚙️ Loading Chunk Manager...")
        _chunk_manager = ChunkingManager(embedding_model_name=EMBEDDING_MODEL_NAME)
    return _chunk_manager
# --- DATABASE OPERATIONS ---
def get_vectorstore(username):
    """Open (or create) the per-user Chroma collection under DB_ROOT."""
    # basename() strips any path components, so a crafted username cannot
    # escape DB_ROOT via path traversal.
    safe_username = os.path.basename(username)
    user_db_path = os.path.join(DB_ROOT, safe_username)
    if not os.path.exists(user_db_path):
        os.makedirs(user_db_path, exist_ok=True)
    return Chroma(
        persist_directory=user_db_path,
        embedding_function=get_embedding_function(),
        collection_name=f"docs_{safe_username}"
    )
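# Example (illustrative; "alice" is a placeholder username):
#   get_vectorstore("alice") opens chroma_db/alice next to this module and
#   uses the collection name "docs_alice".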
def save_uploaded_file(uploaded_file):
    """Persist an uploaded file (any object exposing .name and .getbuffer(),
    e.g. a Streamlit UploadedFile) to UPLOAD_DIR and return its path."""
    if not os.path.exists(UPLOAD_DIR):
        os.makedirs(UPLOAD_DIR)
    safe_filename = os.path.basename(uploaded_file.name)
    file_path = os.path.join(UPLOAD_DIR, safe_filename)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return file_path
def process_and_add_document(file_path, username, strategy="paragraph"):
    try:
        print(f"🧠 Chunking {os.path.basename(file_path)}...")
        strat_map = {
            "token": ChunkingStrategy.TOKEN,
            "paragraph": ChunkingStrategy.PARAGRAPH,
            "page": ChunkingStrategy.PAGE
        }
        selected_strategy = strat_map.get(strategy, ChunkingStrategy.PARAGRAPH)
        manager = get_chunk_manager()
        chunks = manager.process_document(
            file_path=file_path,
            strategy=selected_strategy,
            preprocess=True
        )
        if not chunks:
            return False, "No text extracted. Is the file empty/scanned?"
        # FIX #1: Tag every chunk with the strategy used
        for chunk in chunks:
            chunk.metadata["strategy"] = strategy
        print(f"💾 Indexing {len(chunks)} chunks into Vector DB...")
        db = get_vectorstore(username)
        db.add_documents(chunks)
        tracker.upload_user_db(username)
        # Clean up the temporary upload once it has been indexed
        if os.path.exists(file_path):
            os.remove(file_path)
        return True, f"Successfully added {len(chunks)} chunks to Knowledge Base."
    except Exception as e:
        print(f"❌ Processing Error: {e}")
        return False, str(e)
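# Example ingestion flow (illustrative; `uploaded_file` is assumed to be a
# Streamlit-style upload object and "alice" a placeholder username):
#   path = save_uploaded_file(uploaded_file)
#   ok, msg = process_and_add_document(path, "alice", strategy="token")
#   print(msg)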
# --- RETRIEVAL ENGINE ---
def search_knowledge_base(query, username, k=6):
    """
    Two-Stage Retrieval:
    1. Retrieval: pull a wide candidate set (k=50) via fast vector search.
    2. Reranking: re-score the candidates with a Cross-Encoder (slow/precise)
       and return the top k. See the example usage after this function.
    """
    try:
        db = get_vectorstore(username)
        if db._collection.count() == 0:
            return []
        reranker = get_reranker_model()
        # 1. Vector search (broad net). We deliberately over-fetch: embedding
        #    models can bury exact-term matches (e.g. acronyms like "C&D")
        #    deep in the ranking when the spelling differs, so we widen k to
        #    50 and let the reranker recover precision. (A hosted DB like
        #    Pinecone/Weaviate has built-in keyword/hybrid search; local
        #    Chroma does not, so the wide net is the safety net here.)
        results = db.similarity_search(query, k=50)
        if not results:
            return []
        # 2. Reranking: score every (query, passage) pair with the
        #    cross-encoder, then return the k highest-scoring documents.
        passages = [doc.page_content for doc in results]
        ranks = reranker.rank(query, passages)
        top_results = []
        sorted_ranks = sorted(ranks, key=lambda x: x['score'], reverse=True)
        for rank in sorted_ranks[:k]:
            doc_index = rank['corpus_id']
            doc = results[doc_index]
            doc.metadata["relevance_score"] = round(rank['score'], 4)
            top_results.append(doc)
        return top_results
    except Exception as e:
        print(f"⚠️ Search Error (likely empty DB): {e}")
        return []
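# Example usage (illustrative; the query and "alice" are placeholders):
#   hits = search_knowledge_base("What does the C&D checklist require?", "alice", k=4)
#   for doc in hits:
#       print(doc.metadata.get("relevance_score"), doc.metadata.get("source"),
#             doc.page_content[:80])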
def list_documents(username):
    try:
        db = get_vectorstore(username)
        # Check if empty before fetching to prevent errors
        if db._collection.count() == 0:
            return []
        data = db.get()
        metadatas = data['metadatas']
        file_stats = {}
        for meta in metadatas:
            src = meta.get('source', 'unknown')
            filename = os.path.basename(src)
            # FIX #2: Retrieve the strategy (default to 'unknown' for old docs)
            strat = meta.get('strategy', 'unknown')
            if src not in file_stats:
                file_stats[src] = {
                    'source': src,
                    'filename': filename,
                    'chunks': 0,
                    'strategy': strat
                }
            file_stats[src]['chunks'] += 1
        return list(file_stats.values())
    except Exception as e:
        print(f"❌ Error listing docs: {e}")
        return []
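# Example return value (illustrative; the path and counts are placeholders):
#   [{'source': '/tmp/rag_uploads/report.pdf', 'filename': 'report.pdf',
#     'chunks': 42, 'strategy': 'paragraph'}]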
def delete_document(username, source_path):
    try:
        print(f"🗑️ Deleting {source_path} for {username}...")
        db = get_vectorstore(username)
        db.delete(where={"source": source_path})
        tracker.upload_user_db(username)
        return True, f"Deleted {os.path.basename(source_path)}"
    except Exception as e:
        return False, str(e)
def reset_knowledge_base(username):
    try:
        db = get_vectorstore(username)
        db.delete_collection()
        tracker.upload_user_db(username)
        return True, "Knowledge Base completely reset."
    except Exception as e:
        return False, str(e)
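# The block below is a minimal, optional smoke-test sketch; it is not used by
# the app. The file path and the "demo" username are placeholders, and it
# assumes the file exists and the `tracker` backend is reachable.
if __name__ == "__main__":
    DEMO_FILE = "/tmp/rag_uploads/example.txt"
    ok, msg = process_and_add_document(DEMO_FILE, "demo", strategy="paragraph")
    print(msg)
    for doc in search_knowledge_base("example query", "demo", k=3):
        print(doc.metadata.get("relevance_score"), doc.page_content[:80])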