Saint5 commited on
Commit
8cb5b3d
·
verified ·
1 Parent(s): 0a3175d

Uploading Multimodal Retrieval Augmented Generation System.

Browse files
Files changed (5) hide show
  1. README.md +1 -1
  2. app.py +14 -12
  3. main.py +4 -5
  4. model_setup.py +2 -2
  5. utils.py +2 -2
README.md CHANGED
@@ -24,4 +24,4 @@ A **Multimodal Retrieval-Augmented Generation (RAG) system** that allows users t
24
  - Streams answers from the LLM using Gradio interface.
25
  - Efficient memory usage with bitsandbytes 4-bit quantization.
26
 
27
- The **[google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it)** is both used to generate image descriptions for the extracted images and for text generation for the RAG system.
 
24
  - Streams answers from the LLM using Gradio interface.
25
  - Efficient memory usage with bitsandbytes 4-bit quantization.
26
 
27
+ The **[google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it)** is both used to generate image descriptions for the extracted images and for text generation for the RAG system.
app.py CHANGED
@@ -5,7 +5,9 @@ import os
5
  import hashlib
6
  import torch
7
  import gradio as gr
 
8
 
 
9
  from model_setup import embedding_model, model, processor
10
  from main import preprocess_pdf, semantic_search, generate_answer_stream
11
 
@@ -30,7 +32,7 @@ state = {
30
 
31
  def _make_cache_names(pdf_path: str) -> tuple[str, str]:
32
  """Generate unique cache file names per PDF based on hash of filename."""
33
- pdf_hash = hashlib.md5(pdf_path.encode()).hexdigest[:8] # Shorten for readability
34
  base_name = os.path.splitext(os.path.basename(pdf_path))[0]
35
  index_file = os.path.join(CACHE_DIR, f"{base_name}_{pdf_hash}_index.faiss")
36
  chunks_file = os.path.join(CACHE_DIR, f"{base_name}_{pdf_hash}_chunks.json")
@@ -40,7 +42,10 @@ def handle_pdf_upload(file):
40
  if file is None:
41
  return "[ERROR ⚠️] No file uploaded.", gr.update()
42
 
43
- new_pdf_path = file.name
 
 
 
44
  state["pdf_path"] = new_pdf_path
45
 
46
  # Create unique cache file names for this PDF
@@ -56,6 +61,7 @@ def handle_pdf_upload(file):
56
  use_cache=True # allow cache for the PDF
57
  )
58
  state["index"], state["chunks"] = index, chunks
 
59
 
60
  # Store in processed_pdfs for later selection
61
  pdf_key = os.path.basename(state["pdf_path"])
@@ -71,18 +77,13 @@ def handle_pdf_selection(pdf_name):
71
  if pdf_name not in state["processed_pdfs"]:
72
  return "[ERROR] Selected PDF not found in cache."
73
 
74
- state["pdf_path"] = pdf_name
75
  state["index_file"], state["chunks_file"] = state["processed_pdfs"][pdf_name]
76
 
77
  # Reload index + chunks from cache
78
- index, chunks = preprocess_pdf(
79
- file_path=pdf_name,
80
- image_dir=state["image_dir"],
81
- embedding_model=embedding_model,
82
- index_file=state["index_file"],
83
- chunks_file=state["chunks_file"],
84
- use_cache=True
85
- )
86
  state["index"], state["chunks"] = index, chunks
87
  return f"📂 Switched to cached PDF: {pdf_name}"
88
 
@@ -93,6 +94,7 @@ def chat_streaming(message, history):
93
 
94
  # Perform semantic search
95
  retrieved_chunks = semantic_search(message, embedding_model, state["index"], state["chunks"], top_k=10)
 
96
 
97
  # Stream the answer
98
  for partial in generate_answer_stream(message, retrieved_chunks, model, processor):
@@ -111,7 +113,7 @@ with gr.Blocks() as demo:
111
 
