# Spaces: Sleeping — Hugging Face Spaces status banner captured alongside the
# source when this file was exported; kept as a comment so the module parses.
| import os | |
| import re | |
| import json | |
| import yaml | |
| import argparse | |
| from pathlib import Path | |
| from typing import List, Dict, Tuple, Optional | |
| import numpy as np | |
| import faiss | |
| import gradio as gr | |
| from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering | |
| from sentence_transformers import SentenceTransformer | |
| import PyPDF2 | |
| import docx | |
| # ----------- Configuration Loader ----------- | |
class Config:
    """Typed, defaulted access to the YAML configuration file.

    All accessors are read-only properties so the rest of the module can use
    plain attribute access (``config.chunk_size``, ``config.quick_actions``)
    and still get sensible defaults when a section is missing from the YAML.
    Without ``@property`` every call site in this module (which uses attribute
    access, never calls) would receive a bound method object instead of the
    configured value.
    """

    def __init__(self, config_path: str = "config.yaml"):
        # safe_load refuses arbitrary-object YAML tags (safe on untrusted files).
        with open(config_path, 'r', encoding='utf-8') as f:
            self.data = yaml.safe_load(f)

    @property
    def client_name(self) -> str:
        return self.data.get('client', {}).get('name', 'RAG Assistant')

    @property
    def client_description(self) -> str:
        return self.data.get('client', {}).get('description', 'AI-powered Q&A with document retrieval and citation')

    @property
    def client_logo(self) -> Optional[str]:
        # None when no logo is configured.
        return self.data.get('client', {}).get('logo')

    @property
    def theme_color(self) -> str:
        return self.data.get('client', {}).get('theme_color', 'blue')

    @property
    def kb_directory(self) -> Path:
        return Path(self.data.get('kb', {}).get('directory', './kb'))

    @property
    def index_directory(self) -> Path:
        return Path(self.data.get('kb', {}).get('index_directory', './.index'))

    @property
    def embedding_model(self) -> str:
        return self.data.get('models', {}).get('embedding', 'sentence-transformers/all-MiniLM-L6-v2')

    @property
    def qa_model(self) -> str:
        return self.data.get('models', {}).get('qa', 'deepset/roberta-base-squad2')

    @property
    def confidence_threshold(self) -> float:
        return self.data.get('thresholds', {}).get('confidence', 0.25)

    @property
    def similarity_threshold(self) -> float:
        return self.data.get('thresholds', {}).get('similarity', 0.35)

    @property
    def chunk_size(self) -> int:
        return self.data.get('chunking', {}).get('chunk_size', 800)

    @property
    def chunk_overlap(self) -> int:
        return self.data.get('chunking', {}).get('overlap', 200)

    @property
    def quick_actions(self) -> List[Tuple[str, str]]:
        # Each configured action is a {'label': ..., 'query': ...} mapping.
        actions = self.data.get('quick_actions', [])
        return [(a['label'], a['query']) for a in actions]

    @property
    def welcome_message(self) -> str:
        return self.data.get('messages', {}).get('welcome',
            'π How can I help? Ask me anything or use a quick action button below.')

    @property
    def no_answer_message(self) -> str:
        return self.data.get('messages', {}).get('no_answer',
            "β **I don't know the answer to that** but if you have any document with details I can learn about it.")

    @property
    def upload_prompt(self) -> str:
        return self.data.get('messages', {}).get('upload_prompt',
            'π€ Upload a relevant document above, and I\'ll be able to help you find the information you need!')
# Global config instance; populated in the __main__ block before KBIndex or
# any handler touches it (module level keeps it importable as None).
config = None
| # ----------- Document Extraction ----------- | |
def extract_text_from_pdf(file_path: str) -> str:
    """Extract plain text from a PDF file, one page per line group.

    Args:
        file_path: Path to the PDF on disk.

    Returns:
        Concatenated page text, each page followed by a newline.

    Raises:
        RuntimeError: If the PDF cannot be opened or parsed.
    """
    text = ""
    try:
        with open(file_path, 'rb') as fh:
            pdf_reader = PyPDF2.PdfReader(fh)
            for page in pdf_reader.pages:
                # extract_text() returns None for pages with no text layer
                # (e.g. scanned images); guard so concatenation can't raise.
                text += (page.extract_text() or "") + "\n"
    except Exception as e:
        # Chain the original cause for debuggability.
        raise RuntimeError(f"Error reading PDF: {str(e)}") from e
    return text
