Spaces:
Running
Running
File size: 7,028 Bytes
c88e290 7841205 c88e290 8e0483e c88e290 7841205 c88e290 b1112c5 c88e290 7841205 c88e290 7841205 1fd5385 c88e290 7841205 b1112c5 7841205 c88e290 7841205 c88e290 7841205 c88e290 7841205 c88e290 3755446 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 | import os
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from core.ChunkingManager import ChunkingManager, ChunkingStrategy
import tracker
# --- CONFIGURATION ---
# Staging directory for raw uploads; files here are deleted after indexing
# by process_and_add_document().
UPLOAD_DIR = "/tmp/rag_uploads"
# Per-user Chroma databases live under <this file's dir>/chroma_db/<username>.
DB_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chroma_db")
# Sentence-transformers model name, shared by the embedding function and the
# ChunkingManager (see get_embedding_function / get_chunk_manager).
EMBEDDING_MODEL_NAME = "all-MiniLM-L12-v2"
# Cross-encoder used to rerank vector-search candidates in search_knowledge_base.
RERANKER_MODEL_NAME = "BAAI/bge-reranker-base"
# --- LAZY LOADING SINGLETONS ---
# Process-wide caches, populated on first use by the get_* helpers below.
_embedding_fn = None
_reranker = None
_chunk_manager = None
def get_embedding_function():
    """Return the process-wide embedding model, loading it on first use."""
    global _embedding_fn
    if _embedding_fn is not None:
        return _embedding_fn
    print("⚙️ Loading Embedding Model...")
    _embedding_fn = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    return _embedding_fn
def get_reranker_model():
    """Return the process-wide cross-encoder reranker, loading it on first use."""
    global _reranker
    if _reranker is not None:
        return _reranker
    print("⚙️ Loading Reranker Model...")
    _reranker = CrossEncoder(RERANKER_MODEL_NAME)
    return _reranker
def get_chunk_manager():
    """Return the process-wide ChunkingManager, constructing it on first use."""
    global _chunk_manager
    if _chunk_manager is not None:
        return _chunk_manager
    print("⚙️ Loading Chunk Manager...")
    _chunk_manager = ChunkingManager(embedding_model_name=EMBEDDING_MODEL_NAME)
    return _chunk_manager
# --- DATABASE OPERATIONS ---
def get_vectorstore(username):
    """Return the per-user Chroma vector store, creating its directory if needed.

    Args:
        username: Raw username. basename() strips any path components so a
            crafted name cannot escape DB_ROOT (path-traversal guard).

    Returns:
        A Chroma store persisted under DB_ROOT/<safe_username> with a
        per-user collection name.
    """
    safe_username = os.path.basename(username)
    user_db_path = os.path.join(DB_ROOT, safe_username)
    # exist_ok=True already tolerates an existing directory, so the original
    # os.path.exists() pre-check was redundant and race-prone (TOCTOU).
    os.makedirs(user_db_path, exist_ok=True)
    return Chroma(
        persist_directory=user_db_path,
        embedding_function=get_embedding_function(),
        collection_name=f"docs_{safe_username}"
    )
def save_uploaded_file(uploaded_file):
    """Persist an upload object to UPLOAD_DIR and return the saved path.

    Args:
        uploaded_file: Object exposing ``.name`` and ``.getbuffer()``
            (e.g. a Streamlit UploadedFile).

    Returns:
        Path of the written file inside UPLOAD_DIR.
    """
    # exist_ok=True replaces the original exists()/makedirs() pair, which
    # could crash if a concurrent upload created the directory in between.
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    # basename() drops any client-supplied path components (traversal guard).
    safe_filename = os.path.basename(uploaded_file.name)
    file_path = os.path.join(UPLOAD_DIR, safe_filename)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return file_path
