# (Removed non-code residue: "Spaces: / Sleeping / Sleeping" — Hugging Face
#  Space status text captured by the page scrape, not part of the source.)
"""Gradio setup for the Multimodal RAG system."""

# Standard library
import os
import shutil

# Third-party
import gradio as gr
import torch

# Local modules
from utils import save_cache, load_cache, save_faiss_index, load_faiss_index
from model_setup import embedding_model, model, processor
from main import preprocess_pdf, semantic_search, generate_answer_stream

# Keep CPU usage bounded on shared hosts (e.g. a Hugging Face Space).
torch.set_num_threads(4)

# Cache directory for the retrieved chunks and the FAISS index so a
# previously processed PDF can be reloaded without re-running extraction.
CACHE_DIR = "cache_dir"
os.makedirs(CACHE_DIR, exist_ok=True)
INDEX_FILE = os.path.join(CACHE_DIR, "index.faiss")
CHUNKS_FILE = os.path.join(CACHE_DIR, "chunks.json")

# Global state shared across chats: the FAISS index, the extracted chunks,
# and the path of the currently loaded PDF.
state = {
    "index": None,
    "chunks": None,
    "pdf_path": None,
}
def handle_pdf_upload(file):
    """Process an uploaded PDF (or reload its cached artifacts) for Q&A.

    Populates the shared ``state`` dict with the FAISS index and chunks,
    and returns a status string shown in the Gradio UI.
    """
    if file is None:
        return "[ERROR] No file uploaded."

    state["pdf_path"] = file.name
    state["image_dir"] = os.path.join(CACHE_DIR, "extracted_images")

    try:
        cache_hit = os.path.exists(INDEX_FILE) and os.path.exists(CHUNKS_FILE)
        if cache_hit:
            # Reuse the previously built artifacts.
            # NOTE(review): the cache is not keyed on the PDF, so uploading a
            # different document without clearing the cache reloads stale
            # data — mitigated by the manual "Clear Cache" button below.
            state["index"] = load_faiss_index(INDEX_FILE)
            state["chunks"] = load_cache(CHUNKS_FILE)
            return "✅ Loaded from cache and ready for Q&A!"

        # No cache yet: run the full PDF preprocessing pipeline.
        built_index, built_chunks = preprocess_pdf(
            state["pdf_path"],
            state["image_dir"],
            embedding_model=embedding_model,
            index_file=INDEX_FILE,
            chunks_file=CHUNKS_FILE,
            use_cache=True,
        )
        state["index"] = built_index
        state["chunks"] = built_chunks

        # Persist the artifacts for the next session.
        save_faiss_index(built_index, INDEX_FILE)
        save_cache(built_chunks, CHUNKS_FILE)
        return "✅ Document processed and ready for Q&A!"
    except Exception as e:
        # UI boundary: surface the failure in the status box instead of
        # crashing the app.
        return f"[⚠️ ERROR] Failed to process document: {e}"
def chat_streaming(message, history):
    """Stream an answer for *message* grounded in the processed PDF.

    Yields partial answer strings so Gradio can render them incrementally.
    *history* is required by ``gr.ChatInterface`` but is unused here.
    """
    # Bug fix: the original guard used `and`, so a half-initialized state
    # (index loaded but chunks missing, or vice versa) slipped past it and
    # crashed inside semantic_search. Either missing piece means no answer.
    if state["index"] is None or state["chunks"] is None:
        yield "[ERROR] Please upload and process a PDF first."
        return

    # Retrieve the chunks most relevant to the question.
    retrieved_chunks = semantic_search(
        message, embedding_model, state["index"], state["chunks"], top_k=10
    )

    # Stream the generated answer piece by piece.
    for partial in generate_answer_stream(message, retrieved_chunks, model, processor):
        yield partial
# Clears the cache files before uploading another document, to prevent
# stale cache retrieval in handle_pdf_upload.
def manual_clear_cache():
    """Delete the cached index/chunks so a new document can be processed.

    Returns a status string shown in the Gradio UI.
    """
    # Bug fix: the original returned early unless BOTH cache files existed,
    # so a partial cache (e.g. index without chunks) could never be cleared.
    if not (os.path.exists(INDEX_FILE) or os.path.exists(CHUNKS_FILE)):
        return "⚠️No cache files exists to clear."
    if os.path.exists(CACHE_DIR):
        shutil.rmtree(CACHE_DIR)
    # Bug fix: recreate the cache directory — it is only created at import
    # time, so subsequent uploads would otherwise try to write into a
    # directory that no longer exists.
    os.makedirs(CACHE_DIR, exist_ok=True)
    state["index"], state["chunks"] = None, None
    return "✅ Cache cleared! You can upload a new document now."
# Short usage hint rendered under the chat title.
description = """
Remember to be specific when querying for better response.
📖🧐
"""

with gr.Blocks() as demo:
    gr.Markdown("## 📚Multimodal RAG System\nUpload a PDF (≤50 pages recommended) and ask questions about it.")

    # Upload controls.
    with gr.Row():
        file_input = gr.File(label="📂Upload PDF")
        upload_button = gr.Button("🔁Process PDF")

    # Cache-management controls.
    with gr.Row():
        clear_cache_button = gr.Button("🧹 Clear Cache")
        clear_cache_status = gr.Textbox(label="Cache Clear Status", interactive=False)

    upload_status = gr.Textbox(label="Upload Status", interactive=False)

    # Wire the buttons to their handlers.
    upload_button.click(handle_pdf_upload, inputs=file_input, outputs=upload_status)
    clear_cache_button.click(manual_clear_cache, outputs=clear_cache_status)

    # Streaming chat panel backed by the RAG pipeline.
    chat = gr.ChatInterface(
        fn=chat_streaming,
        type="messages",
        title="📄Ask Questions from PDF",
        description=description,
        examples=[["What is this document about?"]],
    )
    # NOTE(review): queue() is called on the ChatInterface rather than on
    # `demo` — presumably intended to enable queuing for the streaming chat;
    # confirm against the Gradio version in use.
    chat.queue()

demo.launch()