112
  with gr.Row():
113
  file_input = gr.File(label="📂Upload PDF")
114
- upload_button = gr.Button("Process PDF")
115
 
116
  upload_status = gr.Textbox(label="Upload Status", interactive=False)
117
  pdf_selector = gr.Dropdown(label="📄 Select a Processed PDF", choices=[], interactive=True)
 
5
  import hashlib
6
  import torch
7
  import gradio as gr
8
+ # import gc
9
 
10
+ from utils import load_faiss_index, load_cache
11
  from model_setup import embedding_model, model, processor
12
  from main import preprocess_pdf, semantic_search, generate_answer_stream
13
 
 
32
 
33
  def _make_cache_names(pdf_path: str) -> tuple[str, str]:
34
  """Generate unique cache file names per PDF based on hash of filename."""
35
+ pdf_hash = hashlib.md5(pdf_path.encode()).hexdigest()[:8] # Shorten for readability
36
  base_name = os.path.splitext(os.path.basename(pdf_path))[0]
37
  index_file = os.path.join(CACHE_DIR, f"{base_name}_{pdf_hash}_index.faiss")
38
  chunks_file = os.path.join(CACHE_DIR, f"{base_name}_{pdf_hash}_chunks.json")
 
42
  if file is None:
43
  return "[ERROR ⚠️] No file uploaded.", gr.update()
44
 
45
+ # Save uploaded file to cache directory to ensure accessibility
46
+ new_pdf_path = os.path.join(CACHE_DIR, file.name)
47
+ with open(new_pdf_path, "wb") as f_out:
48
+ f_out.write(file.read())
49
  state["pdf_path"] = new_pdf_path
50
 
51
  # Create unique cache file names for this PDF
 
61
  use_cache=True # allow cache for the PDF
62
  )
63
  state["index"], state["chunks"] = index, chunks
64
+ # gc.collect() # Free memory after PDF processing
65
 
66
  # Store in processed_pdfs for later selection
67
  pdf_key = os.path.basename(state["pdf_path"])
 
77
  if pdf_name not in state["processed_pdfs"]:
78
  return "[ERROR] Selected PDF not found in cache."
79
 
80
+ state["pdf_path"] = os.path.join(CACHE_DIR, pdf_name)
81
  state["index_file"], state["chunks_file"] = state["processed_pdfs"][pdf_name]
82
 
83
  # Reload index + chunks from cache
84
+ index = load_faiss_index(state["index_file"])
85
+ chunks = load_cache(state["chunks_file"])
86
+
 
 
 
 
 
87
  state["index"], state["chunks"] = index, chunks
88
  return f"📂 Switched to cached PDF: {pdf_name}"
89
 
 
94
 
95
  # Perform semantic search
96
  retrieved_chunks = semantic_search(message, embedding_model, state["index"], state["chunks"], top_k=10)
97
+ # gc.collect() # Free memory after semantic search
98
 
99
  # Stream the answer
100
  for partial in generate_answer_stream(message, retrieved_chunks, model, processor):
 
113
 
114
  with gr.Row():
115
  file_input = gr.File(label="📂Upload PDF")
116
+ # upload_button = gr.Button("Process PDF")
117
 
118
  upload_status = gr.Textbox(label="Upload Status", interactive=False)
119
  pdf_selector = gr.Dropdown(label="📄 Select a Processed PDF", choices=[], interactive=True)
main.py CHANGED
@@ -9,10 +9,9 @@ import re
9
  import gc
10
  import numpy as np
11
 
12
- # from time import time
13
  from typing import List, Dict, Tuple
14
  from PIL import Image
15
- # from threading import Thread
16
 
17
  from langchain.text_splitter import RecursiveCharacterTextSplitter
18
  from transformers import TextIteratorStreamer
@@ -106,8 +105,8 @@ def generate_image_descriptions(image_paths):
106
  captions.append({"image_path": image_path, "caption": "<---image---> (Captioning failed)"}) # Add a placeholder caption
