File size: 913 Bytes
fc7d689 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 | import equinox as eqx
def save_model(filename: str, model: "eqx.Module") -> None:
    """
    Write ``model`` to disk with Equinox leaf serialisation.

    The destination is opened in binary write mode, so any existing file at
    ``filename`` is truncated. The leaves of ``model`` are then written into it
    with ``eqx.tree_serialise_leaves``.

    Parameters
    ----------
    filename: `str`
        Path of the destination file. By convention this uses the `.eqx`
        extension. The name is used exactly as given; no suffix is checked or
        appended here.
    model: `eqx.Module`
        The model whose leaves are saved.
    """
    with open(filename, mode="wb") as out_stream:
        eqx.tree_serialise_leaves(out_stream, model)
def load_model(filename: str, model_skeleton: "eqx.Module") -> "eqx.Module":
    """
    Read a serialised model back from disk.

    The file is opened in binary read mode and passed, together with
    ``model_skeleton``, to ``eqx.tree_deserialise_leaves``. The result of that
    call is returned.

    Parameters
    ----------
    filename: `str`
        Path of the file to read, typically one written by `save_model`. By
        convention this uses the `.eqx` extension. The name is used exactly as
        given; no suffix is checked or appended here.
    model_skeleton: `eqx.Module`
        Reference model providing the structure that the stored leaves are
        loaded into.

    Returns
    -------
    model: `eqx.Module`
        The loaded model.
    """
    with open(filename, mode="rb") as in_stream:
        restored = eqx.tree_deserialise_leaves(in_stream, model_skeleton)
    return restored
|