def extract_text_from_docx(file_path: str) -> str:
    """Return the text of every paragraph in a DOCX file, newline-joined."""
    try:
        document = docx.Document(file_path)
        paragraphs = (paragraph.text for paragraph in document.paragraphs)
        return "\n".join(paragraphs)
    except Exception as e:
        raise RuntimeError(f"Error reading DOCX: {str(e)}")
def extract_text_from_txt(file_path: str) -> str:
    """Read a plain-text file as UTF-8, silently dropping undecodable bytes."""
    try:
        with open(file_path, encoding='utf-8', errors='ignore') as handle:
            return handle.read()
    except Exception as e:
        raise RuntimeError(f"Error reading TXT: {str(e)}")
def extract_text_from_file(file_path: str) -> Tuple[str, str]:
    """Dispatch to the extractor matching the file extension.

    Returns a ``(text, type_label)`` pair; raises ValueError for
    extensions we have no extractor for.
    """
    extension = Path(file_path).suffix.lower()
    if extension == '.pdf':
        return extract_text_from_pdf(file_path), 'PDF'
    if extension == '.docx':
        return extract_text_from_docx(file_path), 'DOCX'
    if extension in ('.txt', '.md'):
        return extract_text_from_txt(file_path), 'Text'
    raise ValueError(f"Unsupported file type: {extension}. Supported: .pdf, .docx, .txt, .md")
# ----------- Document Processing -----------
# Matches any ATX markdown heading (1-6 '#'s); group 2 is the heading text.
HEADING_RE = re.compile(r"^(#{1,6})\s+(.*)$", re.MULTILINE)
def read_markdown_files(kb_dir: Path) -> List[Dict]:
    """Load every ``*.md`` file under kb_dir as a document dict.

    Each dict carries filepath, filename, title and the raw text. The title
    is the first level-1 heading when present, otherwise a prettified form
    of the file name.
    """
    documents = []
    for path in sorted(kb_dir.glob("*.md")):
        raw = path.read_text(encoding="utf-8", errors="ignore")
        heading = re.search(r"^#\s+(.*)$", raw, flags=re.MULTILINE)
        if heading:
            title = heading.group(1).strip()
        else:
            title = path.stem.replace("_", " ").title()
        documents.append({
            "filepath": str(path),
            "filename": path.name,
            "title": title,
            "text": raw,
        })
    return documents
def chunk_markdown(doc: Dict, chunk_chars: Optional[int] = None, overlap: Optional[int] = None) -> List[Dict]:
    """Split a markdown document into overlapping character chunks.

    The text is first split ahead of ``##``/``###`` headings so chunks stay
    within one section, then each section is windowed into ``chunk_chars``-
    sized pieces with ``overlap`` characters shared between neighbours.
    Fragments of ~50 chars or less are dropped as noise.

    Args:
        doc: Document dict with "text", "title", "filename", "filepath".
        chunk_chars: Window size; defaults to the configured chunk size.
        overlap: Overlap between windows; defaults to the configured overlap.

    Returns:
        List of chunk dicts carrying the source metadata plus "content".
    """
    if chunk_chars is None:
        chunk_chars = config.chunk_size
    if overlap is None:
        overlap = config.chunk_overlap
    text = doc["text"]
    # Lookahead split keeps the heading attached to its own section.
    sections = re.split(r"(?=^##\s+|\n##\s+|\n###\s+|^###\s+)", text, flags=re.MULTILINE)
    if len(sections) == 1:
        sections = [text]
    chunks = []
    for sec in sections:
        sec = sec.strip()
        if not sec or len(sec) < 50:
            continue  # skip empty / trivially short sections
        heading_match = HEADING_RE.search(sec)
        section_heading = heading_match.group(2).strip() if heading_match else doc["title"]
        start = 0
        while start < len(sec):
            end = min(start + chunk_chars, len(sec))
            chunk_text = sec[start:end].strip()
            if len(chunk_text) > 50:
                chunks.append({
                    "doc_title": doc["title"],
                    "filename": doc["filename"],
                    "filepath": doc["filepath"],
                    "section": section_heading,
                    "content": chunk_text
                })
            if end == len(sec):
                break
            # Guarantee forward progress: with overlap >= chunk_chars the old
            # formula max(0, end - overlap) could stall and loop forever.
            next_start = end - overlap
            start = next_start if next_start > start else end
    return chunks
