Split up Llama model loading so that the model config can be loaded from the base model config path, while model weights are loaded from a separate path
Browse files
src/axolotl/utils/models.py
CHANGED
|
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
|
|
| 10 |
import bitsandbytes as bnb
|
| 11 |
import torch
|
| 12 |
import transformers
|
| 13 |
-
from transformers import AutoModelForCausalLM # noqa: F401
|
| 14 |
from transformers import PreTrainedModel # noqa: F401
|
| 15 |
from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
|
| 16 |
|
|
@@ -172,8 +172,10 @@ def load_model(
|
|
| 172 |
)
|
| 173 |
load_in_8bit = False
|
| 174 |
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
|
|
|
| 175 |
model = LlamaForCausalLM.from_pretrained(
|
| 176 |
base_model,
|
|
|
|
| 177 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
| 178 |
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
| 179 |
torch_dtype=torch_dtype,
|
|
|
|
| 10 |
import bitsandbytes as bnb
|
| 11 |
import torch
|
| 12 |
import transformers
|
| 13 |
+
from transformers import AutoModelForCausalLM, LlamaConfig # noqa: F401
|
| 14 |
from transformers import PreTrainedModel # noqa: F401
|
| 15 |
from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
|
| 16 |
|
|
|
|
| 172 |
)
|
| 173 |
load_in_8bit = False
|
| 174 |
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
| 175 |
+
config = LlamaConfig.from_pretrained(base_model_config)
|
| 176 |
model = LlamaForCausalLM.from_pretrained(
|
| 177 |
base_model,
|
| 178 |
+
config=config,
|
| 179 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
| 180 |
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
| 181 |
torch_dtype=torch_dtype,
|