Spaces:
Sleeping
Sleeping
| import os | |
| from langchain_chroma import Chroma | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from sentence_transformers import CrossEncoder | |
| from core.ChunkingManager import ChunkingManager, ChunkingStrategy | |
| import tracker | |
# --- CONFIGURATION ---
# Scratch directory for files uploaded through the UI before indexing.
UPLOAD_DIR = "/tmp/rag_uploads"
# Root for per-user Chroma databases, kept next to this module on disk.
DB_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chroma_db")
# Sentence-transformers model used both for embedding and by the chunk manager.
EMBEDDING_MODEL_NAME = "all-MiniLM-L12-v2"
# Cross-encoder used for the second-stage rerank of vector-search candidates.
RERANKER_MODEL_NAME = "BAAI/bge-reranker-base"
# --- LAZY LOADING SINGLETONS ---
# Heavy models are loaded on first use by the get_* accessors below and then
# cached in these module-level slots for the life of the process.
_embedding_fn = None
_reranker = None
_chunk_manager = None
def get_embedding_function():
    """Return the process-wide embedding model, loading it on first call."""
    global _embedding_fn
    if _embedding_fn is not None:
        return _embedding_fn
    print("⚙️ Loading Embedding Model...")
    _embedding_fn = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    return _embedding_fn
def get_reranker_model():
    """Return the process-wide cross-encoder reranker, loading it on first call."""
    global _reranker
    if _reranker is not None:
        return _reranker
    print("⚙️ Loading Reranker Model...")
    _reranker = CrossEncoder(RERANKER_MODEL_NAME)
    return _reranker
def get_chunk_manager():
    """Return the process-wide ChunkingManager, loading it on first call."""
    global _chunk_manager
    if _chunk_manager is not None:
        return _chunk_manager
    print("⚙️ Loading Chunk Manager...")
    _chunk_manager = ChunkingManager(embedding_model_name=EMBEDDING_MODEL_NAME)
    return _chunk_manager
| # --- DATABASE OPERATIONS --- | |
def get_vectorstore(username):
    """Open (creating if needed) the per-user Chroma vector store.

    Args:
        username: User identifier; basename-sanitized so path separators
            cannot escape DB_ROOT.

    Returns:
        A Chroma instance persisted under DB_ROOT/<safe_username> with a
        per-user collection name.
    """
    safe_username = os.path.basename(username)
    user_db_path = os.path.join(DB_ROOT, safe_username)
    # exist_ok=True already tolerates an existing directory, so the previous
    # os.path.exists() pre-check was redundant and race-prone (TOCTOU).
    os.makedirs(user_db_path, exist_ok=True)
    return Chroma(
        persist_directory=user_db_path,
        embedding_function=get_embedding_function(),
        collection_name=f"docs_{safe_username}"
    )
def save_uploaded_file(uploaded_file):
    """Persist an uploaded file object into UPLOAD_DIR.

    Args:
        uploaded_file: A Streamlit-style upload object exposing `.name` and
            `.getbuffer()` (assumption based on call shape — confirm against caller).

    Returns:
        The filesystem path the bytes were written to.
    """
    # exist_ok=True replaces the racy exists()-then-makedirs pattern: if two
    # uploads arrive concurrently the old code could raise FileExistsError.
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    # basename() strips any directory components a client could smuggle into
    # the filename, preventing path traversal out of UPLOAD_DIR.
    safe_filename = os.path.basename(uploaded_file.name)
    file_path = os.path.join(UPLOAD_DIR, safe_filename)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return file_path
def process_and_add_document(file_path, username, strategy="paragraph"):
    """Chunk one document, tag every chunk with its strategy, and index it.

    Args:
        file_path: Path to the uploaded file on local disk.
        username: Owner of the vector store the chunks go into.
        strategy: One of "token" / "paragraph" / "page"; unknown values
            fall back to paragraph chunking.

    Returns:
        (success, message) tuple; on failure the message is the error text.
    """
    try:
        print(f"🧠 Chunking {os.path.basename(file_path)}...")
        # Unknown strategy strings silently fall back to PARAGRAPH.
        selected_strategy = {
            "token": ChunkingStrategy.TOKEN,
            "paragraph": ChunkingStrategy.PARAGRAPH,
            "page": ChunkingStrategy.PAGE,
        }.get(strategy, ChunkingStrategy.PARAGRAPH)
        chunks = get_chunk_manager().process_document(
            file_path=file_path,
            strategy=selected_strategy,
            preprocess=True,
        )
        if not chunks:
            return False, "No text extracted. Is the file empty/scanned?"
        # Record the chunking strategy on each chunk so list_documents can
        # report it later.
        for chunk in chunks:
            chunk.metadata["strategy"] = strategy
        print(f"💾 Indexing {len(chunks)} chunks into Vector DB...")
        get_vectorstore(username).add_documents(chunks)
        # Persist the updated per-user DB via the tracker backend.
        tracker.upload_user_db(username)
        # The upload scratch copy is no longer needed once indexed.
        if os.path.exists(file_path):
            os.remove(file_path)
        return True, f"Successfully added {len(chunks)} chunks to Knowledge Base."
    except Exception as e:
        print(f"❌ Processing Error: {e}")
        return False, str(e)
