Saint5 commited on
Commit
c3a4b6a
·
verified ·
1 Parent(s): 7613313

Uploading Multimodal Retrieval Augmented Generation System.

Browse files
Files changed (2) hide show
  1. app.py +72 -58
  2. utils.py +3 -5
app.py CHANGED
@@ -1,16 +1,18 @@
1
  """Gradio setup for the Multimodal RAG system."""
2
  import os
3
  import torch
 
4
  import gradio as gr
5
  # import gc
6
 
7
- from utils import load_faiss_index, load_cache
8
  from model_setup import embedding_model, model, processor
9
  from main import preprocess_pdf, semantic_search, generate_answer_stream
10
 
11
  torch.set_num_threads(4) # cpu thread limit
12
 
13
- CACHE_DIR = "cache"
 
14
  os.makedirs(CACHE_DIR, exist_ok=True)
15
 
16
  INDEX_FILE = os.path.join(CACHE_DIR, "index.faiss")
@@ -21,73 +23,85 @@ state = {
21
  "index": None,
22
  "chunks": None,
23
  "pdf_path": None,
24
- "image_dir": "extracted_images",
25
  }
26
 
27
- # Function to clear cache to prevent stale cache retrieval if new document is uploaded
28
- def clear_cache_files():
29
- if os.path.exists(INDEX_FILE):
30
- os.remove(INDEX_FILE)
31
- if os.path.exists(CHUNKS_FILE):
32
- os.remove(CHUNKS_FILE)
33
- state["index"], state["chunks"] = None, None
34
-
35
  def handle_pdf_upload(file):
36
  if file is None:
37
- return "[ERROR ⚠️] No file uploaded."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # Save uploaded file to cache directory to ensure accessibility
40
- pdf_path = os.path.join(CACHE_DIR, os.path.basename(file.name))
41
- with open(pdf_path, "wb") as f_out:
42
- f_out.write(file.file.read())
43
-
44
- if state["pdf_path"] != pdf_path:
45
- clear_cache_files()
46
 
47
- state["pdf_path"] = pdf_path
 
48
 
49
- index, chunks = preprocess_pdf(
50
- file_path=state["pdf_path"],
51
- image_dir=state["image_dir"],
52
- embedding_model=embedding_model,
53
- index_file=INDEX_FILE,
54
- chunks_file=CHUNKS_FILE,
55
- use_cache=True
56
- )
57
- state["index"], state["chunks"] = index, chunks
58
- return "✅ Document processed and ready for Q&A!"
59
 
60
- def chat_streaming(message, history):
61
- if state["index"] is None or state["chunks"] is None:
62
- yield "[ERROR ⚠️] Please upload and process a PDF first."
63
- return
64
- retrieved_chunks = semantic_search(message, embedding_model, state["index"], state["chunks"], top_k=10)
65
- for partial in generate_answer_stream(message, retrieved_chunks, model, processor):
66
- yield partial
67
 
 
 
68
 
69
  description = """
70
- Remember to be specific when querying for better response.
71
- 📖🧐
72
  """
73
- # Gradio setup
74
- with gr.Blocks() as demo:
75
- gr.Markdown("""## 📚Simple Multimodal RAG System
76
- Upload a PDF (≤50 pages recommended) and ask questions about it.""")
77
- with gr.Row():
78
- file_input = gr.File(label="📂Upload PDF")
79
- upload_button = gr.Button("🔁Process PDF")
80
- upload_status = gr.Textbox(label="Upload Status", interactive=False)
81
-
82
- upload_button.click(handle_pdf_upload, inputs=file_input, outputs=upload_status)
83
-
84
- chat = gr.ChatInterface(
85
- fn=chat_streaming,
86
- type="messages",
87
- title="📄😃 Ask Questions on your PDF!",
88
- description=description,
89
- examples=[["What is this document about?"]]
90
- )
91
- chat.queue()
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  demo.launch()
 
1
  """Gradio setup for the Multimodal RAG system."""
2
  import os
3
  import torch
4
+ import shutil
5
  import gradio as gr
6
  # import gc
7
 
8
+ from utils import save_cache, load_cache, save_faiss_index, load_faiss_index
9
  from model_setup import embedding_model, model, processor
10
  from main import preprocess_pdf, semantic_search, generate_answer_stream
11
 
12
  torch.set_num_threads(4) # cpu thread limit
13
 
14
+ # Creating a cache directory for the retrieved chunks and index files
15
+ CACHE_DIR = "cache_dir"
16
  os.makedirs(CACHE_DIR, exist_ok=True)
17
 
18
  INDEX_FILE = os.path.join(CACHE_DIR, "index.faiss")
 
23
  "index": None,
24
  "chunks": None,
25
  "pdf_path": None,
 
26
  }
27
 
 
 
 
 
 
 
 
 