| # ----------- KB Index ----------- | |
class KBIndex:
    """Embedding retrieval + extractive QA over knowledge-base chunks.

    Holds a FAISS inner-product index over L2-normalised chunk embeddings
    (so inner product == cosine similarity), the chunk metadata aligned
    row-for-row with the index, and a transformers question-answering
    pipeline used to extract an answer span from retrieved chunks.
    """

    def __init__(self):
        # Bi-encoder used to embed both KB chunks and incoming queries.
        self.embedder = SentenceTransformer(config.embedding_model)
        self.reader_tokenizer = AutoTokenizer.from_pretrained(config.qa_model)
        self.reader_model = AutoModelForQuestionAnswering.from_pretrained(config.qa_model)
        # handle_impossible_answer lets the reader signal "no answer here".
        self.reader = pipeline(
            "question-answering",
            model=self.reader_model,
            tokenizer=self.reader_tokenizer,
            max_answer_len=200,
            handle_impossible_answer=True
        )
        self.index = None        # faiss.IndexFlatIP once built/loaded
        self.embeddings = None   # np.ndarray aligned with self.metadata
        self.metadata = []       # chunk dicts; row i describes index row i
        self.uploaded_file_active = False  # True while serving an ad-hoc upload
        # Paths based on config
        self.embeddings_path = config.index_directory / "kb_embeddings.npy"
        self.metadata_path = config.index_directory / "kb_metadata.json"
        self.faiss_path = config.index_directory / "kb_faiss.index"

    def build(self, kb_dir: Path):
        """Build the FAISS index from markdown files and persist it to disk.

        Raises:
            RuntimeError: If the directory has no markdown files or the
                files yield no usable chunks.
        """
        docs = read_markdown_files(kb_dir)
        if not docs:
            raise RuntimeError(f"No markdown files found in {kb_dir.resolve()}")
        all_chunks = []
        for d in docs:
            all_chunks.extend(chunk_markdown(d))
        if not all_chunks:
            raise RuntimeError("No content chunks generated from KB.")
        texts = [c["content"] for c in all_chunks]
        embeddings = self.embedder.encode(
            texts,
            batch_size=32,
            convert_to_numpy=True,
            show_progress_bar=True
        )
        # Normalise so IndexFlatIP's inner product behaves as cosine similarity.
        faiss.normalize_L2(embeddings)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)
        self.index = index
        self.embeddings = embeddings
        self.metadata = all_chunks
        self.uploaded_file_active = False
        # Ensure index directory exists
        config.index_directory.mkdir(exist_ok=True, parents=True)
        # Persist all three artifacts so load() can restore without re-encoding.
        np.save(self.embeddings_path, embeddings)
        with open(self.metadata_path, "w", encoding="utf-8") as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=2)
        faiss.write_index(index, str(self.faiss_path))

    def build_from_uploaded_file(self, file_path: str, filename: str):
        """Build an in-memory index from one uploaded file (not persisted).

        Replaces the current index/metadata and flags uploaded_file_active.

        Returns:
            (number_of_chunks, file_type_label) tuple.

        Raises:
            RuntimeError: If the file is empty/too short or yields no chunks.
        """
        text_content, file_type = extract_text_from_file(file_path)
        if not text_content or len(text_content.strip()) < 100:
            raise RuntimeError("File appears to be empty or too short.")
        doc = {
            "filepath": file_path,
            "filename": filename,
            "title": Path(filename).stem.replace("_", " ").title(),
            "text": text_content
        }
        all_chunks = chunk_markdown(doc)
        if not all_chunks:
            raise RuntimeError("Could not extract meaningful content from file.")
        texts = [c["content"] for c in all_chunks]
        embeddings = self.embedder.encode(
            texts,
            batch_size=32,
            convert_to_numpy=True,
            show_progress_bar=False
        )
        faiss.normalize_L2(embeddings)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)
        self.index = index
        self.embeddings = embeddings
        self.metadata = all_chunks
        self.uploaded_file_active = True
        return len(all_chunks), file_type

    def load(self) -> bool:
        """Load the persisted index from disk; return False if any artifact is missing."""
        if not (self.embeddings_path.exists() and self.metadata_path.exists() and self.faiss_path.exists()):
            return False
        self.embeddings = np.load(self.embeddings_path)
        with open(self.metadata_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)
        self.index = faiss.read_index(str(self.faiss_path))
        self.uploaded_file_active = False
        return True

    def retrieve(self, query: str, top_k: int = 6) -> List[Tuple[int, float]]:
        """Return (chunk_index, cosine_similarity) pairs for the top_k matches."""
        q_emb = self.embedder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(q_emb)
        D, I = self.index.search(q_emb, top_k)
        return list(zip(I[0].tolist(), D[0].tolist()))

    def answer(self, question: str, retrieved: List[Tuple[int, float]]) -> Tuple[Optional[str], float, List[Dict], float]:
        """Run the QA reader over each retrieved chunk and pick the best span.

        Returns:
            (answer_text, qa_score, citations, best_similarity). answer_text
            is None when no chunk produced a usable span; best_similarity is
            the highest retrieval score regardless of QA outcome.
        """
        candidates = []
        for idx, sim in retrieved:
            meta = self.metadata[idx]
            ctx = meta["content"]
            try:
                out = self.reader(question=question, context=ctx)
                score = float(out.get("score", 0.0))
                answer_text = out.get("answer", "").strip()
                # Ignore degenerate spans (empty or <= 3 chars).
                if answer_text and len(answer_text) > 3:
                    expanded_answer = self._expand_answer(answer_text, ctx)
                    candidates.append({
                        "text": expanded_answer,
                        "original": answer_text,
                        "score": score,
                        "meta": meta,
                        "sim": float(sim),
                        "context": ctx
                    })
            except Exception as e:
                # Best-effort: a reader failure on one chunk shouldn't kill the query.
                continue
        if not candidates:
            return None, 0.0, [], max([s for _, s in retrieved]) if retrieved else 0.0
        # Blend reader confidence (70%) with retrieval similarity (30%).
        candidates.sort(key=lambda x: x["score"] * 0.7 + x["sim"] * 0.3, reverse=True)
        best = candidates[0]
        # Cite up to the top-3 retrieved chunks, deduplicated by (file, section).
        citations = []
        seen = set()
        for idx, _ in retrieved[:3]:
            m = self.metadata[idx]
            key = (m["filename"], m["section"])
            if key in seen:
                continue
            seen.add(key)
            citations.append({
                "title": m["doc_title"],
                "filename": m["filename"],
                "section": m["section"]
            })
        best_sim = max([s for _, s in retrieved]) if retrieved else 0.0
        return best["text"], best["score"], citations, best_sim

    def _expand_answer(self, answer: str, context: str, max_chars: int = 300) -> str:
        """Grow the raw answer span to sentence boundaries within its context.

        Walks outward from the span to the nearest sentence delimiters,
        capped at ~max_chars/2 in each direction; falls back to a
        sentence-split search when the expansion stays under 50 chars.
        """
        answer_pos = context.lower().find(answer.lower())
        if answer_pos == -1:
            return answer  # span not literally present (tokenizer artifacts)
        start = answer_pos
        end = answer_pos + len(answer)
        # Walk left to the previous sentence boundary (bounded by max_chars/2).
        while start > 0 and context[start - 1] not in '.!?\n':
            start -= 1
            if answer_pos - start > max_chars // 2:
                break
        # Walk right to the next sentence boundary (bounded by max_chars/2).
        while end < len(context) and context[end] not in '.!?\n':
            end += 1
            if end - answer_pos > max_chars // 2:
                break
        if end < len(context) and context[end] in '.!?':
            end += 1  # include the closing punctuation
        expanded = context[start:end].strip()
        if len(expanded) < 50:
            # Fallback: locate the containing sentence and maybe append the next one.
            sentences = context.split('.')
            for i, sent in enumerate(sentences):
                if answer.lower() in sent.lower():
                    result = sent.strip()
                    if i + 1 < len(sentences) and len(result) < 100:
                        result += ". " + sentences[i + 1].strip()
                    return result + ("." if not result.endswith(".") else "")
        return expanded
