add support for opimum bettertransformers
Browse files- configs/gpt_neox_20b.yml +18 -12
- requirements.txt +1 -0
- scripts/finetune.py +11 -4
- src/axolotl/utils/models.py +6 -2
- src/axolotl/utils/validation.py +8 -0
configs/gpt_neox_20b.yml
CHANGED
|
@@ -1,24 +1,25 @@
|
|
| 1 |
base_model: EleutherAI/gpt-neox-20b
|
|
|
|
| 2 |
base_model_ignore_patterns: pytorch* # prefer safetensors
|
| 3 |
model_type: GPTNeoXForCausalLM
|
| 4 |
tokenizer_type: AutoTokenizer
|
| 5 |
-
load_in_8bit:
|
|
|
|
|
|
|
| 6 |
datasets:
|
| 7 |
-
- path:
|
| 8 |
type: alpaca
|
| 9 |
-
shards: 4
|
| 10 |
-
shards_index: 0
|
| 11 |
dataset_prepared_path: last_run_prepared
|
| 12 |
val_set_size: 0.05
|
| 13 |
-
adapter:
|
| 14 |
lora_model_dir:
|
| 15 |
sequence_len: 2048
|
| 16 |
max_packed_sequence_len: 2048
|
| 17 |
-
lora_r:
|
| 18 |
lora_alpha: 32
|
| 19 |
-
lora_dropout: 0.
|
| 20 |
lora_target_modules:
|
| 21 |
-
|
| 22 |
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
| 23 |
wandb_project: gpt4all-neox-20b
|
| 24 |
wandb_watch:
|
|
@@ -26,14 +27,19 @@ wandb_run_id:
|
|
| 26 |
wandb_log_model:
|
| 27 |
output_dir: ./gpt4all-neox-20b
|
| 28 |
gradient_accumulation_steps: 1
|
| 29 |
-
micro_batch_size:
|
| 30 |
num_epochs: 5
|
| 31 |
learning_rate: 0.00003
|
| 32 |
-
|
|
|
|
| 33 |
train_on_inputs: false
|
| 34 |
group_by_length: false
|
| 35 |
-
bf16:
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
| 37 |
early_stopping_patience:
|
| 38 |
resume_from_checkpoint:
|
| 39 |
local_rank:
|
|
|
|
|
|
| 1 |
base_model: EleutherAI/gpt-neox-20b
|
| 2 |
+
base_model_config: EleutherAI/gpt-neox-20b
|
| 3 |
base_model_ignore_patterns: pytorch* # prefer safetensors
|
| 4 |
model_type: GPTNeoXForCausalLM
|
| 5 |
tokenizer_type: AutoTokenizer
|
| 6 |
+
load_in_8bit: false
|
| 7 |
+
load_in_4bit: true
|
| 8 |
+
load_4bit: false
|
| 9 |
datasets:
|
| 10 |
+
- path: vicgalle/alpaca-gpt4
|
| 11 |
type: alpaca
|
|
|
|
|
|
|
| 12 |
dataset_prepared_path: last_run_prepared
|
| 13 |
val_set_size: 0.05
|
| 14 |
+
adapter:
|
| 15 |
lora_model_dir:
|
| 16 |
sequence_len: 2048
|
| 17 |
max_packed_sequence_len: 2048
|
| 18 |
+
lora_r: 64
|
| 19 |
lora_alpha: 32
|
| 20 |
+
lora_dropout: 0.0
|
| 21 |
lora_target_modules:
|
| 22 |
+
lora_target_linear: true
|
| 23 |
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
| 24 |
wandb_project: gpt4all-neox-20b
|
| 25 |
wandb_watch:
|
|
|
|
| 27 |
wandb_log_model:
|
| 28 |
output_dir: ./gpt4all-neox-20b
|
| 29 |
gradient_accumulation_steps: 1
|
| 30 |
+
micro_batch_size: 2
|
| 31 |
num_epochs: 5
|
| 32 |
learning_rate: 0.00003
|
| 33 |
+
optimizer: paged_adamw_32bit
|
| 34 |
+
lr_scheduler: cosine
|
| 35 |
train_on_inputs: false
|
| 36 |
group_by_length: false
|
| 37 |
+
bf16: false
|
| 38 |
+
fp16: false
|
| 39 |
+
float16: true
|
| 40 |
+
tf32: true
|
| 41 |
+
flash_optimum: true
|
| 42 |
early_stopping_patience:
|
| 43 |
resume_from_checkpoint:
|
| 44 |
local_rank:
|
| 45 |
+
gradient_checkpointing: true
|
requirements.txt
CHANGED
|
@@ -11,6 +11,7 @@ sentencepiece
|
|
| 11 |
wandb
|
| 12 |
einops
|
| 13 |
xformers
|
|
|
|
| 14 |
# qlora things
|
| 15 |
bert-score==0.3.13
|
| 16 |
evaluate==0.4.0
|
|
|
|
| 11 |
wandb
|
| 12 |
einops
|
| 13 |
xformers
|
| 14 |
+
optimum
|
| 15 |
# qlora things
|
| 16 |
bert-score==0.3.13
|
| 17 |
evaluate==0.4.0
|
scripts/finetune.py
CHANGED
|
@@ -6,6 +6,7 @@ import os
|
|
| 6 |
import random
|
| 7 |
import signal
|
| 8 |
import sys
|
|
|
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import Any, Dict, List, Optional, Union
|
| 11 |
|
|
@@ -19,6 +20,8 @@ from axolotl.utils.dict import DictDefault
|
|
| 19 |
from axolotl.utils.models import load_model, load_tokenizer
|
| 20 |
|
| 21 |
# add src to the pythonpath so we don't need to pip install this
|
|
|
|
|
|
|
| 22 |
from axolotl.utils.tokenization import check_dataset_labels
|
| 23 |
from axolotl.utils.trainer import setup_trainer
|
| 24 |
from axolotl.utils.validation import validate_config
|
|
@@ -264,12 +267,14 @@ def train(
|
|
| 264 |
|
| 265 |
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
| 266 |
if cfg.local_rank == 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
signal.signal(
|
| 268 |
signal.SIGINT,
|
| 269 |
-
lambda
|
| 270 |
-
model.save_pretrained(cfg.output_dir),
|
| 271 |
-
sys.exit(0),
|
| 272 |
-
),
|
| 273 |
)
|
| 274 |
|
| 275 |
logging.info("Starting trainer...")
|
|
@@ -299,6 +304,8 @@ def train(
|
|
| 299 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
| 300 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
| 301 |
if cfg.local_rank == 0:
|
|
|
|
|
|
|
| 302 |
model.save_pretrained(cfg.output_dir)
|
| 303 |
|
| 304 |
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
|
|
|
| 6 |
import random
|
| 7 |
import signal
|
| 8 |
import sys
|
| 9 |
+
from functools import partial
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Any, Dict, List, Optional, Union
|
| 12 |
|
|
|
|
| 20 |
from axolotl.utils.models import load_model, load_tokenizer
|
| 21 |
|
| 22 |
# add src to the pythonpath so we don't need to pip install this
|
| 23 |
+
from optimum.bettertransformer import BetterTransformer
|
| 24 |
+
|
| 25 |
from axolotl.utils.tokenization import check_dataset_labels
|
| 26 |
from axolotl.utils.trainer import setup_trainer
|
| 27 |
from axolotl.utils.validation import validate_config
|
|
|
|
| 267 |
|
| 268 |
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
| 269 |
if cfg.local_rank == 0:
|
| 270 |
+
def terminate_handler(signum, frame, model):
|
| 271 |
+
if cfg.flash_optimum:
|
| 272 |
+
model = BetterTransformer.reverse(model)
|
| 273 |
+
model.save_pretrained(cfg.output_dir)
|
| 274 |
+
sys.exit(0)
|
| 275 |
signal.signal(
|
| 276 |
signal.SIGINT,
|
| 277 |
+
lambda signum, frame: terminate_handler(signum, frame, model)
|
|
|
|
|
|
|
|
|
|
| 278 |
)
|
| 279 |
|
| 280 |
logging.info("Starting trainer...")
|
|
|
|
| 304 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
| 305 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
| 306 |
if cfg.local_rank == 0:
|
| 307 |
+
if cfg.flash_optimum:
|
| 308 |
+
model = BetterTransformer.reverse(model)
|
| 309 |
model.save_pretrained(cfg.output_dir)
|
| 310 |
|
| 311 |
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
src/axolotl/utils/models.py
CHANGED
|
@@ -11,7 +11,8 @@ import bitsandbytes as bnb
|
|
| 11 |
import torch
|
| 12 |
import transformers
|
| 13 |
from transformers import PreTrainedModel # noqa: F401
|
| 14 |
-
from
|
|
|
|
| 15 |
AutoConfig,
|
| 16 |
AutoModelForCausalLM,
|
| 17 |
AutoTokenizer,
|
|
@@ -137,7 +138,7 @@ def load_model(
|
|
| 137 |
|
| 138 |
if cfg.bf16:
|
| 139 |
torch_dtype = torch.bfloat16
|
| 140 |
-
elif cfg.load_in_8bit or cfg.fp16:
|
| 141 |
torch_dtype = torch.float16
|
| 142 |
else:
|
| 143 |
torch_dtype = torch.float32
|
|
@@ -342,6 +343,9 @@ def load_model(
|
|
| 342 |
logging.warning("there are no parameters that require gradient updates")
|
| 343 |
model.config.use_cache = False
|
| 344 |
|
|
|
|
|
|
|
|
|
|
| 345 |
# TODO resume_from_checkpoint handling
|
| 346 |
return model, lora_config
|
| 347 |
|
|
|
|
| 11 |
import torch
|
| 12 |
import transformers
|
| 13 |
from transformers import PreTrainedModel # noqa: F401
|
| 14 |
+
from optimum.bettertransformer import BetterTransformer
|
| 15 |
+
from transformers import (
|
| 16 |
AutoConfig,
|
| 17 |
AutoModelForCausalLM,
|
| 18 |
AutoTokenizer,
|
|
|
|
| 138 |
|
| 139 |
if cfg.bf16:
|
| 140 |
torch_dtype = torch.bfloat16
|
| 141 |
+
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
| 142 |
torch_dtype = torch.float16
|
| 143 |
else:
|
| 144 |
torch_dtype = torch.float32
|
|
|
|
| 343 |
logging.warning("there are no parameters that require gradient updates")
|
| 344 |
model.config.use_cache = False
|
| 345 |
|
| 346 |
+
if cfg.flash_optimum:
|
| 347 |
+
model = BetterTransformer.transform(model)
|
| 348 |
+
|
| 349 |
# TODO resume_from_checkpoint handling
|
| 350 |
return model, lora_config
|
| 351 |
|
src/axolotl/utils/validation.py
CHANGED
|
@@ -57,6 +57,14 @@ def validate_config(cfg):
|
|
| 57 |
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
| 58 |
raise ValueError("FSDP is not supported for falcon models")
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
# TODO
|
| 61 |
# MPT 7b
|
| 62 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
|
| 57 |
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
| 58 |
raise ValueError("FSDP is not supported for falcon models")
|
| 59 |
|
| 60 |
+
if cfg.flash_optimum is True:
|
| 61 |
+
if cfg.adapter:
|
| 62 |
+
logging.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
| 63 |
+
if cfg.fp16 or cfg.bf16:
|
| 64 |
+
raise ValueError("AMP is not supported with BetterTransformer")
|
| 65 |
+
if cfg.float16 is not True:
|
| 66 |
+
logging.warning("You should probably set float16 to true")
|
| 67 |
+
|
| 68 |
# TODO
|
| 69 |
# MPT 7b
|
| 70 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|