# multimodal-rag / src/web/app.py
# (scraped page header — author "itachi", commit "Remove all remaining emojis", 2688667)
"""
Gradio Web UI for Multimodal RAG System.
Provides a visual interface for document upload and Q&A.
"""
import gradio as gr
import tempfile
import shutil
from pathlib import Path
from typing import List, Tuple, Generator
# Add parent to path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.preprocessing import PDFParser, TextChunker
from src.embeddings import CustomEmbedder
from src.retrieval import FAISSVectorStore, Document, HybridRetriever, RAGPipeline
from src.utils import get_logger
logger = get_logger(__name__)
# Global state
# Shared singletons mutated by the handler functions below. They are
# (re)populated by initialize_system(), ingest_documents() and
# load_existing_index(); every query handler checks them for None first.
vector_store = None
rag_pipeline = None
embedder = None
def initialize_system(model_name: str = "qwen2") -> str:
    """Build the global embedder, vector store and RAG pipeline.

    Args:
        model_name: LLM identifier handed to RAGPipeline.

    Returns:
        A human-readable status string ("[OK] ..." on success,
        "[ERROR] ..." on any failure) for display in the UI.
    """
    global vector_store, rag_pipeline, embedder
    try:
        embedder = CustomEmbedder()
        vector_store = FAISSVectorStore(embedding_dim=embedder.embedding_dim)
        # NOTE(review): here the raw vector store is passed as the dense
        # retriever, unlike ingest_documents()/load_existing_index() which
        # wrap it in DenseRetriever — confirm HybridRetriever accepts both.
        rag_pipeline = RAGPipeline(
            retriever=HybridRetriever(
                dense_retriever=vector_store,
                embedder=embedder,
            ),
            model_name=model_name,
        )
    except Exception as exc:
        return f"[ERROR] Initialization failed: {str(exc)}"
    return f"[OK] System initialized with model: {model_name}"
def ingest_documents(files: List[tempfile.SpooledTemporaryFile]) -> str:
    """Parse, chunk, embed and index the uploaded PDFs, then build the pipeline.

    Only files with a ``.pdf`` suffix are processed; each page is chunked and
    every chunk is tagged with its source file and page number so answers can
    cite them. After indexing, a hybrid (dense + sparse) RAG pipeline is
    created so the Chat tab works immediately.

    Args:
        files: Uploaded file handles from gr.File; only the ``.name`` path is
            read (annotation kept from original — Gradio may actually pass
            NamedString/temp-file wrappers, TODO confirm).

    Returns:
        An "[OK] ..." or "[ERROR] ..." status string for the UI.
    """
    global vector_store, embedder, rag_pipeline
    # Auto-initialize if not done
    if vector_store is None or embedder is None:
        embedder = CustomEmbedder()
        vector_store = FAISSVectorStore(embedding_dim=embedder.embedding_dim)
    if not files:
        return "[ERROR] No files uploaded!"
    try:
        pdf_parser = PDFParser()
        chunker = TextChunker(chunk_size=512, chunk_overlap=50)
        all_chunks = []
        for file in files:
            file_path = Path(file.name)
            if file_path.suffix.lower() == ".pdf":
                doc = pdf_parser.parse(file_path)
                for page in doc.pages:
                    chunks = chunker.chunk(page.text)
                    for chunk in chunks:
                        # Tag each chunk so citations can point back to file/page.
                        chunk.metadata["source_file"] = file_path.name
                        chunk.metadata["page_number"] = page.page_number
                        all_chunks.append(chunk)
        if not all_chunks:
            return "[ERROR] No text extracted from documents!"
        # Generate embeddings
        texts = [c.text for c in all_chunks]
        embeddings = embedder.encode(texts, show_progress=True)
        # Create documents (embeddings[i] lines up with all_chunks[i])
        documents = [
            Document(
                id=c.chunk_id,
                text=c.text,
                embedding=embeddings[i],
                metadata=c.metadata
            )
            for i, c in enumerate(all_chunks)
        ]
        vector_store.add_documents(documents)
        # Auto-initialize RAG pipeline
        from src.retrieval import SparseRetriever, DenseRetriever
        # Wrap vector_store with DenseRetriever
        dense_retriever = DenseRetriever(vector_store=vector_store, embedder=embedder)
        sparse_retriever = SparseRetriever()
        sparse_retriever.index_documents(documents)
        retriever = HybridRetriever(
            dense_retriever=dense_retriever,
            sparse_retriever=sparse_retriever
        )
        rag_pipeline = RAGPipeline(
            retriever=retriever,
            # NOTE(review): hard-coded model; ignores the Chat tab's model
            # dropdown (initialize_system defaults to "qwen2") — confirm intent.
            model_name="llama3"
        )
        return f"[OK] Ingested {len(documents)} chunks from {len(files)} file(s) - Ready to chat!"
    except Exception as e:
        logger.error(f"Ingestion error: {e}")
        return f"[ERROR] Error: {str(e)}"
