.. _export_usage:

Exporting Functions
===================

.. currentmodule:: mlx.core

MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).

This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions, check out the :ref:`API documentation
<export>`.

Basics of Exporting
-------------------

Let's start with a simple example:

.. code-block:: python

  import mlx.core as mx

  def fun(x, y):
    return x + y

  x = mx.array(1.0)
  y = mx.array(1.0)
  mx.export_function("add.mlxfn", fun, x, y)

To export a function, provide sample input arrays that the function
can be called with. The data doesn't matter, but the shapes and types of the
arrays do. In the above example we exported ``fun`` with two ``float32``
scalar arrays. We can then import the function and run it:

.. code-block:: python

  add_fun = mx.import_function("add.mlxfn")

  out, = add_fun(mx.array(1.0), mx.array(2.0))
  # Prints: array(3, dtype=float32)
  print(out)

  out, = add_fun(mx.array(1.0), mx.array(3.0))
  # Prints: array(4, dtype=float32)
  print(out)

  # Raises an exception
  add_fun(mx.array(1), mx.array(3.0))

  # Raises an exception
  add_fun(mx.array([1.0, 2.0]), mx.array(3.0))

Notice the third and fourth calls to ``add_fun`` raise exceptions because the
shapes and types of the inputs differ from the shapes and types of the
example inputs we exported the function with.

Also notice that even though the original ``fun`` returns a single output
array, the imported function always returns a tuple of one or more arrays.

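Because an imported function only accepts inputs that match the exported shapes
and types, one option is to cast an input before the call. The following is a
minimal sketch, assuming ``add.mlxfn`` is exported with two ``float32`` scalars
as above. It catches the broad ``Exception`` class as an assumption, since this
guide does not name the exception type that is raised:

.. code-block:: python

  import mlx.core as mx

  def fun(x, y):
    return x + y

  # Export with two float32 scalar example inputs
  mx.export_function("add.mlxfn", fun, mx.array(1.0), mx.array(1.0))
  add_fun = mx.import_function("add.mlxfn")

  a = mx.array(1)    # int32, does not match the exported float32 input
  b = mx.array(3.0)  # float32 scalar, matches

  # The mismatched call is expected to raise (exception type assumed unknown)
  try:
    add_fun(a, b)
    mismatch_raised = False
  except Exception:
    mismatch_raised = True

  # Cast to the exported type before calling the imported function
  out, = add_fun(a.astype(mx.float32), b)
  print(out)

Casting only addresses a type mismatch. An input with a different shape still
needs a separate export or a shapeless export.
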
The inputs to :func:`export_function` and to an imported function can be
specified as variable positional arguments or as a tuple of arrays:

.. code-block:: python

  def fun(x, y):
    return x + y

  x = mx.array(1.0)
  y = mx.array(1.0)

  # Both arguments to fun are positional
  mx.export_function("add.mlxfn", fun, x, y)

  # Same as above
  mx.export_function("add.mlxfn", fun, (x, y))

  imported_fun = mx.import_function("add.mlxfn")

  # Ok
  out, = imported_fun(x, y)

  # Also ok
  out, = imported_fun((x, y))

You can pass example inputs to functions as positional or keyword arguments. If
you use keyword arguments to export the function, then you have to use the same
keyword arguments when calling the imported function.

.. code-block:: python

  def fun(x, y):
    return x + y

  # One argument to fun is positional, the other is a kwarg
  mx.export_function("add.mlxfn", fun, x, y=y)

  imported_fun = mx.import_function("add.mlxfn")

  # Ok
  out, = imported_fun(x, y=y)

  # Also ok
  out, = imported_fun((x,), {"y": y})

  # Raises since the keyword argument is missing
  out, = imported_fun(x, y)

  # Raises since the keyword argument has the wrong key
  out, = imported_fun(x, z=y)

Exporting Modules
-----------------

An :obj:`mlx.nn.Module` can be exported with or without the parameters included
in the exported function. Here's an example:

.. code-block:: python

  import mlx.nn as nn

  model = nn.Linear(4, 4)
  mx.eval(model.parameters())

  def call(x):
    return model(x)

  mx.export_function("model.mlxfn", call, mx.zeros(4))

In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
parameters are also saved to the ``model.mlxfn`` file.

.. note::

   For enclosed arrays inside an exported function, be extra careful to ensure
   they are evaluated. The computation graph that gets exported will include
   the computation that produces enclosed inputs.

   If the above example were missing ``mx.eval(model.parameters())``, the
   exported function would include the random initialization of the
   :obj:`mlx.nn.Module` parameters.

If you only want to export the ``Module.__call__`` function without the
parameters, pass them as inputs to the ``call`` wrapper:

.. code-block:: python

  from mlx.utils import tree_flatten, tree_unflatten

  model = nn.Linear(4, 4)
  mx.eval(model.parameters())

  def call(x, **params):
    # Set the model's parameters to the input parameters
    model.update(tree_unflatten(list(params.items())))
    return model(x)

  params = tree_flatten(model.parameters(), destination={})
  mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

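Since the parameters are exported as keyword arguments, the imported function
should expect the same keywords at call time. The following is a minimal
sketch under that assumption. It builds the parameter dictionary with
``dict(tree_flatten(...))``, which is assumed to be equivalent to passing
``destination={}``, and the input ``x`` and the name ``imported_call`` are
illustrative only:

.. code-block:: python

  import mlx.core as mx
  import mlx.nn as nn
  from mlx.utils import tree_flatten, tree_unflatten

  model = nn.Linear(4, 4)
  mx.eval(model.parameters())

  # Flat dictionary of parameters, e.g. {"weight": ..., "bias": ...}
  params = dict(tree_flatten(model.parameters()))

  # Reference output computed directly from the module before exporting
  x = mx.ones(4)
  expected = model(x)
  mx.eval(expected)

  def call(x, **params):
    # Set the model's parameters to the input parameters
    model.update(tree_unflatten(list(params.items())))
    return model(x)

  mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

  imported_call = mx.import_function("model.mlxfn")

  # Pass the input positionally and the parameters as keyword arguments
  out, = imported_call((x,), params)

  # Equivalent call form using Python keyword arguments
  out_kw, = imported_call(x, **params)

With this layout the same exported file can be reused with different parameter
values, as long as their names, shapes, and types match the export.
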
Shapeless Exports
-----------------

As with :func:`compile`, functions can also be exported for dynamically shaped
inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
to export a function which can be used for inputs with variable shapes:

.. code-block:: python

  mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
  imported_abs = mx.import_function("fun.mlxfn")

  # Ok
  out, = imported_abs(mx.array([-1.0]))

  # Also ok
  out, = imported_abs(mx.array([-1.0, -2.0]))

With ``shapeless=False`` (which is the default), the second call to
``imported_abs`` would raise an exception with a shape mismatch.

Shapeless exporting works the same as shapeless compilation and should be
used carefully. See the :ref:`documentation on shapeless compilation
<shapeless_compile>` for more information.

Exporting Multiple Traces
-------------------------

In some cases, functions build different computation graphs for different
input arguments. A simple way to manage this is to export to a new file with
each set of inputs. This is a fine option in many cases. But it can be
suboptimal if the exported functions have a large amount of duplicate constant
data (for example the parameters of a :obj:`mlx.nn.Module`).

The export API in MLX lets you export multiple traces of the same function to
a single file by creating an exporting context manager with :func:`exporter`:

.. code-block:: python

  def fun(x, y=None):
    constant = mx.array(3.0)
    if y is not None:
      x += y
    return x + constant

  with mx.exporter("fun.mlxfn", fun) as exporter:
    exporter(mx.array(1.0))
    exporter(mx.array(1.0), y=mx.array(0.0))

  imported_function = mx.import_function("fun.mlxfn")

  # Call the function with y=None
  out, = imported_function(mx.array(1.0))
  print(out)

  # Call the function with y specified
  out, = imported_function(mx.array(1.0), y=mx.array(1.0))
  print(out)

In the above example, the function's constant data (i.e. ``constant``) is saved
only once.

Transformations with Imported Functions
---------------------------------------

Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work
on imported functions just like regular Python functions:

.. code-block:: python

  def fun(x):
    return mx.sin(x)

  x = mx.array(0.0)
  mx.export_function("sine.mlxfn", fun, x)

  imported_fun = mx.import_function("sine.mlxfn")

  # Take the derivative of the imported function
  dfdx = mx.grad(lambda x: imported_fun(x)[0])
  # Prints: array(1, dtype=float32)
  print(dfdx(x))

  # Compile the imported function
  compiled_fun = mx.compile(imported_fun)
  # Prints: array(0, dtype=float32)
  print(compiled_fun(x)[0])

Importing Functions in C++
--------------------------

Importing and running functions in C++ is basically the same as importing and
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
set up a simple C++ project that uses MLX as a library.

Next, export a simple function from Python:

.. code-block:: python

  def fun(x, y):
    return mx.exp(x + y)

  x = mx.array(1.0)
  y = mx.array(1.0)
  mx.export_function("fun.mlxfn", fun, x, y)

Import and run the function in C++ with only a few lines of code:

.. code-block:: c++

  auto fun = mx::import_function("fun.mlxfn");

  auto inputs = {mx::array(1.0), mx::array(1.0)};
  auto outputs = fun(inputs);

  std::cout << outputs[0] << std::endl;

Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

More Examples
-------------

Here are a few more complete examples exporting more complex functions from
Python and importing and running them in C++:

* `Inference and training a multi-layer perceptron <https: