add code for save checkpoint
Browse files
main.py
CHANGED
|
@@ -9,6 +9,7 @@ from transformers.trainer_utils import get_last_checkpoint
|
|
| 9 |
import json
|
| 10 |
import os, glob
|
| 11 |
from callbacks import BreakEachEpoch
|
|
|
|
| 12 |
|
| 13 |
logging.set_verbosity_info()
|
| 14 |
|
|
@@ -85,8 +86,14 @@ def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_ma
|
|
| 85 |
return processed_dataset
|
| 86 |
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
if __name__ == "__main__":
|
| 92 |
|
|
@@ -208,3 +215,6 @@ if __name__ == "__main__":
|
|
| 208 |
# Clear cache file to free disk
|
| 209 |
test_dataset.cleanup_cache_files()
|
| 210 |
train_dataset.cleanup_cache_files()
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import json
|
| 10 |
import os, glob
|
| 11 |
from callbacks import BreakEachEpoch
|
| 12 |
+
import subprocess
|
| 13 |
|
| 14 |
logging.set_verbosity_info()
|
| 15 |
|
|
|
|
| 86 |
return processed_dataset
|
| 87 |
|
| 88 |
|
| 89 |
+
def commit_checkpoint():
|
| 90 |
+
submit_commands = [
|
| 91 |
+
'git add model-bin/finetune/base/*',
|
| 92 |
+
'git commit -m "auto commit"',
|
| 93 |
+
'git push origin main'
|
| 94 |
+
]
|
| 95 |
+
for command in submit_commands:
|
| 96 |
+
print(subprocess.run(command.split(), stdout=subprocess.PIPE).stdout.decode('utf-8'))
|
| 97 |
|
| 98 |
if __name__ == "__main__":
|
| 99 |
|
|
|
|
| 215 |
# Clear cache file to free disk
|
| 216 |
test_dataset.cleanup_cache_files()
|
| 217 |
train_dataset.cleanup_cache_files()
|
| 218 |
+
|
| 219 |
+
if epoch_idx % 10 == 0:
|
| 220 |
+
commit_checkpoint()
|