Spaces:
Runtime error
Runtime error
feat(train): use MultiSteps for gradient accumulation
Browse files
- tools/train/train.py +2 -4
tools/train/train.py
CHANGED
|
@@ -647,9 +647,7 @@ def main():
|
|
| 647 |
|
| 648 |
# add gradient accumulation
|
| 649 |
if training_args.gradient_accumulation_steps > 1:
|
| 650 |
-
optimizer = optax.chain(
|
| 651 |
-
optax.apply_every(training_args.gradient_accumulation_steps), optimizer
|
| 652 |
-
)
|
| 653 |
|
| 654 |
# Setup train state
|
| 655 |
state = TrainState.create(
|
|
@@ -693,7 +691,7 @@ def main():
|
|
| 693 |
|
| 694 |
metrics = {
|
| 695 |
"loss": loss,
|
| 696 |
-
"learning_rate": learning_rate_fn(state.step),
|
| 697 |
}
|
| 698 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
| 699 |
|
|
|
|
| 647 |
|
| 648 |
# add gradient accumulation
|
| 649 |
if training_args.gradient_accumulation_steps > 1:
|
| 650 |
+
optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
|
|
|
|
|
|
|
| 651 |
|
| 652 |
# Setup train state
|
| 653 |
state = TrainState.create(
|
|
|
|
| 691 |
|
| 692 |
metrics = {
|
| 693 |
"loss": loss,
|
| 694 |
+
"learning_rate": learning_rate_fn(state.step // training_args.gradient_accumulation_steps),
|
| 695 |
}
|
| 696 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
| 697 |
|