Update libra/model/builder.py
Browse files- libra/model/builder.py +9 -20
libra/model/builder.py
CHANGED
|
@@ -23,25 +23,14 @@ from libra.model import *
|
|
| 23 |
from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
| 24 |
|
| 25 |
|
| 26 |
-
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
|
| 30 |
-
if device != "cuda":
|
| 31 |
-
kwargs['device_map'] = {"": device}
|
| 32 |
|
| 33 |
-
if load_8bit:
|
| 34 |
-
kwargs['load_in_8bit'] = True
|
| 35 |
-
elif load_4bit:
|
| 36 |
-
kwargs['load_in_4bit'] = True
|
| 37 |
-
kwargs['quantization_config'] = BitsAndBytesConfig(
|
| 38 |
-
load_in_4bit=True,
|
| 39 |
-
bnb_4bit_compute_dtype=torch.float16,
|
| 40 |
-
bnb_4bit_use_double_quant=True,
|
| 41 |
-
bnb_4bit_quant_type='nf4'
|
| 42 |
-
)
|
| 43 |
-
else:
|
| 44 |
-
kwargs['torch_dtype'] = torch.float16
|
| 45 |
|
| 46 |
if 'libra' in model_name.lower():
|
| 47 |
# Load Libra model
|
|
@@ -92,7 +81,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|
| 92 |
model.load_state_dict(mm_projector_weights, strict=False)
|
| 93 |
else:
|
| 94 |
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
| 95 |
-
model = LibraLlamaForCausalLM.from_pretrained(model_path, **kwargs)
|
| 96 |
else:
|
| 97 |
# Load language model
|
| 98 |
if model_base is not None:
|
|
@@ -124,8 +113,8 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|
| 124 |
|
| 125 |
vision_tower = model.get_vision_tower()
|
| 126 |
if not vision_tower.is_loaded:
|
| 127 |
-
vision_tower.load_model()
|
| 128 |
-
vision_tower.to(device=device, dtype=torch.float16)
|
| 129 |
image_processor = vision_tower.image_processor
|
| 130 |
|
| 131 |
if hasattr(model.config, "max_sequence_length"):
|
|
|
|
| 23 |
from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
| 24 |
|
| 25 |
|
| 26 |
+
def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
|
| 27 |
+
device_map = {"": device}
|
| 28 |
+
kwargs = {
|
| 29 |
+
"device_map": device_map,
|
| 30 |
+
"torch_dtype": torch.float32  # for CPU, float32 or bfloat16 is recommended
|
| 31 |
+
}
|
| 32 |
|
|
|
|
|
|
|
|
|
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
if 'libra' in model_name.lower():
|
| 36 |
# Load Libra model
|
|
|
|
| 81 |
model.load_state_dict(mm_projector_weights, strict=False)
|
| 82 |
else:
|
| 83 |
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
| 84 |
+
model = LibraLlamaForCausalLM.from_pretrained(model_path, **kwargs)
|
| 85 |
else:
|
| 86 |
# Load language model
|
| 87 |
if model_base is not None:
|
|
|
|
| 113 |
|
| 114 |
vision_tower = model.get_vision_tower()
|
| 115 |
if not vision_tower.is_loaded:
|
| 116 |
+
vision_tower.load_model()
|
| 117 |
+
vision_tower.to(device=device, dtype=torch.float32)
|
| 118 |
image_processor = vision_tower.image_processor
|
| 119 |
|
| 120 |
if hasattr(model.config, "max_sequence_length"):
|