def query_rag(
    message: str,
    history: List[Tuple[str, str]],
    top_k: int = 5
) -> str:
    """Answer one chat message via the global RAG pipeline.

    Args:
        message: The user's question.
        history: Prior conversation turns (unused here; kept for the
            Gradio handler signature).
        top_k: Number of passages to retrieve.

    Returns:
        The generated answer with a markdown "Sources" footer built from
        up to three citations, or an "[ERROR] ..." string on any failure.
    """
    global rag_pipeline
    if rag_pipeline is None:
        return "[ERROR] Please load an index first (Documents tab → Load Index)!"
    if not message.strip():
        return "[ERROR] Please enter a question!"
    try:
        logger.info(f"Processing query: {message}")
        response = rag_pipeline.query(message, top_k=top_k)
        answer = response.answer
        # RAGResponse exposes 'citations'; render at most the top three.
        if response.citations:
            pieces = ["\n\n---\n**Sources:**\n"]
            for idx, cite in enumerate(response.citations[:3], 1):
                snippet = ""
                if cite.text_snippet:
                    snippet = cite.text_snippet[:150].replace("\n", " ")
                label = cite.source_file
                if cite.page:
                    label += f" (p.{cite.page})"
                pieces.append(f"\n[{idx}] **{label}**: {snippet}...")
            answer += "".join(pieces)
        return answer
    except Exception as e:
        import traceback
        logger.error(f"Query error: {e}")
        logger.error(traceback.format_exc())
        return f"[ERROR] Error: {str(e)}"
def query_rag_streaming(
    message: str,
    history: List[Tuple[str, str]],
    top_k: int = 5
) -> Generator[str, None, None]:
    """Yield progressively longer answer text for a typing effect.

    The pipeline itself is not streaming: the full answer is generated
    first, then re-emitted a few words at a time, with a citation footer
    appended on the final yield.
    """
    global rag_pipeline
    if rag_pipeline is None:
        yield "[ERROR] Please load an index first (Documents tab → Load Index)!"
        return
    if not message.strip():
        yield "[ERROR] Please enter a question!"
        return
    try:
        logger.info(f"Processing streaming query: {message}")
        # Placeholder shown while retrieval and generation run.
        yield "Searching documents..."
        response = rag_pipeline.query(message, top_k=top_k)
        partial = ""
        for idx, word in enumerate(response.answer.split()):
            partial += word + " "
            if idx % 5 == 0:  # emit an update every 5 words
                yield partial
        if not response.citations:
            yield partial
            return
        footer = "\n\n---\n**Sources:**\n"
        for num, cite in enumerate(response.citations[:3], 1):
            snippet = cite.text_snippet[:150].replace("\n", " ") if cite.text_snippet else ""
            label = cite.source_file
            if cite.page:
                label += f" (p.{cite.page})"
            footer += f"\n[{num}] **{label}**: {snippet}..."
        yield partial + footer
    except Exception as e:
        logger.error(f"Streaming query error: {e}")
        yield f"[ERROR] Error: {str(e)}"
def export_conversation(history: List[dict]) -> str:
    """Render a chat history as a markdown transcript.

    Args:
        history: Messages in Gradio "messages" format — dicts with "role"
            ("user"/"assistant") and "content" keys. Any role other than
            "user" is rendered as an answer.

    Returns:
        A markdown document with a timestamped header and one
        "## Question" / "## Answer" section per message (each followed by
        a "---" separator), or a short notice when the history is empty.
    """
    # Proper local import replaces the original __import__('datetime') hack.
    from datetime import datetime

    if not history:
        return "No conversation to export."
    # Build via list + join instead of repeated string concatenation.
    parts = ["# RAG Conversation Export\n\n"]
    parts.append(f"*Exported on: {datetime.now().strftime('%Y-%m-%d %H:%M')}*\n\n")
    parts.append("---\n\n")
    for msg in history:
        role = msg.get("role", "unknown")
        content = msg.get("content", "")
        heading = "Question" if role == "user" else "Answer"
        parts.append(f"## {heading}\n\n{content}\n\n")
        # Separator after each message (source indentation was ambiguous;
        # per-message separators match the transcript layout).
        parts.append("---\n\n")
    return "".join(parts)
def save_export(history: List[dict]) -> str:
    """Write the markdown export of *history* to a temp-directory file.

    Args:
        history: Conversation messages (see export_conversation).

    Returns:
        The path of the written ``rag_export_<timestamp>.md`` file as a string.
    """
    from datetime import datetime

    # tempfile is already imported at module level; the original's redundant
    # function-local `import tempfile` is removed.
    markdown = export_conversation(history)
    filename = f"rag_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
    filepath = Path(tempfile.gettempdir()) / filename
    # utf-8 so non-ASCII answer text round-trips correctly.
    filepath.write_text(markdown, encoding='utf-8')
    return str(filepath)
def load_existing_index(index_path: str) -> str:
    """Load a previously saved FAISS index and build the hybrid RAG pipeline.

    Rebuilds the embedder and vector store, restores the index from disk,
    re-indexes all stored documents for sparse (BM25-style) retrieval, and
    eagerly loads the LLM so the first chat query is not blocked.

    Args:
        index_path: Filesystem path of the saved index directory/file.

    Returns:
        An "[OK] ..." or "[ERROR] ..." status string for the UI.
    """
    global vector_store, rag_pipeline, embedder
    try:
        path = Path(index_path)
        if not path.exists():
            return f"[ERROR] Index path not found: {index_path}"
        embedder = CustomEmbedder()
        vector_store = FAISSVectorStore(embedding_dim=embedder.embedding_dim)
        vector_store.load(index_path)
        # Import SparseRetriever for hybrid search
        from src.retrieval import SparseRetriever, DenseRetriever
        # Wrap vector_store with DenseRetriever
        dense_retriever = DenseRetriever(vector_store=vector_store, embedder=embedder)
        # Get documents for sparse indexing
        docs = vector_store.get_all_documents()
        sparse_retriever = SparseRetriever()
        sparse_retriever.index_documents(docs)
        retriever = HybridRetriever(
            dense_retriever=dense_retriever,
            sparse_retriever=sparse_retriever
        )
        rag_pipeline = RAGPipeline(
            retriever=retriever,
            # NOTE(review): hard-coded model, same as ingest_documents().
            model_name="llama3"
        )
        # PRELOAD the LLM to avoid threading issues during query
        # (calls a private method of RAGPipeline — fragile across versions).
        logger.info("Preloading LLM (this takes 30-60 seconds)...")
        rag_pipeline._load_llm()
        logger.info("LLM preloaded successfully!")
        return f"[OK] Loaded index from {index_path} ({vector_store.count} documents) - Ready to chat!"
    except Exception as e:
        import traceback
        logger.error(traceback.format_exc())
        return f"[ERROR] Error loading index: {str(e)}"
