|
|
|
|
|
|
|
|
#include <memory>
#include <sstream>

#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h>

#include "mlx/stream.h"
#include "mlx/utils.h"
|
|
|
|
|
namespace mx = mlx::core; |
|
|
namespace nb = nanobind; |
|
|
using namespace nb::literals; |
|
|
|
|
|
|
|
|
class PyStreamContext { |
|
|
public: |
|
|
PyStreamContext(mx::StreamOrDevice s) : _inner(nullptr) { |
|
|
if (std::holds_alternative<std::monostate>(s)) { |
|
|
throw std::runtime_error( |
|
|
"[StreamContext] Invalid argument, please specify a stream or device."); |
|
|
} |
|
|
_s = s; |
|
|
} |
|
|
|
|
|
void enter() { |
|
|
_inner = new mx::StreamContext(_s); |
|
|
} |
|
|
|
|
|
void exit() { |
|
|
if (_inner != nullptr) { |
|
|
delete _inner; |
|
|
_inner = nullptr; |
|
|
} |
|
|
} |
|
|
|
|
|
private: |
|
|
mx::StreamOrDevice _s; |
|
|
mx::StreamContext* _inner; |
|
|
}; |
|
|
|
|
|
// Registers the stream-related Python API on module `m`:
//   - the `Stream` class (read-only `device`, `__repr__`, `__eq__`),
//   - implicit conversion from `Device::DeviceType` to `Device`,
//   - `default_stream`, `set_default_stream`, `new_stream`,
//   - the `StreamContext` context manager and its `stream(...)` factory,
//   - `synchronize`.
void init_stream(nb::module_& m) {
  nb::class_<mx::Stream>(
      m,
      "Stream",
      R"pbdoc(
      A stream for running operations on a given device.
      )pbdoc")
      .def_ro("device", &mx::Stream::device)
      .def(
          "__repr__",
          [](const mx::Stream& s) {
            // Reuse the C++ stream insertion operator for the Python repr.
            std::ostringstream os;
            os << s;
            return os.str();
          })
      .def("__eq__", [](const mx::Stream& s, const nb::object& other) {
        // Comparing against a non-Stream object yields False rather than
        // raising a cast error.
        return nb::isinstance<mx::Stream>(other) &&
            s == nb::cast<mx::Stream>(other);
      });

  // Lets a bare device type be passed wherever a Device is expected.
  nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();

  m.def(
      "default_stream",
      &mx::default_stream,
      "device"_a,
      R"pbdoc(Get the device's default stream.)pbdoc");
  m.def(
      "set_default_stream",
      &mx::set_default_stream,
      "stream"_a,
      R"pbdoc(
        Set the default stream.

        This will make the given stream the default for the
        stream's device. It will not change the default device.

        Args:
          stream (Stream): Stream to make the default.
      )pbdoc");
  m.def(
      "new_stream",
      &mx::new_stream,
      "device"_a,
      R"pbdoc(Make a new stream on the given device.)pbdoc");

  nb::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
        A context manager for setting the current device and stream.

        See :func:`stream` for usage.

        Args:
            s: The stream or device to set as the default.
  )pbdoc")
      .def(nb::init<mx::StreamOrDevice>(), "s"_a)
      .def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
      .def(
          "__exit__",
          // The exception details are accepted to satisfy the context
          // manager protocol but are not inspected; nothing is returned,
          // so exceptions are never suppressed.
          [](PyStreamContext& scm,
             const std::optional<nb::type_object>& /* exc_type */,
             const std::optional<nb::object>& /* exc_value */,
             const std::optional<nb::object>& /* traceback */) {
            scm.exit();
          },
          "exc_type"_a = nb::none(),
          "exc_value"_a = nb::none(),
          "traceback"_a = nb::none());
  m.def(
      "stream",
      [](mx::StreamOrDevice s) { return PyStreamContext(s); },
      "s"_a,
      R"pbdoc(
        Create a context manager to set the default device and stream.

        Args:
            s: The :obj:`Stream` or :obj:`Device` to set as the default.

        Returns:
            A context manager that sets the default device and stream.

        Example:

        .. code-block:: python

          import mlx.core as mx

          # Create a context manager for the default device and stream.
          with mx.stream(mx.cpu):
              # Operations here will use mx.cpu by default.
              pass
      )pbdoc");
  m.def(
      "synchronize",
      [](const std::optional<mx::Stream>& s) {
        // With no stream given, fall back to the argument-less overload.
        if (s.has_value()) {
          mx::synchronize(s.value());
        } else {
          mx::synchronize();
        }
      },
      "stream"_a = nb::none(),
      R"pbdoc(
      Synchronize with the given stream.

      Args:
        stream (Stream, optional): The stream to synchronize with. If ``None``
           then the default stream of the default device is used.
           Default: ``None``.
      )pbdoc");
}
|
|
|