Saint5 committed on
Commit
27b8e9f
·
verified ·
1 Parent(s): acca017

Uploading Multimodal Retrieval Augmented Generation System.

Browse files
Files changed (1) hide show
  1. model_setup.py +6 -4
model_setup.py CHANGED
@@ -5,6 +5,7 @@ import gc
5
 
6
  from sentence_transformers import SentenceTransformer
7
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, BitsAndBytesConfig
 
8
  from utils import clear_gpu_cache
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -12,25 +13,26 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
12
  # Embedding model
13
  embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
14
 
15
- # Gemma3 quantized config
16
  model_name = "google/gemma-3-4b-it"
17
  bnb_config = BitsAndBytesConfig(
18
  load_in_4bit=True,
19
  bnb_4bit_compute_dtype=torch.bfloat16,
20
  bnb_4bit_use_double_quant=True,
21
  bnb_4bit_quant_type="nf4",
22
- llm_int8_enable_fp32_cpu_offload=True # Allow offloading
23
  )
24
 
25
  # Load Gemma3
26
  model = Gemma3ForConditionalGeneration.from_pretrained(
27
  model_name,
28
  torch_dtype=torch.bfloat16,
29
- device_map="auto",
30
  quantization_config=bnb_config,
31
  low_cpu_mem_usage=True,
32
- attn_implementation="sdpa"
33
  )
 
34
  model.eval()
35
 
36
  # Processor
 
5
 
6
  from sentence_transformers import SentenceTransformer
7
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, BitsAndBytesConfig
8
+ from accelerate import disk_offload
9
  from utils import clear_gpu_cache
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
13
  # Embedding model
14
  embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
15
 
16
+ # Gemma3 quantization config
17
  model_name = "google/gemma-3-4b-it"
18
  bnb_config = BitsAndBytesConfig(
19
  load_in_4bit=True,
20
  bnb_4bit_compute_dtype=torch.bfloat16,
21
  bnb_4bit_use_double_quant=True,
22
  bnb_4bit_quant_type="nf4",
23
+ # llm_int8_enable_fp32_cpu_offload=True # Allow offloading
24
  )
25
 
26
  # Load Gemma3
27
  model = Gemma3ForConditionalGeneration.from_pretrained(
28
  model_name,
29
  torch_dtype=torch.bfloat16,
30
+ device_map="cpu", # not "auto" since there is no GPU
31
  quantization_config=bnb_config,
32
  low_cpu_mem_usage=True,
33
+ # attn_implementation="sdpa"
34
  )
35
+ disk_offload(model=model, offload_dir="offload")
36
  model.eval()
37
 
38
  # Processor