|
|
.. _indexing: |
|
|
|
|
|
Indexing Arrays |
|
|
=============== |
|
|
|
|
|
.. currentmodule:: mlx.core |
|
|
|
|
|
For the most part, indexing an MLX :obj:`array` works the same as indexing a |
|
|
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation |
|
|
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on |
|
|
how that works. |
|
|
|
|
|
For example, you can use regular integers and slices (:obj:`slice`) to index arrays: |
|
|
|
|
|
.. code-block:: shell |
|
|
|
|
|
>>> arr = mx.arange(10) |
|
|
>>> arr[3] |
|
|
array(3, dtype=int32) |
|
|
>>> arr[-2] |
|
|
array(8, dtype=int32) |
|
|
>>> arr[2:8:2] |
|
|
array([2, 4, 6], dtype=int32) |
|
|
|
|
|
For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy: |
|
|
|
|
|
.. code-block:: shell |
|
|
|
|
|
>>> arr = mx.arange(8).reshape(2, 2, 2) |
|
|
>>> arr[:, :, 0] |
|
|
array(3, dtype=int32) |
|
|
array([[0, 2], |
|
|
[4, 6]], dtype=int32 |
|
|
>>> arr[..., 0] |
|
|
array([[0, 2], |
|
|
[4, 6]], dtype=int32 |
|
|
|
|
|
You can index with ``None`` to create a new axis: |
|
|
|
|
|
.. code-block:: shell |
|
|
|
|
|
>>> arr = mx.arange(8) |
|
|
>>> arr.shape |
|
|
[8] |
|
|
>>> arr[None].shape |
|
|
[1, 8] |
|
|
|
|
|
|
|
|
You can also use an :obj:`array` to index another :obj:`array`: |
|
|
|
|
|
.. code-block:: shell |
|
|
|
|
|
>>> arr = mx.arange(10) |
|
|
>>> idx = mx.array([5, 7]) |
|
|
>>> arr[idx] |
|
|
array([5, 7], dtype=int32) |
|
|
|
|
|
Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices |
|
|
works just as in NumPy. |
|
|
|
|
|
Other functions which may be useful for indexing arrays are :func:`take` and |
|
|
:func:`take_along_axis`. |
|
|
|
|
|
Differences from NumPy |
|
|
---------------------- |
|
|
|
|
|
.. Note:: |
|
|
|
|
|
MLX indexing is different from NumPy indexing in two important ways: |
|
|
|
|
|
* Indexing does not perform bounds checking. Indexing out of bounds is |
|
|
undefined behavior. |
|
|
* Boolean mask based indexing is not yet supported. |
|
|
|
|
|
The reason for the lack of bounds checking is that exceptions cannot propagate |
|
|
from the GPU. Performing bounds checking for array indices before launching the |
|
|
kernel would be extremely inefficient. |
|
|
|
|
|
Indexing with boolean masks is something that MLX may support in the future. In |
|
|
general, MLX has limited support for operations for which output |
|
|
*shapes* are dependent on input *data*. Other examples of these types of |
|
|
operations which MLX does not yet support include :func:`numpy.nonzero` and the |
|
|
single input version of :func:`numpy.where`. |
|
|
|
|
|
In Place Updates |
|
|
---------------- |
|
|
|
|
|
In place updates to indexed arrays are possible in MLX. For example: |
|
|
|
|
|
.. code-block:: shell |
|
|
|
|
|
>>> a = mx.array([1, 2, 3]) |
|
|
>>> a[2] = 0 |
|
|
>>> a |
|
|
array([1, 2, 0], dtype=int32) |
|
|
|
|
|
Just as in NumPy, in place updates will be reflected in all references to the |
|
|
same array: |
|
|
|
|
|
.. code-block:: shell |
|
|
|
|
|
>>> a = mx.array([1, 2, 3]) |
|
|
>>> b = a |
|
|
>>> b[2] = 0 |
|
|
>>> b |
|
|
array([1, 2, 0], dtype=int32) |
|
|
>>> a |
|
|
array([1, 2, 0], dtype=int32) |
|
|
|
|
|
Note that unlike NumPy, slicing an array creates a copy, not a view. So |
|
|
mutating it does not mutate the original array: |
|
|
|
|
|
.. code-block:: shell |
|
|
|
|
|
>>> a = mx.array([1, 2, 3]) |
|
|
>>> b = a[:] |
|
|
>>> b[2] = 0 |
|
|
>>> b |
|
|
array([1, 2, 0], dtype=int32) |
|
|
>>> a |
|
|
array([1, 2, 3], dtype=int32) |
|
|
|
|
|
Also unlike NumPy, updates to the same location are nondeterministic: |
|
|
|
|
|
.. code-block:: shell |
|
|
|
|
|
>>> a = mx.array([1, 2, 3]) |
|
|
>>> a[[0, 0]] = mx.array([4, 5]) |
|
|
|
|
|
The first element of ``a`` could be ``4`` or ``5``. |
|
|
|
|
|
Transformations of functions which use in-place updates are allowed and work as |
|
|
expected. For example: |
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
def fun(x, idx): |
|
|
x[idx] = 2.0 |
|
|
return x.sum() |
|
|
|
|
|
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1])) |
|
|
print(dfdx) |
|
|
|
|
|
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx`` |
|
|
and ones elsewhere. |
|
|
|