28
def handle_pdf_upload(file):
    """Index an uploaded PDF (or reuse its on-disk cache) for Q&A.

    Args:
        file: Gradio file object for the uploaded PDF (exposes the temp
            path as ``.name``), or ``None`` when nothing was uploaded.

    Returns:
        A status string displayed in the upload-status textbox.
    """
    if file is None:
        return "[ERROR] No file uploaded."

    # Bug fix: if a *different* PDF is uploaded while cache files from a
    # previous document still exist, the cache branch below would load the
    # old document's index and answer questions about the wrong PDF.
    # Invalidate the stale cache first (matches the manual clear button).
    if state["pdf_path"] is not None and state["pdf_path"] != file.name:
        for stale in (INDEX_FILE, CHUNKS_FILE):
            if os.path.exists(stale):
                os.remove(stale)
        state["index"], state["chunks"] = None, None

    state["pdf_path"] = file.name
    state["image_dir"] = os.path.join(CACHE_DIR, "extracted_images")

    if os.path.exists(INDEX_FILE) and os.path.exists(CHUNKS_FILE):
        # Warm start: reuse the persisted FAISS index and chunk cache.
        state["index"] = load_faiss_index(INDEX_FILE)
        state["chunks"] = load_cache(CHUNKS_FILE)
        return "✅ Loaded from cache and ready for Q&A!"

    # Cold start: extract, embed, and index the document.
    index, chunks = preprocess_pdf(
        state["pdf_path"],
        state["image_dir"],
        embedding_model=embedding_model,
        index_file=INDEX_FILE,
        chunks_file=CHUNKS_FILE,
        use_cache=True)
    state["index"] = index
    state["chunks"] = chunks

    # Persist so re-uploading the same document is a warm start.
    # NOTE(review): preprocess_pdf already receives index_file/chunks_file
    # with use_cache=True, so this may be redundant — confirm in main.py.
    save_faiss_index(index, INDEX_FILE)
    save_cache(chunks, CHUNKS_FILE)

    return "✅ Document processed and ready for Q&A!"
57
 
58
def chat_streaming(message, history):
    """Stream an answer for *message* grounded in the indexed PDF.

    Yields partial answer strings so gr.ChatInterface renders
    incrementally. *history* is supplied by Gradio and unused here.
    """
    if state["index"] is None or state["chunks"] is None:
        yield "[ERROR] Please upload and process a PDF first."
        # Bug fix: without this return, execution falls through and calls
        # semantic_search with index/chunks still None, crashing the chat.
        return

    # Perform semantic search for the most relevant chunks.
    retrieved_chunks = semantic_search(message, embedding_model, state["index"], state["chunks"], top_k=10)

    # Stream the answer as it is generated.
    for partial in generate_answer_stream(message, retrieved_chunks, model, processor):
        yield partial
 
 
 
 
 
 
 
68
 
69
# Function for clearing the cache files before uploading another document to prevent stale cache retrieval
def manual_clear_cache():
    """Delete the on-disk cache and reset in-memory retrieval state.

    Returns:
        A status string displayed in the cache-clear textbox.
    """
    if not os.path.exists(INDEX_FILE) and not os.path.exists(CHUNKS_FILE):
        return "⚠️No cache files exists to clear."
    if os.path.exists(CACHE_DIR):
        shutil.rmtree(CACHE_DIR)
    # Bug fix: recreate the directory after rmtree — module setup only runs
    # os.makedirs once at import, so without this the next upload fails when
    # INDEX_FILE/CHUNKS_FILE are written into the now-missing CACHE_DIR.
    os.makedirs(CACHE_DIR, exist_ok=True)

    state["index"], state["chunks"] = None, None
    return "✅ Cache cleared! You can upload a new document now."
78
 
79
# Guidance text rendered beneath the chat-interface title.
description = """
Remember to be specific when querying for better response.
📖🧐
"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
# Gradio UI: upload/process controls, manual cache clearing, and a streaming chat.
with gr.Blocks() as demo:
    gr.Markdown("## 📚Multimodal RAG System\nUpload a PDF (≤50 pages recommended) and ask questions about it.")

    # PDF selection and the button that triggers indexing.
    with gr.Row():
        file_input = gr.File(label="📂Upload PDF")
        upload_button = gr.Button("🔁Process PDF")

    # Manual cache clearing, to avoid stale retrieval across documents.
    with gr.Row():
        clear_cache_button = gr.Button("🧹 Clear Cache")
        clear_cache_status = gr.Textbox(label="Cache Clear Status", interactive=False)

    upload_status = gr.Textbox(label="Upload Status", interactive=False)
    upload_button.click(handle_pdf_upload, inputs=file_input, outputs=upload_status)
    clear_cache_button.click(manual_clear_cache, outputs=clear_cache_status)

    # Chat bound to the generator chat_streaming; examples seed the UI.
    chat = gr.ChatInterface(
        fn=chat_streaming,
        type="messages",
        title="📄Ask Questions from PDF",
        description=description,
        examples=[["What is this document about?"]]
    )
    # NOTE(review): presumably queue() enables streaming of generator output —
    # confirm against the installed Gradio version's ChatInterface docs.
    chat.queue()
demo.launch()
utils.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  import gc
5
  import json
6
  import torch
 
7
  from typing import List, Dict
8
  import faiss
9
  import numpy as np
@@ -53,11 +54,8 @@ def load_faiss_index(filepath: str):
53
  # Deleting extracted images directory after captioning
54
  def cleanup_images(image_dir: str):
55
  try:
56
- for filename in os.listdir(image_dir):
57
- file_path = os.path.join(image_dir, filename)
58
- if os.path.isfile(file_path):
59
- os.remove(file_path)
60
- print(f"[INFO] Cleaned up extracted images in: {image_dir}")
61
  except Exception as e:
62
  print(f"[WARNING] Failed to delete some images in {image_dir}: {e}")
63
 
 
4
  import gc
5
  import json
6
  import torch
7
+ import shutil
8
  from typing import List, Dict
9
  import faiss
10
  import numpy as np
 
54
# Deleting extracted images directory after captioning
def cleanup_images(image_dir: str):
    """Best-effort removal of *image_dir* and everything inside it.

    Failures are logged as a warning instead of raised, because image
    cleanup is not critical to the pipeline.

    Args:
        image_dir: Path of the directory holding extracted images.
    """
    try:
        # Guard so a second cleanup call (or a dir that was never created)
        # is a silent no-op instead of raising and printing a warning.
        if os.path.isdir(image_dir):
            shutil.rmtree(image_dir)
            print(f"[INFO] Cleaned up extracted images directory: {image_dir}")
    except Exception as e:
        print(f"[WARNING] Failed to delete some images in {image_dir}: {e}")
61