""" Gradio Web UI for Multimodal RAG System. Provides a visual interface for document upload and Q&A. """ import gradio as gr import tempfile import shutil from pathlib import Path from typing import List, Tuple, Generator # Add parent to path import sys sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from src.preprocessing import PDFParser, TextChunker from src.embeddings import CustomEmbedder from src.retrieval import FAISSVectorStore, Document, HybridRetriever, RAGPipeline from src.utils import get_logger logger = get_logger(__name__) # Global state vector_store = None rag_pipeline = None embedder = None def initialize_system(model_name: str = "qwen2") -> str: """Initialize the RAG system components.""" global vector_store, rag_pipeline, embedder try: embedder = CustomEmbedder() vector_store = FAISSVectorStore(embedding_dim=embedder.embedding_dim) retriever = HybridRetriever( dense_retriever=vector_store, embedder=embedder ) rag_pipeline = RAGPipeline( retriever=retriever, model_name=model_name ) return f"[OK] System initialized with model: {model_name}" except Exception as e: return f"[ERROR] Initialization failed: {str(e)}" def ingest_documents(files: List[tempfile.SpooledTemporaryFile]) -> str: """Process uploaded documents.""" global vector_store, embedder, rag_pipeline # Auto-initialize if not done if vector_store is None or embedder is None: embedder = CustomEmbedder() vector_store = FAISSVectorStore(embedding_dim=embedder.embedding_dim) if not files: return "[ERROR] No files uploaded!" try: pdf_parser = PDFParser() chunker = TextChunker(chunk_size=512, chunk_overlap=50) all_chunks = [] for file in files: file_path = Path(file.name) if file_path.suffix.lower() == ".pdf": doc = pdf_parser.parse(file_path) for page in doc.pages: chunks = chunker.chunk(page.text) for chunk in chunks: chunk.metadata["source_file"] = file_path.name chunk.metadata["page_number"] = page.page_number all_chunks.append(chunk) if not all_chunks: return "[ERROR] No text extracted from documents!" # Generate embeddings texts = [c.text for c in all_chunks] embeddings = embedder.encode(texts, show_progress=True) # Create documents documents = [ Document( id=c.chunk_id, text=c.text, embedding=embeddings[i], metadata=c.metadata ) for i, c in enumerate(all_chunks) ] vector_store.add_documents(documents) # Auto-initialize RAG pipeline from src.retrieval import SparseRetriever, DenseRetriever # Wrap vector_store with DenseRetriever dense_retriever = DenseRetriever(vector_store=vector_store, embedder=embedder) sparse_retriever = SparseRetriever() sparse_retriever.index_documents(documents) retriever = HybridRetriever( dense_retriever=dense_retriever, sparse_retriever=sparse_retriever ) rag_pipeline = RAGPipeline( retriever=retriever, model_name="llama3" ) return f"[OK] Ingested {len(documents)} chunks from {len(files)} file(s) - Ready to chat!" except Exception as e: logger.error(f"Ingestion error: {e}") return f"[ERROR] Error: {str(e)}" def query_rag( message: str, history: List[Tuple[str, str]], top_k: int = 5 ) -> str: """Query the RAG system.""" global rag_pipeline if rag_pipeline is None: return "[ERROR] Please load an index first (Documents tab → Load Index)!" if not message.strip(): return "[ERROR] Please enter a question!" try: logger.info(f"Processing query: {message}") # Query RAG pipeline response = rag_pipeline.query(message, top_k=top_k) # Format answer with sources answer = response.answer # Add source citations (RAGResponse uses 'citations') if response.citations: answer += "\n\n---\n**Sources:**\n" for i, citation in enumerate(response.citations[:3], 1): text_preview = citation.text_snippet[:150].replace("\n", " ") if citation.text_snippet else "" source = citation.source_file if citation.page: source += f" (p.{citation.page})" answer += f"\n[{i}] **{source}**: {text_preview}..." return answer except Exception as e: import traceback logger.error(f"Query error: {e}") logger.error(traceback.format_exc()) return f"[ERROR] Error: {str(e)}" def query_rag_streaming( message: str, history: List[Tuple[str, str]], top_k: int = 5 ) -> Generator[str, None, None]: """Query the RAG system with streaming response.""" global rag_pipeline if rag_pipeline is None: yield "[ERROR] Please load an index first (Documents tab → Load Index)!" return if not message.strip(): yield "[ERROR] Please enter a question!" return try: logger.info(f"Processing streaming query: {message}") # Show thinking indicator yield "Searching documents..." # Get response (we simulate streaming by yielding partial content) response = rag_pipeline.query(message, top_k=top_k) # Stream the answer word by word for effect answer = response.answer words = answer.split() partial = "" for i, word in enumerate(words): partial += word + " " if i % 5 == 0: # Update every 5 words yield partial # Add sources at the end if response.citations: sources = "\n\n---\n**Sources:**\n" for i, citation in enumerate(response.citations[:3], 1): text_preview = citation.text_snippet[:150].replace("\n", " ") if citation.text_snippet else "" source = citation.source_file if citation.page: source += f" (p.{citation.page})" sources += f"\n[{i}] **{source}**: {text_preview}..." yield partial + sources else: yield partial except Exception as e: logger.error(f"Streaming query error: {e}") yield f"[ERROR] Error: {str(e)}" def export_conversation(history: List[dict]) -> str: """Export conversation history to markdown.""" if not history: return "No conversation to export." markdown = "# RAG Conversation Export\n\n" markdown += f"*Exported on: {__import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M')}*\n\n" markdown += "---\n\n" for msg in history: role = msg.get("role", "unknown") content = msg.get("content", "") if role == "user": markdown += f"## Question\n\n{content}\n\n" else: markdown += f"## Answer\n\n{content}\n\n" markdown += "---\n\n" return markdown def save_export(history: List[dict]) -> str: """Save conversation export to file.""" import tempfile from datetime import datetime markdown = export_conversation(history) filename = f"rag_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md" filepath = Path(tempfile.gettempdir()) / filename with open(filepath, 'w', encoding='utf-8') as f: f.write(markdown) return str(filepath) def load_existing_index(index_path: str) -> str: """Load an existing FAISS index.""" global vector_store, rag_pipeline, embedder try: path = Path(index_path) if not path.exists(): return f"[ERROR] Index path not found: {index_path}" embedder = CustomEmbedder() vector_store = FAISSVectorStore(embedding_dim=embedder.embedding_dim) vector_store.load(index_path) # Import SparseRetriever for hybrid search from src.retrieval import SparseRetriever, DenseRetriever # Wrap vector_store with DenseRetriever dense_retriever = DenseRetriever(vector_store=vector_store, embedder=embedder) # Get documents for sparse indexing docs = vector_store.get_all_documents() sparse_retriever = SparseRetriever() sparse_retriever.index_documents(docs) retriever = HybridRetriever( dense_retriever=dense_retriever, sparse_retriever=sparse_retriever ) rag_pipeline = RAGPipeline( retriever=retriever, model_name="llama3" ) # PRELOAD the LLM to avoid threading issues during query logger.info("Preloading LLM (this takes 30-60 seconds)...") rag_pipeline._load_llm() logger.info("LLM preloaded successfully!") return f"[OK] Loaded index from {index_path} ({vector_store.count} documents) - Ready to chat!" except Exception as e: import traceback logger.error(traceback.format_exc()) return f"[ERROR] Error loading index: {str(e)}" # Create Gradio interface def create_ui(): """Create the Gradio interface.""" with gr.Blocks( title="Multimodal RAG System" ) as demo: gr.Markdown(""" # Multimodal RAG System ### Intelligent Document Q&A with Citations """) with gr.Tab("Chat"): with gr.Row(): with gr.Column(scale=3): chatbot = gr.Chatbot( label="Conversation", height=500 ) msg = gr.Textbox( label="Your Question", placeholder="Ask a question about your documents...", lines=2 ) with gr.Row(): submit_btn = gr.Button("Send", variant="primary") clear_btn = gr.Button("Clear") with gr.Column(scale=1): gr.Markdown("### Settings") top_k = gr.Slider( minimum=1, maximum=10, value=5, step=1, label="Number of Sources" ) model_select = gr.Dropdown( choices=["llama3", "mistral", "qwen2", "phi3", "flan-t5"], value="llama3", label="LLM Model" ) init_btn = gr.Button("Initialize System", variant="secondary") init_status = gr.Textbox(label="Status", interactive=False) # Chat handlers (Gradio 6.x format) def respond(message, chat_history, top_k): chat_history = chat_history or [] response = query_rag(message, chat_history, top_k) # Gradio 6.x uses dict format chat_history.append({"role": "user", "content": message}) chat_history.append({"role": "assistant", "content": response}) return "", chat_history submit_btn.click(respond, [msg, chatbot, top_k], [msg, chatbot]) msg.submit(respond, [msg, chatbot, top_k], [msg, chatbot]) clear_btn.click(lambda: [], None, chatbot) init_btn.click(initialize_system, [model_select], [init_status]) with gr.Tab("Documents"): gr.Markdown("### Upload Documents") file_upload = gr.File( label="Upload PDFs", file_types=[".pdf"], file_count="multiple" ) upload_btn = gr.Button("Process Documents", variant="primary") upload_status = gr.Textbox(label="Upload Status", interactive=False) upload_btn.click(ingest_documents, [file_upload], [upload_status]) gr.Markdown("---") gr.Markdown("### Or Load Existing Index") index_path = gr.Textbox( label="Index Path", value="artifacts/index", placeholder="Path to saved FAISS index" ) load_btn = gr.Button("Load Index") load_status = gr.Textbox(label="Load Status", interactive=False) load_btn.click(load_existing_index, [index_path], [load_status]) with gr.Tab("About"): gr.Markdown(""" ## About This System This is a **Multimodal RAG (Retrieval-Augmented Generation)** system for document intelligence. ### Features - **PDF Document Processing** - Extract text from PDFs - **Hybrid Search** - Combines dense vectors + BM25 - **LLM-Powered Answers** - Generates responses with citations - **GPU Accelerated** - Fast inference with CUDA ### How to Use 1. Go to **Documents** tab and upload PDFs (or load existing index) 2. Click **Initialize System** in the Chat tab 3. Ask questions about your documents! ### Models - **Qwen2** (1.5B) - Fast, good quality - **Phi-3** (3.8B) - Better quality, slower - **Flan-T5** - Lightweight option """) return demo if __name__ == "__main__": demo = create_ui() # Enable queue for long-running tasks (LLM loading takes ~30s) demo.queue(default_concurrency_limit=1) demo.launch( server_name="0.0.0.0", server_port=7860, share=False )