.. _function_transforms:

Function Transforms
===================

.. currentmodule:: mlx.core

MLX uses composable function transformations for automatic differentiation,
vectorization, and compute graph optimizations. To see the complete list of
function transformations, check out the :ref:`API documentation <transforms>`.

The key idea behind composable function transformations is that every
transformation returns a function which can be further transformed.

Here is a simple example:
.. code-block:: python

  >>> dfdx = mx.grad(mx.sin)
  >>> dfdx(mx.array(mx.pi))
  array(-1, dtype=float32)
  >>> mx.cos(mx.array(mx.pi))
  array(-1, dtype=float32)
The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function, which is exactly the cosine
function. To get the second derivative you can do:

.. code-block:: python
  >>> d2fdx2 = mx.grad(mx.grad(mx.sin))
  >>> d2fdx2(mx.array(mx.pi / 2))
  array(-1, dtype=float32)
  >>> mx.sin(mx.array(mx.pi / 2))
  array(1, dtype=float32)
Using :func:`grad` on the output of :func:`grad` is always valid: each
application yields the next higher-order derivative.
Any of the MLX function transformations can be composed in any order to any
depth. See the following sections for more information on :ref:`automatic
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
.. _auto diff:

Automatic Differentiation
-------------------------
Automatic differentiation in MLX works on functions rather than on implicit
graphs.

.. note::

   If you are coming to MLX from PyTorch, you no longer need functions like
   ``backward``, ``zero_grad``, and ``detach``, or properties like
   ``requires_grad``.
The most basic example is taking the gradient of a scalar-valued function, as we
saw above. You can use the :func:`grad` and :func:`value_and_grad` functions to
compute gradients of more complex functions. By default these functions compute
the gradient with respect to the first argument:

.. code-block:: python
|
  def loss_fn(w, x, y):
      return mx.mean(mx.square(w * x - y))

  w = mx.array(1.0)
  x = mx.array([0.5, -0.5])
  y = mx.array([1.5, -1.5])

  # Compute the gradient with respect to the first argument, w
  grad_fn = mx.grad(loss_fn)
  dloss_dw = grad_fn(w, x, y)
  # Prints array(-1, dtype=float32)
  print(dloss_dw)

  # Use argnums to differentiate with respect to the second argument, x
  grad_fn = mx.grad(loss_fn, argnums=1)
  dloss_dx = grad_fn(w, x, y)
  # Prints array([-1, 1], dtype=float32)
  print(dloss_dx)
One way to get the loss and gradient is to call ``loss_fn`` followed by
``grad_fn``, but this can result in a lot of redundant work. Instead, you
should use :func:`value_and_grad`. Continuing the above example:
.. code-block:: python

  # Compute the loss and the gradient with respect to w in a single pass
  loss_and_grad_fn = mx.value_and_grad(loss_fn)
  loss, dloss_dw = loss_and_grad_fn(w, x, y)

  # Prints array(1, dtype=float32)
  print(loss)

  # Prints array(-1, dtype=float32)
  print(dloss_dw)
You can also take the gradient with respect to arbitrarily nested Python
containers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or
:obj:`dict`).

Suppose we wanted a weight and a bias parameter in the above example. A nice
way to do that is the following:
.. code-block:: python

  def loss_fn(params, x, y):
      w, b = params["weight"], params["bias"]
      h = w * x + b
      return mx.mean(mx.square(h - y))

  params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
  x = mx.array([0.5, -0.5])
  y = mx.array([1.5, -1.5])

  # The gradient is a dictionary with the same keys as the
  # parameters, "weight" and "bias"
  grad_fn = mx.grad(loss_fn)
  grads = grad_fn(params, x, y)

  # Prints {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
  print(grads)
Notice the tree structure of the parameters is preserved in the gradients.

In some cases you may want to stop gradients from propagating through a
part of the function. You can use :func:`stop_gradient` for that.
.. _vmap:

Automatic Vectorization
-----------------------
Use :func:`vmap` to automate vectorizing complex functions. Here we'll go
through a basic and contrived example for the sake of clarity, but :func:`vmap`
can be quite powerful for more complex functions which are difficult to optimize
by hand.

.. warning::

   Some operations are not yet supported with :func:`vmap`. If you encounter an
   error like ``ValueError: Primitive's vmap not implemented.``, file an `issue
   <https://github.com/ml-explore/mlx/issues>`_ and include your function.
   We will prioritize including it.
A naive way to add the elements from two sets of vectors is with a loop:

.. code-block:: python

  xs = mx.random.uniform(shape=(4096, 100))
  ys = mx.random.uniform(shape=(100, 4096))

  def naive_add(xs, ys):
      return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
Instead you can use :func:`vmap` to automatically vectorize the addition:

.. code-block:: python

  # Vectorize over the first dimension of x and the second dimension of y
  vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
The ``in_axes`` parameter specifies which dimensions of the corresponding
inputs to vectorize over. Similarly, use ``out_axes`` to specify where the
vectorized axes should be in the outputs.
Let's time these two different versions:

.. code-block:: python

  import timeit

  print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
  print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))

On an M1 Max, the naive version takes ``5.639`` seconds in total, whereas the
vectorized version takes only ``0.024`` seconds, more than 200 times faster.

Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.