| from pathlib import Path | |
| import flax | |
| import jax | |
| import jax.numpy as jnp | |
def to_f32(t):
    """Return a copy of pytree ``t`` with every bfloat16 leaf cast to float32.

    Leaves of any other dtype are returned unchanged. Leaves that have no
    ``dtype`` attribute at all are also returned unchanged, e.g. plain Python
    scalars or strings that a msgpack-restored state dict may contain.

    Args:
        t: An arbitrary pytree (nested dicts/lists/tuples of array-like or
            scalar leaves).

    Returns:
        A pytree with the same structure as ``t``.
    """

    def _upcast(x):
        # getattr() lets dtype-less leaves (ints, floats, strings, ...) pass
        # through instead of raising AttributeError on ``x.dtype``.
        if getattr(x, "dtype", None) == jnp.bfloat16:
            return x.astype(jnp.float32)
        return x

    # ``jax.tree_map`` is deprecated (and removed in newer JAX releases);
    # ``jax.tree_util.tree_map`` is the stable spelling of the same function.
    return jax.tree_util.tree_map(_upcast, t)
# Script body: load the serialized parameter tree, upcast its bfloat16 leaves
# with to_f32, and write the converted tree to a sibling msgpack file.
_src_path = Path("output/flax_model.msgpack")
_dst_path = Path("output/flax_model_f32.msgpack")

_raw_bytes = _src_path.read_bytes()
data = flax.serialization.msgpack_restore(_raw_bytes)

transformed_data = to_f32(data)

_payload = flax.serialization.msgpack_serialize(transformed_data)
_dst_path.write_bytes(_payload)