various bugfixes
Browse files
- configs/llama_65B_alpaca.yml +3 -3
- requirements.txt +3 -0
- scripts/finetune.py +8 -7
- setup.cfg +1 -0
- src/axolotl/datasets.py +21 -19
configs/llama_65B_alpaca.yml
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
base_model:
|
| 2 |
model_type: LlamaForCausalLM
|
| 3 |
tokenizer_type: LlamaTokenizer
|
| 4 |
load_in_8bit: true
|
|
@@ -33,8 +33,8 @@ num_epochs: 5
|
|
| 33 |
learning_rate: 0.00003
|
| 34 |
train_on_inputs: false
|
| 35 |
group_by_length: false
|
| 36 |
-
bf16:
|
| 37 |
-
tf32:
|
| 38 |
resume_from_checkpoint:
|
| 39 |
local_rank:
|
| 40 |
deepspeed:
|
|
|
|
| 1 |
+
base_model: huggyllama/llama-7b
|
| 2 |
model_type: LlamaForCausalLM
|
| 3 |
tokenizer_type: LlamaTokenizer
|
| 4 |
load_in_8bit: true
|
|
|
|
| 33 |
learning_rate: 0.00003
|
| 34 |
train_on_inputs: false
|
| 35 |
group_by_length: false
|
| 36 |
+
bf16: true
|
| 37 |
+
tf32: true
|
| 38 |
resume_from_checkpoint:
|
| 39 |
local_rank:
|
| 40 |
deepspeed:
|
requirements.txt
CHANGED
|
@@ -10,3 +10,6 @@ accelerate
|
|
| 10 |
sentencepiece
|
| 11 |
wandb
|
| 12 |
flash-attn
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
sentencepiece
|
| 11 |
wandb
|
| 12 |
flash-attn
|
| 13 |
+
deepspeed
|
| 14 |
+
einops
|
| 15 |
+
|
scripts/finetune.py
CHANGED
|
@@ -11,7 +11,7 @@ import torch
|
|
| 11 |
import transformers
|
| 12 |
import yaml
|
| 13 |
from attrdict import AttrDefault
|
| 14 |
-
from datasets import load_dataset, IterableDataset, Dataset
|
| 15 |
from peft import (
|
| 16 |
LoraConfig,
|
| 17 |
get_peft_model,
|
|
@@ -52,8 +52,9 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
|
| 52 |
if adapter != "lora":
|
| 53 |
raise NotImplementedError(f"{adapter} peft adapter not available")
|
| 54 |
if "llama" in base_model:
|
| 55 |
-
|
| 56 |
-
|
|
|
|
| 57 |
|
| 58 |
try:
|
| 59 |
if "llama" in base_model:
|
|
@@ -86,7 +87,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
|
| 86 |
except:
|
| 87 |
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
| 88 |
|
| 89 |
-
if tokenizer.__class__.__name__
|
| 90 |
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
| 91 |
|
| 92 |
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
|
@@ -255,8 +256,9 @@ def train(
|
|
| 255 |
return
|
| 256 |
|
| 257 |
datasets = []
|
| 258 |
-
if
|
| 259 |
-
|
|
|
|
| 260 |
else:
|
| 261 |
for d in cfg.datasets:
|
| 262 |
ds: IterableDataset = load_dataset(
|
|
@@ -288,7 +290,6 @@ def train(
|
|
| 288 |
[_ for _ in constant_len_dataset]
|
| 289 |
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
|
| 290 |
dataset.save_to_disk("data/last_run")
|
| 291 |
-
print(dataset)
|
| 292 |
|
| 293 |
train_dataset = dataset["train"]
|
| 294 |
eval_dataset = dataset["test"]
|
|
|
|
| 11 |
import transformers
|
| 12 |
import yaml
|
| 13 |
from attrdict import AttrDefault
|
| 14 |
+
from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
|
| 15 |
from peft import (
|
| 16 |
LoraConfig,
|
| 17 |
get_peft_model,
|
|
|
|
| 52 |
if adapter != "lora":
|
| 53 |
raise NotImplementedError(f"{adapter} peft adapter not available")
|
| 54 |
if "llama" in base_model:
|
| 55 |
+
if cfg.device not in ["mps", "cpu"]:
|
| 56 |
+
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
| 57 |
+
replace_llama_attn_with_flash_attn()
|
| 58 |
|
| 59 |
try:
|
| 60 |
if "llama" in base_model:
|
|
|
|
| 87 |
except:
|
| 88 |
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
| 89 |
|
| 90 |
+
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
|
| 91 |
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
| 92 |
|
| 93 |
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
|
|
|
| 256 |
return
|
| 257 |
|
| 258 |
datasets = []
|
| 259 |
+
if not isinstance(cfg.datasets, list) and isinstance(cfg.datasets, str):
|
| 260 |
+
# assumption that we are loading a previously saved/cached dataset
|
| 261 |
+
dataset = load_from_disk(cfg.datasets)
|
| 262 |
else:
|
| 263 |
for d in cfg.datasets:
|
| 264 |
ds: IterableDataset = load_dataset(
|
|
|
|
| 290 |
[_ for _ in constant_len_dataset]
|
| 291 |
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
|
| 292 |
dataset.save_to_disk("data/last_run")
|
|
|
|
| 293 |
|
| 294 |
train_dataset = dataset["train"]
|
| 295 |
eval_dataset = dataset["test"]
|
setup.cfg
CHANGED
|
@@ -23,6 +23,7 @@ install_requires =
|
|
| 23 |
sentencepiece
|
| 24 |
wandb
|
| 25 |
flash-attn
|
|
|
|
| 26 |
|
| 27 |
[options.packages.find]
|
| 28 |
where = src
|
|
|
|
| 23 |
sentencepiece
|
| 24 |
wandb
|
| 25 |
flash-attn
|
| 26 |
+
einops
|
| 27 |
|
| 28 |
[options.packages.find]
|
| 29 |
where = src
|
src/axolotl/datasets.py
CHANGED
|
@@ -93,22 +93,24 @@ class ConstantLengthDataset(IterableDataset):
|
|
| 93 |
buffer_len = 0
|
| 94 |
|
| 95 |
if example:
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
| 93 |
buffer_len = 0
|
| 94 |
|
| 95 |
if example:
|
| 96 |
+
# just going to drop data points that are too long
|
| 97 |
+
if len(example["input_ids"]) <= self.seq_length:
|
| 98 |
+
input_ids = example["input_ids"]
|
| 99 |
+
attention_mask = example["attention_mask"]
|
| 100 |
+
labels = example["labels"]
|
| 101 |
+
|
| 102 |
+
if add_concat_token:
|
| 103 |
+
input_ids.append(self.concat_token_id)
|
| 104 |
+
attention_mask.append(1)
|
| 105 |
+
labels.append(self.concat_token_id)
|
| 106 |
+
|
| 107 |
+
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
|
| 108 |
+
attention_mask_with_concat = torch.tensor(
|
| 109 |
+
attention_mask, dtype=torch.long
|
| 110 |
+
)
|
| 111 |
+
labels_with_concat = torch.tensor(labels, dtype=torch.long)
|
| 112 |
+
|
| 113 |
+
buffer["input_ids"].append(input_ids_with_concat)
|
| 114 |
+
buffer["attention_mask"].append(attention_mask_with_concat)
|
| 115 |
+
buffer["labels"].append(labels_with_concat)
|
| 116 |
+
buffer_len += len(input_ids)
|