File size: 3,085 Bytes
712dbf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
// Copyright © 2024 Apple Inc.
#pragma once
#include <cstring>
#include <memory>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <nanobind/nanobind.h>
#include "mlx/array.h"
#include "mlx/utils.h"
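// The Py_bf_getbuffer / Py_bf_releasebuffer type-slot IDs are only defined in
// Python >= 3.9, so provide them here for older Python versions.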
// https://github.com/python/cpython/blob/f6cdc6b4a191b75027de342aa8b5d344fb31313e/Include/typeslots.h#L2-L3
#ifndef Py_bf_getbuffer
#define Py_bf_getbuffer 1
#define Py_bf_releasebuffer 2
#endif
namespace mx = mlx::core;
namespace nb = nanobind;
// Returns the struct-module / PEP 3118 format string describing the element
// type of `a`, for use as Py_buffer::format.
//
// Declared `inline` because it is defined in a header: without it, every
// translation unit including this header would emit its own external
// definition, violating the one-definition rule.
//
// Throws std::runtime_error("bad dtype: ...") for a dtype with no mapping below.
inline std::string buffer_format(const mx::array& a) {
  // https://docs.python.org/3.10/library/struct.html#format-characters
  switch (a.dtype()) {
    case mx::bool_:
      return "?";
    case mx::uint8:
      return "B";
    case mx::uint16:
      return "H";
    case mx::uint32:
      return "I";
    case mx::uint64:
      return "Q";
    case mx::int8:
      return "b";
    case mx::int16:
      return "h";
    case mx::int32:
      return "i";
    case mx::int64:
      return "q";
    case mx::float16:
      return "e";
    case mx::float32:
      return "f";
    case mx::bfloat16:
      // NOTE(review): struct / PEP 3118 has no bfloat16 code. "B" denotes a
      // 1-byte item while bfloat16 items are 2 bytes wide, so consumers that
      // cross-check itemsize against the format may reject it — confirm intent.
      return "B";
    case mx::float64:
      return "d";
    case mx::complex64:
      // PEP 3118 code for a complex number made of two float32 values.
      return "Zf";
    default: {
      std::ostringstream os;
      os << "bad dtype: " << a.dtype();
      throw std::runtime_error(os.str());
    }
  }
}
// Owns the storage that a Py_buffer's `format`, `shape` and `strides`
// pointers refer to: the format string, and the shape/strides arrays with
// strides expressed in Py_ssize_t units as supplied by the caller.
// Copying is disabled; moving transfers the owned storage member-wise.
struct buffer_info {
  std::string format;
  std::vector<Py_ssize_t> shape;
  std::vector<Py_ssize_t> strides;

  buffer_info(
      std::string format_in,
      std::vector<Py_ssize_t> shape_in,
      std::vector<Py_ssize_t> strides_in)
      : format(std::move(format_in)),
        shape(std::move(shape_in)),
        strides(std::move(strides_in)) {}

  buffer_info(const buffer_info&) = delete;
  buffer_info& operator=(const buffer_info&) = delete;

  // Member-wise moves. Defaulting them avoids default-constructing members
  // before assigning; they are noexcept because std::string/std::vector moves
  // are.
  buffer_info(buffer_info&&) noexcept = default;
  buffer_info& operator=(buffer_info&&) noexcept = default;
};
extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
std::memset(view, 0, sizeof(Py_buffer));
auto a = nb::cast<mx::array>(nb::handle(obj));
{
nb::gil_scoped_release nogil;
a.eval();
}
std::vector<Py_ssize_t> shape(a.shape().begin(), a.shape().end());
std::vector<Py_ssize_t> strides(a.strides().begin(), a.strides().end());
for (auto& s : strides) {
s *= a.itemsize();
}
buffer_info* info =
new buffer_info(buffer_format(a), std::move(shape), std::move(strides));
view->obj = obj;
view->ndim = a.ndim();
view->internal = info;
view->buf = a.data<void>();
view->itemsize = a.itemsize();
view->len = a.nbytes();
view->readonly = false;
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
view->format = const_cast<char*>(info->format.c_str());
}
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
view->strides = info->strides.data();
view->shape = info->shape.data();
}
Py_INCREF(view->obj);
return 0;
}
// Py_bf_releasebuffer slot: frees the buffer_info that getbuffer() stored in
// view->internal, which backs view->format/shape/strides.
// It does not touch view->obj: per the buffer protocol, PyBuffer_Release()
// drops that reference itself after calling this slot.
extern "C" inline void releasebuffer(PyObject*, Py_buffer* view) {
  delete static_cast<buffer_info*>(view->internal);
  // Clear the pointer so an accidental second release deletes nullptr (no-op).
  view->internal = nullptr;
}
|