import os
from typing import Union

import equinox as eqx


def save_model(filename: Union[str, os.PathLike], model: eqx.Module) -> None:
    """
    Serialize and save a model to a file.

    The model's leaves are written with ``eqx.tree_serialise_leaves``. An
    existing file at ``filename`` is overwritten.

    Parameters
    ----------
    filename: `str` or path-like
        The name of the file to save the model to. By convention the file
        extension is `.eqx`; this is not checked or enforced.
    model: `eqx.Module`
        The model to save.

    Raises
    ------
    OSError
        If the file cannot be opened for writing.
    """
    # Binary mode: the serialised leaves are raw bytes, and "w" truncates
    # any previous contents of the file.
    with open(filename, "wb") as f:
        eqx.tree_serialise_leaves(f, model)


def load_model(
    filename: Union[str, os.PathLike], model_skeleton: eqx.Module
) -> eqx.Module:
    """
    Load a serialized model from a file.

    Parameters
    ----------
    filename: `str` or path-like
        The name of the file to load the model from, e.g. one written by
        `save_model`. By convention the file extension is `.eqx`; this is
        not checked or enforced.
    model_skeleton: `eqx.Module`
        A template model with the same structure as the saved model (for
        example a freshly constructed instance of the same class with the
        same hyperparameters). It is passed to ``eqx.tree_deserialise_leaves``
        as the reference structure; the loaded model is returned rather
        than written into this object.

    Returns
    -------
    model: `eqx.Module`
        The loaded model.

    Raises
    ------
    OSError
        If the file cannot be opened for reading.
    """
    # NOTE(review): a skeleton whose structure does not match the file will
    # surface as an error from eqx.tree_deserialise_leaves, not from here.
    with open(filename, "rb") as f:
        return eqx.tree_deserialise_leaves(f, model_skeleton)