File size: 3,971 Bytes
712dbf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
// Copyright © 2023-2024 Apple Inc.
#include <memory>
#include <sstream>
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h>
#include "mlx/stream.h"
#include "mlx/utils.h"
// Short aliases for the MLX core and nanobind namespaces used in this file.
namespace mx = mlx::core;
namespace nb = nanobind;
// Enables the `_a` literal suffix used below to name arguments ("device"_a).
using namespace nb::literals;
// Create the StreamContext on enter and delete on exit.
class PyStreamContext {
public:
PyStreamContext(mx::StreamOrDevice s) : _inner(nullptr) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
}
_s = s;
}
void enter() {
_inner = new mx::StreamContext(_s);
}
void exit() {
if (_inner != nullptr) {
delete _inner;
_inner = nullptr;
}
}
private:
mx::StreamOrDevice _s;
mx::StreamContext* _inner;
};
// Registers the stream-related Python API on module `m`:
//   - class `Stream` (read-only `device`, `__repr__`, `__eq__`)
//   - functions `default_stream`, `set_default_stream`, `new_stream`
//   - class `StreamContext` and its factory function `stream`
//   - function `synchronize`
// The raw-string docstrings are emitted verbatim to Python, so their content
// is kept flush-left: any leading whitespace would become part of the text.
void init_stream(nb::module_& m) {
  nb::class_<mx::Stream>(
      m,
      "Stream",
      R"pbdoc(
A stream for running operations on a given device.
)pbdoc")
      .def_ro("device", &mx::Stream::device)
      .def(
          "__repr__",
          [](const mx::Stream& s) {
            // Reuse the stream's operator<< for the Python representation.
            std::ostringstream os;
            os << s;
            return os.str();
          })
      // Comparing with an object that is not a Stream yields False rather
      // than raising.
      .def("__eq__", [](const mx::Stream& s, const nb::object& other) {
        return nb::isinstance<mx::Stream>(other) &&
            s == nb::cast<mx::Stream>(other);
      });

  // Allow a bare device type wherever a Device argument is expected.
  nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();

  m.def(
      "default_stream",
      &mx::default_stream,
      "device"_a,
      R"pbdoc(Get the device's default stream.)pbdoc");
  m.def(
      "set_default_stream",
      &mx::set_default_stream,
      "stream"_a,
      R"pbdoc(
Set the default stream.
This will make the given stream the default for the
stream's device. It will not change the default device.
Args:
stream (Stream): Stream to make the default.
)pbdoc");
  m.def(
      "new_stream",
      &mx::new_stream,
      "device"_a,
      R"pbdoc(Make a new stream on the given device.)pbdoc");

  nb::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
A context manager for setting the current device and stream.
See :func:`stream` for usage.
Args:
s: The stream or device to set as the default.
)pbdoc")
      .def(nb::init<mx::StreamOrDevice>(), "s"_a)
      .def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
      .def(
          "__exit__",
          // The exception details are accepted to satisfy the context-manager
          // protocol but are not inspected; exit() is called unconditionally.
          [](PyStreamContext& scm,
             const std::optional<nb::type_object>& exc_type,
             const std::optional<nb::object>& exc_value,
             const std::optional<nb::object>& traceback) { scm.exit(); },
          "exc_type"_a = nb::none(),
          "exc_value"_a = nb::none(),
          "traceback"_a = nb::none());

  m.def(
      "stream",
      [](mx::StreamOrDevice s) { return PyStreamContext(s); },
      "s"_a,
      R"pbdoc(
Create a context manager to set the default device and stream.
Args:
s: The :obj:`Stream` or :obj:`Device` to set as the default.
Returns:
A context manager that sets the default device and stream.
Example:
.. code-block:: python
import mlx.core as mx
# Create a context manager for the default device and stream.
with mx.stream(mx.cpu):
# Operations here will use mx.cpu by default.
pass
)pbdoc");

  m.def(
      "synchronize",
      // None selects the no-argument overload of mx::synchronize.
      [](const std::optional<mx::Stream>& s) {
        s ? mx::synchronize(s.value()) : mx::synchronize();
      },
      "stream"_a = nb::none(),
      R"pbdoc(
Synchronize with the given stream.
Args:
stream (Stream, optional): The stream to synchronize with. If ``None``
then the default stream of the default device is used.
Default: ``None``.
)pbdoc");
}
|