# Initialize KB (will be done after config is loaded); module level keeps the
# name importable as None before __main__ assigns the real KBIndex.
kb = None
def ensure_index():
    """Restore the cached index if possible, otherwise build one from the KB directory.

    Never raises: every failure path degrades to a console hint telling the
    operator how to populate the knowledge base.
    """
    # Fast path: reuse an index previously persisted to disk.
    try:
        if kb.load():
            print(f"β Loaded existing index from {config.index_directory}")
            return
    except Exception as e:
        print(f"β οΈ Could not load existing index: {e}")
    # No usable cache — see whether we can build from the KB directory.
    if not config.kb_directory.exists():
        print(f"βΉοΈ KB directory {config.kb_directory} not found. Creating it...")
        config.kb_directory.mkdir(exist_ok=True, parents=True)
        print(f"βΉοΈ Add .md files to {config.kb_directory} or upload documents via the UI")
        return
    md_files = list(config.kb_directory.glob("*.md"))
    if not md_files:
        print(f"βΉοΈ No markdown files found in {config.kb_directory}")
        print(f"βΉοΈ Upload documents via the UI or add .md files to start using the knowledge base")
        return
    try:
        print(f"π¨ Building index from {len(md_files)} markdown files...")
        kb.build(config.kb_directory)
        print(f"β Index built successfully!")
    except Exception as e:
        print(f"β οΈ Could not build index: {e}")
        print(f"βΉοΈ You can upload documents via the UI or add .md files to {config.kb_directory}")
