# Copyright © 2024 Apple Inc.
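
"""Export a small MLP's parameter initialization and one SGD training step
with mx.export_function, then import the training step back with
mx.import_function and train on random data."""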

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)


if __name__ == "__main__":

    batch_size = 8
    input_dim = 32
    output_dim = 10

    def init():
        # Seed for the parameter initialization
        mx.random.seed(0)
        model = MLP(
            num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
        )
        optimizer = optim.SGD(learning_rate=1e-1)
        optimizer.init(model.parameters())
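        # Flatten the parameter and optimizer state trees into a flat list
        # of arrays, keeping the tree structure so it can be rebuilt later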
        state = [model.parameters(), optimizer.state]
        tree_structure, state = zip(*mlx.utils.tree_flatten(state))
        return model, optimizer, tree_structure, state

    # Export the model parameter initialization
    model, optimizer, tree_structure, state = init()
    mx.eval(state)
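    # The lambda returns only the flat list of arrays from init, so the
    # exported initialization maps no inputs to a tuple of arrays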
    mx.export_function("init_mlp.mlxfn", lambda: init()[-1])

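    # The loss takes the parameters explicitly so it can be differentiated
    # with respect to them; model.update swaps the given parameters into
    # the module before the forward pass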
    def loss_fn(params, X, y):
        model.update(params)
        return nn.losses.cross_entropy(model(X), y, reduction="mean")

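    # One step of SGD. The parameter and optimizer state trees are threaded
    # through as a flat list of arrays, so the function's inputs and outputs
    # are all plain arrays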
    def step(*inputs):
        *state, X, y = inputs
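        # Rebuild the parameter and optimizer state trees from the flat list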
        params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state)))
        optimizer.state = opt_state
        loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
        params = optimizer.apply_gradients(grads, params)
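        # Flatten the updated trees back into a list of arrays to return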
        _, state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
        return *state, loss

    # Make some random data
    mx.random.seed(42)
    example_X = mx.random.normal(shape=(batch_size, input_dim))
    example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))
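
    # Export one step of SGD; the example inputs are used to trace the function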
    mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)

    # Import the exported training step and run a few iterations
    imported_step = mx.import_function("train_mlp.mlxfn")

    for it in range(100):
        *state, loss = imported_step(*state, example_X, example_y)
        if it % 10 == 0:
            print(f"Loss {loss.item():.6}")