107
  continue
108
  finally:
109
- clear_gpu_cache()
110
  gc.collect()
 
111
  return captions
112
 
113
  # Cleaning the captions from the extracted images
@@ -301,13 +300,13 @@ def generate_answer_stream(query, retrieved_chunks, model, processor):
301
  with torch.inference_mode():
302
  model.generate(**inputs, streamer=streamer, use_cache=True, max_new_tokens=512)
303
  gc.collect() # Free memory after model generation
304
-
305
  accumulated = ""
306
  for new_text in streamer:
307
  # time.sleep(0.2)
308
  accumulated += new_text
309
  yield accumulated
310
 
311
- # Free memory after streaming
312
  clear_gpu_cache()
313
  gc.collect()
 
9
  import gc
10
  import numpy as np
11
 
12
+
13
  from typing import List, Dict, Tuple
14
  from PIL import Image
 
15
 
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from transformers import TextIteratorStreamer
 
105
  captions.append({"image_path": image_path, "caption": "<---image---> (Captioning failed)"}) # Add a placeholder caption
106
  continue
107
  finally:
 
108
  gc.collect()
109
+ clear_gpu_cache()
110
  return captions
111
 
112
  # Cleaning the captions from the extracted images
 
300
  with torch.inference_mode():
301
  model.generate(**inputs, streamer=streamer, use_cache=True, max_new_tokens=512)
302
  gc.collect() # Free memory after model generation
303
+
304
  accumulated = ""
305
  for new_text in streamer:
306
  # time.sleep(0.2)
307
  accumulated += new_text
308
  yield accumulated
309
 
310
+ # Free memory after streaming is complete
311
  clear_gpu_cache()
312
  gc.collect()
model_setup.py CHANGED
@@ -7,6 +7,7 @@ import gc
7
  from sentence_transformers import SentenceTransformer
8
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, BitsAndBytesConfig
9
  from utils import clear_gpu_cache
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
  # Embedding model
@@ -35,6 +36,5 @@ model.eval()
35
  # Processor
36
  processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
37
 
38
- # Free memory
39
  clear_gpu_cache()
40
- gc.collect()
 
7
  from sentence_transformers import SentenceTransformer
8
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, BitsAndBytesConfig
9
  from utils import clear_gpu_cache
10
+
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
  # Embedding model
 
36
  # Processor
37
  processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
38
 
 
39
  clear_gpu_cache()
40
+ gc.collect()
utils.py CHANGED
@@ -3,10 +3,10 @@
3
  import os
4
  import gc
5
  import json
 
6
  from typing import List, Dict
7
  import faiss
8
  import numpy as np
9
- import torch
10
 
11
  def save_cache(data: List[Dict], filepath: str) -> None:
12
  """Saving the chunks and the embeddings for easy retrieval in .json format"""
@@ -61,7 +61,7 @@ def cleanup_images(image_dir: str):
61
  except Exception as e:
62
  print(f"[WARNING] Failed to delete some images in {image_dir}: {e}")
63
 
64
- # Just being agnostic because this space may only be using CPU but why not?
65
  def clear_gpu_cache():
66
  """Clear GPU cache and run garbage collection(saving on memory)."""
67
  if torch.cuda.is_available():
 
3
  import os
4
  import gc
5
  import json
6
+ import torch
7
  from typing import List, Dict
8
  import faiss
9
  import numpy as np
 
10
 
11
  def save_cache(data: List[Dict], filepath: str) -> None:
12
  """Saving the chunks and the embeddings for easy retrieval in .json format"""
 
61
  except Exception as e:
62
  print(f"[WARNING] Failed to delete some images in {image_dir}: {e}")
63
 
64
+ # Just being agnostic because my space may only be using CPU but why not?
65
  def clear_gpu_cache():
66
  """Clear GPU cache and run garbage collection(saving on memory)."""
67
  if torch.cuda.is_available():