| | |
| | |
| | |
| | |
| |
|
| | import pprint |
| | from functools import partial |
| | import os |
| | import numpy as np |
| | import jax |
| | import jax.numpy as jnp |
| | import flax.serialization |
| | import mlxu |
| | from EasyLM.checkpoint import StreamingCheckpointer |
| | from EasyLM.jax_utils import float_to_dtype |
| |
|
| |
|
| | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( |
| | recover_diff=False, |
| | load_base_checkpoint='', |
| | load_target_checkpoint='', |
| | output_file='', |
| | streaming=True, |
| | float_dtype='bf16', |
| | ) |
| |
|
| |
|
| | def main(argv): |
| | assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != '' |
| | assert FLAGS.output_file != '' |
| | base_params = StreamingCheckpointer.load_trainstate_checkpoint( |
| | FLAGS.load_base_checkpoint, disallow_trainstate=True |
| | )[1]['params'] |
| |
|
| | target_params = StreamingCheckpointer.load_trainstate_checkpoint( |
| | FLAGS.load_target_checkpoint, disallow_trainstate=True |
| | )[1]['params'] |
| |
|
| | if FLAGS.recover_diff: |
| | params = jax.tree_util.tree_map( |
| | lambda b, t: b + t, base_params, target_params |
| | ) |
| | else: |
| | params = jax.tree_util.tree_map( |
| | lambda b, t: t - b, base_params, target_params |
| | ) |
| |
|
| | if FLAGS.streaming: |
| | StreamingCheckpointer.save_train_state_to_file( |
| | params, FLAGS.output_file, float_dtype=FLAGS.float_dtype |
| | ) |
| | else: |
| | params = float_to_dtype(params, FLAGS.float_dtype) |
| | with mlxu.open_file(FLAGS.output, 'wb') as fout: |
| | fout.write(flax.serialization.msgpack_serialize(params, in_place=True)) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | mlxu.run(main) |
| |
|