.. _nn:

.. currentmodule:: mlx.nn

Neural Networks
===============

Writing arbitrarily complex neural networks in MLX can be done using only
:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this
requires the user to write the same simple neural network operations again and
again, as well as handle all the parameter state and initialization manually
and explicitly.

The module :mod:`mlx.nn` solves this problem by providing an intuitive way of
composing neural network layers, initializing their parameters, freezing them
for finetuning, and more.

Quick Start with Neural Networks
--------------------------------

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    class MLP(nn.Module):
        def __init__(self, in_dims: int, out_dims: int):
            super().__init__()

            self.layers = [
                nn.Linear(in_dims, 128),
                nn.Linear(128, 128),
                nn.Linear(128, out_dims),
            ]

        def __call__(self, x):
            for i, l in enumerate(self.layers):
                x = mx.maximum(x, 0) if i > 0 else x
                x = l(x)
            return x

    # The model is created with all its parameters, but nothing is
    # initialized yet because MLX is lazily evaluated
    mlp = MLP(2, 10)

    # We can access its parameters by calling mlp.parameters()
    params = mlp.parameters()
    print(params["layers"][0]["weight"].shape)

    # Printing a parameter will cause it to be evaluated and thus initialized
    print(params["layers"][0])

    # We can also force evaluate all parameters to initialize the model
    mx.eval(mlp.parameters())

    # A simple loss function. Note that it captures the mlp model from the
    # local scope; it could equally take the model as an argument.
    def l2_loss(x, y):
        y_hat = mlp(x)
        return (y_hat - y).square().mean()

    # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
    # gradient with respect to `mlp.trainable_parameters()`
    loss_and_grad = nn.value_and_grad(mlp, l2_loss)

.. _module_class:

The Module Class
----------------

The workhorse of any neural network library is the :class:`Module` class. In
MLX the :class:`Module` class is a container of :class:`mlx.core.array` or
:class:`Module` instances. Its main function is to provide a way to
recursively **access** and **update** its parameters and those of its
submodules.

Parameters
^^^^^^^^^^

A parameter of a module is any public member of type :class:`mlx.core.array`
(its name should not start with ``_``). It can be arbitrarily nested in other
:class:`Module` instances or lists and dictionaries.

:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.
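
For instance, the following minimal sketch (the ``Scale`` module here is made
up purely for illustration) shows which members become parameters and how
:meth:`Module.parameters` reflects them:

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    class Scale(nn.Module):
        def __init__(self, dims: int):
            super().__init__()
            # A public mx.array member is registered as a parameter ...
            self.weight = mx.ones((dims,))
            # ... while a member whose name starts with "_" is not
            self._count = mx.array(0)

        def __call__(self, x):
            return self.weight * x

    # Prints a nested dictionary containing only "weight"
    print(Scale(3).parameters())

The same nesting applies when parameters live in lists or dictionaries, as
with the ``layers`` list of the ``MLP`` above.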

A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. When using
:meth:`mlx.nn.value_and_grad`, the gradients returned will be with respect to
these trainable parameters.
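
For example, continuing with the ``mlp`` instance from the quick start, this
sketch freezes the first layer and compares parameter counts
(``tree_flatten`` is introduced below under "Inspecting Modules"):

.. code-block:: python

    from mlx.utils import tree_flatten

    # Freeze the first linear layer, e.g. to finetune only the rest
    mlp.layers[0].freeze()

    def count(tree):
        return sum(v.size for _, v in tree_flatten(tree))

    print(count(mlp.parameters()))            # includes frozen parameters
    print(count(mlp.trainable_parameters()))  # excludes the frozen layer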

Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^

MLX modules allow accessing and updating individual parameters. However, most
of the time we need to update large subsets of a module's parameters. This
action is performed by :meth:`Module.update`.
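
For example, a short sketch that rescales every parameter of the ``mlp`` from
the quick start and writes the result back:

.. code-block:: python

    from mlx.utils import tree_map

    # Build a new tree with the same structure as the parameters ...
    new_params = tree_map(lambda p: 0.9 * p, mlp.parameters())

    # ... and write it back into the module
    mlp.update(new_params)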

Inspecting Modules
^^^^^^^^^^^^^^^^^^

The simplest way to see the model architecture is to print it. Following along
with the above example, you can print the ``MLP`` with:

.. code-block:: python

    print(mlp)

This will display:

.. code-block:: shell

    MLP(
      (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
      (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
      (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
    )

To get more detailed information on the arrays in a :class:`Module` you can
use :func:`mlx.utils.tree_map` on the parameters. For example, to see the
shapes of all the parameters in a :class:`Module` do:

.. code-block:: python

    from mlx.utils import tree_map

    shapes = tree_map(lambda p: p.shape, mlp.parameters())

As another example, you can count the number of parameters in a
:class:`Module` with:

.. code-block:: python

    from mlx.utils import tree_flatten

    num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))

Value and Grad
--------------

Using a :class:`Module` does not preclude using MLX's higher-order function
transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`,
etc.). However, these function transformations assume pure functions, namely,
the parameters should be passed as an argument to the function being
transformed.

There is an easy pattern to achieve that with MLX modules:

.. code-block:: python

    model = ...

    def f(params, other_inputs):
        # Necessary to make the model use the passed parameters
        model.update(params)
        return model(other_inputs)

    f(model.trainable_parameters(), mx.zeros((10,)))

However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and
only computes the gradients with respect to the trainable parameters of the
model.

In detail:

- it wraps the passed function with a function that calls :meth:`Module.update`
  to make sure the model is using the provided parameters.
- it calls :meth:`mlx.core.value_and_grad` to transform the function into a
  function that also computes the gradients with respect to the passed
  parameters.
- it wraps the returned function with a function that passes the trainable
  parameters as the first argument to the function returned by
  :meth:`mlx.core.value_and_grad`.
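
As a rough sketch of how the wrapped function is used, here is a single
gradient step for the quick-start ``mlp``, using :mod:`mlx.optimizers` (not
covered in this section) and zero-filled placeholder data:

.. code-block:: python

    import mlx.core as mx
    import mlx.optimizers as optim

    optimizer = optim.SGD(learning_rate=0.1)

    # Placeholder inputs and targets shaped for MLP(2, 10)
    x, y = mx.zeros((4, 2)), mx.zeros((4, 10))

    # Same signature as l2_loss; also returns the gradients with respect
    # to mlp.trainable_parameters()
    loss, grads = loss_and_grad(x, y)

    # Apply the gradients and materialize the updated parameters
    optimizer.update(mlp, grads)
    mx.eval(mlp.parameters(), optimizer.state)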

.. autosummary::
    :toctree: _autosummary

    value_and_grad
    quantize
    average_gradients

.. toctree::

    nn/module
    nn/layers
    nn/functions
    nn/losses
    nn/init