| | |
| | #pragma once |
| |
|
#include <complex>
#include <optional>
#include <utility>
#include <variant>

#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>

#include "mlx/array.h"
#include "mlx/ops.h"
| |
|
| | namespace mx = mlx::core; |
| | namespace nb = nanobind; |
| |
|
| | namespace nanobind { |
| | static constexpr dlpack::dtype bfloat16{4, 16, 1}; |
| | }; |
| |
|
| | struct ArrayLike { |
| | ArrayLike(nb::object obj) : obj(obj) {}; |
| | nb::object obj; |
| | }; |
| |
|
// Alternatives accepted by `create_array` (declared below).
// NOTE(review): the order of alternatives is load-bearing. nanobind's
// std::variant type caster tries the alternatives in declaration order
// (a first pass without implicit conversion, then one with), and the first
// match wins. Do not reorder without checking the binding behavior.
using ArrayInitType = std::variant<
    // Before nb::int_: Python `bool` is a subclass of `int`, so nb::int_
    // would otherwise also accept True/False.
    nb::bool_,
    nb::int_,
    nb::float_,
    // Must be above ndarray: an existing mx::array should be taken as-is
    // rather than (presumably) being matched via the ndarray caster.
    mx::array,
    // Must be above complex: presumably so array inputs keep their shape and
    // dtype instead of being coerced to a complex scalar — confirm.
    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
    std::complex<float>,
    nb::list,
    nb::tuple,
    // Last alternative; presumably the catch-all wrapper — keep it last.
    ArrayLike>;
| |
|
| | mx::array nd_array_to_mlx( |
| | nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array, |
| | std::optional<mx::Dtype> dtype); |
| |
|
| | nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a); |
| | nb::ndarray<> mlx_to_dlpack(const mx::array& a); |
| |
|
| | nb::object to_scalar(mx::array& a); |
| |
|
| | nb::object tolist(mx::array& a); |
| |
|
| | mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t); |
| | mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype); |
| | mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype); |
| |
|