| | |
| |
|
| | #include <nanobind/stl/complex.h> |
| |
|
| | #include "python/src/convert.h" |
| | #include "python/src/utils.h" |
| |
|
| | #include "mlx/utils.h" |
| |
|
// Category of a Python scalar encountered while scanning a nested
// list/tuple. The enumerators are ordered from narrowest to widest so that
// two categories can be combined with std::max (as validate_shape does).
enum PyScalarT {
  pybool = 0,
  pyint,
  pyfloat,
  pycomplex,
};
| |
|
| | namespace nanobind { |
| | template <> |
| | struct ndarray_traits<mx::float16_t> { |
| | static constexpr bool is_complex = false; |
| | static constexpr bool is_float = true; |
| | static constexpr bool is_bool = false; |
| | static constexpr bool is_int = false; |
| | static constexpr bool is_signed = true; |
| | }; |
| | }; |
| |
|
// Narrows a 64-bit dimension to `int`.
//
// Throws std::invalid_argument when `dim` cannot be represented as an `int`.
// Both ends of the range are checked so that an out-of-range negative value
// is reported instead of being silently wrapped by the cast below.
int check_shape_dim(int64_t dim) {
  constexpr int64_t kIntMin = std::numeric_limits<int>::min();
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  if (dim < kIntMin || dim > kIntMax) {
    throw std::invalid_argument(
        "Shape dimension falls outside supported `int` range.");
  }
  return static_cast<int>(dim);
}
| |
|
| | template <typename T> |
| | mx::array nd_array_to_mlx_contiguous( |
| | nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array, |
| | const mx::Shape& shape, |
| | mx::Dtype dtype) { |
| | |
| | |
| | auto data_ptr = nd_array.data(); |
| | return mx::array(static_cast<const T*>(data_ptr), shape, dtype); |
| | } |
| |
|
| | mx::array nd_array_to_mlx( |
| | nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array, |
| | std::optional<mx::Dtype> dtype) { |
| | |
| | mx::Shape shape; |
| | shape.reserve(nd_array.ndim()); |
| | for (int i = 0; i < nd_array.ndim(); i++) { |
| | shape.push_back(check_shape_dim(nd_array.shape(i))); |
| | } |
| | auto type = nd_array.dtype(); |
| |
|
| | |
| | if (type == nb::dtype<bool>()) { |
| | return nd_array_to_mlx_contiguous<bool>( |
| | nd_array, shape, dtype.value_or(mx::bool_)); |
| | } else if (type == nb::dtype<uint8_t>()) { |
| | return nd_array_to_mlx_contiguous<uint8_t>( |
| | nd_array, shape, dtype.value_or(mx::uint8)); |
| | } else if (type == nb::dtype<uint16_t>()) { |
| | return nd_array_to_mlx_contiguous<uint16_t>( |
| | nd_array, shape, dtype.value_or(mx::uint16)); |
| | } else if (type == nb::dtype<uint32_t>()) { |
| | return nd_array_to_mlx_contiguous<uint32_t>( |
| | nd_array, shape, dtype.value_or(mx::uint32)); |
| | } else if (type == nb::dtype<uint64_t>()) { |
| | return nd_array_to_mlx_contiguous<uint64_t>( |
| | nd_array, shape, dtype.value_or(mx::uint64)); |
| | } else if (type == nb::dtype<int8_t>()) { |
| | return nd_array_to_mlx_contiguous<int8_t>( |
| | nd_array, shape, dtype.value_or(mx::int8)); |
| | } else if (type == nb::dtype<int16_t>()) { |
| | return nd_array_to_mlx_contiguous<int16_t>( |
| | nd_array, shape, dtype.value_or(mx::int16)); |
| | } else if (type == nb::dtype<int32_t>()) { |
| | return nd_array_to_mlx_contiguous<int32_t>( |
| | nd_array, shape, dtype.value_or(mx::int32)); |
| | } else if (type == nb::dtype<int64_t>()) { |
| | return nd_array_to_mlx_contiguous<int64_t>( |
| | nd_array, shape, dtype.value_or(mx::int64)); |
| | } else if (type == nb::dtype<mx::float16_t>()) { |
| | return nd_array_to_mlx_contiguous<mx::float16_t>( |
| | nd_array, shape, dtype.value_or(mx::float16)); |
| | } else if (type == nb::bfloat16) { |
| | return nd_array_to_mlx_contiguous<mx::bfloat16_t>( |
| | nd_array, shape, dtype.value_or(mx::bfloat16)); |
| | } else if (type == nb::dtype<float>()) { |
| | return nd_array_to_mlx_contiguous<float>( |
| | nd_array, shape, dtype.value_or(mx::float32)); |
| | } else if (type == nb::dtype<double>()) { |
| | return nd_array_to_mlx_contiguous<double>( |
| | nd_array, shape, dtype.value_or(mx::float32)); |
| | } else if (type == nb::dtype<std::complex<float>>()) { |
| | return nd_array_to_mlx_contiguous<mx::complex64_t>( |
| | nd_array, shape, dtype.value_or(mx::complex64)); |
| | } else if (type == nb::dtype<std::complex<double>>()) { |
| | return nd_array_to_mlx_contiguous<mx::complex128_t>( |
| | nd_array, shape, dtype.value_or(mx::complex64)); |
| | } else { |
| | throw std::invalid_argument("Cannot convert numpy array to mlx array."); |
| | } |
| | } |
| |
|
| | template <typename T, typename... NDParams> |
| | nb::ndarray<NDParams...> mlx_to_nd_array_impl( |
| | mx::array a, |
| | std::optional<nb::dlpack::dtype> t = {}) { |
| | { |
| | nb::gil_scoped_release nogil; |
| | a.eval(); |
| | } |
| | std::vector<size_t> shape(a.shape().begin(), a.shape().end()); |
| | return nb::ndarray<NDParams...>( |
| | a.data<T>(), |
| | a.ndim(), |
| | shape.data(), |
| | nb::none(), |
| | a.strides().data(), |
| | t.value_or(nb::dtype<T>())); |
| | } |
| |
|
| | template <typename... NDParams> |
| | nb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) { |
| | switch (a.dtype()) { |
| | case mx::bool_: |
| | return mlx_to_nd_array_impl<bool, NDParams...>(a); |
| | case mx::uint8: |
| | return mlx_to_nd_array_impl<uint8_t, NDParams...>(a); |
| | case mx::uint16: |
| | return mlx_to_nd_array_impl<uint16_t, NDParams...>(a); |
| | case mx::uint32: |
| | return mlx_to_nd_array_impl<uint32_t, NDParams...>(a); |
| | case mx::uint64: |
| | return mlx_to_nd_array_impl<uint64_t, NDParams...>(a); |
| | case mx::int8: |
| | return mlx_to_nd_array_impl<int8_t, NDParams...>(a); |
| | case mx::int16: |
| | return mlx_to_nd_array_impl<int16_t, NDParams...>(a); |
| | case mx::int32: |
| | return mlx_to_nd_array_impl<int32_t, NDParams...>(a); |
| | case mx::int64: |
| | return mlx_to_nd_array_impl<int64_t, NDParams...>(a); |
| | case mx::float16: |
| | return mlx_to_nd_array_impl<mx::float16_t, NDParams...>(a); |
| | case mx::bfloat16: |
| | throw nb::type_error("bfloat16 arrays cannot be converted to NumPy."); |
| | case mx::float32: |
| | return mlx_to_nd_array_impl<float, NDParams...>(a); |
| | case mx::float64: |
| | return mlx_to_nd_array_impl<double, NDParams...>(a); |
| | case mx::complex64: |
| | return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a); |
| | default: |
| | throw nb::type_error("type cannot be converted to NumPy."); |
| | } |
| | } |
| |
|
| | nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a) { |
| | return mlx_to_nd_array<nb::numpy>(a); |
| | } |
| |
|
| | nb::ndarray<> mlx_to_dlpack(const mx::array& a) { |
| | return mlx_to_nd_array<>(a); |
| | } |
| |
|
| | nb::object to_scalar(mx::array& a) { |
| | if (a.size() != 1) { |
| | throw std::invalid_argument( |
| | "[convert] Only length-1 arrays can be converted to Python scalars."); |
| | } |
| | { |
| | nb::gil_scoped_release nogil; |
| | a.eval(); |
| | } |
| | switch (a.dtype()) { |
| | case mx::bool_: |
| | return nb::cast(a.item<bool>()); |
| | case mx::uint8: |
| | return nb::cast(a.item<uint8_t>()); |
| | case mx::uint16: |
| | return nb::cast(a.item<uint16_t>()); |
| | case mx::uint32: |
| | return nb::cast(a.item<uint32_t>()); |
| | case mx::uint64: |
| | return nb::cast(a.item<uint64_t>()); |
| | case mx::int8: |
| | return nb::cast(a.item<int8_t>()); |
| | case mx::int16: |
| | return nb::cast(a.item<int16_t>()); |
| | case mx::int32: |
| | return nb::cast(a.item<int32_t>()); |
| | case mx::int64: |
| | return nb::cast(a.item<int64_t>()); |
| | case mx::float16: |
| | return nb::cast(static_cast<float>(a.item<mx::float16_t>())); |
| | case mx::float32: |
| | return nb::cast(a.item<float>()); |
| | case mx::bfloat16: |
| | return nb::cast(static_cast<float>(a.item<mx::bfloat16_t>())); |
| | case mx::complex64: |
| | return nb::cast(a.item<std::complex<float>>()); |
| | case mx::float64: |
| | return nb::cast(a.item<double>()); |
| | default: |
| | throw nb::type_error("type cannot be converted to Python scalar."); |
| | } |
| | } |
| |
|
| | template <typename T, typename U = T> |
| | nb::list to_list(mx::array& a, size_t index, int dim) { |
| | nb::list pl; |
| | auto stride = a.strides()[dim]; |
| | for (int i = 0; i < a.shape(dim); ++i) { |
| | if (dim == a.ndim() - 1) { |
| | pl.append(static_cast<U>(a.data<T>()[index])); |
| | } else { |
| | pl.append(to_list<T, U>(a, index, dim + 1)); |
| | } |
| | index += stride; |
| | } |
| | return pl; |
| | } |
| |
|
| | nb::object tolist(mx::array& a) { |
| | if (a.ndim() == 0) { |
| | return to_scalar(a); |
| | } |
| | { |
| | nb::gil_scoped_release nogil; |
| | a.eval(); |
| | } |
| | switch (a.dtype()) { |
| | case mx::bool_: |
| | return to_list<bool>(a, 0, 0); |
| | case mx::uint8: |
| | return to_list<uint8_t>(a, 0, 0); |
| | case mx::uint16: |
| | return to_list<uint16_t>(a, 0, 0); |
| | case mx::uint32: |
| | return to_list<uint32_t>(a, 0, 0); |
| | case mx::uint64: |
| | return to_list<uint64_t>(a, 0, 0); |
| | case mx::int8: |
| | return to_list<int8_t>(a, 0, 0); |
| | case mx::int16: |
| | return to_list<int16_t>(a, 0, 0); |
| | case mx::int32: |
| | return to_list<int32_t>(a, 0, 0); |
| | case mx::int64: |
| | return to_list<int64_t>(a, 0, 0); |
| | case mx::float16: |
| | return to_list<mx::float16_t, float>(a, 0, 0); |
| | case mx::float32: |
| | return to_list<float>(a, 0, 0); |
| | case mx::bfloat16: |
| | return to_list<mx::bfloat16_t, float>(a, 0, 0); |
| | case mx::complex64: |
| | return to_list<std::complex<float>>(a, 0, 0); |
| | default: |
| | throw nb::type_error("data type cannot be converted to Python list."); |
| | } |
| | } |
| |
|
| | template <typename T, typename U> |
| | void fill_vector(T list, std::vector<U>& vals) { |
| | for (auto l : list) { |
| | if (nb::isinstance<nb::list>(l)) { |
| | fill_vector(nb::cast<nb::list>(l), vals); |
| | } else if (nb::isinstance<nb::tuple>(*list.begin())) { |
| | fill_vector(nb::cast<nb::tuple>(l), vals); |
| | } else { |
| | vals.push_back(nb::cast<U>(l)); |
| | } |
| | } |
| | } |
| |
|
| | template <typename T> |
| | PyScalarT validate_shape( |
| | T list, |
| | const mx::Shape& shape, |
| | int idx, |
| | bool& all_python_primitive_elements) { |
| | if (idx >= shape.size()) { |
| | throw std::invalid_argument("Initialization encountered extra dimension."); |
| | } |
| | auto s = shape[idx]; |
| | if (nb::len(list) != s) { |
| | throw std::invalid_argument( |
| | "Initialization encountered non-uniform length."); |
| | } |
| |
|
| | if (s == 0) { |
| | return pyfloat; |
| | } |
| |
|
| | PyScalarT type = pybool; |
| | for (auto l : list) { |
| | PyScalarT t; |
| | if (nb::isinstance<nb::list>(l)) { |
| | t = validate_shape( |
| | nb::cast<nb::list>(l), shape, idx + 1, all_python_primitive_elements); |
| | } else if (nb::isinstance<nb::tuple>(*list.begin())) { |
| | t = validate_shape( |
| | nb::cast<nb::tuple>(l), |
| | shape, |
| | idx + 1, |
| | all_python_primitive_elements); |
| | } else if (nb::isinstance<mx::array>(l)) { |
| | all_python_primitive_elements = false; |
| | auto arr = nb::cast<mx::array>(l); |
| | if (arr.ndim() + idx + 1 == shape.size() && |
| | std::equal( |
| | arr.shape().cbegin(), |
| | arr.shape().cend(), |
| | shape.cbegin() + idx + 1)) { |
| | t = pybool; |
| | } else { |
| | throw std::invalid_argument( |
| | "Initialization encountered non-uniform length."); |
| | } |
| | } else { |
| | if (nb::isinstance<nb::bool_>(l)) { |
| | t = pybool; |
| | } else if (nb::isinstance<nb::int_>(l)) { |
| | t = pyint; |
| | } else if (nb::isinstance<nb::float_>(l)) { |
| | t = pyfloat; |
| | } else if (PyComplex_Check(l.ptr())) { |
| | t = pycomplex; |
| | } else { |
| | std::ostringstream msg; |
| | msg << "Invalid type " << nb::type_name(l.type()).c_str() |
| | << " received in array initialization."; |
| | throw std::invalid_argument(msg.str()); |
| | } |
| |
|
| | if (idx + 1 != shape.size()) { |
| | throw std::invalid_argument( |
| | "Initialization encountered non-uniform length."); |
| | } |
| | } |
| | type = std::max(type, t); |
| | } |
| | return type; |
| | } |
| |
|
| | template <typename T> |
| | void get_shape(T list, mx::Shape& shape) { |
| | shape.push_back(check_shape_dim(nb::len(list))); |
| | if (shape.back() > 0) { |
| | auto l = list.begin(); |
| | if (nb::isinstance<nb::list>(*l)) { |
| | return get_shape(nb::cast<nb::list>(*l), shape); |
| | } else if (nb::isinstance<nb::tuple>(*l)) { |
| | return get_shape(nb::cast<nb::tuple>(*l), shape); |
| | } else if (nb::isinstance<mx::array>(*l)) { |
| | auto arr = nb::cast<mx::array>(*l); |
| | for (int i = 0; i < arr.ndim(); i++) { |
| | shape.push_back(arr.shape(i)); |
| | } |
| | return; |
| | } |
| | } |
| | } |
| |
|
| | template <typename T> |
| | mx::array array_from_list_impl( |
| | T pl, |
| | const PyScalarT& inferred_type, |
| | std::optional<mx::Dtype> specified_type, |
| | const mx::Shape& shape) { |
| | |
| | switch (inferred_type) { |
| | case pybool: { |
| | std::vector<bool> vals; |
| | fill_vector(pl, vals); |
| | return mx::array(vals.begin(), shape, specified_type.value_or(mx::bool_)); |
| | } |
| | case pyint: { |
| | auto dtype = specified_type.value_or(mx::int32); |
| | if (dtype == mx::int64) { |
| | std::vector<int64_t> vals; |
| | fill_vector(pl, vals); |
| | return mx::array(vals.begin(), shape, dtype); |
| | } else if (dtype == mx::uint64) { |
| | std::vector<uint64_t> vals; |
| | fill_vector(pl, vals); |
| | return mx::array(vals.begin(), shape, dtype); |
| | } else if (dtype == mx::uint32) { |
| | std::vector<uint32_t> vals; |
| | fill_vector(pl, vals); |
| | return mx::array(vals.begin(), shape, dtype); |
| | } else if (mx::issubdtype(dtype, mx::inexact)) { |
| | std::vector<float> vals; |
| | fill_vector(pl, vals); |
| | return mx::array(vals.begin(), shape, dtype); |
| | } else { |
| | std::vector<int> vals; |
| | fill_vector(pl, vals); |
| | return mx::array(vals.begin(), shape, dtype); |
| | } |
| | } |
| | case pyfloat: { |
| | std::vector<float> vals; |
| | fill_vector(pl, vals); |
| | return mx::array( |
| | vals.begin(), shape, specified_type.value_or(mx::float32)); |
| | } |
| | case pycomplex: { |
| | std::vector<std::complex<float>> vals; |
| | fill_vector(pl, vals); |
| | return mx::array( |
| | reinterpret_cast<mx::complex64_t*>(vals.data()), |
| | shape, |
| | specified_type.value_or(mx::complex64)); |
| | } |
| | default: { |
| | std::ostringstream msg; |
| | msg << "Should not happen, inferred: " << inferred_type |
| | << " on subarray made of only python primitive types."; |
| | throw std::runtime_error(msg.str()); |
| | } |
| | } |
| | } |
| |
|
| | template <typename T> |
| | mx::array array_from_list_impl(T pl, std::optional<mx::Dtype> dtype) { |
| | |
| | mx::Shape shape; |
| | get_shape(pl, shape); |
| |
|
| | |
| | bool all_python_primitive_elements = true; |
| | auto type = validate_shape(pl, shape, 0, all_python_primitive_elements); |
| |
|
| | if (all_python_primitive_elements) { |
| | |
| | return array_from_list_impl(pl, type, dtype, shape); |
| | } |
| |
|
| | |
| | std::vector<mx::array> arrays; |
| | for (auto l : pl) { |
| | arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype)); |
| | } |
| | return mx::stack(arrays); |
| | } |
| |
|
| | mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype) { |
| | return array_from_list_impl(pl, dtype); |
| | } |
| |
|
| | mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype) { |
| | return array_from_list_impl(pl, dtype); |
| | } |
| |
|
| | mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) { |
| | if (auto pv = std::get_if<nb::bool_>(&v); pv) { |
| | return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_)); |
| | } else if (auto pv = std::get_if<nb::int_>(&v); pv) { |
| | auto val = nb::cast<long>(*pv); |
| | auto default_type = (val > std::numeric_limits<int>::max() || |
| | val < std::numeric_limits<int>::min()) |
| | ? mx::int64 |
| | : mx::int32; |
| | return mx::array(val, t.value_or(default_type)); |
| | } else if (auto pv = std::get_if<nb::float_>(&v); pv) { |
| | return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32)); |
| | } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) { |
| | return mx::array( |
| | static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64)); |
| | } else if (auto pv = std::get_if<nb::list>(&v); pv) { |
| | return array_from_list(*pv, t); |
| | } else if (auto pv = std::get_if<nb::tuple>(&v); pv) { |
| | return array_from_list(*pv, t); |
| | } else if (auto pv = std::get_if< |
| | nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v); |
| | pv) { |
| | return nd_array_to_mlx(*pv, t); |
| | } else if (auto pv = std::get_if<mx::array>(&v); pv) { |
| | return mx::astype(*pv, t.value_or((*pv).dtype())); |
| | } else { |
| | auto arr = to_array_with_accessor(std::get<ArrayLike>(v).obj); |
| | return mx::astype(arr, t.value_or(arr.dtype())); |
| | } |
| | } |
| |
|