# Saint5's upload: Multimodal Retrieval Augmented Generation System.
# Commit 8b28692 (verified).
"""Gradio setup for the Multimodal RAG system."""
import os
import torch
import shutil
import gradio as gr
# import gc
from utils import save_cache, load_cache, save_faiss_index, load_faiss_index
from model_setup import embedding_model, model, processor
from main import preprocess_pdf, semantic_search, generate_answer_stream
# Cap intra-op parallelism — the app is expected to run on a small CPU host.
torch.set_num_threads(4) # cpu thread limit
# Creating a cache directory for the retrieved chunks and index files
CACHE_DIR = "cache_dir"
os.makedirs(CACHE_DIR, exist_ok=True)
INDEX_FILE = os.path.join(CACHE_DIR, "index.faiss")  # serialized FAISS vector index
CHUNKS_FILE = os.path.join(CACHE_DIR, "chunks.json")  # extracted document chunks
# Global state shared across chats:
#   "index"    - FAISS index for the current document (None until a PDF is processed)
#   "chunks"   - chunk list aligned with the index (None until a PDF is processed)
#   "pdf_path" - path of the most recently uploaded PDF
# NOTE: handle_pdf_upload() also adds an "image_dir" key at runtime.
state = {
"index": None,
"chunks": None,
"pdf_path": None,
}
def handle_pdf_upload(file):
    """Prepare the global retrieval state for an uploaded PDF.

    Loads a previously built FAISS index and chunk list from the cache
    directory when both cache files are present; otherwise runs the full
    preprocessing pipeline and persists its outputs for later sessions.

    Args:
        file: The gradio File value (or None when nothing was uploaded).

    Returns:
        A human-readable status string for the upload-status textbox.
    """
    if file is None:
        return "[ERROR] No file uploaded."
    # Record where the PDF lives and where extracted images should go.
    state["pdf_path"] = file.name
    state["image_dir"] = os.path.join(CACHE_DIR, "extracted_images")
    try:
        # NOTE(review): the cache is keyed only on file *presence*, not on
        # document identity — uploading a different PDF without pressing
        # "Clear Cache" serves the previous document's index. Confirm this
        # is the intended workflow.
        if os.path.exists(INDEX_FILE) and os.path.exists(CHUNKS_FILE):
            state["index"] = load_faiss_index(INDEX_FILE)
            state["chunks"] = load_cache(CHUNKS_FILE)
            return "✅ Loaded from cache and ready for Q&A!"
        # Fresh document: build the FAISS index and chunk list from scratch.
        new_index, new_chunks = preprocess_pdf(
            state["pdf_path"],
            state["image_dir"],
            embedding_model=embedding_model,
            index_file=INDEX_FILE,
            chunks_file=CHUNKS_FILE,
            use_cache=True,
        )
        state["index"], state["chunks"] = new_index, new_chunks
        # Persist so a restart (or the next session) can skip preprocessing.
        save_faiss_index(new_index, INDEX_FILE)
        save_cache(new_chunks, CHUNKS_FILE)
        return "✅ Document processed and ready for Q&A!"
    except Exception as e:
        return f"[⚠️ ERROR] Failed to process document: {e}"
def chat_streaming(message, history):
    """Stream an answer for *message* grounded in the currently loaded PDF.

    Args:
        message: The user's question.
        history: Chat history supplied by gr.ChatInterface (unused —
            retrieval is per-question, not conversational).

    Yields:
        Progressively longer partial answer strings for streaming display.
    """
    # BUG FIX: the guard previously used `and`, which let a half-initialized
    # state (index loaded but chunks missing, or vice versa) fall through
    # and crash inside semantic_search. `or` rejects any incomplete state.
    if state["index"] is None or state["chunks"] is None:
        yield "[ERROR] Please upload and process a PDF first."
        return
    # Retrieve the top-k most relevant chunks for the query.
    retrieved_chunks = semantic_search(message, embedding_model, state["index"], state["chunks"], top_k=10)
    # Stream the answer as the model produces it.
    for partial in generate_answer_stream(message, retrieved_chunks, model, processor):
        yield partial
# Function for clearing the cache files before uploading another document to prevent stale cache retrieval
def manual_clear_cache():
    """Delete all cached artifacts (index, chunks, extracted images).

    Removes the whole cache directory, recreates it empty, and resets the
    in-memory retrieval state so the next upload is processed from scratch.

    Returns:
        A human-readable status string for the cache-status textbox.
    """
    if not os.path.exists(INDEX_FILE) or not os.path.exists(CHUNKS_FILE):
        return "⚠️No cache files exists to clear."
    if os.path.exists(CACHE_DIR):
        shutil.rmtree(CACHE_DIR)
    # BUG FIX: recreate the (now deleted) cache directory. It was only
    # created once at import time, so without this the next upload's
    # preprocessing would fail trying to write index/chunk/image files
    # into a nonexistent directory.
    os.makedirs(CACHE_DIR, exist_ok=True)
    state["index"], state["chunks"] = None, None
    return "✅ Cache cleared! You can upload a new document now."
description = """
Remember to be specific when querying for better response.
📖🧐
"""
# Assemble the UI: upload + cache controls on top, chat interface below.
with gr.Blocks() as demo:
    gr.Markdown("## 📚Multimodal RAG System\nUpload a PDF (≤50 pages recommended) and ask questions about it.")
    with gr.Row():
        file_input = gr.File(label="📂Upload PDF")
        upload_button = gr.Button("🔁Process PDF")
    with gr.Row():
        clear_cache_button = gr.Button("🧹 Clear Cache")
        clear_cache_status = gr.Textbox(label="Cache Clear Status", interactive=False)
    upload_status = gr.Textbox(label="Upload Status", interactive=False)
    # Wire the buttons to their handlers.
    upload_button.click(handle_pdf_upload, inputs=file_input, outputs=upload_status)
    clear_cache_button.click(manual_clear_cache, outputs=clear_cache_status)
    chat = gr.ChatInterface(
        fn=chat_streaming,
        type="messages",
        title="📄Ask Questions from PDF",
        description=description,
        examples=[["What is this document about?"]],
    )
# NOTE(review): queuing is enabled on the ChatInterface; confirm whether
# demo.queue() was intended instead so all events share the queue.
chat.queue()
demo.launch()