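# NOTE(review): the original first three lines read "Spaces:", "Sleeping", "Sleeping".
# They look like hosting-page status text captured together with the file rather
# than part of the program, so they are preserved here only as a comment.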
| import os | |
| import glob | |
| import yaml | |
| import shutil | |
| import re | |
| from typing import List, Tuple | |
| import faiss | |
| import numpy as np | |
| import gradio as gr | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from PyPDF2 import PdfReader | |
| import docx | |
| # ----------------------------- | |
| # CONFIG | |
| # ----------------------------- | |
def _merge_config(defaults, overrides):
    """Recursively overlay *overrides* onto *defaults* (in place) and return it.

    Nested dictionaries are merged key by key; any other override value
    replaces the default.  ``None`` overrides are skipped so that an empty
    YAML section (e.g. a bare ``kb:``) keeps its default values.
    """
    for key, value in overrides.items():
        if value is None:
            continue
        if isinstance(value, dict) and isinstance(defaults.get(key), dict):
            _merge_config(defaults[key], value)
        else:
            defaults[key] = value
    return defaults


def load_config():
    """Load ``config.yaml`` from the working directory, layered over the defaults.

    Returns:
        A configuration dictionary that always contains every key provided by
        ``get_default_config()``; values present in ``config.yaml`` take
        precedence.  If the file is missing, unreadable, empty, or does not
        contain a mapping at the top level, the defaults are returned as-is.
    """
    config = get_default_config()
    try:
        with open("config.yaml", "r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
    except FileNotFoundError:
        print("⚠️ config.yaml not found, using defaults")
        return config
    except Exception as e:
        print(f"⚠️ Error loading config: {e}, using defaults")
        return config

    # safe_load() yields None for an empty document; returning that would make
    # every later CONFIG[...] lookup fail.
    if loaded is None:
        print("⚠️ config.yaml is empty, using defaults")
        return config
    if not isinstance(loaded, dict):
        print("⚠️ Error loading config: top-level YAML value must be a mapping, using defaults")
        return config

    # Fill in anything the file leaves out so partial configs stay usable.
    return _merge_config(config, loaded)
def get_default_config():
    """Return a freshly built dictionary with the built-in default settings."""
    # Locations on disk; config.yaml may point these elsewhere (e.g., ./kb).
    kb_section = dict(
        directory="./knowledge_base",
        index_directory="./index",
    )
    model_section = dict(
        embedding="sentence-transformers/all-MiniLM-L6-v2",
        qa="google/flan-t5-small",
    )
    chunking_section = dict(chunk_size=1200, overlap=200)
    message_section = dict(
        welcome="Ask me anything about the documents in the knowledge base!",
        no_answer="I couldn't find a relevant answer in the knowledge base.",
    )

    defaults = dict(
        kb=kb_section,
        models=model_section,
        chunking=chunking_section,
        thresholds=dict(similarity=0.1),
        messages=message_section,
        client=dict(name="RAG AI Assistant"),
        quick_actions=[],
    )
    return defaults
# Resolve the configuration once at import time and expose the individual
# settings the rest of the module reads as module-level constants.
CONFIG = load_config()

# Knowledge-base locations
KB_DIR, INDEX_DIR = CONFIG["kb"]["directory"], CONFIG["kb"]["index_directory"]

# Model identifiers (the QA model falls back to a small FLAN-T5 when unset)
EMBEDDING_MODEL_NAME = CONFIG["models"]["embedding"]
QA_MODEL_NAME = CONFIG["models"].get("qa", "google/flan-t5-small")

# Chunking and retrieval tuning
CHUNK_SIZE, CHUNK_OVERLAP = CONFIG["chunking"]["chunk_size"], CONFIG["chunking"]["overlap"]
SIM_THRESHOLD = CONFIG["thresholds"]["similarity"]

# User-facing messages
WELCOME_MSG, NO_ANSWER_MSG = CONFIG["messages"]["welcome"], CONFIG["messages"]["no_answer"]
| # ----------------------------- | |
| # UTILITIES | |
| # ----------------------------- | |
def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split *text* into overlapping, whitespace-trimmed character chunks.

    Consecutive chunks start ``chunk_size - overlap`` characters apart, so
    neighbouring chunks share ``overlap`` characters.  Chunks of 20 characters
    or fewer (after stripping) are dropped to avoid indexing fragments.

    Args:
        text: The text to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared between neighbouring chunks;
            must be smaller than ``chunk_size``.

    Returns:
        The chunks in document order; an empty list for empty or
        whitespace-only text.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size``.  With such values the window would
            never advance and the split would not terminate.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )
    if not text or not text.strip():
        return []

    step = chunk_size - overlap
    text_len = len(text)
    chunks = []
    for start in range(0, text_len, step):
        end = min(start + chunk_size, text_len)
        chunk = text[start:end].strip()
        if len(chunk) > 20:  # Avoid tiny chunks
            chunks.append(chunk)
        if end >= text_len:
            # The window reached the end of the text; a further step would
            # only re-emit a suffix of this chunk.
            break
    return chunks
def load_file_text(path: str) -> str:
    """Read one document from disk and return its text.

    The reader is picked from the lower-cased file extension: ``.pdf`` goes
    through ``PdfReader``, ``.docx``/``.doc`` through ``docx.Document``, and
    everything else is read as UTF-8 text with undecodable bytes ignored.

    Raises:
        FileNotFoundError: If *path* does not exist.
        Exception: Any error raised while reading is printed and re-raised.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    _, extension = os.path.splitext(path)
    extension = extension.lower()
    try:
        if extension == ".pdf":
            # Pages without extractable text contribute nothing.
            page_texts = (page.extract_text() for page in PdfReader(path).pages)
            return "\n".join(extracted for extracted in page_texts if extracted)

        if extension in (".docx", ".doc"):
            paragraphs = docx.Document(path).paragraphs
            return "\n".join(par.text for par in paragraphs if par.text.strip())

        # .txt, .md, etc.
        with open(path, "r", encoding="utf-8", errors="ignore") as handle:
            return handle.read()
    except Exception as e:
        print(f"Error reading {path}: {e}")
        raise
def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
    """Load every supported document found directly inside *kb_dir*.

    Files are selected by lower-cased extension (``.txt``, ``.md``, ``.pdf``,
    ``.docx``, ``.doc``), so ``REPORT.PDF`` is picked up just like
    ``report.pdf``.  Hidden files (names starting with ``.``) and
    sub-directories are skipped.  Paths are processed in sorted order so the
    resulting document list is the same on every run.

    Args:
        kb_dir: Directory to scan.  It is created when it does not exist.

    Returns:
        ``(file name, text)`` pairs for each file that could be read and is
        not empty.  Unreadable or empty files are reported on stdout and
        skipped.
    """
    docs: List[Tuple[str, str]] = []
    if not os.path.exists(kb_dir):
        print(f"⚠️ Knowledge base directory not found: {kb_dir}")
        print(f"Creating directory: {kb_dir}")
        os.makedirs(kb_dir, exist_ok=True)
        return docs
    if not os.path.isdir(kb_dir):
        print(f"⚠️ {kb_dir} is not a directory")
        return docs

    # Support multiple file formats
    supported_extensions = {".txt", ".md", ".pdf", ".docx", ".doc"}
    try:
        entries = os.listdir(kb_dir)
    except OSError as e:
        print(f"⚠️ Could not list {kb_dir}: {e}")
        return docs

    paths = sorted(
        os.path.join(kb_dir, name)
        for name in entries
        if not name.startswith(".")
        and os.path.splitext(name)[1].lower() in supported_extensions
        and os.path.isfile(os.path.join(kb_dir, name))
    )
    if not paths:
        print(f"⚠️ No documents found in {kb_dir}")
        return docs

    print(f"Found {len(paths)} documents in knowledge base")
    for path in paths:
        name = os.path.basename(path)
        try:
            text = load_file_text(path)
        except Exception as e:
            # One bad file must not prevent the rest from loading.
            print(f"✗ Could not read {path}: {e}")
            continue
        if text and text.strip():
            docs.append((name, text))
            print(f"✓ Loaded: {name}")
        else:
            print(f"⚠️ Empty file: {name}")
    return docs
def clean_context_text(text: str) -> str:
    """
    Tidy a raw document chunk before it is handed to the answer builder.

    Each non-blank line is stripped and then has, in this order, a leading
    markdown heading marker (#, ##, ...), a leading ordered-list prefix
    ("1.", "2)") and a leading bullet marker ("-", "*") removed.  A line is
    then discarded when it is shorter than 5 characters, when it looks like a
    title (at most 10 words, with at most one word not starting with an
    upper-case letter), or when an identical line was already kept.

    Returns the surviving lines joined with newlines, in original order.
    """
    prefix_patterns = (
        r"^#+\s*",         # markdown headings like "# 1. Title", "## Section"
        r"^\d+[\.\)]\s*",  # ordered list prefixes like "1. ", "2) "
        r"^[-*]\s*",       # bullet markers like "- ", "* "
    )
    kept = {}  # insertion-ordered: keys are the unique lines kept so far

    for raw_line in text.splitlines():
        candidate = raw_line.strip()
        if not candidate:
            continue
        for pattern in prefix_patterns:
            candidate = re.sub(pattern, "", candidate)

        # Very short lines are treated as noise
        if len(candidate) < 5:
            continue

        # Heuristic: lines where almost every word is capitalized read like headings
        tokens = candidate.split()
        capitalised = len([tok for tok in tokens if tok[:1].isupper()])
        looks_like_title = (
            bool(tokens)
            and len(tokens) <= 10
            and capitalised >= len(tokens) - 1
        )
        if looks_like_title or candidate in kept:
            continue
        kept[candidate] = None

    return "\n".join(kept)
| # ----------------------------- | |
| # KB INDEX (FAISS) | |
| # ----------------------------- | |
class RAGIndex:
    """Retrieval-augmented question answering over the knowledge base.

    Holds a sentence-embedding model, a FAISS inner-product index over
    L2-normalised chunk embeddings, the chunk texts with their source file
    names, and a seq2seq model that phrases the final answer.

    Attributes set by ``__init__``: ``embedder``, ``qa_tokenizer``,
    ``qa_model``, ``chunks``, ``chunk_sources``, ``index`` and ``initialized``
    (True only when both model loading and index setup completed).
    """

    def __init__(self):
        self.embedder = None
        self.qa_tokenizer = None
        self.qa_model = None
        self.chunks: List[str] = []
        self.chunk_sources: List[str] = []
        self.index = None
        self.initialized = False
        try:
            print("🔄 Initializing RAG Assistant...")
            self._initialize_models()
            self._build_or_load_index()
            self.initialized = True
            print("✅ RAG Assistant ready!")
        except Exception as e:
            # Keep the app importable; answer() reports the degraded state.
            print(f"❌ Initialization error: {e}")
            print("The assistant will run in limited mode.")

    def _initialize_models(self):
        """Initialize embedding and QA models; load errors are printed and re-raised."""
        try:
            print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
            self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
            print(f"Loading QA (seq2seq) model: {QA_MODEL_NAME}")
            self.qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
            self.qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
        except Exception as e:
            print(f"Error loading models: {e}")
            raise

    def _clear_index(self):
        """Reset the in-memory index state to 'no documents indexed'."""
        self.index = None
        self.chunks = []
        self.chunk_sources = []

    def _build_or_load_index(self, force_rebuild: bool = False):
        """Build or load FAISS index from knowledge base.

        Args:
            force_rebuild: When True, ignore any index persisted in
                ``INDEX_DIR`` and rebuild from the documents in ``KB_DIR``.
                The default (False) reuses the persisted index when it can be
                loaded and its size matches its metadata.
        """
        os.makedirs(INDEX_DIR, exist_ok=True)
        idx_path = os.path.join(INDEX_DIR, "kb.index")
        meta_path = os.path.join(INDEX_DIR, "kb_meta.npy")

        # Try to load existing index
        if not force_rebuild and os.path.exists(idx_path) and os.path.exists(meta_path):
            try:
                print("Loading existing FAISS index...")
                index = faiss.read_index(idx_path)
                # NOTE(security): allow_pickle=True unpickles the metadata
                # file; only safe while INDEX_DIR contains files written by
                # this application.
                meta = np.load(meta_path, allow_pickle=True).item()
                chunks = list(meta["chunks"])
                sources = list(meta["sources"])
                # A vector count that disagrees with the metadata would map
                # search hits to the wrong (or missing) chunk texts.
                if not (index.ntotal == len(chunks) == len(sources)):
                    raise ValueError("index size does not match stored metadata")
                self.index = index
                self.chunks = chunks
                self.chunk_sources = sources
                print(f"✓ Index loaded with {len(self.chunks)} chunks")
                return
            except Exception as e:
                print(f"⚠️ Could not load existing index: {e}")
                print("Building new index...")

        # Build new index
        print("Building new FAISS index from knowledge base...")
        docs = load_kb_documents(KB_DIR)
        if not docs:
            print("⚠️ No documents found in knowledge base")
            print(f" Please add .txt, .md, .pdf, or .docx files to: {KB_DIR}")
            self._clear_index()
            return

        all_chunks: List[str] = []
        all_sources: List[str] = []
        for source, text in docs:
            doc_chunks = chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
            all_chunks.extend(doc_chunks)
            all_sources.extend([source] * len(doc_chunks))

        if not all_chunks:
            print("⚠️ No valid chunks created from documents")
            self._clear_index()
            return

        print(f"Created {len(all_chunks)} chunks from {len(docs)} documents")
        print("Generating embeddings...")
        embeddings = self.embedder.encode(
            all_chunks,
            show_progress_bar=True,
            convert_to_numpy=True,
            batch_size=32,
        )
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)
        # Normalize for cosine similarity
        faiss.normalize_L2(embeddings)
        index.add(embeddings)

        # Save index (best effort: the in-memory index is used either way)
        try:
            faiss.write_index(index, idx_path)
            np.save(
                meta_path,
                {
                    "chunks": np.array(all_chunks, dtype=object),
                    "sources": np.array(all_sources, dtype=object),
                },
            )
            print("✓ Index saved successfully")
        except Exception as e:
            print(f"⚠️ Could not save index: {e}")

        self.index = index
        self.chunks = all_chunks
        self.chunk_sources = all_sources

    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """Retrieve relevant chunks for a query.

        Returns:
            Up to ``top_k`` ``(chunk text, source name, score)`` tuples in the
            order FAISS returns them, keeping only hits whose score is at
            least ``SIM_THRESHOLD``.  An empty list is returned for a blank
            query, when no index is available, or when retrieval fails.
        """
        if not query or not query.strip():
            return []
        if self.index is None or not self.initialized:
            return []
        try:
            q_emb = self.embedder.encode([query], convert_to_numpy=True)
            faiss.normalize_L2(q_emb)
            k = min(top_k, len(self.chunks)) if self.chunks else 0
            if k == 0:
                return []
            scores, idxs = self.index.search(q_emb, k)
            results: List[Tuple[str, str, float]] = []
            for score, idx in zip(scores[0], idxs[0]):
                # -1 marks an unfilled result slot
                if idx == -1 or idx >= len(self.chunks):
                    continue
                if score < SIM_THRESHOLD:
                    continue
                results.append(
                    (self.chunks[idx], self.chunk_sources[idx], float(score))
                )
            return results
        except Exception as e:
            print(f"Retrieval error: {e}")
            return []

    def _generate_from_context(
        self,
        question: str,
        context: str,
        max_new_tokens: int = 180,
    ) -> str:
        """
        Generate a grounded answer from the retrieved context using a seq2seq model
        (FLAN-T5, BART, etc.). The prompt instructs the model to only use the context.

        Raises:
            RuntimeError: If the QA model or tokenizer has not been loaded.
        """
        if self.qa_model is None or self.qa_tokenizer is None:
            raise RuntimeError("QA model not loaded.")
        prompt = (
            "You are a knowledge base assistant. Answer the question ONLY using the information "
            "in the context below.\n"
            "If the context does not contain the answer, say exactly: "
            "\"The documents do not contain enough information to answer this.\"\n\n"
            f"Question: {question}\n\n"
            "Context:\n"
            f"{context}\n\n"
            "Write a helpful answer in 2–4 sentences. Keep it factual and concise. "
            "Do NOT repeat the question. Do NOT include section titles or headings."
        )
        inputs = self.qa_tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=768,
        )
        # Greedy decoding is already deterministic; no temperature is passed
        # because it only applies when sampling is enabled.
        outputs = self.qa_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        answer = self.qa_tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
        ).strip()
        return answer

    def _candidate_sentences(self, contexts) -> List[str]:
        """Clean each retrieved chunk and split it into sentences of at least 25 characters."""
        sentences: List[str] = []
        for ctx, _source, _score in contexts:
            cleaned_ctx = clean_context_text(ctx)
            if not cleaned_ctx:
                continue
            # Simple sentence splitter: split on ., ?, ! plus newlines
            for raw in re.split(r'(?<=[.!?])\s+|\n+', cleaned_ctx):
                sentence = raw.strip()
                if len(sentence) >= 25:  # skip very short sentences
                    sentences.append(sentence)
        return sentences

    def _score_sentences(self, question: str, sentences: List[str]):
        """Score each sentence as 1.5 * cosine similarity + 0.5 * count of shared words longer than 3 characters."""
        try:
            # Semantic similarity via sentence embeddings
            q_emb = self.embedder.encode([question], convert_to_numpy=True)
            s_embs = self.embedder.encode(sentences, convert_to_numpy=True)
            faiss.normalize_L2(q_emb)
            faiss.normalize_L2(s_embs)
            sims = np.dot(s_embs, q_emb.T).reshape(-1)  # cosine similarity
        except Exception as e:
            print(f"Sentence embedding error, falling back to lexical scoring only: {e}")
            sims = np.zeros(len(sentences), dtype=float)

        # Lexical overlap (shared content words)
        q_words = {w.lower() for w in re.findall(r"\w+", question) if len(w) > 3}
        lex_scores = np.array(
            [
                len(q_words & {w.lower() for w in re.findall(r"\w+", sent) if len(w) > 3})
                for sent in sentences
            ],
            dtype=float,
        )
        return (1.5 * sims) + (0.5 * lex_scores)

    def answer(self, question: str) -> str:
        """
        Answer a question using RAG with sentence-level semantic selection
        and a generic seq2seq model (Flan-T5, BART, etc.).
        This function is fully stateless per call: it only uses the question
        and the indexed knowledge base, never previous answers.

        Returns:
            A markdown string with ``**Answer:**`` and ``**Sources:**``
            sections, or a plain status message when the assistant is not
            initialized, the question is blank, the knowledge base is empty,
            or nothing relevant was retrieved.
        """
        if not self.initialized:
            return "❌ Assistant not properly initialized. Please check the logs."
        if not question or not question.strip():
            return "Please ask a question."
        if self.index is None or not self.chunks:
            return (
                f"📚 Knowledge base is empty.\n\n"
                f"Please add documents to: `{KB_DIR}`\n"
                f"Supported formats: .txt, .md, .pdf, .docx"
            )

        # 1) Retrieve top-K chunks for this question
        contexts = self.retrieve(question, top_k=5)
        if not contexts:
            return (
                f"{NO_ANSWER_MSG}\n\n"
                f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
            )
        # Sources are reported for every retrieved chunk
        used_sources = {source for _ctx, source, _score in contexts}

        # 2) Split retrieved chunks into sentences (generic, no KB-specific logic)
        candidate_sentences = self._candidate_sentences(contexts)
        if not candidate_sentences:
            return (
                f"{NO_ANSWER_MSG}\n\n"
                f"💡 Try adding more detailed documents to the knowledge base."
            )

        # 3) Score sentences: semantic + lexical (generic)
        combined = self._score_sentences(question, candidate_sentences)

        # 4) Pick top-N distinct sentences to form the context.  A stable
        #    sort keeps tied sentences in retrieval order on every run.
        max_sentences = 5  # you can tune this
        chosen_sentences: List[str] = []
        for i in np.argsort(-combined, kind="stable"):
            if len(chosen_sentences) >= max_sentences:
                break
            sentence = candidate_sentences[i]
            if sentence not in chosen_sentences:  # avoid duplicates
                chosen_sentences.append(sentence)

        # 5) Let the seq2seq model generate a natural answer
        try:
            answer_text = self._generate_from_context(
                question=question,
                context="\n".join(chosen_sentences),
                max_new_tokens=200,
            ).strip()
        except Exception as e:
            print(f"Generation error, falling back to extractive answer: {e}")
            answer_text = " ".join(chosen_sentences)
        if not answer_text:
            answer_text = NO_ANSWER_MSG

        sources_str = ", ".join(sorted(used_sources))
        return (
            f"**Answer:** {answer_text}\n\n"
            f"**Sources:** {sources_str}"
        )
# Create the shared assistant once at import time; the separator lines frame
# its start-up output in the log.
_separator = "=" * 50
print(_separator)
rag_index = RAGIndex()
print(_separator)
| # ----------------------------- | |
| # GRADIO APP (BLOCKS) | |
| # ----------------------------- | |
def rag_respond(message, history):
    """Handle one chat turn: record the user message, answer it, record the reply.

    The history is only used for display; the assistant answers from the
    current message alone.  Returns an empty string (to clear the input box)
    and the updated history list.
    """
    history = [] if history is None else history
    user_msg = str(message)

    # UI history only -- it is deliberately NOT passed to rag_index.answer()
    history.append(dict(role="user", content=user_msg))
    bot_reply = rag_index.answer(user_msg)
    history.append(dict(role="assistant", content=bot_reply))

    return "", history
def upload_to_kb(files):
    """Copy uploaded files into the KB directory and return a status message.

    *files* may be a single upload or a list of uploads; each item is either
    an object exposing its path as ``.name`` or something whose ``str()`` is
    the path.  Items whose path does not exist are skipped, and copy failures
    are printed and skipped.
    """
    if not files:
        return "No files uploaded."

    uploads = files if isinstance(files, list) else [files]
    os.makedirs(KB_DIR, exist_ok=True)

    saved_files = []
    for item in uploads:
        src_path = getattr(item, "name", None) or str(item)
        if not os.path.exists(src_path):
            continue
        filename = os.path.basename(src_path)
        dest_path = os.path.join(KB_DIR, filename)
        try:
            shutil.copy(src_path, dest_path)
        except Exception as e:
            print(f"Error saving file {filename}: {e}")
        else:
            saved_files.append(filename)

    if not saved_files:
        return "No files could be saved. Check logs."

    bullet_list = "\n- ".join(saved_files)
    header = f"✅ Saved {len(saved_files)} file(s) to knowledge base:\n- "
    footer = "\n\nClick **Rebuild index** to include them in search."
    return header + bullet_list + footer
def rebuild_index():
    """Rebuild the search index from the knowledge base directory (UI callback).

    ``rag_index._build_or_load_index()`` reuses index files found in
    ``INDEX_DIR`` whenever they exist, so the persisted files are removed
    first; otherwise newly uploaded documents would never reach the index.

    Returns:
        A status message for the UI describing success, an empty knowledge
        base, or the reason the rebuild could not be performed.
    """
    if rag_index.embedder is None:
        # Models failed to load at start-up; embedding new chunks is impossible.
        return "❌ Cannot rebuild index: the embedding model is not loaded. Please check the logs."

    # Drop the persisted index so the next call really rebuilds from KB_DIR.
    stale_paths = [
        os.path.join(INDEX_DIR, name) for name in ("kb.index", "kb_meta.npy")
    ]
    for stale_path in stale_paths:
        try:
            os.remove(stale_path)
        except FileNotFoundError:
            pass
        except OSError as e:
            print(f"⚠️ Could not remove {stale_path}: {e}")
    if all(os.path.exists(p) for p in stale_paths):
        return (
            "❌ Could not remove the existing index files, so the index was not rebuilt.\n"
            f"Check permissions on `{INDEX_DIR}`."
        )

    try:
        rag_index._build_or_load_index()
    except Exception as e:
        print(f"Index rebuild error: {e}")
        return f"❌ Index rebuild failed: {e}"

    if rag_index.index is None or not rag_index.chunks:
        return (
            "⚠️ Index rebuild finished, but no documents or chunks were found.\n"
            f"Add files to `{KB_DIR}` and try again."
        )
    return (
        f"✅ Index rebuilt successfully.\n"
        f"Chunks in index: {len(rag_index.chunks)}"
    )
# Page description shown under the title, plus optional example questions.
description = WELCOME_MSG

# True only when the assistant started up and has at least one indexed chunk.
_kb_ready = (
    rag_index.initialized
    and rag_index.index is not None
    and bool(rag_index.chunks)
)
if not _kb_ready:
    description += (
        "\n\n⚠️ **Note:** Knowledge base is currently empty or index is not built.\n"
        "Upload documents in the **Knowledge Base** tab and click **Rebuild index**."
    )

# Example questions come from the configured quick actions when present ...
_configured_queries = (qa.get("query") for qa in CONFIG.get("quick_actions", []))
examples = [query for query in _configured_queries if query]

# ... otherwise fall back to generic ones, but only if there is something to search.
if not examples and _kb_ready:
    examples = [
        "What is a knowledge base?",
        "What are best practices for maintaining a KB?",
        "How should I structure knowledge base articles?",
    ]
# Two-tab UI: "Chat" for asking questions, "Knowledge Base" for uploading
# documents and rebuilding the index.
_app_title = CONFIG["client"]["name"]

with gr.Blocks(title=_app_title) as demo:
    gr.Markdown(f"# {CONFIG['client']['name']}")
    gr.Markdown(description)

    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(label="RAG Chat")
        with gr.Row():
            txt = gr.Textbox(
                show_label=False,
                placeholder="Ask a question about your documents and press Enter to send...",
                lines=1,  # single line so Enter submits
            )
        with gr.Row():
            send_btn = gr.Button("Send")
            clear_btn = gr.Button("Clear")

        # Enter and the Send button trigger the same handler with the same wiring.
        chat_io = [txt, chatbot]
        txt.submit(rag_respond, chat_io, chat_io)
        send_btn.click(rag_respond, chat_io, chat_io)
        # Clear empties both the conversation and the input box.
        clear_btn.click(lambda: ([], ""), None, [chatbot, txt])

    with gr.Tab("Knowledge Base"):
        kb_help_text = f"""
### Manage Knowledge Base
- Supported formats: `.txt`, `.md`, `.pdf`, `.docx`, `.doc`
- Files are stored in: `{KB_DIR}`
- After uploading, click **Rebuild index** so the assistant can use the new content.
"""
        gr.Markdown(kb_help_text)
        kb_upload = gr.File(
            label="Upload documents",
            file_count="multiple",
        )
        kb_status = gr.Textbox(
            label="Status",
            lines=6,
            interactive=False,
        )
        rebuild_btn = gr.Button("Rebuild index")

        # Both actions report their outcome in the shared status box.
        kb_upload.change(upload_to_kb, inputs=kb_upload, outputs=kb_status)
        rebuild_btn.click(rebuild_index, inputs=None, outputs=kb_status)
if __name__ == "__main__":
    # The listening port comes from the PORT environment variable when set,
    # otherwise 7860 is used; the server binds to all interfaces.
    listen_port = int(os.environ.get("PORT", 7860))
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=listen_port,
        share=False,
    )
    demo.launch(**launch_options)