File size: 868 Bytes
4a9c569
 
 
ca11a08
4a9c569
 
bff5090
 
ca11a08
8cb5b3d
4a9c569
 
 
 
 
27b8e9f
4a9c569
 
 
 
 
 
bff5090
4a9c569
bff5090
 
4a9c569
 
 
 
ca11a08
 
8cb5b3d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""loading the models to be used by the Mulltimodal RAG system."""

import torch
import gc

from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
# from accelerate import disk_offload
from utils import clear_gpu_cache

# Preferred torch device string: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Embedding model: the all-MiniLM-L6-v2 sentence-transformers checkpoint,
# instantiated once at import time.
# NOTE: no device argument is passed, so placement is left to
# SentenceTransformer's own default rather than the module-level `device`.
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Checkpoint identifier for Gemma 3; the processor is loaded from the same id.
# (No quantization is configured here: the weights are loaded as plain bfloat16.)
model_name = "google/gemma-3-4b-it"

# Load Gemma 3 onto the CPU in bfloat16, then put it in inference mode.
# The three steps are chained: to() and eval() each return the module itself,
# so `model` ends up bound to the same object as a step-by-step version.
model = (
    Gemma3ForConditionalGeneration.from_pretrained(
        model_name,
        device_map="cpu",  # author's note: CPU placement avoids meta-device errors
        torch_dtype=torch.bfloat16,
    )
    .to("cpu")
    .eval()
)
# Disabled alternative kept for reference:
# disk_offload(model=model, offload_dir="offload")

# Processor paired with the model: loaded from the same `model_name`
# checkpoint, with use_fast=True requesting the fast processor variant.
processor = AutoProcessor.from_pretrained(model_name, use_fast=True)

# Post-load cleanup, executed once when this module is imported.
# clear_gpu_cache() is a project helper from `utils` (not shown here);
# presumably it releases cached GPU memory -- TODO confirm against utils.
clear_gpu_cache()
# Force a garbage-collection pass after loading.
gc.collect()