Commit: fix for qwen w lora (#906)
Changed file: src/axolotl/utils/models.py (+10 −3)
|
@@ -412,15 +412,22 @@ def load_model(

Before:

    412                 module.to(torch.float32)
    413
    414     needs_fa2_dtype = cfg.adapter or cfg.fsdp
    415     if (cfg.adapter == "lora" and load_in_8bit) or (
    416         cfg.adapter == "qlora" and cfg.load_in_4bit
    417     ):
    418         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
    419         if cfg.gradient_checkpointing:
    420             model.gradient_checkpointing_enable()
    421 -       model = prepare_model_for_kbit_training(
    422 -           model, use_gradient_checkpointing=cfg.gradient_checkpointing
    423 -       )
    424         needs_fa2_dtype = True
    425
    426     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
|
|
After:

    412                 module.to(torch.float32)
    413
    414     needs_fa2_dtype = cfg.adapter or cfg.fsdp
    415 +   skip_prepare_model_for_kbit_training = False
    416 +
    417 +   if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
    418 +       # Qwen doesn't play nicely with LoRA if this is enabled
    419 +       skip_prepare_model_for_kbit_training = True
    420 +
    421     if (cfg.adapter == "lora" and load_in_8bit) or (
    422         cfg.adapter == "qlora" and cfg.load_in_4bit
    423     ):
    424         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
    425         if cfg.gradient_checkpointing:
    426             model.gradient_checkpointing_enable()
    427 +       if not skip_prepare_model_for_kbit_training:
    428 +           model = prepare_model_for_kbit_training(
    429 +               model, use_gradient_checkpointing=cfg.gradient_checkpointing
    430 +           )
    431         needs_fa2_dtype = True
    432
    433     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|