import glob
import os
import re
import shutil
from typing import List, Tuple

import docx
import faiss
import gradio as gr
import numpy as np
import yaml
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# -----------------------------
# CONFIG
# -----------------------------


def load_config():
    """Load configuration from config.yaml, falling back to defaults on any error."""
    try:
        with open("config.yaml", "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    except FileNotFoundError:
        print("⚠️ config.yaml not found, using defaults")
        return get_default_config()
    except Exception as e:
        print(f"⚠️ Error loading config: {e}, using defaults")
        return get_default_config()


def get_default_config():
    """Provide default configuration used when config.yaml is missing or broken."""
    return {
        "kb": {
            "directory": "./knowledge_base",  # can be overridden in config.yaml (e.g., ./kb)
            "index_directory": "./index",
        },
        "models": {
            "embedding": "sentence-transformers/all-MiniLM-L6-v2",
            "qa": "google/flan-t5-small",
        },
        "chunking": {
            "chunk_size": 1200,
            "overlap": 200,
        },
        "thresholds": {
            "similarity": 0.1,
        },
        "messages": {
            "welcome": "Ask me anything about the documents in the knowledge base!",
            "no_answer": "I couldn't find a relevant answer in the knowledge base.",
        },
        "client": {
            "name": "RAG AI Assistant",
        },
        "quick_actions": [],
    }


CONFIG = load_config()

KB_DIR = CONFIG["kb"]["directory"]
INDEX_DIR = CONFIG["kb"]["index_directory"]
EMBEDDING_MODEL_NAME = CONFIG["models"]["embedding"]
QA_MODEL_NAME = CONFIG["models"].get("qa", "google/flan-t5-small")
CHUNK_SIZE = CONFIG["chunking"]["chunk_size"]
CHUNK_OVERLAP = CONFIG["chunking"]["overlap"]
SIM_THRESHOLD = CONFIG["thresholds"]["similarity"]
WELCOME_MSG = CONFIG["messages"]["welcome"]
NO_ANSWER_MSG = CONFIG["messages"]["no_answer"]

# -----------------------------
# UTILITIES
# -----------------------------


def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split text into overlapping chunks of at most chunk_size characters.

    Chunks shorter than ~20 characters are dropped as noise. If overlap >= chunk_size
    (a misconfiguration), chunking falls back to non-overlapping windows instead of
    looping forever.
    """
    if not text or not text.strip():
        return []
    # BUGFIX: a non-positive step (overlap >= chunk_size) previously caused an
    # infinite loop; fall back to non-overlapping chunks in that case.
    step = chunk_size - overlap
    if step <= 0:
        step = chunk_size
    chunks = []
    start = 0
    text_len = len(text)
    while start < text_len:
        end = min(start + chunk_size, text_len)
        chunk = text[start:end].strip()
        if chunk and len(chunk) > 20:  # Avoid tiny chunks
            chunks.append(chunk)
        if end >= text_len:
            break
        start += step
    return chunks


def load_file_text(path: str) -> str:
    """Load text from various file formats (.pdf, .docx/.doc, plain text) with error handling.

    Raises FileNotFoundError if the path does not exist; re-raises parser errors
    after logging them.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    ext = os.path.splitext(path)[1].lower()
    try:
        if ext == ".pdf":
            reader = PdfReader(path)
            text_parts = []
            for page in reader.pages:
                page_text = page.extract_text()
                if page_text:
                    text_parts.append(page_text)
            return "\n".join(text_parts)
        elif ext in [".docx", ".doc"]:
            doc = docx.Document(path)
            return "\n".join(p.text for p in doc.paragraphs if p.text.strip())
        else:
            # .txt, .md, etc.
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()
    except Exception as e:
        print(f"Error reading {path}: {e}")
        raise


def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
    """Load all documents from the knowledge base directory.

    Returns a list of (basename, text) tuples; creates the directory if missing.
    Files that fail to load or are empty are skipped with a log message.
    """
    docs: List[Tuple[str, str]] = []

    if not os.path.exists(kb_dir):
        print(f"⚠️ Knowledge base directory not found: {kb_dir}")
        print(f"Creating directory: {kb_dir}")
        os.makedirs(kb_dir, exist_ok=True)
        return docs

    if not os.path.isdir(kb_dir):
        print(f"⚠️ {kb_dir} is not a directory")
        return docs

    # Support multiple file formats
    patterns = ["*.txt", "*.md", "*.pdf", "*.docx", "*.doc"]
    paths = []
    for pattern in patterns:
        paths.extend(glob.glob(os.path.join(kb_dir, pattern)))

    if not paths:
        print(f"⚠️ No documents found in {kb_dir}")
        return docs

    print(f"Found {len(paths)} documents in knowledge base")

    for path in paths:
        try:
            text = load_file_text(path)
            if text and text.strip():
                docs.append((os.path.basename(path), text))
                print(f"✓ Loaded: {os.path.basename(path)}")
            else:
                print(f"⚠️ Empty file: {os.path.basename(path)}")
        except Exception as e:
            print(f"✗ Could not read {path}: {e}")

    return docs


def clean_context_text(text: str) -> str:
    """
    Clean raw document context before sending to the answer builder:
    - Remove markdown headings (#, ##, ###)
    - Remove list markers (1., 2), -, *)
    - Remove duplicate lines
    - Remove title-like lines (e.g. 'Knowledge Base Structure & Information Architecture Best Practices')
    """
    lines = text.splitlines()
    cleaned = []
    seen = set()

    for line in lines:
        l = line.strip()
        if not l:
            continue

        # Remove markdown headings like "# 1. Title", "## Section"
        l = re.sub(r"^#+\s*", "", l)
        # Remove ordered list prefixes like "1. ", "2) "
        l = re.sub(r"^\d+[\.\)]\s*", "", l)
        # Remove bullet markers like "- ", "* "
        l = re.sub(r"^[-*]\s*", "", l)

        # Skip very short "noise" lines
        if len(l) < 5:
            continue

        # Heuristic: skip "title-like" lines where almost every word is capitalized
        words = l.split()
        if words:
            cap_words = sum(1 for w in words if w[:1].isupper())
            if len(words) <= 10 and cap_words >= len(words) - 1:
                # Looks like a heading / title, skip it
                continue

        # Avoid exact duplicates
        if l in seen:
            continue
        seen.add(l)
        cleaned.append(l)

    return "\n".join(cleaned)


# -----------------------------
# KB INDEX (FAISS)
# -----------------------------


class RAGIndex:
    """Retrieval-augmented QA over a local knowledge base, backed by a FAISS index."""

    def __init__(self):
        self.embedder = None
        self.qa_tokenizer = None
        self.qa_model = None
        self.chunks: List[str] = []
        self.chunk_sources: List[str] = []
        self.index = None
        self.initialized = False

        try:
            print("🔄 Initializing RAG Assistant...")
            self._initialize_models()
            self._build_or_load_index()
            self.initialized = True
            print("✅ RAG Assistant ready!")
        except Exception as e:
            print(f"❌ Initialization error: {e}")
            print("The assistant will run in limited mode.")

    def _initialize_models(self):
        """Initialize embedding and QA models."""
        try:
            print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
            self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
            print(f"Loading QA (seq2seq) model: {QA_MODEL_NAME}")
            self.qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
            self.qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
        except Exception as e:
            print(f"Error loading models: {e}")
            raise

    def _build_or_load_index(self, force_rebuild: bool = False):
        """Build or load the FAISS index from the knowledge base.

        Args:
            force_rebuild: When True, skip loading any existing on-disk index and
                rebuild from the documents currently in KB_DIR. BUGFIX: without
                this flag the "Rebuild index" UI action silently reloaded the
                stale index, so newly uploaded documents were never indexed.
        """
        os.makedirs(INDEX_DIR, exist_ok=True)
        idx_path = os.path.join(INDEX_DIR, "kb.index")
        meta_path = os.path.join(INDEX_DIR, "kb_meta.npy")

        # Try to load existing index (unless a rebuild was explicitly requested)
        if not force_rebuild and os.path.exists(idx_path) and os.path.exists(meta_path):
            try:
                print("Loading existing FAISS index...")
                self.index = faiss.read_index(idx_path)
                meta = np.load(meta_path, allow_pickle=True).item()
                self.chunks = list(meta["chunks"])
                self.chunk_sources = list(meta["sources"])
                print(f"✓ Index loaded with {len(self.chunks)} chunks")
                return
            except Exception as e:
                print(f"⚠️ Could not load existing index: {e}")
                print("Building new index...")

        # Build new index
        print("Building new FAISS index from knowledge base...")
        docs = load_kb_documents(KB_DIR)

        if not docs:
            print("⚠️ No documents found in knowledge base")
            print(f"   Please add .txt, .md, .pdf, or .docx files to: {KB_DIR}")
            self.index = None
            self.chunks = []
            self.chunk_sources = []
            return

        all_chunks: List[str] = []
        all_sources: List[str] = []

        for source, text in docs:
            chunks = chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
            for chunk in chunks:
                all_chunks.append(chunk)
                all_sources.append(source)

        if not all_chunks:
            print("⚠️ No valid chunks created from documents")
            self.index = None
            self.chunks = []
            self.chunk_sources = []
            return

        print(f"Created {len(all_chunks)} chunks from {len(docs)} documents")
        print("Generating embeddings...")

        embeddings = self.embedder.encode(
            all_chunks,
            show_progress_bar=True,
            convert_to_numpy=True,
            batch_size=32,
        )

        dimension = embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)
        # Normalize for cosine similarity (inner product on unit vectors)
        faiss.normalize_L2(embeddings)
        index.add(embeddings)

        # Save index (best-effort: failure to persist is non-fatal)
        try:
            faiss.write_index(index, idx_path)
            np.save(
                meta_path,
                {
                    "chunks": np.array(all_chunks, dtype=object),
                    "sources": np.array(all_sources, dtype=object),
                },
            )
            print("✓ Index saved successfully")
        except Exception as e:
            print(f"⚠️ Could not save index: {e}")

        self.index = index
        self.chunks = all_chunks
        self.chunk_sources = all_sources

    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """Retrieve relevant chunks for a query as (chunk, source, score) tuples."""
        if not query or not query.strip():
            return []
        if self.index is None or not self.initialized:
            return []

        try:
            q_emb = self.embedder.encode([query], convert_to_numpy=True)
            faiss.normalize_L2(q_emb)

            k = min(top_k, len(self.chunks)) if self.chunks else 0
            if k == 0:
                return []

            scores, idxs = self.index.search(q_emb, k)

            results: List[Tuple[str, str, float]] = []
            for score, idx in zip(scores[0], idxs[0]):
                # FAISS returns -1 for unfilled slots; also guard against bad indices
                if idx == -1 or idx >= len(self.chunks):
                    continue
                if score < SIM_THRESHOLD:
                    continue
                results.append(
                    (self.chunks[idx], self.chunk_sources[idx], float(score))
                )
            return results
        except Exception as e:
            print(f"Retrieval error: {e}")
            return []

    def _generate_from_context(
        self,
        question: str,
        context: str,
        max_new_tokens: int = 180,
    ) -> str:
        """
        Generate a grounded answer from the retrieved context using a seq2seq
        model (FLAN-T5, BART, etc.). The prompt forces the model to only use
        the context.
        """
        if self.qa_model is None or self.qa_tokenizer is None:
            raise RuntimeError("QA model not loaded.")

        prompt = (
            "You are a knowledge base assistant. Answer the question ONLY using the information "
            "in the context below.\n"
            "If the context does not contain the answer, say exactly: "
            "\"The documents do not contain enough information to answer this.\"\n\n"
            f"Question: {question}\n\n"
            "Context:\n"
            f"{context}\n\n"
            "Write a helpful answer in 2–4 sentences. Keep it factual and concise. "
            "Do NOT repeat the question. Do NOT include section titles or headings."
        )

        inputs = self.qa_tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=768,
        )
        # FIX: drop temperature=0.0 — transformers ignores temperature when
        # do_sample=False and emits a warning; greedy decoding is deterministic.
        outputs = self.qa_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        answer = self.qa_tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
        ).strip()
        return answer

    def answer(self, question: str) -> str:
        """
        Answer a question using RAG with sentence-level semantic selection and
        a generic seq2seq model (Flan-T5, BART, etc.).

        This function is fully stateless per call: it only uses the question
        and the indexed knowledge base, never previous answers.
        """
        if not self.initialized:
            return "❌ Assistant not properly initialized. Please check the logs."
        if not question or not question.strip():
            return "Please ask a question."
        if self.index is None or not self.chunks:
            return (
                f"📚 Knowledge base is empty.\n\n"
                f"Please add documents to: `{KB_DIR}`\n"
                f"Supported formats: .txt, .md, .pdf, .docx"
            )

        # -----------------------------
        # 1) Retrieve top-K chunks for this question
        # -----------------------------
        contexts = self.retrieve(question, top_k=5)
        if not contexts:
            return (
                f"{NO_ANSWER_MSG}\n\n"
                f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
            )

        used_sources = set()
        candidate_sentences = []
        candidate_sources = []

        # -----------------------------
        # 2) Split retrieved chunks into sentences (generic, no KB-specific logic)
        # -----------------------------
        for ctx, source, score in contexts:
            used_sources.add(source)
            cleaned_ctx = clean_context_text(ctx)
            if not cleaned_ctx:
                continue
            # Simple sentence splitter: split on ., ?, ! plus newlines
            raw_sents = re.split(r'(?<=[.!?])\s+|\n+', cleaned_ctx)
            for s in raw_sents:
                s_clean = s.strip()
                # skip very short sentences
                if len(s_clean) < 25:
                    continue
                candidate_sentences.append(s_clean)
                candidate_sources.append(source)

        if not candidate_sentences:
            return (
                f"{NO_ANSWER_MSG}\n\n"
                f"💡 Try adding more detailed documents to the knowledge base."
            )

        # -----------------------------
        # 3) Score sentences: semantic + lexical (generic)
        # -----------------------------
        try:
            # Semantic similarity via sentence embeddings
            q_emb = self.embedder.encode([question], convert_to_numpy=True)
            s_embs = self.embedder.encode(candidate_sentences, convert_to_numpy=True)
            faiss.normalize_L2(q_emb)
            faiss.normalize_L2(s_embs)
            sims = np.dot(s_embs, q_emb.T).reshape(-1)  # cosine similarity
        except Exception as e:
            print(f"Sentence embedding error, falling back to lexical scoring only: {e}")
            sims = np.zeros(len(candidate_sentences), dtype=float)

        # Lexical overlap (shared content words)
        q_words = {w.lower() for w in re.findall(r"\w+", question) if len(w) > 3}
        lex_scores = []
        for sent in candidate_sentences:
            s_words = {w.lower() for w in re.findall(r"\w+", sent) if len(w) > 3}
            lex_scores.append(len(q_words & s_words))
        lex_scores = np.array(lex_scores, dtype=float)

        # Combine scores in a generic way: semantic + a bit of lexical
        combined = (1.5 * sims) + (0.5 * lex_scores)

        # -----------------------------
        # 4) Pick top-N sentences to form the context
        # -----------------------------
        if len(combined) == 0:
            answer_text = NO_ANSWER_MSG
        else:
            top_idx = np.argsort(-combined)
            max_sentences = 5  # you can tune this
            chosen_sentences = []
            chosen_sources = set()

            for i in top_idx:
                if len(chosen_sentences) >= max_sentences:
                    break
                s = candidate_sentences[i].strip()
                if not s:
                    continue
                if s in chosen_sentences:
                    continue  # avoid duplicates
                chosen_sentences.append(s)
                chosen_sources.add(candidate_sources[i])

            if not chosen_sentences:
                answer_text = NO_ANSWER_MSG
            else:
                context_for_llm = "\n".join(chosen_sentences)

                # -----------------------------
                # 5) Let the seq2seq model generate a natural answer
                # -----------------------------
                try:
                    answer_text = self._generate_from_context(
                        question=question,
                        context=context_for_llm,
                        max_new_tokens=200,
                    ).strip()
                except Exception as e:
                    print(f"Generation error, falling back to extractive answer: {e}")
                    answer_text = " ".join(chosen_sentences)

        if not answer_text:
            answer_text = NO_ANSWER_MSG

        # Track sources from retrieved chunks (or from chosen sentences if you prefer)
        sources_str = ", ".join(sorted(used_sources)) if used_sources else "N/A"

        return (
            f"**Answer:** {answer_text}\n\n"
            f"**Sources:** {sources_str}"
        )


