| | |
| | #pragma once |
#include <cstring>
#include <memory>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <nanobind/nanobind.h>

#include "mlx/array.h"
#include "mlx/utils.h"
| |
|
| | |
| | |
// Fallback type-slot ids for the buffer protocol, used only when the Python
// headers in use do not already define Py_bf_getbuffer. Only Py_bf_getbuffer
// is tested, so Py_bf_releasebuffer is assumed to be defined alongside it.
// NOTE(review): presumably these slot macros are missing from older Python
// headers and the values mirror CPython's Include/typeslots.h — confirm.
#ifndef Py_bf_getbuffer
#define Py_bf_getbuffer 1
#define Py_bf_releasebuffer 2
#endif
| |
|
// Short aliases used throughout this header.
// NOTE(review): these are declared at global scope in a header, so they are
// visible in every translation unit that includes it.
namespace mx = mlx::core;
namespace nb = nanobind;
| |
|
| | std::string buffer_format(const mx::array& a) { |
| | |
| | switch (a.dtype()) { |
| | case mx::bool_: |
| | return "?"; |
| | case mx::uint8: |
| | return "B"; |
| | case mx::uint16: |
| | return "H"; |
| | case mx::uint32: |
| | return "I"; |
| | case mx::uint64: |
| | return "Q"; |
| | case mx::int8: |
| | return "b"; |
| | case mx::int16: |
| | return "h"; |
| | case mx::int32: |
| | return "i"; |
| | case mx::int64: |
| | return "q"; |
| | case mx::float16: |
| | return "e"; |
| | case mx::float32: |
| | return "f"; |
| | case mx::bfloat16: |
| | return "B"; |
| | case mx::float64: |
| | return "d"; |
| | case mx::complex64: |
| | return "Zf\0"; |
| | default: { |
| | std::ostringstream os; |
| | os << "bad dtype: " << a.dtype(); |
| | throw std::runtime_error(os.str()); |
| | } |
| | } |
| | } |
| |
|
| | struct buffer_info { |
| | std::string format; |
| | std::vector<Py_ssize_t> shape; |
| | std::vector<Py_ssize_t> strides; |
| |
|
| | buffer_info( |
| | std::string format, |
| | std::vector<Py_ssize_t> shape_in, |
| | std::vector<Py_ssize_t> strides_in) |
| | : format(std::move(format)), |
| | shape(std::move(shape_in)), |
| | strides(std::move(strides_in)) {} |
| |
|
| | buffer_info(const buffer_info&) = delete; |
| | buffer_info& operator=(const buffer_info&) = delete; |
| |
|
| | buffer_info(buffer_info&& other) noexcept { |
| | (*this) = std::move(other); |
| | } |
| |
|
| | buffer_info& operator=(buffer_info&& rhs) noexcept { |
| | format = std::move(rhs.format); |
| | shape = std::move(rhs.shape); |
| | strides = std::move(rhs.strides); |
| | return *this; |
| | } |
| | }; |
| |
|
| | extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) { |
| | std::memset(view, 0, sizeof(Py_buffer)); |
| | auto a = nb::cast<mx::array>(nb::handle(obj)); |
| |
|
| | { |
| | nb::gil_scoped_release nogil; |
| | a.eval(); |
| | } |
| |
|
| | std::vector<Py_ssize_t> shape(a.shape().begin(), a.shape().end()); |
| | std::vector<Py_ssize_t> strides(a.strides().begin(), a.strides().end()); |
| | for (auto& s : strides) { |
| | s *= a.itemsize(); |
| | } |
| | buffer_info* info = |
| | new buffer_info(buffer_format(a), std::move(shape), std::move(strides)); |
| |
|
| | view->obj = obj; |
| | view->ndim = a.ndim(); |
| | view->internal = info; |
| | view->buf = a.data<void>(); |
| | view->itemsize = a.itemsize(); |
| | view->len = a.nbytes(); |
| | view->readonly = false; |
| | if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { |
| | view->format = const_cast<char*>(info->format.c_str()); |
| | } |
| | if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { |
| | view->strides = info->strides.data(); |
| | view->shape = info->shape.data(); |
| | } |
| | Py_INCREF(view->obj); |
| | return 0; |
| | } |
| |
|
| | extern "C" inline void releasebuffer(PyObject*, Py_buffer* view) { |
| | delete (buffer_info*)view->internal; |
| | } |
| |
|