Spaces:
Runtime error
Runtime error
Pedro Cuenca
commited on
Commit
·
566d5f2
1
Parent(s):
835ea55
Add eval_interval to evaluate and log every so often.
Browse files- seq2seq/run_seq2seq_flax.py +54 -34
seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -225,6 +225,12 @@ class DataTrainingArguments:
|
|
| 225 |
"value if set."
|
| 226 |
},
|
| 227 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
log_model: bool = field(
|
| 229 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
| 230 |
)
|
|
@@ -738,37 +744,8 @@ def main():
|
|
| 738 |
train_time = 0
|
| 739 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
| 740 |
global_step = 0
|
| 741 |
-
for epoch in epochs:
|
| 742 |
-
# ======================== Training ================================
|
| 743 |
-
train_start = time.time()
|
| 744 |
-
|
| 745 |
-
# Create sampling rng
|
| 746 |
-
rng, input_rng = jax.random.split(rng)
|
| 747 |
-
train_metrics = []
|
| 748 |
-
|
| 749 |
-
# Generate an epoch by shuffling sampling indices from the train dataset
|
| 750 |
-
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
|
| 751 |
-
steps_per_epoch = len(train_dataset) // train_batch_size
|
| 752 |
-
# train
|
| 753 |
-
for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
| 754 |
-
global_step +=1
|
| 755 |
-
batch = next(train_loader)
|
| 756 |
-
state, train_metric = p_train_step(state, batch)
|
| 757 |
-
train_metrics.append(train_metric)
|
| 758 |
-
|
| 759 |
-
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
| 760 |
-
for k, v in unreplicate(train_metric).items():
|
| 761 |
-
wandb.log({"train/step": global_step})
|
| 762 |
-
wandb.log({f"train/{k}": jax.device_get(v)})
|
| 763 |
-
|
| 764 |
-
train_time += time.time() - train_start
|
| 765 |
-
|
| 766 |
-
train_metric = unreplicate(train_metric)
|
| 767 |
-
|
| 768 |
-
epochs.write(
|
| 769 |
-
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
| 770 |
-
)
|
| 771 |
|
|
|
|
| 772 |
# ======================== Evaluating ==============================
|
| 773 |
eval_metrics = []
|
| 774 |
if training_args.do_eval:
|
|
@@ -795,17 +772,60 @@ def main():
|
|
| 795 |
eval_metrics = get_metrics(eval_metrics)
|
| 796 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
| 797 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 798 |
# compute ROUGE metrics
|
| 799 |
rouge_desc = ""
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
|
| 805 |
# Print metrics and update progress bar
|
| 806 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
|
| 807 |
epochs.write(desc)
|
| 808 |
epochs.desc = desc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 809 |
|
| 810 |
# Save metrics
|
| 811 |
if has_tensorboard and jax.process_index() == 0:
|
|
|
|
| 225 |
"value if set."
|
| 226 |
},
|
| 227 |
)
|
| 228 |
+
eval_interval: Optional[int] = field(
|
| 229 |
+
default=40,
|
| 230 |
+
metadata={
|
| 231 |
+
"help": "Evaluation will be performed every eval_interval steps"
|
| 232 |
+
},
|
| 233 |
+
)
|
| 234 |
log_model: bool = field(
|
| 235 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
| 236 |
)
|
|
|
|
| 744 |
train_time = 0
|
| 745 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
| 746 |
global_step = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 747 |
|
| 748 |
+
def run_evaluation():
|
| 749 |
# ======================== Evaluating ==============================
|
| 750 |
eval_metrics = []
|
| 751 |
if training_args.do_eval:
|
|
|
|
| 772 |
eval_metrics = get_metrics(eval_metrics)
|
| 773 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
| 774 |
|
| 775 |
+
if jax.process_index() == 0:
|
| 776 |
+
for k, v in eval_metrics.items():
|
| 777 |
+
wandb.log({"eval/step": global_step})
|
| 778 |
+
wandb.log({f"eval/{k}": jax.device_get(v)})
|
| 779 |
+
|
| 780 |
# compute ROUGE metrics
|
| 781 |
rouge_desc = ""
|
| 782 |
+
# if data_args.predict_with_generate:
|
| 783 |
+
# rouge_metrics = compute_metrics(eval_preds, eval_labels)
|
| 784 |
+
# eval_metrics.update(rouge_metrics)
|
| 785 |
+
# rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
|
| 786 |
|
| 787 |
# Print metrics and update progress bar
|
| 788 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
|
| 789 |
epochs.write(desc)
|
| 790 |
epochs.desc = desc
|
| 791 |
+
return eval_metrics
|
| 792 |
+
|
| 793 |
+
for epoch in epochs:
|
| 794 |
+
# ======================== Training ================================
|
| 795 |
+
train_start = time.time()
|
| 796 |
+
|
| 797 |
+
# Create sampling rng
|
| 798 |
+
rng, input_rng = jax.random.split(rng)
|
| 799 |
+
train_metrics = []
|
| 800 |
+
|
| 801 |
+
# Generate an epoch by shuffling sampling indices from the train dataset
|
| 802 |
+
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
|
| 803 |
+
steps_per_epoch = len(train_dataset) // train_batch_size
|
| 804 |
+
# train
|
| 805 |
+
for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
| 806 |
+
global_step +=1
|
| 807 |
+
batch = next(train_loader)
|
| 808 |
+
state, train_metric = p_train_step(state, batch)
|
| 809 |
+
train_metrics.append(train_metric)
|
| 810 |
+
|
| 811 |
+
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
| 812 |
+
print("logging train loss")
|
| 813 |
+
for k, v in unreplicate(train_metric).items():
|
| 814 |
+
wandb.log({"train/step": global_step})
|
| 815 |
+
wandb.log({f"train/{k}": jax.device_get(v)})
|
| 816 |
+
|
| 817 |
+
if global_step % data_args.eval_interval == 0 and jax.process_index() == 0:
|
| 818 |
+
run_evaluation()
|
| 819 |
+
|
| 820 |
+
train_time += time.time() - train_start
|
| 821 |
+
|
| 822 |
+
train_metric = unreplicate(train_metric)
|
| 823 |
+
|
| 824 |
+
epochs.write(
|
| 825 |
+
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
eval_metrics = run_evaluation()
|
| 829 |
|
| 830 |
# Save metrics
|
| 831 |
if has_tensorboard and jax.process_index() == 0:
|