# Hugging Face Space app ("Spaces: Sleeping" status banner captured by the page scrape)
| import os | |
| import glob | |
| import yaml | |
| from typing import List, Tuple | |
| import faiss | |
| import numpy as np | |
| import gradio as gr | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline | |
| from PyPDF2 import PdfReader | |
| import docx | |
| # ----------------------------- | |
| # CONFIG | |
| # ----------------------------- | |
def load_config():
    """Read config.yaml from the working directory.

    Falls back to the built-in defaults when the file is missing or
    cannot be parsed, so the app always starts with a usable config.
    """
    try:
        with open("config.yaml", "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    except FileNotFoundError:
        print("⚠️ config.yaml not found, using defaults")
    except Exception as e:
        print(f"⚠️ Error loading config: {e}, using defaults")
    return get_default_config()
def get_default_config():
    """Return the built-in fallback configuration dictionary."""
    defaults = {}
    defaults["kb"] = {
        "directory": "./knowledge_base",
        "index_directory": "./index",
    }
    defaults["models"] = {
        "embedding": "all-MiniLM-L6-v2",
        "qa": "deepset/roberta-base-squad2",
    }
    defaults["chunking"] = {"chunk_size": 500, "overlap": 50}
    defaults["thresholds"] = {"similarity": 0.3}
    defaults["messages"] = {
        "welcome": "Ask me anything about the documents in the knowledge base!",
        "no_answer": "I couldn't find a relevant answer in the knowledge base.",
    }
    defaults["client"] = {"name": "RAG AI Assistant"}
    defaults["quick_actions"] = []
    return defaults
# Resolved configuration values, exposed as module-level constants.
CONFIG = load_config()
KB_DIR = CONFIG["kb"]["directory"]  # where knowledge-base documents live
INDEX_DIR = CONFIG["kb"]["index_directory"]  # where the FAISS index is persisted
EMBEDDING_MODEL_NAME = CONFIG["models"]["embedding"]
QA_MODEL_NAME = CONFIG["models"]["qa"]
CHUNK_SIZE = CONFIG["chunking"]["chunk_size"]  # characters per chunk
CHUNK_OVERLAP = CONFIG["chunking"]["overlap"]  # characters shared by adjacent chunks
SIM_THRESHOLD = CONFIG["thresholds"]["similarity"]  # minimum retrieval score to keep a hit
WELCOME_MSG = CONFIG["messages"]["welcome"]
NO_ANSWER_MSG = CONFIG["messages"]["no_answer"]
| # ----------------------------- | |
| # UTILITIES | |
| # ----------------------------- | |
def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split *text* into overlapping chunks of at most *chunk_size* characters.

    Args:
        text: Source text; blank or empty input yields no chunks.
        chunk_size: Maximum characters per chunk.
        overlap: Characters shared between consecutive chunks.

    Returns:
        List of stripped chunks; fragments of 20 characters or fewer are
        dropped to avoid useless index entries.
    """
    if not text or not text.strip():
        return []
    # BUGFIX: if overlap >= chunk_size (both come from user config), the
    # original stride `chunk_size - overlap` was <= 0 and the loop never
    # terminated. Always advance by at least one character.
    step = max(1, chunk_size - overlap)
    chunks: List[str] = []
    start = 0
    text_len = len(text)
    while start < text_len:
        end = min(start + chunk_size, text_len)
        chunk = text[start:end].strip()
        if chunk and len(chunk) > 20:  # avoid tiny chunks
            chunks.append(chunk)
        if end >= text_len:
            break
        start += step
    return chunks
def load_file_text(path: str) -> str:
    """Extract plain text from a single document.

    Dispatches on file extension: .pdf via PyPDF2, .docx/.doc via
    python-docx, everything else read as UTF-8 text.

    Raises:
        FileNotFoundError: if *path* does not exist.
        Exception: re-raised after logging when extraction fails.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")
    extension = os.path.splitext(path)[1].lower()
    try:
        if extension == ".pdf":
            extracted = (page.extract_text() for page in PdfReader(path).pages)
            return "\n".join(part for part in extracted if part)
        if extension in (".docx", ".doc"):
            document = docx.Document(path)
            non_empty = [p.text for p in document.paragraphs if p.text.strip()]
            return "\n".join(non_empty)
        # Fallback: treat as plain text (.txt, .md, ...).
        with open(path, "r", encoding="utf-8", errors="ignore") as handle:
            return handle.read()
    except Exception as e:
        print(f"Error reading {path}: {e}")
        raise
def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
    """Collect (filename, text) pairs for every readable document in *kb_dir*.

    Creates the directory when it is missing and returns an empty list;
    unreadable or empty files are logged and skipped.
    """
    docs: List[Tuple[str, str]] = []
    if not os.path.exists(kb_dir):
        print(f"⚠️ Knowledge base directory not found: {kb_dir}")
        print(f"Creating directory: {kb_dir}")
        os.makedirs(kb_dir, exist_ok=True)
        return docs
    if not os.path.isdir(kb_dir):
        print(f"⚠️ {kb_dir} is not a directory")
        return docs
    # Gather every supported file type in one pass.
    paths: List[str] = []
    for pattern in ("*.txt", "*.md", "*.pdf", "*.docx", "*.doc"):
        paths.extend(glob.glob(os.path.join(kb_dir, pattern)))
    if not paths:
        print(f"⚠️ No documents found in {kb_dir}")
        return docs
    print(f"Found {len(paths)} documents in knowledge base")
    for path in paths:
        name = os.path.basename(path)
        try:
            content = load_file_text(path)
        except Exception as e:
            print(f"✗ Could not read {path}: {e}")
            continue
        if content and content.strip():
            docs.append((name, content))
            print(f"✓ Loaded: {name}")
        else:
            print(f"⚠️ Empty file: {name}")
    return docs
| # ----------------------------- | |
| # KB INDEX (FAISS) | |
| # ----------------------------- | |
class RAGIndex:
    """Retrieval-augmented question answering over a local knowledge base.

    Document chunks are embedded with a SentenceTransformer and stored in
    a FAISS inner-product index over L2-normalized vectors (so scores are
    cosine similarities). Questions are answered by an extractive QA
    pipeline run over the top retrieved chunks.
    """

    def __init__(self):
        # All state is populated by the helpers below; `initialized` stays
        # False if any step fails so the app can run in a degraded mode
        # instead of crashing at import time.
        self.embedder = None  # SentenceTransformer, set in _initialize_models
        self.qa_pipeline = None  # transformers question-answering pipeline
        self.chunks: List[str] = []  # chunk texts, parallel to chunk_sources
        self.chunk_sources: List[str] = []  # originating filename per chunk
        self.index = None  # FAISS index, or None when the KB is empty
        self.initialized = False
        try:
            print("🔄 Initializing RAG Assistant...")
            self._initialize_models()
            self._build_or_load_index()
            self.initialized = True
            print("✅ RAG Assistant ready!")
        except Exception as e:
            # Deliberately swallow init errors so the UI still launches.
            print(f"❌ Initialization error: {e}")
            print("The assistant will run in limited mode.")

    def _initialize_models(self):
        """Load the embedding model and the extractive QA pipeline.

        Raises on failure so __init__ can mark the instance uninitialized.
        """
        try:
            print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
            self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
            print(f"Loading QA model: {QA_MODEL_NAME}")
            self.qa_pipeline = pipeline(
                "question-answering",
                model=AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME),
                tokenizer=AutoTokenizer.from_pretrained(QA_MODEL_NAME),
                # Lets the model report "no answer" instead of forcing a span.
                handle_impossible_answer=True,
            )
        except Exception as e:
            print(f"Error loading models: {e}")
            raise

    def _build_or_load_index(self):
        """Load a persisted FAISS index if present, otherwise build one.

        Leaves ``self.index`` as None when the knowledge base yields no
        usable chunks; ``answer`` treats that as "KB empty".
        """
        os.makedirs(INDEX_DIR, exist_ok=True)
        idx_path = os.path.join(INDEX_DIR, "kb.index")
        meta_path = os.path.join(INDEX_DIR, "kb_meta.npy")
        # Try to load an existing index plus its chunk/source metadata.
        if os.path.exists(idx_path) and os.path.exists(meta_path):
            try:
                print("Loading existing FAISS index...")
                self.index = faiss.read_index(idx_path)
                # Metadata was saved as a pickled dict inside a .npy file.
                meta = np.load(meta_path, allow_pickle=True).item()
                self.chunks = list(meta["chunks"])
                self.chunk_sources = list(meta["sources"])
                print(f"✓ Index loaded with {len(self.chunks)} chunks")
                return
            except Exception as e:
                # Fall through to a full rebuild on any load failure.
                print(f"⚠️ Could not load existing index: {e}")
                print("Building new index...")
        # Build new index
        print("Building new FAISS index from knowledge base...")
        docs = load_kb_documents(KB_DIR)
        if not docs:
            print("⚠️ No documents found in knowledge base")
            print(f" Please add .txt, .md, .pdf, or .docx files to: {KB_DIR}")
            self.index = None
            return
        all_chunks: List[str] = []
        all_sources: List[str] = []
        for source, text in docs:
            chunks = chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
            for chunk in chunks:
                all_chunks.append(chunk)
                all_sources.append(source)
        if not all_chunks:
            print("⚠️ No valid chunks created from documents")
            self.index = None
            return
        print(f"Created {len(all_chunks)} chunks from {len(docs)} documents")
        print("Generating embeddings...")
        embeddings = self.embedder.encode(
            all_chunks,
            show_progress_bar=True,
            convert_to_numpy=True,
            batch_size=32,
        )
        dimension = embeddings.shape[1]
        # Inner-product index; combined with the normalization below the
        # returned scores are cosine similarities.
        index = faiss.IndexFlatIP(dimension)
        # Normalize for cosine similarity
        faiss.normalize_L2(embeddings)
        index.add(embeddings)
        # Persist index + metadata; a save failure is non-fatal because the
        # in-memory index is still usable for this process.
        try:
            faiss.write_index(index, idx_path)
            np.save(
                meta_path,
                {
                    "chunks": np.array(all_chunks, dtype=object),
                    "sources": np.array(all_sources, dtype=object),
                },
            )
            print("✓ Index saved successfully")
        except Exception as e:
            print(f"⚠️ Could not save index: {e}")
        self.index = index
        self.chunks = all_chunks
        self.chunk_sources = all_sources

    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """Return up to *top_k* (chunk, source, score) hits above SIM_THRESHOLD.

        Returns an empty list for blank queries, an unbuilt index, or any
        retrieval error.
        """
        if not query or not query.strip():
            return []
        if self.index is None or not self.initialized:
            return []
        try:
            q_emb = self.embedder.encode([query], convert_to_numpy=True)
            faiss.normalize_L2(q_emb)
            scores, idxs = self.index.search(q_emb, min(top_k, len(self.chunks)))
            results: List[Tuple[str, str, float]] = []
            for score, idx in zip(scores[0], idxs[0]):
                # FAISS pads missing results with index -1.
                if idx == -1 or idx >= len(self.chunks):
                    continue
                if score < SIM_THRESHOLD:
                    continue
                results.append(
                    (self.chunks[idx], self.chunk_sources[idx], float(score))
                )
            return results
        except Exception as e:
            print(f"Retrieval error: {e}")
            return []

    def answer(self, question: str) -> str:
        """Answer *question* using retrieval + extractive QA.

        Returns a Markdown-formatted answer (with source and confidence)
        or a user-facing fallback message when nothing relevant is found.
        """
        if not self.initialized:
            return "❌ Assistant not properly initialized. Please check the logs."
        if not question or not question.strip():
            return "Please ask a question."
        if self.index is None:
            return (
                f"📚 Knowledge base is empty.\n\n"
                f"Please add documents to: `{KB_DIR}`\n"
                f"Supported formats: .txt, .md, .pdf, .docx"
            )
        # Retrieve relevant contexts
        contexts = self.retrieve(question, top_k=3)
        if not contexts:
            return (
                f"{NO_ANSWER_MSG}\n\n"
                f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
            )
        # Try to extract answer from each context
        answers = []
        for ctx, source, score in contexts:
            # Truncate context if too long (max 512 tokens for most QA models)
            max_context_length = 2000  # characters, roughly 512 tokens
            truncated_ctx = ctx[:max_context_length]
            qa_input = {"question": question, "context": truncated_ctx}
            try:
                result = self.qa_pipeline(qa_input)
                answer_text = result.get("answer", "").strip()
                answer_score = result.get("score", 0.0)
                if answer_text and answer_score > 0.01:  # Minimum confidence threshold
                    answers.append((answer_text, source, answer_score, score))
            except Exception as e:
                # A failure on one context should not abort the others.
                print(f"QA error on context from {source}: {e}")
                continue
        if not answers:
            # Provide context even if no specific answer found
            best_ctx, best_src, best_score = contexts[0]
            preview = best_ctx[:300] + "..." if len(best_ctx) > 300 else best_ctx
            return (
                f"I found relevant information but couldn't extract a specific answer.\n\n"
                f"**Relevant context from {best_src}:**\n{preview}\n\n"
                f"💡 Try asking a more specific question."
            )
        # Pick best answer (weighted by both retrieval and QA scores)
        answers.sort(key=lambda x: x[2] * x[3], reverse=True)
        best_answer, best_source, qa_score, retrieval_score = answers[0]
        return (
            f"**Answer:** {best_answer}\n\n"
            f"**Source:** {best_source}\n"
            f"**Confidence:** {qa_score:.2%}"
        )