| # --- RETRIEVAL ENGINE --- | |
def search_knowledge_base(query, username, k=6):
    """Two-stage retrieval (RAG):

    1. Retrieval: a single wide vector search (k=50) over the user's store —
       vector models often bury exact-term matches (e.g. acronyms) deep in
       the ranking, so we cast a wide net.
    2. Reranking: score the candidates with the cross-encoder (slow/precise).
    3. Return the top `k` documents, each annotated with its rerank score.

    Args:
        query: Free-text user query.
        username: Whose vector store to search.
        k: Number of reranked results to return.

    Returns:
        List of LangChain Documents (possibly empty), with
        metadata["relevance_score"] set to the rounded rerank score.
    """
    try:
        db = get_vectorstore(username)
        if db._collection.count() == 0:
            return []
        reranker = get_reranker_model()
        # FIX: the old code ran similarity_search twice (k=25 then k=50) and
        # threw the first result away, plus kept unused `normalized_query` /
        # `keyword_results` dead code. One wide search is sufficient.
        results = db.similarity_search(query, k=50)
        if not results:
            return []
        # Rerank all candidates with the cross-encoder.
        passages = [doc.page_content for doc in results]
        ranks = reranker.rank(query, passages)
        top_results = []
        # CrossEncoder.rank yields dicts with 'corpus_id' and 'score'.
        for rank in sorted(ranks, key=lambda r: r['score'], reverse=True)[:k]:
            doc = results[rank['corpus_id']]
            doc.metadata["relevance_score"] = round(rank['score'], 4)
            top_results.append(doc)
        return top_results
    except Exception as e:
        # Best-effort boundary: an empty/corrupt DB degrades to no results.
        print(f"⚠️ Search Error (likely empty DB): {e}")
        return []
def list_documents(username):
    """Summarize the user's indexed documents.

    Groups all stored chunk metadata by source path and reports, per file:
    source path, basename, chunk count, and the chunking strategy recorded
    at ingest time ('unknown' for chunks indexed before strategies were
    tagged).

    Returns:
        List of per-file summary dicts; empty list on error or empty store.
    """
    try:
        db = get_vectorstore(username)
        # Avoid touching an empty collection, which can raise.
        if db._collection.count() == 0:
            return []
        file_stats = {}
        for meta in db.get()['metadatas']:
            src = meta.get('source', 'unknown')
            # First chunk seen for a source creates its summary entry;
            # later chunks only bump the count.
            entry = file_stats.setdefault(src, {
                'source': src,
                'filename': os.path.basename(src),
                'chunks': 0,
                'strategy': meta.get('strategy', 'unknown'),
            })
            entry['chunks'] += 1
        return list(file_stats.values())
    except Exception as e:
        print(f"❌ Error listing docs: {e}")
        return []
def delete_document(username, source_path):
    """Remove every chunk whose metadata 'source' matches, then sync the DB.

    Returns:
        (success, message) tuple.
    """
    try:
        print(f"🗑️ Deleting {source_path} for {username}...")
        get_vectorstore(username).delete(where={"source": source_path})
        tracker.upload_user_db(username)
    except Exception as e:
        return False, str(e)
    return True, f"Deleted {os.path.basename(source_path)}"
def reset_knowledge_base(username):
    """Drop the user's entire Chroma collection and sync the (now empty) DB.

    Returns:
        (success, message) tuple.
    """
    try:
        get_vectorstore(username).delete_collection()
        tracker.upload_user_db(username)
    except Exception as e:
        return False, str(e)
    return True, "Knowledge Base completely reset."