# Initialize RAG system
print("=" * 50)
rag_index = RAGIndex()
print("=" * 50)

# -----------------------------
# GRADIO APP (BLOCKS)
# -----------------------------


def rag_respond(message, history):
    """Chat handler: answer one user message and append the exchange to UI history."""
    if history is None:
        history = []

    user_msg = str(message)

    # Append to UI history ONLY
    history.append({"role": "user", "content": user_msg})

    # ❗ Do NOT pass history to rag_index.answer()
    bot_reply = rag_index.answer(user_msg)

    # Append assistant reply for UI display
    history.append({"role": "assistant", "content": bot_reply})

    # Return blank input + updated UI history
    return "", history


def upload_to_kb(files):
    """Save uploaded files into the KB directory."""
    if not files:
        return "No files uploaded."
    if not isinstance(files, list):
        files = [files]

    os.makedirs(KB_DIR, exist_ok=True)
    saved_files = []

    for f in files:
        src_path = getattr(f, "name", None) or str(f)
        if not os.path.exists(src_path):
            continue
        filename = os.path.basename(src_path)
        dest_path = os.path.join(KB_DIR, filename)
        try:
            shutil.copy(src_path, dest_path)
            saved_files.append(filename)
        except Exception as e:
            # FIX: log which file failed instead of "(unknown)"
            print(f"Error saving file {filename}: {e}")

    if not saved_files:
        return "No files could be saved. Check logs."

    return (
        f"✅ Saved {len(saved_files)} file(s) to knowledge base:\n- "
        + "\n- ".join(saved_files)
        + "\n\nClick **Rebuild index** to include them in search."
    )


