| | |
| | |
| | |
| | |
| |
|
| | import pprint |
| | from functools import partial |
| | import os |
| | import numpy as np |
| | import mlxu |
| | import jax.numpy as jnp |
| | import flax.serialization |
| | from EasyLM.checkpoint import StreamingCheckpointer |
| | from EasyLM.jax_utils import float_to_dtype |
| |
|
| |
|
# Default values for the command-line flags read by main().
_FLAG_DEFAULTS = {
    # Checkpoint to read; must be non-empty (checked in main()).
    'load_checkpoint': '',
    # Destination file for the converted parameters; must be non-empty.
    'output_file': '',
    # When true, main() writes through StreamingCheckpointer instead of
    # serializing the whole parameter tree with msgpack in one call.
    'streaming': False,
    # Name of the floating-point dtype the parameters are converted to.
    'float_dtype': 'bf16',
}

FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(**_FLAG_DEFAULTS)
| |
|
| |
|
def main(argv):
    """Load the parameters from a checkpoint and write them to a new file.

    The parameters are read from ``FLAGS.load_checkpoint`` and written to
    ``FLAGS.output_file`` using the float dtype named by
    ``FLAGS.float_dtype``. With ``FLAGS.streaming`` set, the write goes
    through ``StreamingCheckpointer.save_train_state_to_file``; otherwise the
    parameters are converted with ``float_to_dtype`` and written as a single
    msgpack blob.

    Args:
        argv: Positional command-line arguments supplied by the runner;
            unused.

    Raises:
        AssertionError: If ``load_checkpoint`` or ``output_file`` is empty.
    """
    assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified'

    # The loader returns a sequence; only the 'params' entry of its second
    # element is used here.
    loaded = StreamingCheckpointer.load_trainstate_checkpoint(
        FLAGS.load_checkpoint, disallow_trainstate=True
    )
    params = loaded[1]['params']

    if FLAGS.streaming:
        StreamingCheckpointer.save_train_state_to_file(
            params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
        )
        return

    params = float_to_dtype(params, FLAGS.float_dtype)
    # Fix: the destination flag is `output_file`; `FLAGS.output` is never
    # defined, so this branch previously failed before writing anything.
    with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
        # in_place=True is passed through unchanged; `params` is not used
        # again after this call.
        fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
| |
|
| |
|
# Script entry point: hand main() to mlxu's runner only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    mlxu.run(main)
| |
|