import os
import glob
from typing import List, Tuple

import yaml
import faiss
import numpy as np
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
from PyPDF2 import PdfReader
import docx


# -----------------------------
# CONFIG
# -----------------------------
def load_config():
    """Load configuration from config.yaml, falling back to defaults on any error.

    Returns:
        dict: the parsed configuration mapping.
    """
    try:
        with open("config.yaml", "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
        # FIX: an empty or non-mapping YAML file makes safe_load return None
        # (or a scalar/list), which would crash the CONFIG["..."] lookups below.
        if not isinstance(cfg, dict):
            print("⚠️ config.yaml is empty or not a mapping, using defaults")
            return get_default_config()
        return cfg
    except FileNotFoundError:
        print("⚠️ config.yaml not found, using defaults")
        return get_default_config()
    except Exception as e:
        print(f"⚠️ Error loading config: {e}, using defaults")
        return get_default_config()


def get_default_config():
    """Provide the built-in default configuration.

    Returns:
        dict: defaults mirroring the expected config.yaml schema.
    """
    return {
        "kb": {
            "directory": "./knowledge_base",
            "index_directory": "./index",
        },
        "models": {
            "embedding": "all-MiniLM-L6-v2",
            "qa": "deepset/roberta-base-squad2",
        },
        "chunking": {
            "chunk_size": 500,
            "overlap": 50,
        },
        "thresholds": {
            "similarity": 0.3,
        },
        "messages": {
            "welcome": "Ask me anything about the documents in the knowledge base!",
            "no_answer": "I couldn't find a relevant answer in the knowledge base.",
        },
        "client": {
            "name": "RAG AI Assistant",
        },
        "quick_actions": [],
    }


CONFIG = load_config()

# Module-level constants resolved once from CONFIG.
KB_DIR = CONFIG["kb"]["directory"]
INDEX_DIR = CONFIG["kb"]["index_directory"]
EMBEDDING_MODEL_NAME = CONFIG["models"]["embedding"]
QA_MODEL_NAME = CONFIG["models"]["qa"]
CHUNK_SIZE = CONFIG["chunking"]["chunk_size"]
CHUNK_OVERLAP = CONFIG["chunking"]["overlap"]
SIM_THRESHOLD = CONFIG["thresholds"]["similarity"]
WELCOME_MSG = CONFIG["messages"]["welcome"]
NO_ANSWER_MSG = CONFIG["messages"]["no_answer"]


# -----------------------------
# UTILITIES
# -----------------------------
def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split text into overlapping chunks of roughly ``chunk_size`` characters.

    Args:
        text: source text to split.
        chunk_size: target chunk length in characters.
        overlap: number of characters shared between consecutive chunks.

    Returns:
        List of stripped chunks longer than 20 characters.
    """
    if not text or not text.strip():
        return []
    # FIX: if overlap >= chunk_size the original step (chunk_size - overlap)
    # is <= 0 and the while-loop below never terminates. Clamp to at least 1.
    step = max(1, chunk_size - overlap)
    chunks = []
    start = 0
    text_len = len(text)
    while start < text_len:
        end = min(start + chunk_size, text_len)
        chunk = text[start:end].strip()
        if chunk and len(chunk) > 20:  # Avoid tiny chunks
            chunks.append(chunk)
        if end >= text_len:
            break
        start += step
    return chunks


def load_file_text(path: str) -> str:
    """Load text from various file formats with error handling.

    Args:
        path: filesystem path to a .pdf, .docx/.doc, or plain-text file.

    Returns:
        The extracted text.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        Exception: re-raised from the underlying reader after logging.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")
    ext = os.path.splitext(path)[1].lower()
    try:
        if ext == ".pdf":
            reader = PdfReader(path)
            text_parts = []
            for page in reader.pages:
                # extract_text() may return None for image-only pages.
                page_text = page.extract_text()
                if page_text:
                    text_parts.append(page_text)
            return "\n".join(text_parts)
        elif ext in [".docx", ".doc"]:
            # NOTE(review): python-docx only supports .docx; a legacy .doc file
            # will raise here and be reported by the caller — confirm intended.
            doc = docx.Document(path)
            return "\n".join(p.text for p in doc.paragraphs if p.text.strip())
        else:
            # .txt, .md, etc.
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()
    except Exception as e:
        print(f"Error reading {path}: {e}")
        raise


def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
    """Load all documents from the knowledge base directory.

    Args:
        kb_dir: directory to scan (created if missing).

    Returns:
        List of (basename, text) tuples for every readable, non-empty file.
    """
    docs: List[Tuple[str, str]] = []
    if not os.path.exists(kb_dir):
        print(f"⚠️ Knowledge base directory not found: {kb_dir}")
        print(f"Creating directory: {kb_dir}")
        os.makedirs(kb_dir, exist_ok=True)
        return docs
    if not os.path.isdir(kb_dir):
        print(f"⚠️ {kb_dir} is not a directory")
        return docs

    # Support multiple file formats
    patterns = ["*.txt", "*.md", "*.pdf", "*.docx", "*.doc"]
    paths = []
    for pattern in patterns:
        paths.extend(glob.glob(os.path.join(kb_dir, pattern)))

    if not paths:
        print(f"⚠️ No documents found in {kb_dir}")
        return docs

    print(f"Found {len(paths)} documents in knowledge base")
    for path in paths:
        try:
            text = load_file_text(path)
            if text and text.strip():
                docs.append((os.path.basename(path), text))
                print(f"✓ Loaded: {os.path.basename(path)}")
            else:
                print(f"⚠️ Empty file: {os.path.basename(path)}")
        except Exception as e:
            # Best-effort: skip unreadable files, keep loading the rest.
            print(f"✗ Could not read {path}: {e}")
    return docs


# -----------------------------
# KB INDEX (FAISS)
# -----------------------------
class RAGIndex:
    """Retrieval-augmented QA over a FAISS index built from the knowledge base.

    On construction this loads the embedding and QA models, then loads a saved
    FAISS index from disk or builds (and saves) a new one from the documents
    in KB_DIR. ``initialized`` is False if any step failed.
    """

    def __init__(self):
        self.embedder = None          # SentenceTransformer for chunk/query embeddings
        self.qa_pipeline = None       # transformers extractive-QA pipeline
        self.chunks: List[str] = []   # chunk texts, aligned with the index rows
        self.chunk_sources: List[str] = []  # source filename per chunk
        self.index = None             # FAISS index, or None if KB is empty
        self.initialized = False
        try:
            print("🔄 Initializing RAG Assistant...")
            self._initialize_models()
            self._build_or_load_index()
            self.initialized = True
            print("✅ RAG Assistant ready!")
        except Exception as e:
            # Keep the app running in a degraded mode instead of crashing.
            print(f"❌ Initialization error: {e}")
            print("The assistant will run in limited mode.")

    def _initialize_models(self):
        """Initialize embedding and QA models.

        Raises:
            Exception: re-raised from model loading after logging.
        """
        try:
            print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
            self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
            print(f"Loading QA model: {QA_MODEL_NAME}")
            self.qa_pipeline = pipeline(
                "question-answering",
                model=AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME),
                tokenizer=AutoTokenizer.from_pretrained(QA_MODEL_NAME),
                handle_impossible_answer=True,
            )
        except Exception as e:
            print(f"Error loading models: {e}")
            raise

    def _build_or_load_index(self):
        """Build or load the FAISS index from the knowledge base.

        Tries to load a previously saved index + metadata; on any failure it
        rebuilds from the documents in KB_DIR and saves the result. Leaves
        ``self.index`` as None when the knowledge base is empty.
        """
        os.makedirs(INDEX_DIR, exist_ok=True)
        idx_path = os.path.join(INDEX_DIR, "kb.index")
        meta_path = os.path.join(INDEX_DIR, "kb_meta.npy")

        # Try to load existing index
        if os.path.exists(idx_path) and os.path.exists(meta_path):
            try:
                print("Loading existing FAISS index...")
                self.index = faiss.read_index(idx_path)
                meta = np.load(meta_path, allow_pickle=True).item()
                self.chunks = list(meta["chunks"])
                self.chunk_sources = list(meta["sources"])
                print(f"✓ Index loaded with {len(self.chunks)} chunks")
                return
            except Exception as e:
                print(f"⚠️ Could not load existing index: {e}")
                print("Building new index...")

        # Build new index
        print("Building new FAISS index from knowledge base...")
        docs = load_kb_documents(KB_DIR)
        if not docs:
            print("⚠️ No documents found in knowledge base")
            print(f"  Please add .txt, .md, .pdf, or .docx files to: {KB_DIR}")
            self.index = None
            return

        all_chunks: List[str] = []
        all_sources: List[str] = []
        for source, text in docs:
            for chunk in chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP):
                all_chunks.append(chunk)
                all_sources.append(source)

        if not all_chunks:
            print("⚠️ No valid chunks created from documents")
            self.index = None
            return

        print(f"Created {len(all_chunks)} chunks from {len(docs)} documents")
        print("Generating embeddings...")
        embeddings = self.embedder.encode(
            all_chunks,
            show_progress_bar=True,
            convert_to_numpy=True,
            batch_size=32,
        )

        dimension = embeddings.shape[1]
        # Inner-product index over L2-normalized vectors == cosine similarity.
        index = faiss.IndexFlatIP(dimension)
        faiss.normalize_L2(embeddings)
        index.add(embeddings)

        # Save index (best-effort: a save failure still leaves a usable index)
        try:
            faiss.write_index(index, idx_path)
            np.save(
                meta_path,
                {
                    "chunks": np.array(all_chunks, dtype=object),
                    "sources": np.array(all_sources, dtype=object),
                },
            )
            print("✓ Index saved successfully")
        except Exception as e:
            print(f"⚠️ Could not save index: {e}")

        self.index = index
        self.chunks = all_chunks
        self.chunk_sources = all_sources

    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """Retrieve relevant chunks for a query.

        Args:
            query: the user's question.
            top_k: maximum number of chunks to return.

        Returns:
            List of (chunk_text, source_filename, cosine_score) tuples whose
            score is at least SIM_THRESHOLD; empty on any error.
        """
        if not query or not query.strip():
            return []
        if self.index is None or not self.initialized:
            return []
        try:
            q_emb = self.embedder.encode([query], convert_to_numpy=True)
            faiss.normalize_L2(q_emb)
            scores, idxs = self.index.search(q_emb, min(top_k, len(self.chunks)))
            results: List[Tuple[str, str, float]] = []
            for score, idx in zip(scores[0], idxs[0]):
                # FAISS pads missing results with -1; also guard stale indices.
                if idx == -1 or idx >= len(self.chunks):
                    continue
                if score < SIM_THRESHOLD:
                    continue
                results.append(
                    (self.chunks[idx], self.chunk_sources[idx], float(score))
                )
            return results
        except Exception as e:
            print(f"Retrieval error: {e}")
            return []

    def answer(self, question: str) -> str:
        """Answer a question using RAG (retrieve contexts, then extractive QA).

        Args:
            question: the user's question.

        Returns:
            A markdown-formatted answer string (or a helpful fallback message).
        """
        if not self.initialized:
            return "❌ Assistant not properly initialized. Please check the logs."
        if not question or not question.strip():
            return "Please ask a question."
        if self.index is None:
            return (
                f"📚 Knowledge base is empty.\n\n"
                f"Please add documents to: `{KB_DIR}`\n"
                f"Supported formats: .txt, .md, .pdf, .docx"
            )

        # Retrieve relevant contexts
        contexts = self.retrieve(question, top_k=3)
        if not contexts:
            return (
                f"{NO_ANSWER_MSG}\n\n"
                f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
            )

        # Try to extract answer from each context
        answers = []
        for ctx, source, score in contexts:
            # Truncate context if too long (max 512 tokens for most QA models)
            max_context_length = 2000  # characters, roughly 512 tokens
            truncated_ctx = ctx[:max_context_length]
            qa_input = {"question": question, "context": truncated_ctx}
            try:
                result = self.qa_pipeline(qa_input)
                answer_text = result.get("answer", "").strip()
                answer_score = result.get("score", 0.0)
                if answer_text and answer_score > 0.01:  # Minimum confidence threshold
                    answers.append((answer_text, source, answer_score, score))
            except Exception as e:
                print(f"QA error on context from {source}: {e}")
                continue

        if not answers:
            # Provide context even if no specific answer found
            best_ctx, best_src, _ = contexts[0]
            preview = best_ctx[:300] + "..." if len(best_ctx) > 300 else best_ctx
            return (
                f"I found relevant information but couldn't extract a specific answer.\n\n"
                f"**Relevant context from {best_src}:**\n{preview}\n\n"
                f"💡 Try asking a more specific question."
            )

        # Pick best answer (weighted by both retrieval and QA scores)
        answers.sort(key=lambda x: x[2] * x[3], reverse=True)
        best_answer, best_source, qa_score, retrieval_score = answers[0]
        return (
            f"**Answer:** {best_answer}\n\n"
            f"**Source:** {best_source}\n"
            f"**Confidence:** {qa_score:.2%}"
        )