| # ----------- Response Generation ----------- | |
def format_citations(citations: List[Dict]) -> str:
    """Render citation dicts as a markdown bullet list; empty string if none."""
    if not citations:
        return ""
    return "\n".join(
        f"β’ **{c['title']}** β _{c['section']}_" for c in citations
    )
def respond(user_msg: str, history: List, uploaded_file_info: str = None) -> str:
    """Answer a user question via the retrieve-then-read pipeline.

    Returns markdown: a cited answer, a low-confidence variant, or a
    fallback message when the question is empty, nothing is indexed, or
    retrieval/QA confidence is too weak.
    """
    question = (user_msg or "").strip()
    if not question:
        return config.welcome_message
    fallback = f"{config.no_answer_message}\n\n{config.upload_prompt}"
    if kb.index is None or len(kb.metadata) == 0:
        return fallback
    if kb.uploaded_file_active and uploaded_file_info:
        source_info = " in the uploaded file"
    else:
        source_info = " in the knowledge base"
    retrieved = kb.retrieve(question, top_k=6)
    # Reject outright when even the best retrieval similarity is negligible.
    if not retrieved or max(s for _, s in retrieved) < 0.20:
        return fallback
    answer, qa_score, citations, best_sim = kb.answer(question, retrieved)
    if not answer or qa_score < 0.15 or best_sim < 0.25:
        return (
            f"{config.no_answer_message}\n\n"
            f"The question seems outside the scope of what I currently know{source_info}. "
            f"Try uploading a relevant document, or rephrase your question if you think the information might be here."
        )
    answer = answer.strip()
    if answer and answer[-1] not in '.!?':
        answer += "."
    citations_md = format_citations(citations)
    is_uncertain = (qa_score < config.confidence_threshold) or (best_sim < config.similarity_threshold)
    if is_uncertain:
        return (
            f"β οΈ **Answer (Low Confidence):**\n\n{answer}\n\n"
            f"---\n"
            f"π **Related Sources:**\n{citations_md}\n\n"
            f"π¬ *I'm not entirely certain about this answer. If you have a more detailed document about this topic, please upload it for better accuracy.*"
        )
    return (
        f"β **Answer:**\n\n{answer}\n\n"
        f"---\n"
        f"π **Sources:**\n{citations_md}\n\n"
        f"π‘ *Say \"show more details\" to see the full context.*"
    )
