Advait3009 commited on
Commit
a72f088
·
verified ·
1 Parent(s): d2a0e9f

Update utils/model_loader.py

Browse files
Files changed (1) hide show
  1. utils/model_loader.py +40 -9
utils/model_loader.py CHANGED
@@ -1,18 +1,49 @@
1
- from transformers import pipeline
2
  import torch
 
3
 
4
  def load_llava_model():
 
 
 
5
  return pipeline(
6
  "image-to-text",
7
- model="llava-hf/llava-1.5-7b-hf",
8
- torch_dtype=torch.float16,
9
  device_map="auto",
10
- max_new_tokens=200
 
 
 
 
 
 
 
 
 
11
  )
12
 
13
- def load_clip_model():
 
14
  return pipeline(
15
- "feature-extraction",
16
- model="openai/clip-vit-base-patch32",
17
- device_map="auto"
18
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
2
  import torch
3
+ from typing import Optional
4
 
5
  def load_llava_model():
6
+ """Load LLaVA model with 4-bit quantization for HF Spaces"""
7
+ model_id = "llava-hf/llava-1.5-7b-hf"
8
+
9
  return pipeline(
10
  "image-to-text",
11
+ model=model_id,
 
12
  device_map="auto",
13
+ model_kwargs={
14
+ "torch_dtype": torch.float16,
15
+ "load_in_4bit": True,
16
+ "quantization_config": {
17
+ "load_in_4bit": True,
18
+ "bnb_4bit_compute_dtype": torch.float16,
19
+ "bnb_4bit_use_double_quant": True,
20
+ "bnb_4bit_quant_type": "nf4"
21
+ }
22
+ }
23
  )
24
 
25
+ def load_caption_model():
26
+ """BLIP-2 with efficient loading"""
27
  return pipeline(
28
+ "image-to-text",
29
+ model="Salesforce/blip2-opt-2.7b",
30
+ device_map="auto",
31
+ torch_dtype=torch.float16,
32
+ model_kwargs={"cache_dir": "/tmp/models"}
33
+ )
34
+
35
+ def load_retrieval_models():
36
+ """Load encoders with shared weights"""
37
+ models = {}
38
+ models['text_encoder'] = SentenceTransformer(
39
+ 'sentence-transformers/all-MiniLM-L6-v2',
40
+ device="cuda" if torch.cuda.is_available() else "cpu"
41
+ )
42
+
43
+ models['image_encoder'] = AutoModel.from_pretrained(
44
+ "openai/clip-vit-base-patch32",
45
+ device_map="auto",
46
+ torch_dtype=torch.float16
47
+ )
48
+
49
+ return models