Spaces:
Runtime error
Runtime error
feat: update defaults
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -253,7 +253,7 @@ class DataTrainingArguments:
|
|
| 253 |
metadata={"help": "Overwrite the cached training and evaluation sets"},
|
| 254 |
)
|
| 255 |
save_model_steps: Optional[int] = field(
|
| 256 |
-
default=
|
| 257 |
metadata={
|
| 258 |
"help": "For logging the model more frequently. Used only when `log_model` is set."
|
| 259 |
},
|
|
@@ -290,9 +290,9 @@ class DataTrainingArguments:
|
|
| 290 |
|
| 291 |
|
| 292 |
class TrainState(train_state.TrainState):
|
| 293 |
-
dropout_rng: jnp.ndarray
|
| 294 |
-
grad_accum: jnp.ndarray
|
| 295 |
-
optimizer_step: int
|
| 296 |
|
| 297 |
def replicate(self):
|
| 298 |
return jax_utils.replicate(self).replace(
|
|
|
|
| 253 |
metadata={"help": "Overwrite the cached training and evaluation sets"},
|
| 254 |
)
|
| 255 |
save_model_steps: Optional[int] = field(
|
| 256 |
+
default=5000, # about once every 1.5h in our experiments
|
| 257 |
metadata={
|
| 258 |
"help": "For logging the model more frequently. Used only when `log_model` is set."
|
| 259 |
},
|
|
|
|
| 290 |
|
| 291 |
|
| 292 |
class TrainState(train_state.TrainState):
|
| 293 |
+
dropout_rng: jnp.ndarray = None
|
| 294 |
+
grad_accum: jnp.ndarray = None
|
| 295 |
+
optimizer_step: int = None
|
| 296 |
|
| 297 |
def replicate(self):
|
| 298 |
return jax_utils.replicate(self).replace(
|