.. _export_usage:

Exporting Functions
===================

.. currentmodule:: mlx.core

MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).

This guide walks through the basics of the MLX export API with some examples.
For the full list of functions, see the :ref:`API documentation
<export>`.

Basics of Exporting
-------------------

Let's start with a simple example:

.. code-block:: python

  import mlx.core as mx

  def fun(x, y):
    return x + y

  x = mx.array(1.0)
  y = mx.array(1.0)
  mx.export_function("add.mlxfn", fun, x, y)

To export a function, provide sample input arrays that the function
can be called with. The data doesn't matter, but the shapes and types of the
arrays do. In the above example we exported ``fun`` with two ``float32``
scalar arrays. We can then import the function and run it:

.. code-block:: python

  add_fun = mx.import_function("add.mlxfn")

  out, = add_fun(mx.array(1.0), mx.array(2.0))
  # Prints: array(3, dtype=float32)
  print(out)

  out, = add_fun(mx.array(1.0), mx.array(3.0))
  # Prints: array(4, dtype=float32)
  print(out)

  # Raises an exception
  add_fun(mx.array(1), mx.array(3.0))

  # Raises an exception
  add_fun(mx.array([1.0, 2.0]), mx.array(3.0))

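Notice that the third and fourth calls to ``add_fun`` raise exceptions because
the inputs differ from the example inputs the function was exported with: the
third call passes an ``int32`` array where a ``float32`` array was exported, and
the fourth passes an array of shape ``(2,)`` where a scalar was exported.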

Also notice that even though the original ``fun`` returns a single output
array, the imported function always returns a tuple of one or more arrays.
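
As a minimal sketch of how a caller might handle both behaviors, the snippet
below casts an input to the exported ``float32`` type before the call and
unpacks the single output. It re-exports ``add.mlxfn`` so that it runs on its
own. Catching the broad ``Exception`` class is an assumption made here, since
this guide does not name the exception type.

.. code-block:: python

  import mlx.core as mx

  def fun(x, y):
      return x + y

  mx.export_function("add.mlxfn", fun, mx.array(1.0), mx.array(1.0))
  add_fun = mx.import_function("add.mlxfn")

  a = mx.array(1)  # int32, does not match the exported float32 input
  b = mx.array(3.0)

  try:
      add_fun(a, b)
      mismatch_raised = False
  except Exception:  # the exact exception type is not assumed here
      mismatch_raised = True

  # Cast to the exported dtype, then unpack the one-element result
  (out,) = add_fun(a.astype(mx.float32), b)
  print(mismatch_raised, out.item())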

The inputs to :func:`export_function` and to an imported function can be
specified as variable positional arguments or as a tuple of arrays:

.. code-block:: python

  def fun(x, y):
    return x + y

  x = mx.array(1.0)
  y = mx.array(1.0)

  # Both arguments to fun are positional
  mx.export_function("add.mlxfn", fun, x, y)

  # Same as above
  mx.export_function("add.mlxfn", fun, (x, y))

  imported_fun = mx.import_function("add.mlxfn")

  # Ok
  out, = imported_fun(x, y)

  # Also ok
  out, = imported_fun((x, y))

You can pass example inputs to functions as positional or keyword arguments. If
you use keyword arguments to export the function, then you have to use the same
keyword arguments when calling the imported function.

.. code-block:: python

  def fun(x, y):
    return x + y

  # One argument to fun is positional, the other is a kwarg
  mx.export_function("add.mlxfn", fun, x, y=y)

  imported_fun = mx.import_function("add.mlxfn")

  # Ok
  out, = imported_fun(x, y=y)

  # Also ok
  out, = imported_fun((x,), {"y": y})

  # Raises since the keyword argument is missing
  out, = imported_fun(x, y)

  # Raises since the keyword argument has the wrong key
  out, = imported_fun(x, z=y)


Exporting Modules
-----------------

An :obj:`mlx.nn.Module` can be exported with or without the parameters included
in the exported function. Here's an example:

.. code-block:: python

   import mlx.nn as nn

   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x):
      return model(x)

   mx.export_function("model.mlxfn", call, mx.zeros(4))

In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
parameters are also saved to the ``model.mlxfn`` file.

.. note::

   For enclosed arrays inside an exported function, be extra careful to ensure
   they are evaluated. The computation graph that gets exported will include
   the computation that produces enclosed inputs.

   If the above example was missing ``mx.eval(model.parameters())``, the
   exported function would include the random initialization of the
   :obj:`mlx.nn.Module` parameters.

If you only want to export the ``Module.__call__`` function without the
parameters, pass them as inputs to the ``call`` wrapper:

.. code-block:: python

   from mlx.utils import tree_flatten, tree_unflatten

   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x, **params):
     # Set the model's parameters to the input parameters
     model.update(tree_unflatten(list(params.items())))
     return model(x)

   params = tree_flatten(model.parameters(), destination={})
   mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
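
On the import side, the parameters then have to be supplied with every call.
The following is a sketch rather than a reference implementation. It assumes
the flattened parameter names (``weight`` and ``bias`` for
:obj:`mlx.nn.Linear`) are passed back as keyword arguments, builds the
dictionary with ``dict(tree_flatten(...))``, and records a reference output
before exporting so the comparison does not depend on the model's state after
tracing.

.. code-block:: python

  import mlx.core as mx
  import mlx.nn as nn
  from mlx.utils import tree_flatten, tree_unflatten

  model = nn.Linear(4, 4)
  mx.eval(model.parameters())

  x = mx.ones(4)
  expected = model(x)
  mx.eval(expected)  # reference output, computed before exporting

  def call(x, **params):
      # Set the model's parameters to the input parameters
      model.update(tree_unflatten(list(params.items())))
      return model(x)

  # Flattened parameters keyed by name, e.g. "weight" and "bias"
  params = dict(tree_flatten(model.parameters()))
  mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

  imported = mx.import_function("model.mlxfn")

  # The parameters are not stored in the file, so pass them as keyword arguments
  (out,) = imported(x, **params)
  print(mx.allclose(out, expected).item())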


Shapeless Exports
-----------------

Just like :func:`compile`, functions can also be exported for dynamically shaped
inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
to export a function which can be used for inputs with variable shapes:

.. code-block:: python

  mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
  imported_abs = mx.import_function("fun.mlxfn")

  # Ok
  out, = imported_abs(mx.array([-1.0]))

  # Also ok
  out, = imported_abs(mx.array([-1.0, -2.0]))

With ``shapeless=False`` (which is the default), the second call to
``imported_abs`` would raise an exception with a shape mismatch.

Shapeless exporting works the same as shapeless compilation and should be
used carefully. See the :ref:`documentation on shapeless compilation
<shapeless_compile>` for more information.

Exporting Multiple Traces
-------------------------

In some cases, functions build different computation graphs for different
input arguments. A simple way to manage this is to export to a new file with
each set of inputs. This is a fine option in many cases. But it can be
suboptimal if the exported functions have a large amount of duplicate constant
data (for example the parameters of a :obj:`mlx.nn.Module`).

The export API in MLX lets you export multiple traces of the same function to
a single file by creating an exporting context manager with :func:`exporter`:

.. code-block:: python

  def fun(x, y=None):
      constant = mx.array(3.0)
      if y is not None:
        x += y
      return x + constant

  with mx.exporter("fun.mlxfn", fun) as exporter:
      exporter(mx.array(1.0))
      exporter(mx.array(1.0), y=mx.array(0.0))

  imported_function = mx.import_function("fun.mlxfn")

  # Call the function with y=None
  out, = imported_function(mx.array(1.0))
  print(out)

  # Call the function with y specified
  out, = imported_function(mx.array(1.0), y=mx.array(1.0))
  print(out)

In the above example, the function's constant data (i.e. ``constant``) is saved
only once, even though two traces are exported.
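
The same context manager can also record one trace per input shape. The sketch
below assumes that inputs of different shapes are stored as separate traces and
matched at call time, as the shape checks described earlier suggest. The
function and the file name ``sumsq.mlxfn`` are hypothetical.

.. code-block:: python

  import mlx.core as mx

  def fun(x):
      return mx.sum(x * x)

  # Export one trace for each vector length we expect to see
  with mx.exporter("sumsq.mlxfn", fun) as exporter:
      for n in (2, 3, 4):
          exporter(mx.zeros((n,)))

  imported = mx.import_function("sumsq.mlxfn")

  # Matches the trace exported for shape (3,)
  (out,) = imported(mx.array([1.0, 2.0, 3.0]))
  print(out.item())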

Transformations with Imported Functions
---------------------------------------

Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work
on imported functions just like regular Python functions:

.. code-block:: python

  def fun(x):
      return mx.sin(x)

  x = mx.array(0.0)
  mx.export_function("sine.mlxfn", fun, x)

  imported_fun = mx.import_function("sine.mlxfn")

  # Take the derivative of the imported function
  dfdx = mx.grad(lambda x: imported_fun(x)[0])
  # Prints: array(1, dtype=float32)
  print(dfdx(x))

  # Compile the imported function
  compiled_fun = mx.compile(imported_fun)
  # Prints: array(0, dtype=float32)
  print(compiled_fun(x)[0])
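
:func:`vmap` can be applied in the same way. The following hedged sketch maps
the imported scalar sine over a vector. It assumes the mapped input reaches the
imported function with the exported scalar shape, and it re-exports the file so
the snippet runs on its own.

.. code-block:: python

  import mlx.core as mx

  def fun(x):
      return mx.sin(x)

  mx.export_function("sine.mlxfn", fun, mx.array(0.0))
  imported_fun = mx.import_function("sine.mlxfn")

  # Map the scalar function over the leading axis of a vector
  batched_sin = mx.vmap(lambda x: imported_fun(x)[0])

  xs = mx.array([0.0, 0.5, 1.0])
  ys = batched_sin(xs)
  print(mx.allclose(ys, mx.sin(xs)).item())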


Importing Functions in C++
--------------------------

Importing and running functions in C++ works much the same way as importing and
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
set up a simple C++ project that uses MLX as a library.

Next, export a simple function from Python:

.. code-block:: python

  def fun(x, y):
      return mx.exp(x + y)

  x = mx.array(1.0)
  y = mx.array(1.0)
  mx.export_function("fun.mlxfn", fun, x, y)


Import and run the function in C++ with only a few lines of code:

.. code-block:: c++

  auto fun = mx::import_function("fun.mlxfn");

  auto inputs = {mx::array(1.0), mx::array(1.0)};
  auto outputs = fun(inputs);

  // Prints exp(1 + 1): array(7.38906, dtype=float32)
  std::cout << outputs[0] << std::endl;

Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

More Examples
-------------

Here are a few more complete examples exporting more complex functions from
Python and importing and running them in C++:

* `Inference and training of a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_