File size: 913 Bytes
fc7d689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import equinox as eqx


def save_model(filename, model):
    """
    Serialize and save a model to a file.

    Parameters
    ----------
    filename: `str`
        The name of the file to save the model to.
        The file extension must be `.eqx`.
    model: `eqx.Module`
        The model to save.
    """
    with open(filename, "wb") as f:
        eqx.tree_serialise_leaves(f, model)


def load_model(filename, model_skeleton):
    """
    Load a serialized model from a file.

    Parameters
    ----------
    filename: `str`
        The name of the file to load the model from.
        The file extension must be `.eqx`.
    model_skeleton: `eqx.Module`
        The reference skeleton of the model to load the model into.

    Returns
    -------
    model: `eqx.Module`
        The loaded model.
    """
    with open(filename, "rb") as f:
        return eqx.tree_deserialise_leaves(f, model_skeleton)