wire up gradient checkpointing for 4bit
Browse files
src/axolotl/utils/trainer.py
CHANGED
|
@@ -28,7 +28,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
| 28 |     training_arguments_kwargs["warmup_steps"] = warmup_steps
| 29 |     training_arguments_kwargs["logging_steps"] = logging_steps
| 30 |     if cfg.gradient_checkpointing is not None:
| 31 | -       training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
| 32 |
| 33 |     # deepspeed
| 34 |     if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1:
(removed line 31 content reconstructed from the re-added statement in the new version's `else` branch)
|
| 28 |     training_arguments_kwargs["warmup_steps"] = warmup_steps
| 29 |     training_arguments_kwargs["logging_steps"] = logging_steps
| 30 |     if cfg.gradient_checkpointing is not None:
| 31 | +       if cfg.load_4bit:
| 32 | +           from alpaca_lora_4bit.gradient_checkpointing import apply_gradient_checkpointing
| 33 | +           gradient_checkpointing_ratio = cfg.gradient_checkpointing_ratio if cfg.gradient_checkpointing_ratio else 1.0
| 34 | +           apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio)
| 35 | +       else:
| 36 | +           training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
| 37 | +
| 38 |
| 39 |     # deepspeed
| 40 |     if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1: