# Multimodal-rag-chatbot / utils/model_loader.py
# NOTE: the lines below are a Hugging Face Hub listing header captured with the
# file (uploader: Advait3009; commit a558a96 "Update utils/model_loader.py",
# verified) — kept here as a comment so the module parses.
from transformers import pipeline, AutoTokenizer, BitsAndBytesConfig
import torch
from typing import Optional
def load_llava_model():
    """Build an image-to-text pipeline around LLaVA 1.5 (7B), 4-bit quantized.

    Returns:
        A transformers ``pipeline`` for the ``image-to-text`` task with the
        model loaded via NF4 4-bit quantization (fp16 compute) and placed
        automatically across available devices.
    """
    model_id = "llava-hf/llava-1.5-7b-hf"
    # NF4 + double quantization keeps the 7B checkpoint small enough for the
    # constrained memory of a HF Space, while computing in fp16.
    four_bit_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    loader_kwargs = {
        "quantization_config": four_bit_config,
        "torch_dtype": torch.float16,
    }
    return pipeline(
        "image-to-text",
        model=model_id,
        tokenizer=model_id,
        device_map="auto",
        model_kwargs=loader_kwargs,
    )
def load_caption_model():
    """Build a BLIP-2 (OPT-2.7B) captioning pipeline in fp16.

    Returns:
        A transformers ``pipeline`` for ``image-to-text`` captioning, with the
        checkpoint cached under ``/tmp/models`` and devices assigned
        automatically.
    """
    # Cache under /tmp so repeated Space restarts reuse the downloaded weights.
    extra_kwargs = {"cache_dir": "/tmp/models"}
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip2-opt-2.7b",
        torch_dtype=torch.float16,
        device_map="auto",
        model_kwargs=extra_kwargs,
    )
    return captioner
def load_retrieval_models():
    """Load the encoders used for retrieval.

    Returns:
        dict: ``'text_encoder'`` — a MiniLM SentenceTransformer on CUDA when
        available, else CPU; ``'image_encoder'`` — CLIP ViT-B/32 loaded in
        fp16 with automatic device placement.
    """
    # Imported lazily so merely importing this module does not pull in
    # sentence-transformers.
    from sentence_transformers import SentenceTransformer
    from transformers import AutoModel

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    text_encoder = SentenceTransformer(
        'sentence-transformers/all-MiniLM-L6-v2',
        device=target_device,
    )
    image_encoder = AutoModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        device_map="auto",
        torch_dtype=torch.float16,
    )
    return {'text_encoder': text_encoder, 'image_encoder': image_encoder}