# Create Gradio interface
def create_ui():
    """Create the Gradio interface.

    Builds a three-tab Blocks app:
      * Chat      — chatbot, question box, top-k slider, model dropdown.
      * Documents — PDF upload/ingest, or load a saved FAISS index.
      * About     — static usage notes.

    Returns:
        The assembled gr.Blocks demo (not yet launched).
    """
    with gr.Blocks(
        title="Multimodal RAG System"
    ) as demo:
        gr.Markdown("""
        # Multimodal RAG System
        ### Intelligent Document Q&A with Citations
        """)
        with gr.Tab("Chat"):
            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        label="Conversation",
                        height=500
                    )
                    msg = gr.Textbox(
                        label="Your Question",
                        placeholder="Ask a question about your documents...",
                        lines=2
                    )
                    with gr.Row():
                        submit_btn = gr.Button("Send", variant="primary")
                        clear_btn = gr.Button("Clear")
                with gr.Column(scale=1):
                    gr.Markdown("### Settings")
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of Sources"
                    )
                    model_select = gr.Dropdown(
                        choices=["llama3", "mistral", "qwen2", "phi3", "flan-t5"],
                        value="llama3",
                        label="LLM Model"
                    )
                    init_btn = gr.Button("Initialize System", variant="secondary")
                    init_status = gr.Textbox(label="Status", interactive=False)
            # Chat handlers (Gradio 6.x format)
            def respond(message, chat_history, top_k):
                # Append the turn as role/content dicts.
                # NOTE(review): assumes the Chatbot is in "messages" mode —
                # confirm the installed Gradio version defaults to it.
                chat_history = chat_history or []
                response = query_rag(message, chat_history, top_k)
                # Gradio 6.x uses dict format
                chat_history.append({"role": "user", "content": message})
                chat_history.append({"role": "assistant", "content": response})
                return "", chat_history
            submit_btn.click(respond, [msg, chatbot, top_k], [msg, chatbot])
            msg.submit(respond, [msg, chatbot, top_k], [msg, chatbot])
            clear_btn.click(lambda: [], None, chatbot)
            init_btn.click(initialize_system, [model_select], [init_status])
        with gr.Tab("Documents"):
            gr.Markdown("### Upload Documents")
            file_upload = gr.File(
                label="Upload PDFs",
                file_types=[".pdf"],
                file_count="multiple"
            )
            upload_btn = gr.Button("Process Documents", variant="primary")
            upload_status = gr.Textbox(label="Upload Status", interactive=False)
            upload_btn.click(ingest_documents, [file_upload], [upload_status])
            gr.Markdown("---")
            gr.Markdown("### Or Load Existing Index")
            index_path = gr.Textbox(
                label="Index Path",
                value="artifacts/index",
                placeholder="Path to saved FAISS index"
            )
            load_btn = gr.Button("Load Index")
            load_status = gr.Textbox(label="Load Status", interactive=False)
            load_btn.click(load_existing_index, [index_path], [load_status])
        with gr.Tab("About"):
            gr.Markdown("""
            ## About This System
            This is a **Multimodal RAG (Retrieval-Augmented Generation)** system for document intelligence.
            ### Features
            - **PDF Document Processing** - Extract text from PDFs
            - **Hybrid Search** - Combines dense vectors + BM25
            - **LLM-Powered Answers** - Generates responses with citations
            - **GPU Accelerated** - Fast inference with CUDA
            ### How to Use
            1. Go to **Documents** tab and upload PDFs (or load existing index)
            2. Click **Initialize System** in the Chat tab
            3. Ask questions about your documents!
            ### Models
            - **Qwen2** (1.5B) - Fast, good quality
            - **Phi-3** (3.8B) - Better quality, slower
            - **Flan-T5** - Lightweight option
            """)
    return demo
if __name__ == "__main__":
    demo = create_ui()
    # Enable queue for long-running tasks (LLM loading takes ~30s);
    # concurrency 1 serializes requests against the single global pipeline.
    demo.queue(default_concurrency_limit=1)
    # Bind on all interfaces (reachable from other hosts on the network);
    # no public Gradio share link.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )