.. _compile:

Compilation
===========

.. currentmodule:: mlx.core

MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.

Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.

Basics of Compile
-----------------

Let's start with a simple example:

.. code-block:: python

    import mlx.core as mx

    def fun(x, y):
        return mx.exp(-x) + y

    x = mx.array(1.0)
    y = mx.array(2.0)

    # Regular call, no compilation
    # Prints: array(2.36788, dtype=float32)
    print(fun(x, y))

    # Compile the function
    compiled_fun = mx.compile(fun)

    # Prints: array(2.36788, dtype=float32)
    print(compiled_fun(x, y))

The output of both the regular function and the compiled function is the same
up to numerical precision.

The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.

.. code-block:: python

    def fun(x, y):
        return mx.exp(-x) + y

    x = mx.array(1.0)
    y = mx.array(2.0)

    compiled_fun = mx.compile(fun)

    # Compiled here
    compiled_fun(x, y)

    # Not compiled again
    compiled_fun(x, y)

    # Not compiled again
    mx.compile(fun)(x, y)

There are some important cases to be aware of that can cause a function to
be recompiled:

* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function

In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.

Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:

.. code-block:: python

    a = mx.array(1.0)
    # Don't do this, compiles lambda at each iteration
    for _ in range(5):
        mx.compile(lambda x: mx.exp(mx.abs(x)))(a)

Example Speedup
---------------

The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:

.. code-block:: python

    import math
    import mlx.core as mx

    def gelu(x):
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays it will be memory bandwidth bound. However, all of
the operations in ``gelu`` are fusible into a single kernel with
:func:`compile`. This can speed up both cases considerably.

Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:

.. code-block:: python

    import time

    def timeit(fun, x):
        # warm up
        for _ in range(10):
            mx.eval(fun(x))

        tic = time.perf_counter()
        for _ in range(100):
            mx.eval(fun(x))
        toc = time.perf_counter()
        tpi = 1e3 * (toc - tic) / 100
        print(f"Time per iteration {tpi:.3f} (ms)")

Now make an array, and benchmark both functions:

.. code-block:: python

    x = mx.random.uniform(shape=(32, 1000, 4096))
    timeit(gelu, x)
    timeit(mx.compile(gelu), x)

On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.

Debugging
---------

When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.

.. code-block:: python

    @mx.compile
    def fun(x):
        z = -x
        print(z)  # Crash
        return mx.exp(z)

    fun(mx.array(5.0))

For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
the ``MLX_DISABLE_COMPILE`` flag. For example the following is okay even
though ``fun`` is compiled:

.. code-block:: python

    @mx.compile
    def fun(x):
        z = -x
        print(z)  # Okay
        return mx.exp(z)

    mx.disable_compile()
    fun(mx.array(5.0))

Pure Functions
--------------

Compiled functions are intended to be *pure*; that is, they should not have
side effects. For example:

.. code-block:: python

    state = []

    @mx.compile
    def fun(x, y):
        z = x + y
        state.append(z)
        return mx.exp(z)

    fun(mx.array(1.0), mx.array(2.0))
    # Crash!
    print(state)

After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.

You have two options to deal with this. The first option is to simply return
``state`` as an output:

.. code-block:: python

    state = []

    @mx.compile
    def fun(x, y):
        z = x + y
        state.append(z)
        return mx.exp(z), state

    _, state = fun(mx.array(1.0), mx.array(2.0))
    # Prints [array(3, dtype=float32)]
    print(state)

In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:

.. code-block:: python

    from functools import partial

    state = []

    # Tell compile to capture state as an output
    @partial(mx.compile, outputs=state)
    def fun(x, y):
        z = x + y
        state.append(z)
        return mx.exp(z)

    fun(mx.array(1.0), mx.array(2.0))
    # Prints [array(3, dtype=float32)]
    print(state)

This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.

Compiled functions will also treat any inputs not in the parameter list as
constants. For example:

.. code-block:: python

    state = [mx.array(1.0)]

    @mx.compile
    def fun(x):
        return x + state[0]

    # Prints array(2, dtype=float32)
    print(fun(mx.array(1.0)))

    # Update state
    state[0] = mx.array(5.0)

    # Still prints array(2, dtype=float32)
    print(fun(mx.array(1.0)))

In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:

.. code-block:: python

    from functools import partial

    state = [mx.array(1.0)]

    # Tell compile to capture state as an input
    @partial(mx.compile, inputs=state)
    def fun(x):
        return x + state[0]

    # Prints array(2, dtype=float32)
    print(fun(mx.array(1.0)))

    # Update state
    state[0] = mx.array(5.0)

    # Prints array(6, dtype=float32)
    print(fun(mx.array(1.0)))

