Saint5 commited on
Commit
4a9c569
·
verified ·
1 Parent(s): 28edfba

Uploading Multimodal Retrieval Augmented Generation System.

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. main.py +2 -2
  3. model_setup.py +35 -0
app.py CHANGED
@@ -6,8 +6,8 @@ import hashlib
6
  import torch
7
  import gradio as gr
8
 
9
- from setup.multimodal_rag.model import embedding_model, model, processor
10
- from setup.multimodal_rag.main import preprocess_pdf, semantic_search, generate_answer_stream
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
 
6
  import torch
7
  import gradio as gr
8
 
9
+ from model_setup import embedding_model, model, processor
10
+ from main import preprocess_pdf, semantic_search, generate_answer_stream
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
main.py CHANGED
@@ -16,13 +16,13 @@ from PIL import Image
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from transformers import TextIteratorStreamer
18
 
19
- from setup.multimodal_rag.utils import (
20
  save_cache, load_cache,
21
  init_faiss_indexflatip, add_embeddings_to_index,
22
  search_faiss_index, save_faiss_index, load_faiss_index, cleanup_images
23
  )
24
 
25
- from setup.multimodal_rag.model import embedding_model, model, processor
26
 
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
 
 
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from transformers import TextIteratorStreamer
18
 
19
+ from utils import (
20
  save_cache, load_cache,
21
  init_faiss_indexflatip, add_embeddings_to_index,
22
  search_faiss_index, save_faiss_index, load_faiss_index, cleanup_images
23
  )
24
 
25
+ from model_setup import embedding_model, model, processor
26
 
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
 
model_setup.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """loading the models to be used by the Mulltimodal RAG system."""
3
+
4
+ import torch
5
+
6
+ from sentence_transformers import SentenceTransformer
7
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, BitsAndBytesConfig
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ # Embedding model
12
+ embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
13
+
14
+ # Gemma3 quantized config
15
+ model_name = "google/gemma-3-4b-it"
16
+ bnb_config = BitsAndBytesConfig(
17
+ load_in_4bit=True,
18
+ bnb_4bit_compute_dtype=torch.bfloat16,
19
+ bnb_4bit_use_double_quant=True,
20
+ bnb_4bit_quant_type="nf4",
21
+ )
22
+
23
+ # Load Gemma3
24
+ model = Gemma3ForConditionalGeneration.from_pretrained(
25
+ model_name,
26
+ torch_dtype=torch.bfloat16,
27
+ device_map="auto",
28
+ quantization_config=bnb_config,
29
+ low_cpu_mem_usage=True,
30
+ attn_implementation="sdpa"
31
+ )
32
+ model.eval()
33
+
34
+ # Processor
35
+ processor = AutoProcessor.from_pretrained(model_name, use_fast=True)