"""RAG pipeline module: file chunking, Chroma ingestion, and reranked retrieval.

Part 1 routes uploaded files to chunking strategies (markdown-header,
paragraph, or token based). Part 2 manages per-user Chroma collections:
ingestion, cross-encoder reranked search, listing, and deletion.
"""

import os
import shutil
import logging
from typing import List, Literal, Optional, Tuple

# --- LANGCHAIN & DB IMPORTS ---
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder

# --- CUSTOM CORE IMPORTS ---
from core.ParagraphChunker import ParagraphChunker
from core.TokenChunker import TokenChunker

# --- CONFIGURATION ---
CHROMA_PATH = "chroma_db"
UPLOAD_DIR = "source_documents"
EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# Configure Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- LAZY LOADING GLOBALS ---
# Models are heavyweight; they are loaded on first use, not at import time.
_embedding_func = None
_rerank_model = None


def get_embedding_func():
    """Lazy loads the embedding model to save startup resources."""
    global _embedding_func
    if _embedding_func is None:
        logger.info(f"⏳ Loading Embedding Model: {EMBED_MODEL_NAME}...")
        _embedding_func = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
        logger.info("✅ Embedding Model Loaded.")
    return _embedding_func


def get_rerank_model():
    """Lazy loads the Cross-Encoder model."""
    global _rerank_model
    if _rerank_model is None:
        logger.info(f"⏳ Loading Reranker: {RERANK_MODEL_NAME}...")
        _rerank_model = CrossEncoder(RERANK_MODEL_NAME)
        logger.info("✅ Reranker Loaded.")
    return _rerank_model


# --- PART 1: CHUNKING LOGIC (The New System) ---

def _process_markdown(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 100) -> List[Document]:
    """Internal helper: split a Markdown file semantically by headers.

    First splits on H1-H3 headers (preserving header text as metadata),
    then size-limits each section with a recursive character splitter.

    Returns a list of Documents, or [] on any error.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            markdown_text = f.read()

        headers_to_split_on = [
            ("#", "Header 1"),
            ("##", "Header 2"),
            ("###", "Header 3"),
        ]
        markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
        md_header_splits = markdown_splitter.split_text(markdown_text)

        # Header sections can still be arbitrarily long; cap their size.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        final_docs = text_splitter.split_documents(md_header_splits)

        for doc in final_docs:
            doc.metadata['source'] = os.path.basename(file_path)
            doc.metadata['file_type'] = 'md'
            doc.metadata['strategy'] = 'markdown_header'

        return final_docs
    except Exception as e:
        logger.error(f"Error processing Markdown file {file_path}: {e}")
        return []


def process_file(
    file_path: str,
    chunking_strategy: Literal["paragraph", "token"] = "paragraph",
    chunk_size: int = 512,
    chunk_overlap: int = 50,
    model_name: str = "gpt-4"
) -> List[Document]:
    """
    Main chunking engine. Routes file to specific chunkers based on type/strategy.

    - .md files always use header-based semantic splitting (strategy ignored).
    - .pdf/.txt files use TokenChunker or ParagraphChunker per `chunking_strategy`.

    Returns a list of chunked Documents with 'source' and 'strategy' metadata,
    or [] for missing files, unsupported extensions, or chunker errors.
    """
    if not os.path.exists(file_path):
        logger.error(f"File not found: {file_path}")
        return []

    file_extension = os.path.splitext(file_path)[1].lower()
    file_name = os.path.basename(file_path)
    logger.info(f"Processing {file_name} using strategy: {chunking_strategy}")

    # 1. Handle Markdown
    if file_extension == ".md":
        return _process_markdown(file_path, chunk_size, chunk_overlap)

    # 2. Handle PDF and TXT
    elif file_extension in [".pdf", ".txt"]:
        if chunking_strategy == "token":
            chunker = TokenChunker(
                model_name=model_name,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap
            )
        else:
            chunker = ParagraphChunker(model_name=model_name)

        try:
            if file_extension == ".pdf":
                docs = chunker.process_document(file_path)
            else:  # ".txt"
                docs = chunker.process_text_file(file_path)

            # Ensure metadata consistency regardless of what the chunker set.
            for doc in docs:
                doc.metadata["source"] = file_name
                doc.metadata["strategy"] = chunking_strategy
            return docs
        except Exception as e:
            logger.error(f"Error using {chunking_strategy} chunker on {file_name}: {e}")
            return []
    else:
        logger.warning(f"Unsupported file extension: {file_extension}")
        return []


# --- PART 2: DATABASE & FILE MANAGEMENT (The Old Stable System) ---

def save_uploaded_file(uploaded_file, username: str = "default") -> Optional[str]:
    """Saves a StreamlitUploadedFile to disk so the loaders can read it.

    Files land in UPLOAD_DIR/<username>/<original name>.

    Returns the saved path, or None on failure.
    (Annotation fixed: the error path returns None, not str.)
    """
    try:
        user_dir = os.path.join(UPLOAD_DIR, username)
        os.makedirs(user_dir, exist_ok=True)
        file_path = os.path.join(user_dir, uploaded_file.name)
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        logger.info(f"File saved: {file_path}")
        return file_path
    except Exception as e:
        logger.error(f"Error saving file: {e}")
        return None


def process_and_add_text(text: str, source_name: str, username: str) -> Tuple[bool, str]:
    """
    Ingests raw text string (e.g., from the Flattener tool) directly into Chroma.

    The whole string is stored as a single un-chunked Document tagged with
    strategy 'flattened_text'. Returns (success, human-readable message).
    """
    try:
        user_db_path = os.path.join(CHROMA_PATH, username)
        emb_fn = get_embedding_func()

        # Initialize DB
        db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)

        # Create a Document object directly
        doc = Document(
            page_content=text,
            metadata={
                "source": source_name,
                "strategy": "flattened_text",
                "file_type": "generated"
            }
        )

        # Add single document
        db.add_documents([doc])
        return True, f"Successfully indexed flattened text: {source_name}"
    except Exception as e:
        logger.error(f"Error indexing raw text: {e}")
        return False, f"Error: {str(e)}"


def ingest_file(file_path: str, username: str, strategy: str = "paragraph") -> Tuple[bool, str]:
    """
    The High-Level Bridge: Takes a file path, chunks it, and saves to Vector DB.
    Replaces the old 'process_and_add_document'.

    Returns (success, human-readable message).
    """
    try:
        # 1. Chunk the file using the new engine
        docs = process_file(file_path, chunking_strategy=strategy)
        if not docs:
            return False, "No valid chunks generated from file."

        # 2. Add to Chroma DB
        user_db_path = os.path.join(CHROMA_PATH, username)
        emb_fn = get_embedding_func()
        db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
        db.add_documents(docs)

        return True, f"Successfully indexed {len(docs)} chunks."
    except Exception as e:
        logger.error(f"Ingestion failed: {e}")
        return False, f"System Error: {str(e)}"


def search_knowledge_base(query: str, username: str, k: int = 10, final_k: int = 4) -> List[Document]:
    """Retrieves top K chunks, then uses Cross-Encoder to re-rank them.

    Two-stage retrieval: fast vector search pulls `k` candidates, then the
    cross-encoder scores each (query, chunk) pair and the best `final_k`
    are returned. Returns [] if the user has no DB or on any error.
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    if not os.path.exists(user_db_path):
        return []

    try:
        # 1. Vector Retrieval
        emb_fn = get_embedding_func()
        db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
        results = db.similarity_search_with_relevance_scores(query, k=k)
        if not results:
            return []

        # 2. Reranking — vector scores are discarded; cross-encoder decides.
        candidate_docs = [doc for doc, _ in results]
        pairs = [[query, doc.page_content] for doc in candidate_docs]

        reranker = get_rerank_model()
        scores = reranker.predict(pairs)

        # Sort by new score, best first
        scored_docs = sorted(zip(candidate_docs, scores), key=lambda x: x[1], reverse=True)
        return [doc for doc, score in scored_docs[:final_k]]
    except Exception as e:
        logger.error(f"Search Error: {e}")
        return []


