.. _lazy eval:

Lazy Evaluation
===============

.. currentmodule:: mlx.core

Why Lazy Evaluation
-------------------

When you perform operations in MLX, no computation actually happens. Instead, a
compute graph is recorded. The computation happens only when an :func:`eval` is
performed.

MLX uses lazy evaluation because it has several benefits, some of which are
described below.
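
The record-then-evaluate behavior can be illustrated with a toy model in plain
Python. This is only a sketch of the concept, not how MLX is implemented: the
``Lazy`` class, its ``eval`` method, and the ``constant`` helper are
hypothetical names invented for this example.

```python
class Lazy:
    """Toy graph node: stores an operation and its inputs, computes on demand."""

    def __init__(self, op, inputs=(), name="leaf"):
        self.op = op          # callable that produces this node's value
        self.inputs = inputs  # upstream Lazy nodes
        self.name = name
        self.value = None
        self.evaluated = False

    def __add__(self, other):
        # Record an "add" node instead of adding anything now.
        return Lazy(lambda x, y: x + y, (self, other), "add")

    def __mul__(self, other):
        # Record a "mul" node instead of multiplying anything now.
        return Lazy(lambda x, y: x * y, (self, other), "mul")

    def eval(self):
        # Evaluate the inputs first, then this node. The result is cached,
        # so evaluating the same node again does no further work.
        if not self.evaluated:
            args = [node.eval() for node in self.inputs]
            self.value = self.op(*args)
            self.evaluated = True
        return self.value


def constant(v):
    """Hypothetical helper: wrap a plain value as a leaf node."""
    return Lazy(lambda: v)


x = constant(3)
y = constant(4)
z = x * y + x                    # only records the graph
recorded_only = not z.evaluated  # nothing has been computed at this point
result = z.eval()                # the computation happens here
```

In this toy model, calling ``z.eval()`` a second time returns the cached value
without recomputing anything.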

Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Lazy evaluation lets us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations.

Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to
integrate compilation for future performance enhancements.

Only Compute What You Use
^^^^^^^^^^^^^^^^^^^^^^^^^

In MLX you do not need to worry as much about computing outputs that are never
used. For example:

.. code-block:: python

  def fun(x):
      a = fun1(x)
      b = expensive_fun(a)
      return a, b

  y, _ = fun(x)

Here, we never actually compute the output of ``expensive_fun``. Use this
pattern with care though, as the graph of ``expensive_fun`` is still built, and
that has some cost associated with it.
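
The same effect can be reproduced with a small plain-Python sketch. It is an
illustration under stated assumptions rather than MLX code: ``Deferred`` is a
hypothetical stand-in for a lazily evaluated array, and the ``calls`` list
exists only to record which functions actually ran.

```python
calls = []  # records which functions actually ran


class Deferred:
    """Hypothetical stand-in for a lazy value: runs ``fn`` only on ``eval``."""

    def __init__(self, fn, *args):
        self.fn, self.args = fn, args
        self._done, self._value = False, None

    def eval(self):
        if not self._done:
            # Evaluate any deferred inputs before calling the function.
            args = [a.eval() if isinstance(a, Deferred) else a for a in self.args]
            self._value = self.fn(*args)
            self._done = True
        return self._value


def fun1(x):
    calls.append("fun1")
    return x + 1


def expensive_fun(a):
    calls.append("expensive_fun")
    return sum(a * i for i in range(1000))


def fun(x):
    a = Deferred(fun1, x)
    b = Deferred(expensive_fun, a)  # the node is built, which has a small cost
    return a, b


y, _ = fun(10)
value = y.eval()  # only fun1 runs; expensive_fun is never called
```

Building the ``Deferred`` node for ``expensive_fun`` still allocates an object,
which mirrors the graph-construction cost mentioned above.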

Similarly, lazy evaluation can be beneficial for saving memory while keeping
code simple. Say you have a very large model ``Model`` derived from
:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.
Typically, this will initialize all of the weights as ``float32``, but the
initialization does not actually compute anything until you perform an
:func:`eval`. If you update the model with ``float16`` weights before
evaluating, your peak memory use will be half of what eager computation would
have required.

This pattern is simple to do in MLX thanks to lazy computation:

.. code-block:: python

  model = Model() # no memory used yet
  model.load_weights("weights_fp16.safetensors")

When to Evaluate
----------------

A common question is when to use :func:`eval`. The trade-off is between
letting graphs get too large and not batching enough useful work.

For example:

.. code-block:: python

  for _ in range(100):
      a = a + b
      mx.eval(a)
      b = b * 2
      mx.eval(b)

Evaluating this often is a bad idea because each graph evaluation carries a
fixed overhead. On the other hand, there is a slight overhead that grows with
the compute graph size, so extremely large graphs (while computationally
correct) can be costly.

Luckily, a wide range of compute graph sizes work pretty well with MLX:
anything from a few tens of operations to many thousands of operations per
evaluation should be okay.

Most numerical computations have an iterative outer loop (e.g. the iteration in
stochastic gradient descent). A natural and usually efficient place to use
:func:`eval` is at each iteration of this outer loop.

Here is a concrete example:

.. code-block:: python

   for batch in dataset:

       # Nothing has been evaluated yet
       loss, grad = value_and_grad_fn(model, batch)

       # Still nothing has been evaluated
       optimizer.update(model, grad)

       # Evaluate the loss and the new parameters which will
       # run the full gradient computation and optimizer update
       mx.eval(loss, model.parameters())


It is important to know when the graph will be implicitly evaluated. Any time
you ``print`` an array, convert it to a :obj:`numpy.ndarray`, or otherwise
access its memory via :obj:`memoryview`, the graph will be evaluated. Saving
arrays via :func:`save` (or any other MLX saving function) will also evaluate
the array.


Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.

Also, calling :func:`eval` on an array or set of arrays multiple times is
perfectly fine. This is effectively a no-op.

.. warning::

  Using scalar arrays for control flow will cause an evaluation.

Here is an example:

.. code-block:: python

   def fun(x):
       h, y = first_layer(x)
       if y > 0:  # An evaluation is done here!
           z = second_layer_a(h)
       else:
           z = second_layer_b(h)
       return z

Using arrays for control flow should be done with care. The above example works
and can even be used with gradient transformations. However, this can be very
inefficient if evaluations are done too frequently.
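
Why a branch forces an evaluation can be seen in a toy plain-Python model.
Python needs a concrete ``True`` or ``False`` to choose a branch, so a lazy
value must be computed at that point. The ``Scalar`` class and its
``eval_count`` counter below are hypothetical and only illustrate the idea;
they are not part of MLX.

```python
class Scalar:
    """Toy lazy scalar; evaluation is forced when a truth value is needed."""

    eval_count = 0  # how many nodes have been computed so far

    def __init__(self, thunk):
        self._thunk = thunk
        self._value = None
        self._ready = False

    def eval(self):
        if not self._ready:
            Scalar.eval_count += 1
            self._value = self._thunk()
            self._ready = True
        return self._value

    def __gt__(self, other):
        # The comparison is itself lazy: it returns a new unevaluated node.
        return Scalar(lambda: self.eval() > other)

    def __bool__(self):
        # ``if`` calls this, which forces the whole upstream graph to run.
        return bool(self.eval())


y = Scalar(lambda: 2.5)
cond = y > 0                 # still nothing computed
before = Scalar.eval_count

if cond:                     # an evaluation is done here
    branch = "a"
else:
    branch = "b"

after = Scalar.eval_count
```

In this sketch the counter is zero before the ``if`` statement and non-zero
after it, which shows that the branch itself triggered the computation.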