Merge pull request #19 from NanoCode012/feat/callback-save-lora
Browse files
src/axolotl/utils/callbacks.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
|
| 4 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
| 5 |
+
|
| 6 |
+
class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that persists the PEFT adapter at every checkpoint.

    The stock ``Trainer`` checkpointing path serializes the wrapped model;
    this hook additionally calls ``save_pretrained`` on the model passed in
    via ``kwargs``, writing the adapter (e.g. LoRA) weights into an
    ``adapter_model`` subdirectory of the current checkpoint folder.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save the adapter weights alongside the regular checkpoint.

        Args:
            args: Training arguments; ``output_dir`` is the checkpoint root.
            state: Trainer state; ``global_step`` names the checkpoint folder.
            control: Flow-control object, returned unchanged.
            **kwargs: Callback context; the live model is under the
                ``"model"`` key.

        Returns:
            The (unmodified) ``control`` object, as the callback API expects.
        """
        folder_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        checkpoint_dir = os.path.join(args.output_dir, folder_name)
        adapter_dir = os.path.join(checkpoint_dir, "adapter_model")

        # save_pretrained on a PEFT-wrapped model writes only the adapter
        # weights, which is the point of this callback.
        kwargs["model"].save_pretrained(adapter_dir)

        return control
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -13,6 +13,7 @@ from transformers import EarlyStoppingCallback
|
|
| 13 |
from transformers.trainer_pt_utils import get_parameter_names
|
| 14 |
|
| 15 |
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
@@ -188,6 +189,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
| 188 |
data_collator_kwargs["padding"] = "longest"
|
| 189 |
else:
|
| 190 |
data_collator_kwargs["pad_to_multiple_of"] = 8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
trainer = transformers.Trainer(
|
| 192 |
model=model,
|
| 193 |
train_dataset=train_dataset,
|
|
@@ -198,6 +204,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
| 198 |
return_tensors="pt",
|
| 199 |
**data_collator_kwargs,
|
| 200 |
),
|
|
|
|
| 201 |
**trainer_kwargs,
|
| 202 |
)
|
| 203 |
|
|
|
|
| 13 |
from transformers.trainer_pt_utils import get_parameter_names
|
| 14 |
|
| 15 |
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
| 16 |
+
from axolotl.utils.callbacks import SavePeftModelCallback
|
| 17 |
|
| 18 |
|
| 19 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
|
|
| 189 |
data_collator_kwargs["padding"] = "longest"
|
| 190 |
else:
|
| 191 |
data_collator_kwargs["pad_to_multiple_of"] = 8
|
| 192 |
+
|
| 193 |
+
callbacks = []
|
| 194 |
+
if cfg.adapter == 'lora':
|
| 195 |
+
callbacks.append(SavePeftModelCallback)
|
| 196 |
+
|
| 197 |
trainer = transformers.Trainer(
|
| 198 |
model=model,
|
| 199 |
train_dataset=train_dataset,
|
|
|
|
| 204 |
return_tensors="pt",
|
| 205 |
**data_collator_kwargs,
|
| 206 |
),
|
| 207 |
+
callbacks=callbacks,
|
| 208 |
**trainer_kwargs,
|
| 209 |
)
|
| 210 |
|