def list_documents(username: str) -> List[dict]:
    """
    Returns a list of unique files currently in the vector database.
    (Used for the sidebar list)

    Each entry: {"filename", "source", "chunks", "strategy"}; 'source'
    duplicates 'filename' for callers keyed on either name.
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    if not os.path.exists(user_db_path):
        return []

    try:
        emb_fn = get_embedding_func()
        db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)

        # Chroma's .get() returns all metadata
        data = db.get()
        inventory = {}
        for m in data['metadatas']:
            # Metadata entries can be None or missing keys; handle gracefully
            m = m or {}
            src = m.get('source', 'Unknown')
            if src not in inventory:
                inventory[src] = {
                    "chunks": 0,
                    "strategy": m.get('strategy', 'unknown')
                }
            inventory[src]["chunks"] += 1

        return [
            {"filename": k, "chunks": v["chunks"], "strategy": v["strategy"], "source": k}
            for k, v in inventory.items()
        ]
    except Exception as e:
        logger.error(f"Error listing docs: {e}")
        return []


def delete_document(username: str, filename: str) -> Tuple[bool, str]:
    """Removes a document from the vector database.

    Deletes every chunk whose 'source' metadata matches `filename`.
    Returns (success, human-readable message).
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    try:
        emb_fn = get_embedding_func()
        db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)

        data = db.get()
        # ids and metadatas are parallel lists; pair them with zip.
        ids_to_delete = [
            doc_id
            for doc_id, meta in zip(data['ids'], data['metadatas'])
            if (meta or {}).get('source') == filename
        ]

        if ids_to_delete:
            db.delete(ids=ids_to_delete)
            # Bug fix: success message previously read "Deleted (unknown)."
            return True, f"Deleted {filename}."
        else:
            return False, "File not found in index."
    except Exception as e:
        return False, f"Delete failed: {e}"


def reset_knowledge_base(username: str) -> Tuple[bool, str]:
    """Nukes the user's database folder.

    Returns (success, human-readable message).
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    if os.path.exists(user_db_path):
        shutil.rmtree(user_db_path)
        return True, "Database Reset."
    return False, "Database already empty."