Commit ·
98dfb11
1
Parent(s): a9f9b4a
avoid pushing checkpoints
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
|
@@ -797,6 +797,9 @@ def main():
|
|
| 797 |
)
|
| 798 |
logger.info("*** Trainer initialized ***")
|
| 799 |
|
|
|
|
|
|
|
|
|
|
| 800 |
# 12. Training
|
| 801 |
if training_args.do_train:
|
| 802 |
logger.info("*** Train ***")
|
|
@@ -812,10 +815,7 @@ def main():
|
|
| 812 |
# We don't want to push the model to the hub now
|
| 813 |
# so we temporarily set to false the push_to_hub attribute
|
| 814 |
# and then reset it to the original value
|
| 815 |
-
orig_push_to_hub = trainer.args.push_to_hub
|
| 816 |
-
trainer.args.push_to_hub = False
|
| 817 |
trainer.save_model() # Saves the feature extractor too for easy upload
|
| 818 |
-
trainer.args.push_to_hub = orig_push_to_hub
|
| 819 |
logger.info("*** Model saved ***")
|
| 820 |
metrics = train_result.metrics
|
| 821 |
if data_args.max_train_samples:
|
|
@@ -909,7 +909,7 @@ def main():
|
|
| 909 |
notify_me(recipient=RECIPIENT_ADDRESS,
|
| 910 |
message=f"Training complete! {train_results = } {eval_results = }")
|
| 911 |
|
| 912 |
-
|
| 913 |
if training_args.push_to_hub:
|
| 914 |
logger.info("*** Pushing to hub ***")
|
| 915 |
trainer.push_to_hub(**kwargs)
|
|
|
|
| 797 |
)
|
| 798 |
logger.info("*** Trainer initialized ***")
|
| 799 |
|
| 800 |
+
orig_push_to_hub = trainer.args.push_to_hub
|
| 801 |
+
trainer.args.push_to_hub = False
|
| 802 |
+
|
| 803 |
# 12. Training
|
| 804 |
if training_args.do_train:
|
| 805 |
logger.info("*** Train ***")
|
|
|
|
| 815 |
# We don't want to push the model to the hub now
|
| 816 |
# so we temporarily set to false the push_to_hub attribute
|
| 817 |
# and then reset it to the original value
|
|
|
|
|
|
|
| 818 |
trainer.save_model() # Saves the feature extractor too for easy upload
|
|
|
|
| 819 |
logger.info("*** Model saved ***")
|
| 820 |
metrics = train_result.metrics
|
| 821 |
if data_args.max_train_samples:
|
|
|
|
| 909 |
notify_me(recipient=RECIPIENT_ADDRESS,
|
| 910 |
message=f"Training complete! {train_results = } {eval_results = }")
|
| 911 |
|
| 912 |
+
trainer.args.push_to_hub = orig_push_to_hub
|
| 913 |
if training_args.push_to_hub:
|
| 914 |
logger.info("*** Pushing to hub ***")
|
| 915 |
trainer.push_to_hub(**kwargs)
|