Spaces:
Runtime error
Runtime error
fix: comment
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -753,6 +753,7 @@ def main():
|
|
| 753 |
# restore optimizer state and step
|
| 754 |
state = state.restore_state(artifact_dir)
|
| 755 |
# TODO: number of remaining training epochs/steps and dataloader state need to be adjusted
|
|
|
|
| 756 |
|
| 757 |
# label smoothed cross entropy
|
| 758 |
def loss_fn(logits, labels):
|
|
@@ -937,7 +938,7 @@ def main():
|
|
| 937 |
for epoch in epochs:
|
| 938 |
# ======================== Training ================================
|
| 939 |
step = unreplicate(state.step)
|
| 940 |
-
|
| 941 |
|
| 942 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
| 943 |
if data_args.streaming:
|
|
|
|
| 753 |
# restore optimizer state and step
|
| 754 |
state = state.restore_state(artifact_dir)
|
| 755 |
# TODO: number of remaining training epochs/steps and dataloader state need to be adjusted
|
| 756 |
+
# TODO: optimizer may use a different step for learning rate; we should serialize/restore the entire state
|
| 757 |
|
| 758 |
# label smoothed cross entropy
|
| 759 |
def loss_fn(logits, labels):
|
|
|
|
| 938 |
for epoch in epochs:
|
| 939 |
# ======================== Training ================================
|
| 940 |
step = unreplicate(state.step)
|
| 941 |
+
wandb_log({"train/epoch": epoch}, step=step)
|
| 942 |
|
| 943 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
| 944 |
if data_args.streaming:
|