Spaces:
Runtime error
Runtime error
add tokenizer save to wandb:
Browse filesFormer-commit-id: 36b4af0d456410a4c2996d1476525e91205d3d1c
seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -811,13 +811,16 @@ def main():
|
|
| 811 |
params=params,
|
| 812 |
)
|
| 813 |
|
|
|
|
|
|
|
|
|
|
| 814 |
# save state
|
| 815 |
state = unreplicate(state)
|
| 816 |
with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
|
| 817 |
f.write(to_bytes(state.opt_state))
|
| 818 |
with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
|
| 819 |
json.dump({'step': state.step.item()}, f)
|
| 820 |
-
|
| 821 |
# save to W&B
|
| 822 |
if data_args.log_model:
|
| 823 |
metadata = {'step': step, 'epoch': epoch}
|
|
@@ -827,6 +830,11 @@ def main():
|
|
| 827 |
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
| 828 |
)
|
| 829 |
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 830 |
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
| 831 |
artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
|
| 832 |
artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
|
|
|
|
| 811 |
params=params,
|
| 812 |
)
|
| 813 |
|
| 814 |
+
# save tokenizer
|
| 815 |
+
tokenizer.save_pretrained(training_args.output_dir)
|
| 816 |
+
|
| 817 |
# save state
|
| 818 |
state = unreplicate(state)
|
| 819 |
with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
|
| 820 |
f.write(to_bytes(state.opt_state))
|
| 821 |
with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
|
| 822 |
json.dump({'step': state.step.item()}, f)
|
| 823 |
+
|
| 824 |
# save to W&B
|
| 825 |
if data_args.log_model:
|
| 826 |
metadata = {'step': step, 'epoch': epoch}
|
|
|
|
| 830 |
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
| 831 |
)
|
| 832 |
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
| 833 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer_config.json'))
|
| 834 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'special_tokens_map.json'))
|
| 835 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'vocab.json'))
|
| 836 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'added_tokens.json'))
|
| 837 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'merges.txt'))
|
| 838 |
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
| 839 |
artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
|
| 840 |
artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
|