# Initialize RAG system at import time; module-level singleton used by the
# Gradio chat callback below.
print("=" * 50)
rag_index = RAGIndex()
print("=" * 50)
| # ----------------------------- | |
| # GRADIO CHAT | |
| # ----------------------------- | |
def rag_respond(message, history):
    """Gradio chat callback: forward the user's message to the RAG index.

    *history* is accepted to satisfy the ChatInterface signature but is
    not used — each question is answered independently.
    """
    text = str(message)
    if not message or not text.strip():
        return "Please enter a question."
    return rag_index.answer(text)
# Build interface: description and example prompts, then the chat widget.
description = WELCOME_MSG
if not rag_index.initialized or rag_index.index is None:
    # Surface the empty/uninitialized KB state directly in the UI.
    description += (
        f"\n\n⚠️ **Note:** Knowledge base is empty. "
        f"Add documents to `{KB_DIR}` and restart."
    )
# Prefer example questions from config "quick_actions"; fall back to
# generic prompts only when the KB is actually usable.
examples = [
    qa.get("query")
    for qa in CONFIG.get("quick_actions", [])
    if qa.get("query")
]
if not examples and rag_index.initialized and rag_index.index is not None:
    examples = [
        "What is this document about?",
        "Can you summarize the main points?",
        "What are the key findings?",
    ]
chat = gr.ChatInterface(
    fn=rag_respond,
    title=CONFIG["client"]["name"],
    description=description,
    # NOTE(review): ChatInterface documents type="messages"/"tuples";
    # "text" is not a listed value — confirm against the pinned Gradio version.
    type="text",  # FIX: use text so `message` is a string
    examples=examples if examples else None,
    cache_examples=False,
    # NOTE(review): retry_btn/undo_btn/clear_btn were removed in Gradio 5
    # and raise TypeError there — verify the Gradio version pin.
    retry_btn="🔄 Retry",
    undo_btn="↩️ Undo",
    clear_btn="🗑️ Clear",
)
if __name__ == "__main__":
    # Respect a platform-provided port (e.g. Hugging Face Spaces sets PORT)
    # and fall back to Gradio's default 7860 for local runs.
    server_port = int(os.environ.get("PORT", 7860))  # FIX: use HF port if provided
    chat.launch(
        server_name="0.0.0.0",  # bind all interfaces so the container is reachable
        server_port=server_port,
        share=False,
    )