import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils
class MLP(nn.Module):
    """A multi-layer perceptron: Linear layers with ReLU activations between them."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        # dims has num_layers + 2 entries, giving num_layers + 1 Linear layers.
        dims = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = []
        for i in range(len(dims) - 1):
            self.layers.append(nn.Linear(dims[i], dims[i + 1]))

    def __call__(self, x):
        # ReLU after every layer except the last, which emits raw logits.
        *hidden, final = self.layers
        for layer in hidden:
            x = nn.relu(layer(x))
        return final(x)
|
|
|
|
if __name__ == "__main__":
    batch_size = 8
    input_dim = 32
    output_dim = 10

    def init():
        """Build a fresh model/optimizer pair and return them with the
        flattened state (tree keys and the flat list of arrays)."""
        # Fixed seed so the exported initializer is reproducible.
        mx.random.seed(0)
        model = MLP(
            num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
        )
        optimizer = optim.SGD(learning_rate=1e-1)
        optimizer.init(model.parameters())
        # Flatten [params, opt_state] into (keys, values): exported functions
        # exchange state as a flat sequence of arrays.
        full_state = [model.parameters(), optimizer.state]
        tree_structure, flat_state = zip(*mlx.utils.tree_flatten(full_state))
        return model, optimizer, tree_structure, flat_state

    model, optimizer, tree_structure, state = init()
    mx.eval(state)
    # Export an initializer that yields only the flat state arrays.
    mx.export_function("init_mlp.mlxfn", lambda: init()[-1])

    def loss_fn(params, X, y):
        # Load the candidate parameters into the model before the forward pass.
        model.update(params)
        return nn.losses.cross_entropy(model(X), y, reduction="mean")

    def step(*inputs):
        """One SGD step: (flat state..., X, y) -> (new flat state..., loss)."""
        *flat, X, y = inputs
        # Rebuild the parameter and optimizer trees from the flat arrays.
        params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, flat)))
        optimizer.state = opt_state
        loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
        new_params = optimizer.apply_gradients(grads, params)
        # Re-flatten in the same key order so callers can round-trip the state.
        flat_out = [v for _, v in mlx.utils.tree_flatten([new_params, optimizer.state])]
        return *flat_out, loss

    # Example batch used both to trace the export and to drive training below.
    mx.random.seed(42)
    example_X = mx.random.normal(shape=(batch_size, input_dim))
    example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))
    mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)

    imported_step = mx.import_function("train_mlp.mlxfn")

    for it in range(100):
        *state, loss = imported_step(*state, example_X, example_y)
        if it % 10 == 0:
            print(f"Loss {loss.item():.6}")
|