File size: 1,364 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class MLP(nn.Module):
    """A feed-forward network: `num_layers` hidden Linear+ReLU layers plus a
    final Linear projection to `output_dim`."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        # Widths of every layer boundary: input -> hidden x num_layers -> output.
        dims = [input_dim, *([hidden_dim] * num_layers), output_dim]
        # Consecutive pairs of widths define each Linear layer.
        self.layers = [nn.Linear(d_in, d_out) for d_in, d_out in zip(dims, dims[1:])]

    def __call__(self, x):
        # ReLU after every layer except the last (raw logits out).
        *hidden, last = self.layers
        for layer in hidden:
            x = nn.relu(layer(x))
        return last(x)


if __name__ == "__main__":

    batch = 8
    in_features = 32
    out_features = 10

    # Fixed seed so the randomly-initialized parameters are reproducible.
    mx.random.seed(0)  # Seed for params
    model = MLP(
        num_layers=5,
        input_dim=in_features,
        hidden_dim=64,
        output_dim=out_features,
    )
    # Force materialization of the lazily-initialized parameters.
    mx.eval(model)

    # The model parameters are captured and saved by the export below.
    def forward(x):
        return model(x)

    # Separate seed for the example input used to trace the function.
    mx.random.seed(42)  # Seed for input
    example_x = mx.random.uniform(shape=(batch, in_features))

    # Export the traced function (weights included) to disk.
    mx.export_function("eval_mlp.mlxfn", forward, example_x)

    # Re-import in Python and verify it reproduces the original outputs.
    imported_forward = mx.import_function("eval_mlp.mlxfn")
    expected = forward(example_x)
    (out,) = imported_forward(example_x)
    assert mx.allclose(expected, out)
    print(out)