.. _numpy:

Conversion to NumPy and Other Frameworks
========================================

MLX arrays support conversion to and from other frameworks with either:

* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.

Let's convert an array to NumPy and back.

.. code-block:: python

  import mlx.core as mx
  import numpy as np

  a = mx.arange(3)
  b = np.array(a) # copy of a
  c = mx.array(b) # copy of b

.. note::

    Since NumPy does not support ``bfloat16`` arrays, you will need to convert
    to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
    Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
    buffer format string does not match the dtype V item size 0.``

By default, NumPy copies data to a new array. This can be prevented by creating
an array view:

.. code-block:: python

  a = mx.arange(3)
  a_view = np.array(a, copy=False)
  print(a_view.flags.owndata) # False
  a_view[0] = 1
  print(a[0].item()) # 1

.. note::

    NumPy arrays with type ``float64`` will be converted by default to MLX
    arrays with type ``float32``.

A NumPy array view is a normal NumPy array, except that it does not own its
memory. This means writing to the view is reflected in the original array.

While this is quite powerful for avoiding array copies, be aware that external
changes to the memory of arrays are not reflected in gradients.

Let's demonstrate this in an example:

.. code-block:: python

  def f(x):
      x_view = np.array(x, copy=False)
      x_view[:] *= x_view # modify memory without telling mx
      return x.sum()

  x = mx.array([3.0])
  y, df = mx.value_and_grad(f)(x)
  print("f(x) = x² =", y.item()) # 9.0
  print("f'(x) = 2x !=", df.item()) # 1.0


The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the
last line outputting ``1.0``, representing the gradient of the sum operation
alone.  The squaring of ``x`` occurs externally to MLX, meaning that no
gradient is incorporated.  It's important to note that a similar issue arises
during array conversion and copying.  For instance, a function defined as
``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.

PyTorch
-------

.. warning::

   PyTorch support for :obj:`memoryview` is experimental and can break for
   multi-dimensional arrays. Casting to NumPy first is advised for now.

PyTorch supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.

.. code-block:: python

  import mlx.core as mx
  import torch

  a = mx.arange(3)
  b = torch.tensor(memoryview(a))
  c = mx.array(b.numpy())

Conversion from PyTorch tensors back to MLX arrays must be done via
intermediate NumPy arrays with ``numpy()``.

JAX
---

JAX fully supports the buffer protocol.

.. code-block:: python

  import mlx.core as mx
  import jax.numpy as jnp

  a = mx.arange(3)
  b = jnp.array(a)
  c = mx.array(b)

TensorFlow
----------

TensorFlow supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.

.. code-block:: python

  import mlx.core as mx
  import tensorflow as tf

  a = mx.arange(3)
  b = tf.constant(memoryview(a))
  c = mx.array(b)