def process_and_add_document(file_path, username, strategy="paragraph"):
    """Chunk a document, index the chunks into the user's vector DB, and sync.

    Args:
        file_path: Path to the staged upload (removed when done — see below).
        username: Owner of the target vector store.
        strategy: One of "token" / "paragraph" / "page"; unknown values fall
            back to paragraph chunking.

    Returns:
        Tuple ``(success, message)`` — ``message`` is user-displayable.
    """
    try:
        print(f"🧠 Chunking {os.path.basename(file_path)}...")
        strat_map = {
            "token": ChunkingStrategy.TOKEN,
            "paragraph": ChunkingStrategy.PARAGRAPH,
            "page": ChunkingStrategy.PAGE,
        }
        selected_strategy = strat_map.get(strategy, ChunkingStrategy.PARAGRAPH)

        chunks = get_chunk_manager().process_document(
            file_path=file_path,
            strategy=selected_strategy,
            preprocess=True,
        )
        if not chunks:
            return False, "No text extracted. Is the file empty/scanned?"

        # Tag every chunk with the strategy used so list_documents() can
        # report it per source file.
        for chunk in chunks:
            chunk.metadata["strategy"] = strategy

        print(f"💾 Indexing {len(chunks)} chunks into Vector DB...")
        db = get_vectorstore(username)
        db.add_documents(chunks)
        tracker.upload_user_db(username)
        return True, f"Successfully added {len(chunks)} chunks to Knowledge Base."
    except Exception as e:
        print(f"❌ Processing Error: {e}")
        return False, str(e)
    finally:
        # Always remove the staged upload. The original only deleted it on
        # success, leaking files into UPLOAD_DIR whenever processing failed.
        if os.path.exists(file_path):
            os.remove(file_path)
# --- RETRIEVAL ENGINE ---
def search_knowledge_base(query, username, k=6):
    """Two-stage retrieval over the user's knowledge base.

    1. Retrieval: a wide vector search (k=50) casts a broad net, since the
       bi-encoder often ranks the true match deep in the list when spelling
       differs from the query.
    2. Reranking: a cross-encoder re-scores every candidate against the
       query (slow but precise).
    3. Return the top ``k`` documents, each annotated with
       ``metadata["relevance_score"]``.

    Returns an empty list for an empty/missing DB or on any error.
    """
    try:
        db = get_vectorstore(username)
        if db._collection.count() == 0:
            return []

        # Stage 1: single wide vector search. (The original issued a
        # discarded k=25 search first — a pure waste of an embed+query —
        # and carried unused keyword-search scaffolding; both removed.)
        results = db.similarity_search(query, k=50)
        if not results:
            return []

        # Stage 2: cross-encoder reranking. Loaded only once we know there
        # is something to rerank.
        reranker = get_reranker_model()
        passages = [doc.page_content for doc in results]
        ranks = reranker.rank(query, passages)
        sorted_ranks = sorted(ranks, key=lambda x: x['score'], reverse=True)

        top_results = []
        for rank in sorted_ranks[:k]:
            doc = results[rank['corpus_id']]
            doc.metadata["relevance_score"] = round(rank['score'], 4)
            top_results.append(doc)
        return top_results
    except Exception as e:
        print(f"⚠️ Search Error (likely empty DB): {e}")
        return []
def list_documents(username):
    """Summarize the user's indexed files: source path, filename, chunk
    count, and the chunking strategy recorded at index time ('unknown' for
    documents indexed before strategy tagging existed).

    Returns a list of per-file stat dicts; empty on an empty DB or error.
    """
    try:
        db = get_vectorstore(username)
        # Guard against fetching from an empty collection.
        if db._collection.count() == 0:
            return []
        stats = {}
        for meta in db.get()['metadatas']:
            src = meta.get('source', 'unknown')
            entry = stats.get(src)
            if entry is None:
                entry = {
                    'source': src,
                    'filename': os.path.basename(src),
                    'chunks': 0,
                    'strategy': meta.get('strategy', 'unknown'),
                }
                stats[src] = entry
            entry['chunks'] += 1
        return list(stats.values())
    except Exception as e:
        print(f"❌ Error listing docs: {e}")
        return []
def delete_document(username, source_path):
    """Delete every chunk whose metadata 'source' matches, then sync.

    Returns ``(success, message)``.
    """
    try:
        print(f"🗑️ Deleting {source_path} for {username}...")
        store = get_vectorstore(username)
        store.delete(where={"source": source_path})
        tracker.upload_user_db(username)
        return True, f"Deleted {os.path.basename(source_path)}"
    except Exception as e:
        return False, str(e)
def reset_knowledge_base(username):
    """Drop the user's entire Chroma collection and sync the (now empty) DB.

    Returns ``(success, message)``.
    """
    try:
        store = get_vectorstore(username)
        store.delete_collection()
        tracker.upload_user_db(username)
        return True, "Knowledge Base completely reset."
    except Exception as e:
        return False, str(e)