|
|
|
|
|
#include <iostream> |
|
|
|
|
|
#include <nanobind/nanobind.h> |
|
|
#include <nanobind/stl/optional.h> |
|
|
#include <nanobind/stl/string.h> |
|
|
#include <nanobind/stl/unordered_map.h> |
|
|
#include <nanobind/stl/variant.h> |
|
|
#include <nanobind/stl/vector.h> |
|
|
|
|
|
#include "mlx/backend/metal/metal.h" |
|
|
#include "mlx/memory.h" |
|
|
#include "python/src/small_vector.h" |
|
|
|
|
|
namespace mx = mlx::core; |
|
|
namespace nb = nanobind; |
|
|
using namespace nb::literals; |
|
|
|
|
|
// Write a one-line deprecation notice for `old_fn` to std::cerr that tells the
// user to call `new_fn` instead.
//
// Always returns true; the return value exists only so that a call can be used
// as the initializer of a variable (the DEPRECATE macro defined after this
// function relies on that).
bool DEPRECATE(const char* old_fn, const char* new_fn) {
  std::ostream& err = std::cerr;
  err << old_fn;
  err << " is deprecated and will be removed in a future ";
  err << "version. Use ";
  err << new_fn;
  err << " instead.";
  err << std::endl;
  return true;
}
|
|
|
|
|
// Emit a deprecation notice from inside a function body.
//
// Each expansion declares its own function-local static, so the notice is
// printed only when that static is first initialized, i.e. the first time the
// enclosing scope runs. The `DEPRECATE` inside the replacement list calls the
// function of the same name rather than recursing: the preprocessor never
// re-expands a macro within its own expansion.
//
// `dep` exists only for its initializer's side effect and is never read, so
// it is marked [[maybe_unused]] to keep -Wunused-variable quiet at every use.
#define DEPRECATE(oldfn, newfn) \
  [[maybe_unused]] static bool dep = DEPRECATE(oldfn, newfn)
|
|
|
|
|
// Register the `metal` submodule (docstring "mlx.metal") on module `m`.
//
// The submodule contains:
//   * Metal-specific functions bound directly to `mx::metal`:
//     `is_available`, `start_capture`, `stop_capture` and `device_info`.
//   * Deprecated memory-management aliases (`get_active_memory`,
//     `get_peak_memory`, `reset_peak_memory`, `get_cache_memory`,
//     `set_memory_limit`, `set_cache_limit`, `set_wired_limit`, `clear_cache`)
//     that forward to the same-named functions in namespace `mx`. Each alias
//     goes through the DEPRECATE macro, which writes a notice to stderr the
//     first time that alias is called.
//
// Every deprecated alias also carries a docstring naming its replacement, so
// that `help()` in Python points users to the new API and not only the
// runtime notice.
void init_metal(nb::module_& m) {
  nb::module_ metal = m.def_submodule("metal", "mlx.metal");

  metal.def(
      "is_available",
      &mx::metal::is_available,
      R"pbdoc(
      Check if the Metal back-end is available.
      )pbdoc");

  metal.def(
      "get_active_memory",
      []() {
        DEPRECATE("mx.metal.get_active_memory", "mx.get_active_memory");
        return mx::get_active_memory();
      },
      R"pbdoc(
      Deprecated: use ``mx.get_active_memory`` instead.
      )pbdoc");

  metal.def(
      "get_peak_memory",
      []() {
        DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory");
        return mx::get_peak_memory();
      },
      R"pbdoc(
      Deprecated: use ``mx.get_peak_memory`` instead.
      )pbdoc");

  metal.def(
      "reset_peak_memory",
      []() {
        DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory");
        mx::reset_peak_memory();
      },
      R"pbdoc(
      Deprecated: use ``mx.reset_peak_memory`` instead.
      )pbdoc");

  metal.def(
      "get_cache_memory",
      []() {
        DEPRECATE("mx.metal.get_cache_memory", "mx.get_cache_memory");
        return mx::get_cache_memory();
      },
      R"pbdoc(
      Deprecated: use ``mx.get_cache_memory`` instead.
      )pbdoc");

  metal.def(
      "set_memory_limit",
      [](size_t limit) {
        DEPRECATE("mx.metal.set_memory_limit", "mx.set_memory_limit");
        return mx::set_memory_limit(limit);
      },
      "limit"_a,
      R"pbdoc(
      Deprecated: use ``mx.set_memory_limit`` instead.
      )pbdoc");

  metal.def(
      "set_cache_limit",
      [](size_t limit) {
        DEPRECATE("mx.metal.set_cache_limit", "mx.set_cache_limit");
        return mx::set_cache_limit(limit);
      },
      "limit"_a,
      R"pbdoc(
      Deprecated: use ``mx.set_cache_limit`` instead.
      )pbdoc");

  metal.def(
      "set_wired_limit",
      [](size_t limit) {
        DEPRECATE("mx.metal.set_wired_limit", "mx.set_wired_limit");
        return mx::set_wired_limit(limit);
      },
      "limit"_a,
      R"pbdoc(
      Deprecated: use ``mx.set_wired_limit`` instead.
      )pbdoc");

  metal.def(
      "clear_cache",
      []() {
        DEPRECATE("mx.metal.clear_cache", "mx.clear_cache");
        mx::clear_cache();
      },
      R"pbdoc(
      Deprecated: use ``mx.clear_cache`` instead.
      )pbdoc");

  metal.def(
      "start_capture",
      &mx::metal::start_capture,
      "path"_a,
      R"pbdoc(
      Start a Metal capture.

      Args:
        path (str): The path to save the capture which should have
          the extension ``.gputrace``.
      )pbdoc");

  metal.def(
      "stop_capture",
      &mx::metal::stop_capture,
      R"pbdoc(
      Stop a Metal capture.
      )pbdoc");

  metal.def(
      "device_info",
      &mx::metal::device_info,
      R"pbdoc(
      Get information about the GPU device and system settings.

      Currently returns:

      * ``architecture``
      * ``max_buffer_size``
      * ``max_recommended_working_set_size``
      * ``memory_size``
      * ``resource_limit``

      Returns:
          dict: A dictionary with string keys and string or integer values.
      )pbdoc");
}
|
|
|