Compiling Training Graphs
-------------------------

This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.

To start, here is the simple example without any compilation:

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    # 4 examples with 10 features each
    x = mx.random.uniform(shape=(4, 10))

    # 0, 1 targets
    y = mx.array([0, 1, 0, 1])

    # Simple linear model
    model = nn.Linear(10, 1)

    # SGD with momentum
    optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

    def loss_fn(model, x, y):
        logits = model(x).squeeze()
        return nn.losses.binary_cross_entropy(logits, y)

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    # Perform 10 steps of gradient descent
    for it in range(10):
        loss, grads = loss_and_grad_fn(model, x, y)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)

To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim
    from functools import partial

    # 4 examples with 10 features each
    x = mx.random.uniform(shape=(4, 10))

    # 0, 1 targets
    y = mx.array([0, 1, 0, 1])

    # Simple linear model
    model = nn.Linear(10, 1)

    # SGD with momentum
    optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

    def loss_fn(model, x, y):
        logits = model(x).squeeze()
        return nn.losses.binary_cross_entropy(logits, y)

    # The state that will be captured as input and output
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(x, y):
        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
        loss, grads = loss_and_grad_fn(model, x, y)
        optimizer.update(model, grads)
        return loss

    # Perform 10 steps of gradient descent
    for it in range(10):
        loss = step(x, y)
        # Evaluate the model and optimizer state
        mx.eval(state)
        print(loss)

.. note::

    If you are using a module which performs random sampling such as
    :func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in
    the ``state`` captured by :func:`compile`, i.e. ``state = [model.state,
    optimizer.state, mx.random.state]``.

.. note::

    For more examples of compiling full training graphs check out the `MLX
    Examples <https://github.com/ml-explore/mlx-examples>`_ repository.

Transformations with Compile
----------------------------

In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.

Compiling transformed functions works just as expected:

.. code-block:: python

    grad_fn = mx.grad(mx.exp)

    compiled_grad_fn = mx.compile(grad_fn)

    # Prints: array(2.71828, dtype=float32)
    print(grad_fn(mx.array(1.0)))

    # Also prints: array(2.71828, dtype=float32)
    print(compiled_grad_fn(mx.array(1.0)))

.. note::

    In order to compile as much as possible, a transformation of a compiled
    function will not by default be compiled. To compile the transformed
    function simply pass it through :func:`compile`.

You can also compile functions which themselves call compiled functions. A
good practice is to compile the outermost function to give :func:`compile`
the most opportunity to optimize the computation graph:

.. code-block:: python

    @mx.compile
    def inner(x):
        return mx.exp(-mx.abs(x))

    def outer(x):
        return inner(inner(x))

    # Compiling the outer function is good to do as it will likely
    # be faster even though the inner functions are compiled
    fun = mx.compile(outer)

.. _shapeless_compile:

Shapeless Compilation
---------------------

When the shape of an input to a compiled function changes, the function is
recompiled. You can compile a function once and run it on inputs with
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
case changes to the shapes of the inputs do not cause the function to be
recompiled.

.. code-block:: python

    def fun(x, y):
        return mx.abs(x + y)

    compiled_fun = mx.compile(fun, shapeless=True)

    x = mx.array(1.0)
    y = mx.array(-2.0)

    # First call compiles the function
    print(compiled_fun(x, y))

    # Second call with different shapes
    # does not recompile the function
    x = mx.array([1.0, -6.0])
    y = mx.array([-2.0, 3.0])
    print(compiled_fun(x, y))

Use shapeless compilations carefully. Since compilation is not triggered when
shapes change, any graphs which are conditional on the input shapes will not
work as expected. Shape-dependent computations are common and sometimes subtle
to detect. For example:

.. code-block:: python

    def fun(x):
        return x.reshape(x.shape[0] * x.shape[1], -1)

    compiled_fun = mx.compile(fun, shapeless=True)

    x = mx.random.uniform(shape=(2, 3, 4))

    out = compiled_fun(x)

    x = mx.random.uniform(shape=(5, 5, 3))

    # Error, can't reshape (5, 5, 3) to (6, -1)
    out = compiled_fun(x)

The second call to ``compiled_fun`` fails because of the call to
:func:`reshape` which uses the static shape of ``x`` from the first call. We
can fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:

.. code-block:: python

    def fun(x):
        return x.flatten(0, 1)

    compiled_fun = mx.compile(fun, shapeless=True)

    x = mx.random.uniform(shape=(2, 3, 4))

    out = compiled_fun(x)

    x = mx.random.uniform(shape=(5, 5, 3))

    # Ok
    out = compiled_fun(x)