|
|
|
|
|
|
|
|
#include <sstream> |
|
|
|
|
|
#include <nanobind/nanobind.h> |
|
|
#include <nanobind/stl/string.h> |
|
|
|
|
|
#include "mlx/device.h" |
|
|
#include "mlx/utils.h" |
|
|
|
|
|
namespace mx = mlx::core; |
|
|
namespace nb = nanobind; |
|
|
using namespace nb::literals; |
|
|
|
|
|
void init_device(nb::module_& m) { |
|
|
auto device_class = nb::class_<mx::Device>( |
|
|
m, "Device", R"pbdoc(A device to run operations on.)pbdoc"); |
|
|
nb::enum_<mx::Device::DeviceType>(m, "DeviceType") |
|
|
.value("cpu", mx::Device::DeviceType::cpu) |
|
|
.value("gpu", mx::Device::DeviceType::gpu) |
|
|
.export_values() |
|
|
.def( |
|
|
"__eq__", |
|
|
[](const mx::Device::DeviceType& d, const nb::object& other) { |
|
|
if (!nb::isinstance<mx::Device>(other) && |
|
|
!nb::isinstance<mx::Device::DeviceType>(other)) { |
|
|
return false; |
|
|
} |
|
|
return d == nb::cast<mx::Device>(other); |
|
|
}); |
|
|
|
|
|
device_class |
|
|
.def(nb::init<mx::Device::DeviceType, int>(), "type"_a, "index"_a = 0) |
|
|
.def_ro("type", &mx::Device::type) |
|
|
.def( |
|
|
"__repr__", |
|
|
[](const mx::Device& d) { |
|
|
std::ostringstream os; |
|
|
os << d; |
|
|
return os.str(); |
|
|
}) |
|
|
.def("__eq__", [](const mx::Device& d, const nb::object& other) { |
|
|
if (!nb::isinstance<mx::Device>(other) && |
|
|
!nb::isinstance<mx::Device::DeviceType>(other)) { |
|
|
return false; |
|
|
} |
|
|
return d == nb::cast<mx::Device>(other); |
|
|
}); |
|
|
|
|
|
nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>(); |
|
|
|
|
|
m.def( |
|
|
"default_device", |
|
|
&mx::default_device, |
|
|
R"pbdoc(Get the default device.)pbdoc"); |
|
|
m.def( |
|
|
"set_default_device", |
|
|
&mx::set_default_device, |
|
|
"device"_a, |
|
|
R"pbdoc(Set the default device.)pbdoc"); |
|
|
m.def( |
|
|
"is_available", |
|
|
&mx::is_available, |
|
|
"device"_a, |
|
|
R"pbdoc(Check if a back-end is available for the given device.)pbdoc"); |
|
|
} |
|
|
|