|
|
|
|
|
|
|
|
#include <nanobind/nanobind.h> |
|
|
#include <nanobind/stl/optional.h> |
|
|
#include <nanobind/stl/shared_ptr.h> |
|
|
#include <nanobind/stl/string.h> |
|
|
#include <nanobind/stl/variant.h> |
|
|
#include <nanobind/stl/vector.h> |
|
|
|
|
|
#include "mlx/distributed/distributed.h" |
|
|
#include "mlx/distributed/ops.h" |
|
|
#include "python/src/small_vector.h" |
|
|
#include "python/src/utils.h" |
|
|
|
|
|
namespace mx = mlx::core; |
|
|
namespace nb = nanobind; |
|
|
using namespace nb::literals; |
|
|
|
|
|
// Registers the `mlx.core.distributed` submodule on `parent_module`.
//
// The submodule exposes the `Group` class and the module-level functions
// `is_available`, `init`, `all_sum`, `all_max`, `all_min`, `all_gather`,
// `send`, `recv` and `recv_like`.
//
// Conventions shared by the communication bindings below:
//   * `x` is accepted as a ScalarOrArray and normalized with `to_array`
//     before being forwarded, so Python callers may pass scalars as well as
//     arrays.
//   * `group` and `stream` are keyword-only (`nb::kw_only()`) and default to
//     None. A None `group` arrives in C++ as std::nullopt, which (per the
//     docstrings) selects the global group.
//   * `nb::sig(...)` overrides the auto-generated signature so that the
//     rendered docs and stubs show Python-level types (array, Group, Stream,
//     Device) rather than C++ type names.
void init_distributed(nb::module_& parent_module) {
  auto m = parent_module.def_submodule(
      "distributed", "mlx.core.distributed: Communication operations");

  nb::class_<mx::distributed::Group>(
      m,
      "Group",
      R"pbdoc(
        An :class:`mlx.core.distributed.Group` represents a group of independent mlx
        processes that can communicate.
      )pbdoc")
      .def(
          "rank", &mx::distributed::Group::rank, "Get the rank of this process")
      .def("size", &mx::distributed::Group::size, "Get the size of the group")
      .def(
          "split",
          &mx::distributed::Group::split,
          "color"_a,
          "key"_a = -1,
          nb::sig("def split(self, color: int, key: int = -1) -> Group"),
          R"pbdoc(
            Split the group into subgroups based on the provided color.

            Processes that use the same color go to the same group. The ``key``
            argument defines the rank in the new group. The smaller the key the
            smaller the rank. If the key is negative then the rank in the
            current group is used.

            Args:
              color (int): A value to group processes into subgroups.
              key (int, optional): A key to optionally change the rank ordering
                of the processes. Default: ``-1``.
          )pbdoc");

  m.def(
      "is_available",
      &mx::distributed::is_available,
      R"pbdoc(
      Check if a communication backend is available.
    )pbdoc");

  m.def(
      "init",
      &mx::distributed::init,
      "strict"_a = false,
      "backend"_a = "any",
      nb::sig("def init(strict: bool = False, backend: str = 'any') -> Group"),
      R"pbdoc(
        Initialize the communication backend and create the global communication group.

        Example:

          .. code:: python

            import mlx.core as mx

            group = mx.distributed.init(backend="ring")

        Args:
          strict (bool, optional): Controls what happens when no communication
            backend is available (``mx.distributed.is_available()`` returns
            ``False``). If ``False``, a singleton group is returned; if ``True``,
            a runtime error is raised instead. Default: ``False``.
          backend (str, optional): Which distributed backend to initialize.
            Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to
            ``any``, all available backends are tried and the first one that
            succeeds becomes the global group, which will be returned in
            subsequent calls. Default: ``any``.

        Returns:
          Group: The group representing all the launched processes.
      )pbdoc");

  m.def(
      "all_sum",
      [](const ScalarOrArray& x,
         std::optional<mx::distributed::Group> group,
         mx::StreamOrDevice s) {
        return mx::distributed::all_sum(to_array(x), group, s);
      },
      "x"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
      "stream"_a = nb::none(),
      nb::sig(
          "def all_sum(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        All reduce sum.

        Sum the ``x`` arrays from all processes in the group.

        Args:
          x (array): Input array.
          group (Group, optional): The group of processes that will participate
            in the reduction. If set to ``None`` the global group is used.
            Default: ``None``.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

        Returns:
          array: The sum of all ``x`` arrays.
      )pbdoc");

  m.def(
      "all_max",
      [](const ScalarOrArray& x,
         std::optional<mx::distributed::Group> group,
         mx::StreamOrDevice s) {
        return mx::distributed::all_max(to_array(x), group, s);
      },
      "x"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
      "stream"_a = nb::none(),
      nb::sig(
          "def all_max(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        All reduce max.

        Find the maximum of the ``x`` arrays from all processes in the group.

        Args:
          x (array): Input array.
          group (Group, optional): The group of processes that will participate
            in the reduction. If set to ``None`` the global group is used.
            Default: ``None``.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

        Returns:
          array: The maximum of all ``x`` arrays.
      )pbdoc");

  m.def(
      "all_min",
      [](const ScalarOrArray& x,
         std::optional<mx::distributed::Group> group,
         mx::StreamOrDevice s) {
        return mx::distributed::all_min(to_array(x), group, s);
      },
      "x"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
      "stream"_a = nb::none(),
      nb::sig(
          "def all_min(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        All reduce min.

        Find the minimum of the ``x`` arrays from all processes in the group.

        Args:
          x (array): Input array.
          group (Group, optional): The group of processes that will participate
            in the reduction. If set to ``None`` the global group is used.
            Default: ``None``.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

        Returns:
          array: The minimum of all ``x`` arrays.
      )pbdoc");

  m.def(
      "all_gather",
      [](const ScalarOrArray& x,
         std::optional<mx::distributed::Group> group,
         mx::StreamOrDevice s) {
        return mx::distributed::all_gather(to_array(x), group, s);
      },
      "x"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
      "stream"_a = nb::none(),
      nb::sig(
          "def all_gather(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Gather arrays from all processes.

        Gather the ``x`` arrays from all processes in the group and concatenate
        them along the first axis. The arrays should all have the same shape.

        Args:
          x (array): Input array.
          group (Group, optional): The group of processes that will participate
            in the gather. If set to ``None`` the global group is used.
            Default: ``None``.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

        Returns:
          array: The concatenation of all ``x`` arrays.
      )pbdoc");

  m.def(
      "send",
      [](const ScalarOrArray& x,
         int dst,
         std::optional<mx::distributed::Group> group,
         mx::StreamOrDevice s) {
        return mx::distributed::send(to_array(x), dst, group, s);
      },
      "x"_a,
      "dst"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
      "stream"_a = nb::none(),
      nb::sig(
          "def send(x: array, dst: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Send an array from the current process to the process that has rank
        ``dst`` in the group.

        Args:
          x (array): Input array.
          dst (int): Rank of the destination process in the group.
          group (Group, optional): The group of processes that will participate
            in the send. If set to ``None`` the global group is used.
            Default: ``None``.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

        Returns:
          array: An array identical to ``x``. The send is performed when this
          array is evaluated.
      )pbdoc");

  // Bound directly (no lambda): `recv` has no input array to normalize, so
  // the Python arguments map straight onto the C++ parameters.
  m.def(
      "recv",
      &mx::distributed::recv,
      "shape"_a,
      "dtype"_a,
      "src"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
      "stream"_a = nb::none(),
      nb::sig(
          "def recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Recv an array with shape ``shape`` and dtype ``dtype`` from process
        with rank ``src``.

        Args:
          shape (Sequence[int]): The shape of the array we are receiving.
          dtype (Dtype): The data type of the array we are receiving.
          src (int): Rank of the source process in the group.
          group (Group, optional): The group of processes that will participate
            in the recv. If set to ``None`` the global group is used.
            Default: ``None``.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

        Returns:
          array: The array that was received from ``src``.
      )pbdoc");

  m.def(
      "recv_like",
      [](const ScalarOrArray& x,
         int src,
         std::optional<mx::distributed::Group> group,
         mx::StreamOrDevice s) {
        return mx::distributed::recv_like(to_array(x), src, group, s);
      },
      "x"_a,
      "src"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
      "stream"_a = nb::none(),
      nb::sig(
          "def recv_like(x: array, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Recv an array with shape and type like ``x`` from process with rank
        ``src``.

        It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.

        Args:
          x (array): An array defining the shape and dtype of the array we are
            receiving.
          src (int): Rank of the source process in the group.
          group (Group, optional): The group of processes that will participate
            in the recv. If set to ``None`` the global group is used.
            Default: ``None``.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

        Returns:
          array: The array that was received from ``src``.
      )pbdoc");
}
|
|
|