Spaces:
Runtime error
Runtime error
fix: style
Browse files
- tools/train/train.py +6 -2
tools/train/train.py
CHANGED
|
@@ -647,7 +647,9 @@ def main():
|
|
| 647 |
|
| 648 |
# add gradient accumulation
|
| 649 |
if training_args.gradient_accumulation_steps > 1:
|
| 650 |
-
optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
|
|
|
|
|
|
|
| 651 |
|
| 652 |
# Setup train state
|
| 653 |
state = TrainState.create(
|
|
@@ -691,7 +693,9 @@ def main():
|
|
| 691 |
|
| 692 |
metrics = {
|
| 693 |
"loss": loss,
|
| 694 |
-
"learning_rate": learning_rate_fn(state.step // training_args.gradient_accumulation_steps),
|
|
|
|
|
|
|
| 695 |
}
|
| 696 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
| 697 |
|
|
|
|
| 647 |
|
| 648 |
# add gradient accumulation
|
| 649 |
if training_args.gradient_accumulation_steps > 1:
|
| 650 |
+
optimizer = optax.MultiSteps(
|
| 651 |
+
optimizer, training_args.gradient_accumulation_steps
|
| 652 |
+
)
|
| 653 |
|
| 654 |
# Setup train state
|
| 655 |
state = TrainState.create(
|
|
|
|
| 693 |
|
| 694 |
metrics = {
|
| 695 |
"loss": loss,
|
| 696 |
+
"learning_rate": learning_rate_fn(
|
| 697 |
+
state.step // training_args.gradient_accumulation_steps
|
| 698 |
+
),
|
| 699 |
}
|
| 700 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
| 701 |
|