# RAG pipeline utilities: document chunking, ingestion, and retrieval
# (Chroma vector store + cross-encoder reranking).
import logging
import os
import shutil
from typing import List, Literal, Optional, Tuple

# --- LANGCHAIN & DB IMPORTS ---
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder

# --- CUSTOM CORE IMPORTS ---
from core.AcronymManager import AcronymManager
from core.ParagraphChunker import ParagraphChunker
from core.TokenChunker import TokenChunker
# --- CONFIGURATION ---
CHROMA_PATH = "chroma_db"  # root folder; one Chroma DB sub-directory per user
UPLOAD_DIR = "source_documents"  # root folder; one upload sub-directory per user
EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"  # dense-retrieval embedding model
RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # cross-encoder used for reranking
# Configure Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- LAZY LOADING GLOBALS ---
# Module-level model singletons, populated on first use by
# get_embedding_func() / get_rerank_model() to avoid slow imports at startup.
_embedding_func = None
_rerank_model = None
def get_embedding_func():
    """Return the shared embedding model, instantiating it on first call."""
    global _embedding_func
    if _embedding_func is not None:
        return _embedding_func
    logger.info(f"⏳ Loading Embedding Model: {EMBED_MODEL_NAME}...")
    _embedding_func = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
    logger.info("✅ Embedding Model Loaded.")
    return _embedding_func
def get_rerank_model():
    """Return the shared Cross-Encoder reranker, instantiating it on first call."""
    global _rerank_model
    if _rerank_model is not None:
        return _rerank_model
    logger.info(f"⏳ Loading Reranker: {RERANK_MODEL_NAME}...")
    _rerank_model = CrossEncoder(RERANK_MODEL_NAME)
    logger.info("✅ Reranker Loaded.")
    return _rerank_model
| # --- PART 1: CHUNKING LOGIC (The New System) --- | |
def _process_markdown(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 100) -> List[Document]:
    """Split a Markdown file on headers, then recursively into sized chunks.

    Returns a list of Documents tagged with source/file_type/strategy
    metadata, or an empty list when anything goes wrong.
    """
    # Split on up to three header levels; the level name is stored as metadata.
    header_levels = [
        ("#", "Header 1"),
        ("##", "Header 2"),
        ("###", "Header 3"),
    ]
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            raw_md = handle.read()

        # Pass 1: semantic split on headers. Pass 2: size-bounded recursive split.
        by_header = MarkdownHeaderTextSplitter(
            headers_to_split_on=header_levels
        ).split_text(raw_md)
        chunks = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        ).split_documents(by_header)

        base_name = os.path.basename(file_path)
        for chunk in chunks:
            chunk.metadata.update(
                source=base_name,
                file_type='md',
                strategy='markdown_header',
            )
        return chunks
    except Exception as e:
        logger.error(f"Error processing Markdown file {file_path}: {e}")
        return []
def process_file(
    file_path: str,
    chunking_strategy: Literal["paragraph", "token"] = "paragraph",
    chunk_size: int = 512,
    chunk_overlap: int = 100,
    model_name: str = "gpt-4o"
) -> List[Document]:
    """
    Main chunking engine. Routes file to specific chunkers based on type/strategy.

    Returns the produced Document chunks, or an empty list on any failure.
    """
    if not os.path.exists(file_path):
        logger.error(f"File not found: {file_path}")
        return []

    ext = os.path.splitext(file_path)[1].lower()
    base_name = os.path.basename(file_path)
    logger.info(f"Processing {base_name} using strategy: {chunking_strategy}")

    # Markdown gets its own header-aware pipeline.
    if ext == ".md":
        return _process_markdown(file_path, chunk_size, chunk_overlap)

    # Everything other than PDF/TXT is unsupported.
    if ext not in (".pdf", ".txt"):
        logger.warning(f"Unsupported file extension: {ext}")
        return []

    # Select the chunker implementation for PDF/TXT files.
    if chunking_strategy == "token":
        chunker = TokenChunker(
            model_name=model_name,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    else:
        chunker = ParagraphChunker(model_name=model_name)

    try:
        docs = (
            chunker.process_document(file_path)
            if ext == ".pdf"
            else chunker.process_text_file(file_path)
        )
        # Ensure metadata consistency
        for doc in docs:
            doc.metadata["source"] = base_name
            doc.metadata["strategy"] = chunking_strategy
        return docs
    except Exception as e:
        logger.error(f"Error using {chunking_strategy} chunker on {base_name}: {e}")
        return []
| # --- PART 2: DATABASE & FILE MANAGEMENT (The Old Stable System) --- | |
def save_uploaded_file(uploaded_file, username: str = "default") -> Optional[str]:
    """Persist a Streamlit UploadedFile to disk so the loaders can read it.

    Args:
        uploaded_file: Object exposing `.name` and `.getbuffer()`
            (Streamlit's UploadedFile interface — TODO confirm at caller).
        username: Sub-directory under UPLOAD_DIR keeping users' files separate.

    Returns:
        The saved file path, or None when the write failed.
    """
    try:
        user_dir = os.path.join(UPLOAD_DIR, username)
        os.makedirs(user_dir, exist_ok=True)
        # FIX: strip any directory components from the client-supplied name to
        # prevent path traversal escaping the user's upload folder.
        safe_name = os.path.basename(uploaded_file.name)
        file_path = os.path.join(user_dir, safe_name)
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        logger.info(f"File saved: {file_path}")
        return file_path
    except Exception as e:
        logger.error(f"Error saving file: {e}")
        return None
def process_and_add_text(text: str, source_name: str, username: str) -> Tuple[bool, str]:
    """
    Ingests raw text string (e.g., from the Flattener tool) directly into Chroma.

    Returns (success, user-facing status message).
    """
    try:
        # Open (or create) this user's Chroma collection.
        db = Chroma(
            persist_directory=os.path.join(CHROMA_PATH, username),
            embedding_function=get_embedding_func(),
        )
        # Wrap the raw string in a single Document and index it as-is.
        db.add_documents([
            Document(
                page_content=text,
                metadata={
                    "source": source_name,
                    "strategy": "flattened_text",
                    "file_type": "generated",
                },
            )
        ])
        return True, f"Successfully indexed flattened text: {source_name}"
    except Exception as e:
        logger.error(f"Error indexing raw text: {e}")
        return False, f"Error: {str(e)}"
def ingest_file(file_path: str, username: str, strategy: str = "paragraph") -> Tuple[bool, str]:
    """Chunk a file, harvest acronyms from its text, and index the chunks in Chroma.

    Returns (success, user-facing status message).
    """
    try:
        # 1. Chunk the file with the requested strategy.
        docs = process_file(file_path, chunking_strategy=strategy)
        if not docs:
            return False, "No valid chunks generated from file."

        # --- ACRONYM SCANNING ---
        # Scan the raw chunk text so new acronym definitions are learned.
        acronym_mgr = AcronymManager()
        for chunk in docs:
            acronym_mgr.scan_text_for_acronyms(chunk.page_content)
        # -----------------------------

        # 2. Persist the chunks into the user's Chroma collection.
        db = Chroma(
            persist_directory=os.path.join(CHROMA_PATH, username),
            embedding_function=get_embedding_func(),
        )
        db.add_documents(docs)
        return True, f"Successfully indexed {len(docs)} chunks."
    except Exception as e:
        logger.error(f"Ingestion failed: {e}")
        return False, f"System Error: {str(e)}"
def search_knowledge_base(query: str, username: str, k: int = 10, final_k: int = 4) -> List[Document]:
    """Two-stage retrieval: vector search for k candidates, rerank down to final_k.

    Returns an empty list when the user has no database or on any error.
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    if not os.path.exists(user_db_path):
        return []
    try:
        # --- QUERY EXPANSION ---
        # Expand known acronyms so retrieval matches both short and long forms.
        expanded_query = AcronymManager().expand_query(query)
        if expanded_query != query:
            logger.info(f"Query Expanded: '{query}' -> '{expanded_query}'")

        # Stage 1: dense vector retrieval over the expanded query.
        db = Chroma(
            persist_directory=user_db_path,
            embedding_function=get_embedding_func(),
        )
        hits = db.similarity_search_with_relevance_scores(expanded_query, k=k)
        if not hits:
            return []

        # Stage 2: cross-encoder reranking of the candidate documents.
        candidates = [doc for doc, _ in hits]
        pairs = [[expanded_query, doc.page_content] for doc in candidates]
        scores = get_rerank_model().predict(pairs)

        # Highest cross-encoder score first; keep the top final_k documents.
        ranked = sorted(zip(candidates, scores), key=lambda item: item[1], reverse=True)
        return [doc for doc, _ in ranked[:final_k]]
    except Exception as e:
        logger.error(f"Search Error: {e}")
        return []
def list_documents(username: str) -> List[dict]:
    """
    Returns a list of unique files currently in the vector database.
    (Used for the sidebar list)
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    if not os.path.exists(user_db_path):
        return []
    try:
        db = Chroma(
            persist_directory=user_db_path,
            embedding_function=get_embedding_func(),
        )
        # Chroma's .get() returns metadata for every stored chunk;
        # aggregate chunk counts per source file.
        inventory = {}
        for meta in db.get()['metadatas']:
            # Metadata keys might differ slightly, handle gracefully
            src = meta.get('source', 'Unknown')
            entry = inventory.setdefault(
                src,
                {"chunks": 0, "strategy": meta.get('strategy', 'unknown')},
            )
            entry["chunks"] += 1
        # "source" duplicates "filename" so callers can use either key.
        return [
            {"filename": name, "chunks": info["chunks"], "strategy": info["strategy"], "source": name}
            for name, info in inventory.items()
        ]
    except Exception as e:
        logger.error(f"Error listing docs: {e}")
        return []
def delete_document(username: str, filename: str) -> Tuple[bool, str]:
    """Remove every chunk of `filename` from the user's vector database.

    Args:
        username: Owner whose Chroma collection is searched.
        filename: The `source` metadata value to match (as set at ingest time).

    Returns:
        (success, message) — False when the file is not indexed or the
        delete raised an exception.
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    try:
        emb_fn = get_embedding_func()
        db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
        data = db.get()
        # Collect the id of every chunk whose source matches the filename.
        ids_to_delete = [
            doc_id
            for doc_id, meta in zip(data['ids'], data['metadatas'])
            if meta.get('source') == filename
        ]
        if not ids_to_delete:
            return False, "File not found in index."
        db.delete(ids=ids_to_delete)
        # FIX: previous message was the uninformative "Deleted (unknown)." —
        # a placeholder that lost its interpolation; report what was removed.
        return True, f"Deleted {len(ids_to_delete)} chunks of {filename}."
    except Exception as e:
        return False, f"Delete failed: {e}"
def reset_knowledge_base(username: str) -> Tuple[bool, str]:
    """Nukes the user's database folder."""
    user_db_path = os.path.join(CHROMA_PATH, username)
    if not os.path.exists(user_db_path):
        return False, "Database already empty."
    shutil.rmtree(user_db_path)
    return True, "Database Reset."