Merge pull request #36 from OpenAccess-AI-Collective/qlora
Browse files- requirements.txt +1 -1
- scripts/finetune.py +3 -0
- src/axolotl/prompters.py +5 -0
- src/axolotl/utils/data.py +5 -0
- src/axolotl/utils/models.py +21 -3
requirements.txt
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
peft @ git+https://github.com/huggingface/peft.git
|
| 2 |
transformers @ git+https://github.com/huggingface/transformers.git
|
|
|
|
| 3 |
attrdict
|
| 4 |
fire
|
| 5 |
PyYAML==6.0
|
| 6 |
black
|
| 7 |
-
bitsandbytes==0.37.2
|
| 8 |
datasets
|
| 9 |
accelerate>=0.19.0
|
| 10 |
sentencepiece
|
|
|
|
| 1 |
peft @ git+https://github.com/huggingface/peft.git
|
| 2 |
transformers @ git+https://github.com/huggingface/transformers.git
|
| 3 |
+
bitsandbytes>=0.39.0
|
| 4 |
attrdict
|
| 5 |
fire
|
| 6 |
PyYAML==6.0
|
| 7 |
black
|
|
|
|
| 8 |
datasets
|
| 9 |
accelerate>=0.19.0
|
| 10 |
sentencepiece
|
scripts/finetune.py
CHANGED
|
@@ -14,6 +14,7 @@ from attrdict import AttrDefault
|
|
| 14 |
|
| 15 |
# add src to the pythonpath so we don't need to pip install this
|
| 16 |
from axolotl.utils.tokenization import check_dataset_labels
|
|
|
|
| 17 |
|
| 18 |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 19 |
src_dir = os.path.join(project_root, "src")
|
|
@@ -158,6 +159,8 @@ def train(
|
|
| 158 |
cfg.fp16 = True
|
| 159 |
cfg.bf16 = False
|
| 160 |
|
|
|
|
|
|
|
| 161 |
# Load the model and tokenizer
|
| 162 |
logging.info("loading model, tokenizer, and peft_config...")
|
| 163 |
model, tokenizer, peft_config = load_model(
|
|
|
|
| 14 |
|
| 15 |
# add src to the pythonpath so we don't need to pip install this
|
| 16 |
from axolotl.utils.tokenization import check_dataset_labels
|
| 17 |
+
from axolotl.utils.validation import validate_config
|
| 18 |
|
| 19 |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 20 |
src_dir = os.path.join(project_root, "src")
|
|
|
|
| 159 |
cfg.fp16 = True
|
| 160 |
cfg.bf16 = False
|
| 161 |
|
| 162 |
+
validate_config(cfg)
|
| 163 |
+
|
| 164 |
# Load the model and tokenizer
|
| 165 |
logging.info("loading model, tokenizer, and peft_config...")
|
| 166 |
model, tokenizer, peft_config = load_model(
|
src/axolotl/prompters.py
CHANGED
|
@@ -11,6 +11,7 @@ class PromptStyle(Enum):
|
|
| 11 |
instruct = "instruct"
|
| 12 |
chat = "chat"
|
| 13 |
|
|
|
|
| 14 |
class AlpacaPrompter:
|
| 15 |
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
| 16 |
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
|
@@ -50,6 +51,10 @@ class AlpacaPrompter:
|
|
| 50 |
return output.split(self.response_split)[1].strip()
|
| 51 |
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
class JeopardyPrompter(AlpacaPrompter):
|
| 54 |
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
| 55 |
|
|
|
|
| 11 |
instruct = "instruct"
|
| 12 |
chat = "chat"
|
| 13 |
|
| 14 |
+
|
| 15 |
class AlpacaPrompter:
|
| 16 |
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
| 17 |
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
|
|
|
| 51 |
return output.split(self.response_split)[1].strip()
|
| 52 |
|
| 53 |
|
| 54 |
+
class UnpromptedPrompter(AlpacaPrompter):
|
| 55 |
+
system_prompt = ""
|
| 56 |
+
system_no_input_prompt = ""
|
| 57 |
+
|
| 58 |
class JeopardyPrompter(AlpacaPrompter):
|
| 59 |
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
| 60 |
|
src/axolotl/utils/data.py
CHANGED
|
@@ -98,6 +98,11 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|
| 98 |
ds = load_dataset("json", data_files=fp, streaming=False, split=None)
|
| 99 |
if not ds:
|
| 100 |
raise Exception("unhandled dataset load")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
d_type = d.type
|
| 102 |
d_type_split = d_type.split(":")
|
| 103 |
d_base_type = d_type_split[0]
|
|
|
|
| 98 |
ds = load_dataset("json", data_files=fp, streaming=False, split=None)
|
| 99 |
if not ds:
|
| 100 |
raise Exception("unhandled dataset load")
|
| 101 |
+
# support for using a subset of the data
|
| 102 |
+
if d.shards:
|
| 103 |
+
ds = ds.shuffle(seed=42)["train"].shard(
|
| 104 |
+
num_shards=cfg.shards, index=0
|
| 105 |
+
)
|
| 106 |
d_type = d.type
|
| 107 |
d_type_split = d_type.split(":")
|
| 108 |
d_base_type = d_type_split[0]
|
src/axolotl/utils/models.py
CHANGED
|
@@ -6,11 +6,12 @@ from typing import Optional, Tuple, TYPE_CHECKING
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import transformers
|
|
|
|
| 9 |
from transformers import (
|
| 10 |
AutoModelForCausalLM,
|
| 11 |
AutoTokenizer,
|
| 12 |
PreTrainedModel,
|
| 13 |
-
AutoConfig,
|
| 14 |
)
|
| 15 |
|
| 16 |
try:
|
|
@@ -81,6 +82,16 @@ def load_model(
|
|
| 81 |
logging.exception(e)
|
| 82 |
raise e
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
try:
|
| 85 |
if cfg.load_4bit and is_llama_derived_model:
|
| 86 |
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
|
@@ -123,8 +134,10 @@ def load_model(
|
|
| 123 |
model = LlamaForCausalLM.from_pretrained(
|
| 124 |
base_model,
|
| 125 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
|
|
| 126 |
torch_dtype=torch_dtype,
|
| 127 |
device_map=cfg.device_map,
|
|
|
|
| 128 |
)
|
| 129 |
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
| 130 |
# This is a WIP, still an issue with the backward pass
|
|
@@ -156,9 +169,11 @@ def load_model(
|
|
| 156 |
model = getattr(transformers, model_type).from_pretrained(
|
| 157 |
base_model,
|
| 158 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
|
|
| 159 |
torch_dtype=torch_dtype,
|
| 160 |
device_map=cfg.device_map,
|
| 161 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
|
|
|
| 162 |
)
|
| 163 |
else:
|
| 164 |
config = AutoConfig.from_pretrained(
|
|
@@ -169,9 +184,11 @@ def load_model(
|
|
| 169 |
base_model,
|
| 170 |
config=config,
|
| 171 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
|
|
| 172 |
torch_dtype=torch_dtype,
|
| 173 |
device_map=cfg.device_map,
|
| 174 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
|
|
|
| 175 |
)
|
| 176 |
except Exception as e:
|
| 177 |
logging.error(
|
|
@@ -184,6 +201,7 @@ def load_model(
|
|
| 184 |
torch_dtype=torch_dtype,
|
| 185 |
device_map=cfg.device_map,
|
| 186 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
|
|
|
| 187 |
)
|
| 188 |
|
| 189 |
if not tokenizer:
|
|
@@ -225,7 +243,7 @@ def load_model(
|
|
| 225 |
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
| 226 |
model.resize_token_embeddings(embeddings_len)
|
| 227 |
|
| 228 |
-
if cfg.adapter and load_in_8bit and not cfg.load_4bit:
|
| 229 |
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
| 230 |
model = prepare_model_for_int8_training(model)
|
| 231 |
|
|
@@ -270,7 +288,7 @@ def load_adapter(model, cfg, adapter):
|
|
| 270 |
|
| 271 |
if adapter is None:
|
| 272 |
return model, None
|
| 273 |
-
if adapter == "lora":
|
| 274 |
return load_lora(model, cfg)
|
| 275 |
if adapter == "llama-adapter":
|
| 276 |
return load_llama_adapter(model, cfg)
|
|
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import transformers
|
| 9 |
+
from torch import nn
|
| 10 |
from transformers import (
|
| 11 |
AutoModelForCausalLM,
|
| 12 |
AutoTokenizer,
|
| 13 |
PreTrainedModel,
|
| 14 |
+
AutoConfig, BitsAndBytesConfig,
|
| 15 |
)
|
| 16 |
|
| 17 |
try:
|
|
|
|
| 82 |
logging.exception(e)
|
| 83 |
raise e
|
| 84 |
|
| 85 |
+
model_kwargs = {}
|
| 86 |
+
if cfg.adapter == "qlora":
|
| 87 |
+
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 88 |
+
load_in_4bit=True,
|
| 89 |
+
llm_int8_threshold=6.0,
|
| 90 |
+
llm_int8_has_fp16_weight=False,
|
| 91 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 92 |
+
bnb_4bit_use_double_quant=True,
|
| 93 |
+
bnb_4bit_quant_type="nf4",
|
| 94 |
+
)
|
| 95 |
try:
|
| 96 |
if cfg.load_4bit and is_llama_derived_model:
|
| 97 |
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
|
|
|
| 134 |
model = LlamaForCausalLM.from_pretrained(
|
| 135 |
base_model,
|
| 136 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
| 137 |
+
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
| 138 |
torch_dtype=torch_dtype,
|
| 139 |
device_map=cfg.device_map,
|
| 140 |
+
**model_kwargs,
|
| 141 |
)
|
| 142 |
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
| 143 |
# This is a WIP, still an issue with the backward pass
|
|
|
|
| 169 |
model = getattr(transformers, model_type).from_pretrained(
|
| 170 |
base_model,
|
| 171 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
| 172 |
+
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
| 173 |
torch_dtype=torch_dtype,
|
| 174 |
device_map=cfg.device_map,
|
| 175 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
| 176 |
+
**model_kwargs,
|
| 177 |
)
|
| 178 |
else:
|
| 179 |
config = AutoConfig.from_pretrained(
|
|
|
|
| 184 |
base_model,
|
| 185 |
config=config,
|
| 186 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
| 187 |
+
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
| 188 |
torch_dtype=torch_dtype,
|
| 189 |
device_map=cfg.device_map,
|
| 190 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
| 191 |
+
**model_kwargs,
|
| 192 |
)
|
| 193 |
except Exception as e:
|
| 194 |
logging.error(
|
|
|
|
| 201 |
torch_dtype=torch_dtype,
|
| 202 |
device_map=cfg.device_map,
|
| 203 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
| 204 |
+
**model_kwargs,
|
| 205 |
)
|
| 206 |
|
| 207 |
if not tokenizer:
|
|
|
|
| 243 |
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
| 244 |
model.resize_token_embeddings(embeddings_len)
|
| 245 |
|
| 246 |
+
if ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") and not cfg.load_4bit:
|
| 247 |
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
| 248 |
model = prepare_model_for_int8_training(model)
|
| 249 |
|
|
|
|
| 288 |
|
| 289 |
if adapter is None:
|
| 290 |
return model, None
|
| 291 |
+
if adapter == "lora" or adapter == "qlora":
|
| 292 |
return load_lora(model, cfg)
|
| 293 |
if adapter == "llama-adapter":
|
| 294 |
return load_llama_adapter(model, cfg)
|