# Initialize RAG system
print("=" * 50)
rag_index = RAGIndex()
print("=" * 50)


# -----------------------------
# GRADIO CHAT
# -----------------------------
def rag_respond(message, history):
    """Handle a chat message; ``history`` is ignored (each question is standalone)."""
    if not message or not str(message).strip():
        return "Please enter a question."
    return rag_index.answer(str(message))


# Build interface
description = WELCOME_MSG
if not rag_index.initialized or rag_index.index is None:
    description += (
        f"\n\n⚠️ **Note:** Knowledge base is empty. "
        f"Add documents to `{KB_DIR}` and restart."
    )

examples = [
    qa.get("query") for qa in CONFIG.get("quick_actions", []) if qa.get("query")
]
if not examples and rag_index.initialized and rag_index.index is not None:
    examples = [
        "What is this document about?",
        "Can you summarize the main points?",
        "What are the key findings?",
    ]

chat = gr.ChatInterface(
    fn=rag_respond,
    title=CONFIG["client"]["name"],
    description=description,
    # FIX: "text" is not a valid value for ChatInterface's `type` parameter
    # (valid values are "messages" and "tuples") and raises at startup.
    # `message` is delivered to fn as a plain str in both modes.
    type="messages",
    examples=examples if examples else None,
    cache_examples=False,
    # NOTE(review): these button kwargs were removed in Gradio 5 — confirm the
    # installed Gradio version supports them, else drop these three lines.
    retry_btn="🔄 Retry",
    undo_btn="↩️ Undo",
    clear_btn="🗑️ Clear",
)

if __name__ == "__main__":
    # Launch with better settings for Hugging Face Spaces
    port = int(os.environ.get("PORT", 7860))  # use HF-provided port if set
    chat.launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,
    )