litwell committed on
Commit
3ba711c
·
verified ·
1 Parent(s): 5d224a4

Upload models/src/utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/src/utils.py +91 -0
models/src/utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from peft import PeftModel
2
+ import torch
3
+ from transformers import BitsAndBytesConfig, Qwen2VLForConditionalGeneration, AutoProcessor, AutoConfig, Qwen2_5_VLForConditionalGeneration
4
+ import warnings
5
+ import os
6
+ import json
7
+
8
+ def disable_torch_init():
9
+ """
10
+ Disable the redundant torch default initialization to accelerate model creation.
11
+ """
12
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
13
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
14
+
15
+ # This code is borrowed from LLaVA
16
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False,
17
+ device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
18
+ kwargs = {"device_map": device_map}
19
+
20
+ if device != "cuda":
21
+ kwargs['device_map'] = {"":device}
22
+
23
+ if load_8bit:
24
+ kwargs['load_in_8bit'] = True
25
+ elif load_4bit:
26
+ kwargs['quantization_config'] = BitsAndBytesConfig(
27
+ load_in_4bit=True,
28
+ bnb_4bit_compute_dtype=torch.float16,
29
+ bnb_4bit_use_double_quant=True,
30
+ bnb_4bit_quant_type='nf4'
31
+ )
32
+ else:
33
+ kwargs['torch_dtype'] = torch.float16
34
+
35
+ if use_flash_attn:
36
+ kwargs['_attn_implementation'] = 'flash_attention_2'
37
+
38
+ if 'lora' in model_name.lower() and model_base is None:
39
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument.')
40
+ if 'lora' in model_name.lower() and model_base is not None:
41
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
42
+ if hasattr(lora_cfg_pretrained, 'quantization_config'):
43
+ del lora_cfg_pretrained.quantization_config
44
+ processor = AutoProcessor.from_pretrained(model_base)
45
+ print('Loading Qwen2-VL from base model...')
46
+ if "Qwen2.5" in model_base:
47
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
48
+ else:
49
+ model = Qwen2VLForConditionalGeneration.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
50
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
51
+ if model.lm_head.weight.shape[0] != token_num:
52
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
53
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
54
+
55
+ print('Loading additional Qwen2-VL weights...')
56
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_state_dict.bin'), map_location='cpu')
57
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
58
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
59
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
60
+ model.load_state_dict(non_lora_trainables, strict=False)
61
+
62
+ print('Loading LoRA weights...')
63
+ model = PeftModel.from_pretrained(model, model_path)
64
+
65
+ print('Merging LoRA weights...')
66
+ model = model.merge_and_unload()
67
+
68
+ print('Model Loaded!!!')
69
+
70
+ else:
71
+ with open(os.path.join(model_path, 'config.json'), 'r') as f:
72
+ config = json.load(f)
73
+
74
+ if "Qwen2.5" in config["_name_or_path"]:
75
+ processor = AutoProcessor.from_pretrained(model_path)
76
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
77
+
78
+ else:
79
+ processor = AutoProcessor.from_pretrained(model_path)
80
+ model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
81
+
82
+ return processor, model
83
+
84
+
85
+ def get_model_name_from_path(model_path):
86
+ model_path = model_path.strip("/")
87
+ model_paths = model_path.split("/")
88
+ if model_paths[-1].startswith('checkpoint-'):
89
+ return model_paths[-2] + "_" + model_paths[-1]
90
+ else:
91
+ return model_paths[-1]