def rebuild_index():
    """Trigger index rebuild from UI.

    BUGFIX: force a real rebuild — previously this reloaded the stale on-disk
    index, so newly uploaded documents never became searchable.
    """
    rag_index._build_or_load_index(force_rebuild=True)
    if rag_index.index is None or not rag_index.chunks:
        return (
            "⚠️ Index rebuild finished, but no documents or chunks were found.\n"
            f"Add files to `{KB_DIR}` and try again."
        )
    return (
        f"✅ Index rebuilt successfully.\n"
        f"Chunks in index: {len(rag_index.chunks)}"
    )


# Description + optional examples
description = WELCOME_MSG
if not rag_index.initialized or rag_index.index is None or not rag_index.chunks:
    description += (
        f"\n\n⚠️ **Note:** Knowledge base is currently empty or index is not built.\n"
        f"Upload documents in the **Knowledge Base** tab and click **Rebuild index**."
    )

examples = [
    qa.get("query")
    for qa in CONFIG.get("quick_actions", [])
    if qa.get("query")
]
if not examples and rag_index.initialized and rag_index.index is not None and rag_index.chunks:
    examples = [
        "What is a knowledge base?",
        "What are best practices for maintaining a KB?",
        "How should I structure knowledge base articles?",
    ]

with gr.Blocks(title=CONFIG["client"]["name"]) as demo:
    gr.Markdown(f"# {CONFIG['client']['name']}")
    gr.Markdown(description)

    with gr.Tab("Chat"):
        # FIX: history is built as {"role": ..., "content": ...} dicts, which
        # requires the "messages" format (Gradio 4.44+); the default tuple
        # format would reject dict entries.
        chatbot = gr.Chatbot(label="RAG Chat", type="messages")
        with gr.Row():
            txt = gr.Textbox(
                show_label=False,
                placeholder="Ask a question about your documents and press Enter to send...",
                lines=1,  # single line so Enter submits
            )
        with gr.Row():
            send_btn = gr.Button("Send")
            clear_btn = gr.Button("Clear")

        txt.submit(rag_respond, [txt, chatbot], [txt, chatbot])
        send_btn.click(rag_respond, [txt, chatbot], [txt, chatbot])
        clear_btn.click(lambda: ([], ""), None, [chatbot, txt])

    with gr.Tab("Knowledge Base"):
        gr.Markdown(
            f"""
### Manage Knowledge Base
- Supported formats: `.txt`, `.md`, `.pdf`, `.docx`, `.doc`
- Files are stored in: `{KB_DIR}`
- After uploading, click **Rebuild index** so the assistant can use the new content.
"""
        )
        kb_upload = gr.File(
            label="Upload documents",
            file_count="multiple",
        )
        kb_status = gr.Textbox(
            label="Status",
            lines=6,
            interactive=False,
        )
        rebuild_btn = gr.Button("Rebuild index")

        kb_upload.change(upload_to_kb, inputs=kb_upload, outputs=kb_status)
        rebuild_btn.click(rebuild_index, inputs=None, outputs=kb_status)


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,
    )