Update run_mlm_flax.py
Browse files- run_mlm_flax.py +2 -1
run_mlm_flax.py
CHANGED
|
@@ -786,7 +786,8 @@ def main():
|
|
| 786 |
return new_state, metrics, new_dropout_rng
|
| 787 |
|
| 788 |
# Create parallel version of the train step
|
| 789 |
-
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
|
|
|
| 790 |
|
| 791 |
# Define eval fn
|
| 792 |
def eval_step(params, batch):
|
|
|
|
| 786 |
return new_state, metrics, new_dropout_rng
|
| 787 |
|
| 788 |
# Create parallel version of the train step
|
| 789 |
+
# p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
| 790 |
+
p_train_step = jax.pmap(train_step, "batch")
|
| 791 |
|
| 792 |
# Define eval fn
|
| 793 |
def eval_step(params, batch):
|