| # ----------- UI Handlers ----------- | |
def process_message(user_input: str, history: List, uploaded_file_info: str) -> Tuple[List, Dict]:
    """Run one chat turn: get the assistant reply and extend the history.

    Returns the new messages-format history plus a gr.update that clears
    the input textbox.
    """
    message = (user_input or "").strip()
    if not message:
        return history, gr.update(value="")
    reply = respond(message, history or [], uploaded_file_info)
    updated = list(history or [])
    updated.append({"role": "user", "content": message})
    updated.append({"role": "assistant", "content": reply})
    return updated, gr.update(value="")
def process_quick(label: str, history: List, uploaded_file_info: str) -> Tuple[List, Dict]:
    """Translate a quick-action button label into its configured query and run it."""
    for action_label, action_query in config.quick_actions:
        if action_label == label:
            return process_message(action_query, history, uploaded_file_info)
    # Unknown label: leave the chat untouched, just clear the textbox.
    return history, gr.update(value="")
def handle_file_upload(file):
    """Index an uploaded document so subsequent questions run against it.

    Args:
        file: Gradio file object (or None when the upload is cleared).

    Returns:
        (status_markdown, filename) — filename is stored in gr.State so
        `respond` can tell which source is active; "" on failure.
    """
    if file is None:
        return "βΉοΈ No file uploaded.", ""
    try:
        filename = Path(file.name).name
        num_chunks, file_type = kb.build_from_uploaded_file(file.name, filename)
        return (
            f"β **File processed successfully!**\n\n"
            # Show the actual file name (was a hard-coded placeholder).
            f"π **File:** {filename}\n"
            f"π **Type:** {file_type}\n"
            f"π’ **Chunks:** {num_chunks}\n\n"
            f"You can now ask questions about this document!"
        ), filename
    except Exception as e:
        return f"β **Error processing file:** {str(e)}\n\nPlease ensure the file is a valid PDF, DOCX, TXT, or MD file.", ""
def clear_uploaded_file():
    """Discard the uploaded-file index and fall back to the persistent KB.

    Returns (status_markdown, filename_state, file_component_value) so the
    caller can reset both the gr.State and the gr.File widget.
    """
    try:
        if kb.load():
            return "β Switched back to knowledge base.", "", None
        # No cached KB on disk: leave the engine empty but in a consistent state.
        kb.index = None
        kb.embeddings = None
        kb.metadata = []
        kb.uploaded_file_active = False
        return "βΉοΈ No knowledge base found. Please upload a file or build the KB index.", "", None
    except Exception as e:
        return f"β οΈ Error: {str(e)}", "", None
def rebuild_index_handler():
    """Re-scan the KB directory and rebuild the search index; return a status string."""
    try:
        kb.build(config.kb_directory)
    except Exception as e:
        return f"β Error rebuilding index: {str(e)}"
    return "β Index rebuilt successfully! Ready to answer questions."
