File size: 1,952 Bytes
712dbf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
// Copyright © 2023-2024 Apple Inc.
#include <sstream>
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include "mlx/device.h"
#include "mlx/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
void init_device(nb::module_& m) {
auto device_class = nb::class_<mx::Device>(
m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
nb::enum_<mx::Device::DeviceType>(m, "DeviceType")
.value("cpu", mx::Device::DeviceType::cpu)
.value("gpu", mx::Device::DeviceType::gpu)
.export_values()
.def(
"__eq__",
[](const mx::Device::DeviceType& d, const nb::object& other) {
if (!nb::isinstance<mx::Device>(other) &&
!nb::isinstance<mx::Device::DeviceType>(other)) {
return false;
}
return d == nb::cast<mx::Device>(other);
});
device_class
.def(nb::init<mx::Device::DeviceType, int>(), "type"_a, "index"_a = 0)
.def_ro("type", &mx::Device::type)
.def(
"__repr__",
[](const mx::Device& d) {
std::ostringstream os;
os << d;
return os.str();
})
.def("__eq__", [](const mx::Device& d, const nb::object& other) {
if (!nb::isinstance<mx::Device>(other) &&
!nb::isinstance<mx::Device::DeviceType>(other)) {
return false;
}
return d == nb::cast<mx::Device>(other);
});
nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();
m.def(
"default_device",
&mx::default_device,
R"pbdoc(Get the default device.)pbdoc");
m.def(
"set_default_device",
&mx::set_default_device,
"device"_a,
R"pbdoc(Set the default device.)pbdoc");
m.def(
"is_available",
&mx::is_available,
"device"_a,
R"pbdoc(Check if a back-end is available for the given device.)pbdoc");
}
|