.. _usage_distributed:

Distributed Communication
=========================

.. currentmodule:: mlx.core.distributed

MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:

* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_, a
  full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
  faster for Thunderbolt connections.

The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.

.. note::
  Some operations may not be supported yet, or may not be as fast as they
  should be. We are adding more and tuning the ones we have as we figure out
  the best way to do distributed computing on Macs using MLX.

Getting Started
---------------

A distributed program in MLX is as simple as:

.. code:: python

  import mlx.core as mx

  world = mx.distributed.init()
  x = mx.distributed.all_sum(mx.ones(10))
  print(world.rank(), x)

The program above sums the array ``mx.ones(10)`` across all
distributed processes. However, when this script is run with ``python``, only
one process is launched and no distributed communication takes place. Namely,
all operations in ``mx.distributed`` are no-ops when the distributed group has
a size of one. This property allows us to avoid code that checks whether we
are in a distributed setting, such as the snippet below:

.. code:: python

  import mlx.core as mx

  x = ...
  world = mx.distributed.init()
  # No need for the check; we can simply do x = mx.distributed.all_sum(x)
  if world.size() > 1:
      x = mx.distributed.all_sum(x)

Running Distributed Programs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

MLX provides ``mlx.launch``, a helper script to launch distributed programs.
Continuing with our initial example, we can run it on localhost with 4
processes using:

.. code:: shell

  $ mlx.launch -n 4 my_script.py
  3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
  2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
  1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
  0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

We can also run it on some remote hosts by providing their IPs (provided that
the script exists on all hosts and they are reachable by ssh):

.. code:: shell

  $ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
  3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
  2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
  1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
  0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

Consult the dedicated :doc:`usage guide<launching_distributed>` for more
information on using ``mlx.launch``.

Selecting Backend
^^^^^^^^^^^^^^^^^

You can select the backend you want to use when calling :func:`init` by
passing one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try
to initialize the ``ring`` backend and, if that fails, the ``mpi`` backend. If
both fail, then a singleton group is created.

.. note::
  After a distributed backend is successfully initialized, :func:`init` will
  return **the same backend** if called without arguments or with the backend
  set to ``any``.

The following examples aim to clarify the backend initialization logic in MLX:

.. code:: python

  # Case 1: Initialize MPI regardless of whether the ring backend could be initialized
  world = mx.distributed.init(backend="mpi")
  world2 = mx.distributed.init()  # subsequent calls return the MPI backend!

  # Case 2: Initialize any backend
  world = mx.distributed.init(backend="any")  # equivalent to no arguments
  world2 = mx.distributed.init()  # same as above

  # Case 3: Initialize both backends at the same time
  world_mpi = mx.distributed.init(backend="mpi")
  world_ring = mx.distributed.init(backend="ring")
  world_any = mx.distributed.init()  # same as MPI because it was initialized first!

Training Example
----------------

In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization:

.. code:: python

  model = ...
  optimizer = ...
  dataset = ...

  def step(model, x, y):
      loss, grads = loss_grad_fn(model, x, y)
      optimizer.update(model, grads)
      return loss

  for x, y in dataset:
      loss = step(model, x, y)
      mx.eval(loss, model.parameters())

All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely, we
have to :func:`mlx.utils.tree_map` the gradients with the following function:

.. code:: python

  def all_avg(x):
      return mx.distributed.all_sum(x) / mx.distributed.init().size()

Putting everything together, our training loop step looks as follows, with
everything else remaining the same:

.. code:: python

  from mlx.utils import tree_map

  def all_reduce_grads(grads):
      N = mx.distributed.init().size()
      if N == 1:
          return grads
      return tree_map(
          lambda x: mx.distributed.all_sum(x) / N,
          grads
      )

  def step(model, x, y):
      loss, grads = loss_grad_fn(model, x, y)
      grads = all_reduce_grads(grads)  # <--- This line was added
      optimizer.update(model, grads)
      return loss

Utilizing ``nn.average_gradients``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Although the code example above works correctly, it performs one communication
per gradient. It is significantly more efficient to aggregate several
gradients together and perform fewer communication steps.
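The saving comes from batching: each rank can flatten all of its gradients
into one contiguous buffer, perform a single reduction, and split the result
back into per-gradient arrays. The sketch below simulates this idea with
plain Python lists standing in for both the MLX arrays and the communication;
the function is illustrative and is not the actual MLX implementation.

```python
def all_avg_batched(per_rank_grads):
    """Average every rank's gradients with one simulated reduction.

    `per_rank_grads` holds one list of 1-D "gradients" per rank; plain
    Python lists stand in for MLX arrays and for the communication.
    """
    n = len(per_rank_grads)
    # 1. Each rank flattens its gradients into one contiguous buffer.
    flat = [[v for g in grads for v in g] for grads in per_rank_grads]
    sizes = [len(g) for g in per_rank_grads[0]]
    # 2. One all-sum over the buffer replaces one all-sum per gradient.
    averaged = [sum(vals) / n for vals in zip(*flat)]
    # 3. Split the averaged buffer back into per-gradient arrays.
    out, start = [], 0
    for s in sizes:
        out.append(averaged[start:start + s])
        start += s
    return out

# Two ranks with two gradients each: one reduction instead of two.
print(all_avg_batched([[[1.0, 1.0], [2.0]],     # rank 0's gradients
                       [[3.0, 3.0], [4.0]]]))   # prints [[2.0, 2.0], [3.0]]
```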

This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
almost identical to the example above:

.. code:: python

  import mlx.nn

  model = ...
  optimizer = ...
  dataset = ...

  def step(model, x, y):
      loss, grads = loss_grad_fn(model, x, y)
      grads = mlx.nn.average_gradients(grads)  # <---- This line was added
      optimizer.update(model, grads)
      return loss

  for x, y in dataset:
      loss = step(model, x, y)
      mx.eval(loss, model.parameters())

Getting Started with MPI
------------------------

MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi``, which takes care of some nuisances such as
setting absolute paths for the ``mpirun`` executable and the ``libmpi.dyld``
shared library.

The simplest possible usage is the following which, assuming the minimal
example at the beginning of this page, should result in:

.. code:: shell

  $ mlx.launch --backend mpi -n 2 test.py
  1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
  0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each
other and compute the sum, which is printed. Launching with
``mlx.launch -n 4 ...`` would print 4s, and so on.
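The printed values follow directly from the reduction: every process
contributes an array of ones, so after the all-sum every position holds the
number of processes. A plain-Python simulation of that arithmetic (no MLX
involved, purely illustrative):

```python
def simulated_all_sum(num_ranks, length):
    """Each of `num_ranks` processes contributes an array of ones, so
    after the all-sum every position holds `num_ranks`."""
    contributions = [[1] * length for _ in range(num_ranks)]
    return [sum(vals) for vals in zip(*contributions)]

print(simulated_all_sum(2, 5))  # prints [2, 2, 2, 2, 2]
print(simulated_all_sum(4, 5))  # prints [4, 4, 4, 4, 4]
```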

Installing MPI
^^^^^^^^^^^^^^

MPI can be installed with Homebrew, using the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

  $ conda install conda-forge::openmpi

Installing with Homebrew may require specifying the location of
``libmpi.dyld`` so that MLX can find it and load it at runtime. This can
simply be achieved by passing the ``DYLD_LIBRARY_PATH`` environment variable
to ``mpirun``, and it is done automatically by ``mlx.launch``:

.. code:: shell

  $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
  $ # or simply
  $ mlx.launch -n 2 test.py

Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^

MPI can automatically connect to remote hosts and set up the communication
over the network if the remote hosts can be accessed via ssh. A good
checklist to debug connectivity issues is the following:

* ``ssh hostname`` works from all machines to all machines without asking for
  a password or host confirmation.
* ``mpirun`` is accessible on all machines.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
  in the ``.ssh/config`` files on all machines.
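The first two items of the checklist can be scripted. The sketch below is one
possible approach, not part of MLX: it builds non-interactive ``ssh``
commands (``BatchMode`` makes ``ssh`` fail instead of prompting for a
password) and reports the hosts where either ``ssh`` or ``mpirun`` is not
available.

```python
import subprocess

def check_cmd(host):
    """Non-interactive ssh command that checks for mpirun on `host`.
    BatchMode makes ssh fail instead of prompting for a password."""
    return ["ssh", "-o", "BatchMode=yes", "-o", "ConnectTimeout=5",
            host, "which", "mpirun"]

def failing_hosts(hosts):
    """Return the hosts where the ssh or mpirun check fails."""
    failed = []
    for host in hosts:
        result = subprocess.run(check_cmd(host), capture_output=True)
        if result.returncode != 0:
            failed.append(host)
    return failed

# Example: failing_hosts(["hostname1", "hostname2"]) should return []
# when the checklist above is satisfied.
```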

Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^

.. note::

  For faster all reduce consider using the ring backend, either with
  Thunderbolt connections or over Ethernet.

Configure MPI to use N TCP connections between each pair of hosts to improve
bandwidth by passing ``--mca btl_tcp_links N``.

Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>``, where ``<iface>`` should be the interface you
want to use.

Getting Started with Ring
-------------------------

The ring backend does not depend on any third-party library, so it is always
available. It uses TCP sockets, so the nodes need to be reachable via a
network. As the name suggests, the nodes are connected in a ring, which means
that rank 1 can only communicate with ranks 0 and 2, rank 2 only with ranks 1
and 3, and so on and so forth. As a result, :func:`send` and :func:`recv`
with arbitrary sender and receiver are not supported in the ring backend.
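The neighbor relation above is plain modular arithmetic. The helper below is
illustrative (it is not part of the MLX API) and computes the only two ranks
a given rank can exchange data with:

```python
def ring_neighbors(rank, size):
    """The only two ranks `rank` can exchange data with in a ring:
    the previous and the next rank, wrapping around at the ends."""
    return ((rank - 1) % size, (rank + 1) % size)

# In a 4-node ring, rank 0 talks to ranks 3 and 1:
print(ring_neighbors(0, 4))  # prints (3, 1)
print(ring_neighbors(2, 4))  # prints (1, 3)
```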

Defining a Ring
^^^^^^^^^^^^^^^

The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node,
one defines a hostname to ssh into to run commands on that node and one or
more IPs that the node will listen on for connections.

For example, the hostfile below defines a 4-node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1, etc.

.. code:: json

  [
      {"ssh": "hostname1", "ips": ["123.123.123.1"]},
      {"ssh": "hostname2", "ips": ["123.123.123.2"]},
      {"ssh": "hostname3", "ips": ["123.123.123.3"]},
      {"ssh": "hostname4", "ips": ["123.123.123.4"]}
  ]

Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node and run the script, which will listen for connections on each of the
provided IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2``
and accept a connection from ``123.123.123.4``, and so on and so forth.
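In general, rank ``i`` connects to rank ``(i + 1) % n`` and accepts a
connection from rank ``(i - 1) % n``. The sketch below builds a hostfile like
the one above from a list of IPs and lists each rank's connect target; the
hostnames are placeholders and the helpers are not part of MLX:

```python
import json

def make_hostfile(ips):
    """Build one ring hostfile entry per node from a list of IPs.
    The "hostnameN" ssh names are illustrative placeholders."""
    return [{"ssh": f"hostname{i + 1}", "ips": [ip]}
            for i, ip in enumerate(ips)]

def connect_targets(hosts):
    """Rank i connects to the first IP of rank (i + 1) % n."""
    n = len(hosts)
    return [hosts[(i + 1) % n]["ips"][0] for i in range(n)]

hosts = make_hostfile(["123.123.123.1", "123.123.123.2",
                       "123.123.123.3", "123.123.123.4"])
print(json.dumps(hosts, indent=2))  # the hostfile to pass to mlx.launch
print(connect_targets(hosts))       # each rank's connect target
```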

Thunderbolt Ring
^^^^^^^^^^^^^^^^

Although the ring backend can have benefits over MPI even over Ethernet, its
main purpose is to use Thunderbolt rings for higher-bandwidth communication.
Setting up such Thunderbolt rings can be done manually, but it is a
relatively tedious process. To simplify it, we provide the utility
``mlx.distributed_config``.

To use ``mlx.distributed_config``, your computers need to be accessible via
ssh over Ethernet or Wi-Fi. Subsequently, connect them with Thunderbolt
cables and then call the utility as follows:

.. code:: shell

  mlx.distributed_config --verbose --hosts host1,host2,host3,host4

By default, the script will attempt to discover the Thunderbolt ring and
provide you with the commands to configure each node as well as the
``hostfile.json`` to use with ``mlx.launch``. If password-less ``sudo`` is
available on the nodes, then ``--auto-setup`` can be used to configure them
automatically.

To validate your connection without configuring anything,
``mlx.distributed_config`` can also plot the ring using DOT format:

.. code:: shell

  mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
  dot -Tpng ring.dot >ring.png
  open ring.png

If you want to go through the process manually, the steps are as follows:

* Disable the Thunderbolt bridge interface.
* For the cable connecting rank ``i`` to rank ``i + 1``, find the interfaces
  corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
  interfaces. For instance, if the cable corresponds to ``en2`` on node ``i``
  and ``en2`` also on node ``i + 1``, then we may assign IPs ``192.168.0.1``
  and ``192.168.0.2`` respectively to the two nodes. For more details you can
  see the commands prepared by the utility script.
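The subnet assignment in the last step can also be sketched programmatically.
Assuming, purely for illustration, one ``192.168.i.0/24`` subnet per cable
(``mlx.distributed_config`` may choose a different scheme), the addresses for
each link of the ring could be derived as follows:

```python
def link_addresses(num_nodes):
    """Assign a distinct subnet to each cable of the ring. For the cable
    between rank i and rank (i + 1) % n this sketch uses 192.168.i.1 and
    192.168.i.2; the real utility may pick a different scheme."""
    n = num_nodes
    return [{"link": (i, (i + 1) % n),
             "addresses": (f"192.168.{i}.1", f"192.168.{i}.2")}
            for i in range(n)]

# One subnet per cable of a 4-node ring:
for link in link_addresses(4):
    print(link)
```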