credent007 commited on
Commit
011728b
·
verified ·
1 Parent(s): 33c13a8

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +47 -15
model_loader.py CHANGED
@@ -1,22 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from transformers import AutoProcessor, AutoModelForVision2Seq
3
 
4
  MODEL_NAME = "Qwen/Qwen2.5-VL-7B-Instruct"
5
 
 
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
 
8
- print("Loading processor...")
9
- processor = AutoProcessor.from_pretrained(
10
- MODEL_NAME,
11
- trust_remote_code=True,
12
- use_fast=True) # use_fast to avoid warnings in logs
13
-
14
- print("Loading model...")
15
- model = AutoModelForVision2Seq.from_pretrained(
16
- MODEL_NAME,
17
- trust_remote_code=True,
18
- torch_dtype=torch.float16,
19
- device_map="auto"
20
- )
21
-
22
- print("Model loaded successfully")
 
 
 
 
 
 
 
 
 
1
+ # import torch
2
+ # from transformers import AutoProcessor, AutoModelForVision2Seq
3
+
4
+ # MODEL_NAME = "Qwen/Qwen2.5-VL-7B-Instruct"
5
+
6
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
7
+
8
+ # print("Loading processor...")
9
+ # processor = AutoProcessor.from_pretrained(
10
+ # MODEL_NAME,
11
+ # trust_remote_code=True,
12
+ # use_fast=True) # use_fast to avoid warnings in logs
13
+
14
+ # print("Loading model...")
15
+ # model = AutoModelForVision2Seq.from_pretrained(
16
+ # MODEL_NAME,
17
+ # trust_remote_code=True,
18
+ # torch_dtype=torch.float16,
19
+ # device_map="auto"
20
+ # )
21
+
22
+ # print("Model loaded successfully")
23
  import torch
24
  from transformers import AutoProcessor, AutoModelForVision2Seq
25
 
26
  MODEL_NAME = "Qwen/Qwen2.5-VL-7B-Instruct"
27
 
28
+ model = None
29
+ processor = None
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
+
33
+ def get_model():
34
+ global model, processor, device
35
+
36
+ if model is None or processor is None:
37
+ print("Loading processor...")
38
+ processor = AutoProcessor.from_pretrained(
39
+ MODEL_NAME,
40
+ trust_remote_code=True,
41
+ use_fast=True
42
+ )
43
+
44
+ print("Loading model...")
45
+ model = AutoModelForVision2Seq.from_pretrained(
46
+ MODEL_NAME,
47
+ trust_remote_code=True,
48
+ torch_dtype=torch.float16,
49
+ device_map="auto"
50
+ )
51
+
52
+ print("Model loaded successfully")
53
+
54
+ return model, processor, device