.. _compile:

Compilation
===========

.. currentmodule:: mlx.core

MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.

Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.

Basics of Compile
-----------------

Let's start with a simple example:

.. code-block:: python

  def fun(x, y):
      return mx.exp(-x) + y

  x = mx.array(1.0)
  y = mx.array(2.0)

  # Regular call, no compilation
  # Prints: array(2.36788, dtype=float32)
  print(fun(x, y))

  # Compile the function
  compiled_fun = mx.compile(fun)

  # Prints: array(2.36788, dtype=float32)
  print(compiled_fun(x, y))

The output of both the regular function and the compiled function is the same
up to numerical precision.

The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.

.. code-block:: python

  def fun(x, y):
      return mx.exp(-x) + y

  x = mx.array(1.0)
  y = mx.array(2.0)

  compiled_fun = mx.compile(fun)

  # Compiled here
  compiled_fun(x, y)

  # Not compiled again
  compiled_fun(x, y)

  # Not compiled again
  mx.compile(fun)(x, y)

There are some important cases to be aware of that can cause a function to
be recompiled:

* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function

In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.
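To make the triggers above concrete, here is a small sketch (reusing the toy function from earlier) where each differently-typed or differently-shaped call pays a fresh compilation:

.. code-block:: python

  import mlx.core as mx

  def fun(x, y):
      return mx.exp(-x) + y

  compiled_fun = mx.compile(fun)

  # First call: compiles for float32 scalar inputs
  compiled_fun(mx.array(1.0), mx.array(2.0))

  # Same shapes and dtypes: reuses the cached compilation
  compiled_fun(mx.array(3.0), mx.array(4.0))

  # New shape: triggers a recompilation
  compiled_fun(mx.array([1.0, 2.0]), mx.array([3.0, 4.0]))

  # New dtype (int32 inputs): reruns the full compilation stack
  compiled_fun(mx.array(1), mx.array(2))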

Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:

.. code-block:: python

  a = mx.array(1.0)
  # Don't do this, compiles lambda at each iteration
  for _ in range(5):
      mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
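A simple fix is to hoist the compilation out of the loop so the cached function is reused:

.. code-block:: python

  import mlx.core as mx

  a = mx.array(1.0)

  # Compile once and reuse; only the first call pays the compilation cost
  compiled = mx.compile(lambda x: mx.exp(mx.abs(x)))
  for _ in range(5):
      compiled(a)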

Example Speedup
---------------

:func:`mlx.nn.gelu` is a nonlinear activation function commonly used in
Transformer-based models. Its implementation involves several unary and binary
element-wise operations:

.. code-block:: python

  import math

  def gelu(x):
      return x * (1 + mx.erf(x / math.sqrt(2))) / 2

If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays, it will be memory bandwidth bound. However, all of
the operations in ``gelu`` are fusible into a single kernel with
:func:`compile`, which can speed up both cases considerably.

Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:

.. code-block:: python

  import time

  def timeit(fun, x):
      # warm up
      for _ in range(10):
          mx.eval(fun(x))

      tic = time.perf_counter()
      for _ in range(100):
          mx.eval(fun(x))
      toc = time.perf_counter()
      tpi = 1e3 * (toc - tic) / 100
      print(f"Time per iteration {tpi:.3f} (ms)")


Now make an array, and benchmark both functions:

.. code-block:: python

  x = mx.random.uniform(shape=(32, 1000, 4096))
  timeit(gelu, x)
  timeit(mx.compile(gelu), x)

On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.

Debugging
---------

When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.

.. code-block:: python

  @mx.compile
  def fun(x):
      z = -x
      print(z)  # Crash
      return mx.exp(z)

  fun(mx.array(5.0))

For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
``fun`` is compiled:

.. code-block:: python

  @mx.compile
  def fun(x):
      z = -x
      print(z) # Okay
      return mx.exp(z)

  mx.disable_compile()
  fun(mx.array(5.0))


Pure Functions
--------------

Compiled functions are intended to be *pure*; that is they should not have side
effects. For example:

.. code-block:: python

  state = []

  @mx.compile
  def fun(x, y):
      z = x + y
      state.append(z)
      return mx.exp(z)

  fun(mx.array(1.0), mx.array(2.0))
  # Crash!
  print(state)

After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.

You have two options to deal with this. The first option is to simply return
``state`` as an output:

.. code-block:: python

  state = []

  @mx.compile
  def fun(x, y):
      z = x + y
      state.append(z)
      return mx.exp(z), state

  _, state = fun(mx.array(1.0), mx.array(2.0))
  # Prints [array(3, dtype=float32)]
  print(state)

In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:

.. code-block:: python

  from functools import partial

  state = []

  # Tell compile to capture state as an output
  @partial(mx.compile, outputs=state)
  def fun(x, y):
      z = x + y
      state.append(z)
      return mx.exp(z)

  fun(mx.array(1.0), mx.array(2.0))
  # Prints [array(3, dtype=float32)]
  print(state)

This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.

Compiled functions will also treat any inputs not in the parameter list as
constants. For example:

.. code-block:: python

  state = [mx.array(1.0)]

  @mx.compile
  def fun(x):
      return x + state[0]

  # Prints array(2, dtype=float32)
  print(fun(mx.array(1.0)))

  # Update state
  state[0] = mx.array(5.0)

  # Still prints array(2, dtype=float32)
  print(fun(mx.array(1.0)))

In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:

.. code-block:: python

  from functools import partial
  state = [mx.array(1.0)]

  # Tell compile to capture state as an input
  @partial(mx.compile, inputs=state)
  def fun(x):
      return x + state[0]

  # Prints array(2, dtype=float32)
  print(fun(mx.array(1.0)))

  # Update state
  state[0] = mx.array(5.0)

  # Prints array(6, dtype=float32)
  print(fun(mx.array(1.0)))


Compiling Training Graphs
-------------------------

This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.

To start, here is the simple example without any compilation:

.. code-block:: python

  import mlx.core as mx
  import mlx.nn as nn
  import mlx.optimizers as optim

  # 4 examples with 10 features each
  x = mx.random.uniform(shape=(4, 10))

  # 0, 1 targets
  y = mx.array([0, 1, 0, 1])

  # Simple linear model
  model = nn.Linear(10, 1)

  # SGD with momentum
  optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

  def loss_fn(model, x, y):
      logits = model(x).squeeze()
      return nn.losses.binary_cross_entropy(logits, y)

  loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

  # Perform 10 steps of gradient descent
  for it in range(10):
      loss, grads = loss_and_grad_fn(model, x, y)
      optimizer.update(model, grads)
      mx.eval(model.parameters(), optimizer.state)

To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:

.. code-block:: python

  import mlx.core as mx
  import mlx.nn as nn
  import mlx.optimizers as optim
  from functools import partial

  # 4 examples with 10 features each
  x = mx.random.uniform(shape=(4, 10))

  # 0, 1 targets
  y = mx.array([0, 1, 0, 1])

  # Simple linear model
  model = nn.Linear(10, 1)

  # SGD with momentum
  optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

  def loss_fn(model, x, y):
      logits = model(x).squeeze()
      return nn.losses.binary_cross_entropy(logits, y)

  # The state that will be captured as input and output
  state = [model.state, optimizer.state]

  @partial(mx.compile, inputs=state, outputs=state)
  def step(x, y):
      loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
      loss, grads = loss_and_grad_fn(model, x, y)
      optimizer.update(model, grads)
      return loss

  # Perform 10 steps of gradient descent
  for it in range(10):
      loss = step(x, y)
      # Evaluate the model and optimizer state
      mx.eval(state)
      print(loss)


.. note::

  If you are using a module which performs random sampling, such as
  :class:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
  ``state`` captured by :func:`compile`, i.e. ``state = [model.state,
  optimizer.state, mx.random.state]``.


.. note::

   For more examples of compiling full training graphs, check out the `MLX
   Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.

Transformations with Compile
----------------------------

In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.

Compiling transformed functions works just as expected:

.. code-block:: python

  grad_fn = mx.grad(mx.exp)

  compiled_grad_fn = mx.compile(grad_fn)

  # Prints: array(2.71828, dtype=float32)
  print(grad_fn(mx.array(1.0)))

  # Also prints: array(2.71828, dtype=float32)
  print(compiled_grad_fn(mx.array(1.0)))

.. note::

   In order to compile as much as possible, a transformation of a compiled
   function will not by default be compiled. To compile the transformed
   function simply pass it through :func:`compile`.
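For instance, here is a sketch of compiling the gradient of an already-compiled function:

.. code-block:: python

  import mlx.core as mx

  @mx.compile
  def fun(x):
      return mx.exp(-x)

  # The transformed function is not compiled by default ...
  grad_fn = mx.grad(fun)

  # ... so pass it through compile explicitly
  compiled_grad_fn = mx.compile(mx.grad(fun))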

You can also compile functions which themselves call compiled functions. A
good practice is to compile the outermost function to give :func:`compile`
the most opportunity to optimize the computation graph:

.. code-block:: python

  @mx.compile
  def inner(x):
      return mx.exp(-mx.abs(x))

  def outer(x):
      return inner(inner(x))

  # Compiling the outer function is still worthwhile, as it will
  # likely be faster even though inner is already compiled
  fun = mx.compile(outer)



.. _shapeless_compile:

Shapeless Compilation
---------------------

When the shape of an input to a compiled function changes, the function is
recompiled. You can compile a function once and run it on inputs with
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
case changes to the shapes of the inputs do not cause the function to be
recompiled.

.. code-block:: python

  def fun(x, y):
      return mx.abs(x + y)

  compiled_fun = mx.compile(fun, shapeless=True)

  x = mx.array(1.0)
  y = mx.array(-2.0)

  # First call compiles the function
  print(compiled_fun(x, y))

  # Second call with different shapes
  # does not recompile the function
  x = mx.array([1.0, -6.0])
  y = mx.array([-2.0, 3.0])
  print(compiled_fun(x, y))


Use shapeless compilation carefully. Since compilation is not triggered when
shapes change, any graph which is conditional on the input shapes will not
work as expected. Shape-dependent computations are common and can be subtle
to detect. For example:

.. code-block:: python

  def fun(x):
      return x.reshape(x.shape[0] * x.shape[1], -1)

  compiled_fun = mx.compile(fun, shapeless=True)

  x = mx.random.uniform(shape=(2, 3, 4))

  out = compiled_fun(x)

  x = mx.random.uniform(shape=(5, 5, 3))

  # Error, can't reshape (5, 5, 3) to (6, -1)
  out = compiled_fun(x)

The second call to the ``compiled_fun`` fails because of the call to
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:

.. code-block:: python

  def fun(x):
      return x.flatten(0, 1)

  compiled_fun = mx.compile(fun, shapeless=True)

  x = mx.random.uniform(shape=(2, 3, 4))

  out = compiled_fun(x)

  x = mx.random.uniform(shape=(5, 5, 3))

  # Ok
  out = compiled_fun(x)