Spaces:
Runtime error
Runtime error
fix: comments
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -605,8 +605,8 @@ def main():
|
|
| 605 |
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
| 606 |
optimizer = optax.adafactor(
|
| 607 |
learning_rate=learning_rate_fn,
|
| 608 |
-
|
| 609 |
-
|
| 610 |
)
|
| 611 |
else:
|
| 612 |
optimizer = optax.adamw(
|
|
|
|
| 605 |
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
| 606 |
optimizer = optax.adafactor(
|
| 607 |
learning_rate=learning_rate_fn,
|
| 608 |
+
weight_decay_rate=training_args.weight_decay,
|
| 609 |
+
weight_decay_mask=decay_mask_fn,
|
| 610 |
)
|
| 611 |
else:
|
| 612 |
optimizer = optax.adamw(
|