|
|
.. _numpy: |
|
|
|
|
|
Conversion to NumPy and Other Frameworks
========================================
|
|
|
|
|
MLX arrays support conversion to and from other frameworks with either:
|
|
|
|
|
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
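
Both mechanisms let consumers wrap existing memory rather than copy it. As a
minimal illustrative sketch of the buffer protocol using only the standard
library and NumPy (the stdlib ``array`` module stands in here for any
buffer-exporting object, such as an MLX array):

```python
import array

import numpy as np

# An ``array.array`` exposes its memory via the buffer protocol.
buf = array.array("i", [0, 1, 2])

# ``np.frombuffer`` wraps that memory without copying it.
view = np.frombuffer(buf, dtype="i")

# Mutating the underlying buffer is visible through the NumPy view,
# because both refer to the same memory.
buf[0] = 7
print(view[0])  # 7
```

The same zero-copy behavior is what makes the conversions below cheap.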
|
|
|
|
|
Let's convert an array to NumPy and back. |
|
|
|
|
|
.. code-block:: python

    import mlx.core as mx
    import numpy as np

    a = mx.arange(3)
    b = np.array(a)  # copy of a
    c = mx.array(b)  # copy of b
|
|
|
|
|
.. note::

    Since NumPy does not support ``bfloat16`` arrays, you will need to convert
    to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
    Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
    buffer format string does not match the dtype V item size 0.``
|
|
|
|
|
By default, NumPy copies data to a new array. This can be prevented by creating
an array view:
|
|
|
|
|
.. code-block:: python

    a = mx.arange(3)
    a_view = np.array(a, copy=False)
    print(a_view.flags.owndata)  # False
    a_view[0] = 1
    print(a[0].item())  # 1
|
|
|
|
|
.. note::

    NumPy arrays with type ``float64`` are converted by default to MLX arrays
    with type ``float32``.
|
|
|
|
|
A NumPy array view is a normal NumPy array, except that it does not own its
memory. This means writing to the view is reflected in the original array.
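
The same semantics can be seen with NumPy alone. A small sketch (plain NumPy,
no MLX involved) showing that writes through a non-owning view propagate to
the original array:

```python
import numpy as np

base = np.arange(3)
view = base[:]             # slicing creates a view: same memory, not owned

print(view.flags.owndata)  # False: the view does not own its memory

view[0] = 42               # writing through the view...
print(base[0])             # ...is reflected in the original: 42
```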
|
|
|
|
|
While this is a powerful way to avoid copying arrays, it should be noted that
external changes to the memory of arrays cannot be reflected in gradients.
|
|
|
|
|
Let's demonstrate this in an example: |
|
|
|
|
|
.. code-block:: python

    def f(x):
        x_view = np.array(x, copy=False)
        x_view[:] *= x_view  # modify memory without telling mx
        return x.sum()

    x = mx.array([3.0])
    y, df = mx.value_and_grad(f)(x)
    print("f(x) = x² =", y.item())     # 9.0
    print("f'(x) = 2x !=", df.item())  # 1.0
|
|
|
|
|
|
|
|
The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the
last line outputting ``1.0``, representing the gradient of the sum operation
alone. The squaring of ``x`` occurs externally to MLX, meaning that no
gradient is incorporated. It's important to note that a similar issue arises
during array conversion and copying. For instance, a function defined as
``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.
|
|
|
|
|
PyTorch
-------
|
|
|
|
|
.. warning::

    PyTorch support for :obj:`memoryview` is experimental and can break for
    multi-dimensional arrays. Casting to NumPy first is advised for now.
|
|
|
|
|
PyTorch supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.
|
|
|
|
|
.. code-block:: python

    import mlx.core as mx
    import torch

    a = mx.arange(3)
    b = torch.tensor(memoryview(a))
    c = mx.array(b.numpy())
|
|
|
|
|
Conversion from PyTorch tensors back to MLX arrays must be done via
intermediate NumPy arrays with ``numpy()``.
|
|
|
|
|
JAX
---
|
|
JAX fully supports the buffer protocol. |
|
|
|
|
|
.. code-block:: python

    import mlx.core as mx
    import jax.numpy as jnp

    a = mx.arange(3)
    b = jnp.array(a)
    c = mx.array(b)
|
|
|
|
|
TensorFlow
----------
|
|
|
|
|
TensorFlow supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.
|
|
|
|
|
.. code-block:: python

    import mlx.core as mx
    import tensorflow as tf

    a = mx.arange(3)
    b = tf.constant(memoryview(a))
    c = mx.array(b)
|
|
|