| # ----------- Gradio UI ----------- | |
def create_interface():
    """Assemble the Gradio Blocks UI wired to the global config and KB index.

    Returns:
        gr.Blocks: the fully wired (not yet launched) interface.
    """
    with gr.Blocks(
        title=config.client_name,
        theme=gr.themes.Soft(primary_hue=config.theme_color),
        css="""
        .contain { max-width: 1200px; margin: auto; }
        .quick-btn { min-width: 180px !important; }
        """
    ) as demo:
        # Name of the currently uploaded file; "" while the KB is active.
        uploaded_file_state = gr.State("")
        # Header
        header_text = f"# π€ {config.client_name}\n### {config.client_description}"
        if config.client_logo:
            # NOTE(review): only a newline is appended here — the logo image
            # markup appears to have been lost; confirm the intended markdown.
            header_text += f"\n"
        gr.Markdown(header_text)
        # File upload section
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### π€ Upload Document")
                file_upload = gr.File(
                    label="Upload PDF, DOCX, TXT, or MD file",
                    file_types=[".pdf", ".docx", ".txt", ".md"],
                    type="filepath"
                )
                upload_status = gr.Markdown("βΉοΈ Upload a file to ask questions about it.")
                with gr.Row():
                    clear_btn = gr.Button("π Clear & Use KB", variant="secondary", size="sm")
        # Main chat interface
        with gr.Row():
            with gr.Column(scale=1):
                chat = gr.Chatbot(
                    height=500,
                    show_copy_button=True,
                    type="messages",
                    avatar_images=(None, "https://em-content.zobj.net/source/twitter/376/robot_1f916.png")
                )
                with gr.Row():
                    txt = gr.Textbox(
                        placeholder="π¬ Ask a question about the document or knowledge base...",
                        scale=9,
                        show_label=False,
                        container=False
                    )
                    send = gr.Button("Send", variant="primary", scale=1)
        # Quick action buttons (if configured)
        if config.quick_actions:
            with gr.Accordion("β‘ Quick Actions", open=False):
                with gr.Row():
                    quick_buttons = []
                    for label, _ in config.quick_actions:
                        btn = gr.Button(label, elem_classes="quick-btn", size="sm")
                        quick_buttons.append((btn, label))
        # Admin section
        with gr.Accordion("π§ Admin Panel", open=False):
            # BUG FIX: this was a plain string, so users saw the literal
            # "{config.kb_directory}" — it must be an f-string to interpolate.
            gr.Markdown(
                f"""
                **Rebuild Index:** Use this after adding or modifying files in the `{config.kb_directory}` directory.
                The system will re-scan all markdown files and update the search index.
                """
            )
            with gr.Row():
                rebuild_btn = gr.Button("π Rebuild KB Index", variant="secondary")
            status_msg = gr.Markdown("")
        # Event handlers
        file_upload.change(
            handle_file_upload,
            inputs=[file_upload],
            outputs=[upload_status, uploaded_file_state]
        )
        clear_btn.click(
            clear_uploaded_file,
            outputs=[upload_status, uploaded_file_state, file_upload]
        )
        send.click(
            process_message,
            inputs=[txt, chat, uploaded_file_state],
            outputs=[chat, txt]
        )
        txt.submit(
            process_message,
            inputs=[txt, chat, uploaded_file_state],
            outputs=[chat, txt]
        )
        if config.quick_actions:
            for btn, label in quick_buttons:
                # gr.State(label) pins the label value at wiring time.
                btn.click(
                    process_quick,
                    inputs=[gr.State(label), chat, uploaded_file_state],
                    outputs=[chat, txt]
                )
        rebuild_btn.click(rebuild_index_handler, outputs=status_msg)
        # Footer
        gr.Markdown(
            """
            ---
            π‘ **Tips:**
            - Upload a document to ask questions specifically about that file
            - Use "Clear & Use KB" to switch back to the knowledge base
            - Be specific in your questions for better results
            - Check the cited sources for full context
            """
        )
    return demo
# ----------- Main Entry Point -----------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Configurable RAG Assistant')
    parser.add_argument('--config', type=str, default='config.yaml',
        help='Path to configuration YAML file (default: config.yaml)')
    args = parser.parse_args()
    # Load configuration first: KBIndex and ensure_index read the global
    # `config`, so this assignment order is load-bearing.
    config = Config(args.config)
    # Initialize KB with config (downloads/loads embedding + QA models).
    kb = KBIndex()
    ensure_index()
    # Create and launch interface
    demo = create_interface()
    demo.launch()