| | |
| | #include <cstdint> |
| | #include <cstring> |
| | #include <sstream> |
| |
|
| | #include <nanobind/ndarray.h> |
| | #include <nanobind/stl/complex.h> |
| | #include <nanobind/stl/optional.h> |
| | #include <nanobind/stl/string.h> |
| | #include <nanobind/stl/variant.h> |
| | #include <nanobind/stl/vector.h> |
| | #include <nanobind/typing.h> |
| |
|
| | #include "mlx/backend/metal/metal.h" |
| | #include "python/src/buffer.h" |
| | #include "python/src/convert.h" |
| | #include "python/src/indexing.h" |
| | #include "python/src/small_vector.h" |
| | #include "python/src/utils.h" |
| |
|
| | #include "mlx/mlx.h" |
| |
|
| | namespace mx = mlx::core; |
| | namespace nb = nanobind; |
| | using namespace nb::literals; |
| |
|
| | class ArrayAt { |
| | public: |
| | ArrayAt(mx::array x) : x_(std::move(x)) {} |
| | ArrayAt& set_indices(nb::object indices) { |
| | initialized_ = true; |
| | indices_ = indices; |
| | return *this; |
| | } |
| | void check_initialized() { |
| | if (!initialized_) { |
| | throw std::invalid_argument( |
| | "Must give indices to array.at (e.g. `x.at[0].add(4)`)."); |
| | } |
| | } |
| |
|
| | mx::array add(const ScalarOrArray& v) { |
| | check_initialized(); |
| | return mlx_add_item(x_, indices_, v); |
| | } |
| | mx::array subtract(const ScalarOrArray& v) { |
| | check_initialized(); |
| | return mlx_subtract_item(x_, indices_, v); |
| | } |
| | mx::array multiply(const ScalarOrArray& v) { |
| | check_initialized(); |
| | return mlx_multiply_item(x_, indices_, v); |
| | } |
| | mx::array divide(const ScalarOrArray& v) { |
| | check_initialized(); |
| | return mlx_divide_item(x_, indices_, v); |
| | } |
| | mx::array maximum(const ScalarOrArray& v) { |
| | check_initialized(); |
| | return mlx_maximum_item(x_, indices_, v); |
| | } |
| | mx::array minimum(const ScalarOrArray& v) { |
| | check_initialized(); |
| | return mlx_minimum_item(x_, indices_, v); |
| | } |
| |
|
| | private: |
| | mx::array x_; |
| | bool initialized_{false}; |
| | nb::object indices_; |
| | }; |
| |
|
| | class ArrayPythonIterator { |
| | public: |
| | ArrayPythonIterator(mx::array x) : idx_(0), x_(std::move(x)) { |
| | if (x_.shape(0) > 0 && x_.shape(0) < 10) { |
| | splits_ = mx::split(x_, x_.shape(0)); |
| | } |
| | } |
| |
|
| | mx::array next() { |
| | if (idx_ >= x_.shape(0)) { |
| | throw nb::stop_iteration(); |
| | } |
| |
|
| | if (idx_ >= 0 && idx_ < splits_.size()) { |
| | return mx::squeeze(splits_[idx_++], 0); |
| | } |
| |
|
| | return *(x_.begin() + idx_++); |
| | } |
| |
|
| | private: |
| | int idx_; |
| | mx::array x_; |
| | std::vector<mx::array> splits_; |
| | }; |
| |
|
| | void init_array(nb::module_& m) { |
| | |
| | mx::get_global_formatter().capitalize_bool = true; |
| |
|
| | |
| | nb::class_<mx::Dtype>( |
| | m, |
| | "Dtype", |
| | R"pbdoc( |
| | An object to hold the type of a :class:`array`. |
| | |
| | See the :ref:`list of types <data_types>` for more details |
| | on available data types. |
| | )pbdoc") |
| | .def_prop_ro( |
| | "size", &mx::Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc") |
| | .def( |
| | "__repr__", |
| | [](const mx::Dtype& t) { |
| | std::ostringstream os; |
| | os << "mlx.core."; |
| | os << t; |
| | return os.str(); |
| | }) |
| | .def( |
| | "__eq__", |
| | [](const mx::Dtype& t, const nb::object& other) { |
| | return nb::isinstance<mx::Dtype>(other) && |
| | t == nb::cast<mx::Dtype>(other); |
| | }) |
| | .def("__hash__", [](const mx::Dtype& t) { |
| | return static_cast<int64_t>(t.val()); |
| | }); |
| |
|
| | m.attr("bool_") = nb::cast(mx::bool_); |
| | m.attr("uint8") = nb::cast(mx::uint8); |
| | m.attr("uint16") = nb::cast(mx::uint16); |
| | m.attr("uint32") = nb::cast(mx::uint32); |
| | m.attr("uint64") = nb::cast(mx::uint64); |
| | m.attr("int8") = nb::cast(mx::int8); |
| | m.attr("int16") = nb::cast(mx::int16); |
| | m.attr("int32") = nb::cast(mx::int32); |
| | m.attr("int64") = nb::cast(mx::int64); |
| | m.attr("float16") = nb::cast(mx::float16); |
| | m.attr("float32") = nb::cast(mx::float32); |
| | m.attr("float64") = nb::cast(mx::float64); |
| | m.attr("bfloat16") = nb::cast(mx::bfloat16); |
| | m.attr("complex64") = nb::cast(mx::complex64); |
| | nb::enum_<mx::Dtype::Category>( |
| | m, |
| | "DtypeCategory", |
| | R"pbdoc( |
| | Type to hold categories of :class:`dtypes <Dtype>`. |
| | |
| | * :attr:`~mlx.core.generic` |
| | |
| | * :ref:`bool_ <data_types>` |
| | * :attr:`~mlx.core.number` |
| | |
| | * :attr:`~mlx.core.integer` |
| | |
| | * :attr:`~mlx.core.unsignedinteger` |
| | |
| | * :ref:`uint8 <data_types>` |
| | * :ref:`uint16 <data_types>` |
| | * :ref:`uint32 <data_types>` |
| | * :ref:`uint64 <data_types>` |
| | |
| | * :attr:`~mlx.core.signedinteger` |
| | |
| | * :ref:`int8 <data_types>` |
| | * :ref:`int32 <data_types>` |
| | * :ref:`int64 <data_types>` |
| | |
| | * :attr:`~mlx.core.inexact` |
| | |
| | * :attr:`~mlx.core.floating` |
| | |
| | * :ref:`float16 <data_types>` |
| | * :ref:`bfloat16 <data_types>` |
| | * :ref:`float32 <data_types>` |
| | * :ref:`float64 <data_types>` |
| | |
| | * :attr:`~mlx.core.complexfloating` |
| | |
| | * :ref:`complex64 <data_types>` |
| | |
| | See also :func:`~mlx.core.issubdtype`. |
| | )pbdoc") |
| | .value("complexfloating", mx::complexfloating) |
| | .value("floating", mx::floating) |
| | .value("inexact", mx::inexact) |
| | .value("signedinteger", mx::signedinteger) |
| | .value("unsignedinteger", mx::unsignedinteger) |
| | .value("integer", mx::integer) |
| | .value("number", mx::number) |
| | .value("generic", mx::generic) |
| | .export_values(); |
| |
|
| | nb::class_<mx::finfo>( |
| | m, |
| | "finfo", |
| | R"pbdoc( |
| | Get information on floating-point types. |
| | )pbdoc") |
| | .def(nb::init<mx::Dtype>()) |
| | .def_ro( |
| | "min", |
| | &mx::finfo::min, |
| | R"pbdoc(The smallest representable number.)pbdoc") |
| | .def_ro( |
| | "max", |
| | &mx::finfo::max, |
| | R"pbdoc(The largest representable number.)pbdoc") |
| | .def_ro( |
| | "eps", |
| | &mx::finfo::eps, |
| | R"pbdoc( |
| | The difference between 1.0 and the next smallest |
| | representable number larger than 1.0. |
| | )pbdoc") |
| | .def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc") |
| | .def("__repr__", [](const mx::finfo& f) { |
| | std::ostringstream os; |
| | os << "finfo(" |
| | << "min=" << f.min << ", max=" << f.max << ", dtype=" << f.dtype |
| | << ")"; |
| | return os.str(); |
| | }); |
| |
|
| | nb::class_<mx::iinfo>( |
| | m, |
| | "iinfo", |
| | R"pbdoc( |
| | Get information on integer types. |
| | )pbdoc") |
| | .def(nb::init<mx::Dtype>()) |
| | .def_ro( |
| | "min", |
| | &mx::iinfo::min, |
| | R"pbdoc(The smallest representable number.)pbdoc") |
| | .def_ro( |
| | "max", |
| | &mx::iinfo::max, |
| | R"pbdoc(The largest representable number.)pbdoc") |
| | .def_ro("dtype", &mx::iinfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc") |
| | .def("__repr__", [](const mx::iinfo& i) { |
| | std::ostringstream os; |
| | os << "iinfo(" |
| | << "min=" << i.min << ", max=" << i.max << ", dtype=" << i.dtype |
| | << ")"; |
| | return os.str(); |
| | }); |
| |
|
| | nb::class_<ArrayAt>( |
| | m, |
| | "ArrayAt", |
| | R"pbdoc( |
| | A helper object to apply updates at specific indices. |
| | )pbdoc") |
| | .def("__getitem__", &ArrayAt::set_indices, "indices"_a.none()) |
| | .def("add", &ArrayAt::add, "value"_a) |
| | .def("subtract", &ArrayAt::subtract, "value"_a) |
| | .def("multiply", &ArrayAt::multiply, "value"_a) |
| | .def("divide", &ArrayAt::divide, "value"_a) |
| | .def("maximum", &ArrayAt::maximum, "value"_a) |
| | .def("minimum", &ArrayAt::minimum, "value"_a); |
| |
|
| | nb::class_<ArrayLike>( |
| | m, |
| | "ArrayLike", |
| | R"pbdoc( |
| | Any Python object which has an ``__mlx__array__`` method that |
| | returns an :obj:`array`. |
| | )pbdoc") |
| | .def(nb::init_implicit<nb::object>()); |
| |
|
| | nb::class_<ArrayPythonIterator>( |
| | m, |
| | "ArrayIterator", |
| | R"pbdoc( |
| | A helper object to iterate over the 1st dimension of an array. |
| | )pbdoc") |
| | .def("__next__", &ArrayPythonIterator::next) |
| | .def("__iter__", [](const ArrayPythonIterator& it) { return it; }); |
| |
|
| | |
| | PyType_Slot array_slots[] = { |
| | {Py_bf_getbuffer, (void*)getbuffer}, |
| | {Py_bf_releasebuffer, (void*)releasebuffer}, |
| | {0, nullptr}}; |
| |
|
| | nb::class_<mx::array>( |
| | m, |
| | "array", |
| | R"pbdoc(An N-dimensional array object.)pbdoc", |
| | nb::type_slots(array_slots), |
| | nb::is_weak_referenceable()) |
| | .def( |
| | "__init__", |
| | [](mx::array* aptr, ArrayInitType v, std::optional<mx::Dtype> t) { |
| | new (aptr) mx::array(create_array(v, t)); |
| | }, |
| | "val"_a, |
| | "dtype"_a = nb::none(), |
| | nb::sig( |
| | "def __init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None)")) |
| | .def_prop_ro( |
| | "size", |
| | &mx::array::size, |
| | R"pbdoc(Number of elements in the array.)pbdoc") |
| | .def_prop_ro( |
| | "ndim", &mx::array::ndim, R"pbdoc(The array's dimension.)pbdoc") |
| | .def_prop_ro( |
| | "itemsize", |
| | &mx::array::itemsize, |
| | R"pbdoc(The size of the array's datatype in bytes.)pbdoc") |
| | .def_prop_ro( |
| | "nbytes", |
| | &mx::array::nbytes, |
| | R"pbdoc(The number of bytes in the array.)pbdoc") |
| | .def_prop_ro( |
| | "shape", |
| | [](const mx::array& a) { return nb::cast(a.shape()); }, |
| | nb::sig("def shape(self) -> tuple[int, ...]"), |
| | R"pbdoc( |
| | The shape of the array as a Python tuple. |
| | |
| | Returns: |
| | tuple(int): A tuple containing the sizes of each dimension. |
| | )pbdoc") |
| | .def_prop_ro( |
| | "dtype", |
| | &mx::array::dtype, |
| | R"pbdoc( |
| | The array's :class:`Dtype`. |
| | )pbdoc") |
| | .def_prop_ro( |
| | "real", |
| | [](const mx::array& a) { return mx::real(a); }, |
| | R"pbdoc( |
| | The real part of a complex array. |
| | )pbdoc") |
| | .def_prop_ro( |
| | "imag", |
| | [](const mx::array& a) { return mx::imag(a); }, |
| | R"pbdoc( |
| | The imaginary part of a complex array. |
| | )pbdoc") |
| | .def( |
| | "item", |
| | &to_scalar, |
| | nb::sig("def item(self) -> scalar"), |
| | R"pbdoc( |
| | Access the value of a scalar array. |
| | |
| | Returns: |
| | Standard Python scalar. |
| | )pbdoc") |
| | .def( |
| | "tolist", |
| | &tolist, |
| | nb::sig("def tolist(self) -> list_or_scalar"), |
| | R"pbdoc( |
| | Convert the array to a Python :class:`list`. |
| | |
| | Returns: |
| | list: The Python list. |
| | |
| | If the array is a scalar then a standard Python scalar is returned. |
| | |
| | If the array has more than one dimension then the result is a nested |
| | list of lists. |
| | |
| | The value type of the list corresponding to the last dimension is either |
| | ``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array. |
| | )pbdoc") |
| | .def( |
| | "astype", |
| | &mx::astype, |
| | "dtype"_a, |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | Cast the array to a specified type. |
| | |
| | Args: |
| | dtype (Dtype): Type to which the array is cast. |
| | stream (Stream): Stream (or device) for the operation. |
| | |
| | Returns: |
| | array: The array with type ``dtype``. |
| | )pbdoc") |
| | .def( |
| | "__array_namespace__", |
| | [](const mx::array& a, |
| | const std::optional<std::string>& api_version) { |
| | if (api_version) { |
| | throw std::invalid_argument( |
| | "Explicitly specifying api_version is not yet implemented."); |
| | } |
| | return nb::module_::import_("mlx.core"); |
| | }, |
| | "api_version"_a = nb::none(), |
| | R"pbdoc( |
| | Returns an object that has all the array API functions on it. |
| | |
| | See the `Python array API <https://data-apis.org/array-api/latest/index.html>`_ |
| | for more information. |
| | |
| | Args: |
| | api_version (str, optional): String representing the version |
| | of the array API spec to return. Default: ``None``. |
| | |
| | Returns: |
| | out (Any): An object representing the array API namespace. |
| | )pbdoc") |
| | .def("__getitem__", mlx_get_item, nb::arg().none()) |
| | .def("__setitem__", mlx_set_item, nb::arg().none(), nb::arg()) |
| | .def_prop_ro( |
| | "at", |
| | [](const mx::array& a) { return ArrayAt(a); }, |
| | R"pbdoc( |
| | Used to apply updates at the given indices. |
| | |
| | .. note:: |
| | |
| | Regular in-place updates map to assignment. For instance ``x[idx] += y`` |
| | maps to ``x[idx] = x[idx] + y``. As a result, assigning to the |
| | same index ignores all but one update. Using ``x.at[idx].add(y)`` |
| | will correctly apply all updates to all indices. |
| | |
| | .. list-table:: |
| | :header-rows: 1 |
| | |
| | * - array.at syntax |
| | - In-place syntax |
| | * - ``x = x.at[idx].add(y)`` |
| | - ``x[idx] += y`` |
| | * - ``x = x.at[idx].subtract(y)`` |
| | - ``x[idx] -= y`` |
| | * - ``x = x.at[idx].multiply(y)`` |
| | - ``x[idx] *= y`` |
| | * - ``x = x.at[idx].divide(y)`` |
| | - ``x[idx] /= y`` |
| | * - ``x = x.at[idx].maximum(y)`` |
| | - ``x[idx] = mx.maximum(x[idx], y)`` |
| | * - ``x = x.at[idx].minimum(y)`` |
| | - ``x[idx] = mx.minimum(x[idx], y)`` |
| | |
| | Example: |
| | >>> a = mx.array([0, 0]) |
| | >>> idx = mx.array([0, 1, 0, 1]) |
| | >>> a[idx] += 1 |
| | >>> a |
| | array([1, 1], dtype=int32) |
| | >>> |
| | >>> a = mx.array([0, 0]) |
| | >>> a.at[idx].add(1) |
| | array([2, 2], dtype=int32) |
| | )pbdoc") |
| | .def( |
| | "__len__", |
| | [](const mx::array& a) { |
| | if (a.ndim() == 0) { |
| | throw nb::type_error("len() 0-dimensional array."); |
| | } |
| | return a.shape(0); |
| | }) |
| | .def( |
| | "__iter__", [](const mx::array& a) { return ArrayPythonIterator(a); }) |
| | .def( |
| | "__getstate__", |
| | [](const mx::array& a) { |
| | auto nd = (a.dtype() == mx::bfloat16) |
| | ? mlx_to_np_array(mx::view(a, mx::uint16)) |
| | : mlx_to_np_array(a); |
| | return nb::make_tuple(nd, static_cast<uint8_t>(a.dtype().val())); |
| | }) |
| | .def( |
| | "__setstate__", |
| | [](mx::array& arr, const nb::tuple& state) { |
| | if (nb::len(state) != 2) { |
| | throw std::invalid_argument( |
| | "Invalid pickle state: expected (ndarray, Dtype::Val)"); |
| | } |
| | using ND = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>; |
| | ND nd = nb::cast<ND>(state[0]); |
| | auto val = static_cast<mx::Dtype::Val>(nb::cast<uint8_t>(state[1])); |
| | if (val == mx::Dtype::Val::bfloat16) { |
| | auto owner = nb::handle(state[0].ptr()); |
| | new (&arr) mx::array(nd_array_to_mlx( |
| | ND(nd.data(), |
| | nd.ndim(), |
| | reinterpret_cast<const size_t*>(nd.shape_ptr()), |
| | owner, |
| | nullptr, |
| | nb::bfloat16), |
| | mx::bfloat16)); |
| | } else { |
| | new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt)); |
| | } |
| | }) |
| | .def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); }) |
| | .def( |
| | "__dlpack_device__", |
| | [](const mx::array& a) { |
| | |
| | |
| | if (mx::metal::is_available()) { |
| | return nb::make_tuple(8, 0); |
| | } else if (mx::cu::is_available()) { |
| | return nb::make_tuple(13, 0); |
| | } else { |
| | |
| | return nb::make_tuple(1, 0); |
| | } |
| | }) |
| | .def("__copy__", [](const mx::array& self) { return mx::array(self); }) |
| | .def( |
| | "__deepcopy__", |
| | [](const mx::array& self, nb::dict) { return mx::array(self); }, |
| | "memo"_a) |
| | .def( |
| | "__add__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("addition", v); |
| | } |
| | auto b = to_array(v, a.dtype()); |
| | return mx::add(a, b); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__iadd__", |
| | [](mx::array& a, const ScalarOrArray v) -> mx::array& { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("inplace addition", v); |
| | } |
| | a.overwrite_descriptor(mx::add(a, to_array(v, a.dtype()))); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def( |
| | "__radd__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("addition", v); |
| | } |
| | return mx::add(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__sub__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("subtraction", v); |
| | } |
| | return mx::subtract(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__isub__", |
| | [](mx::array& a, const ScalarOrArray v) -> mx::array& { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("inplace subtraction", v); |
| | } |
| | a.overwrite_descriptor(mx::subtract(a, to_array(v, a.dtype()))); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def( |
| | "__rsub__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("subtraction", v); |
| | } |
| | return mx::subtract(to_array(v, a.dtype()), a); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__mul__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("multiplication", v); |
| | } |
| | return mx::multiply(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__imul__", |
| | [](mx::array& a, const ScalarOrArray v) -> mx::array& { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("inplace multiplication", v); |
| | } |
| | a.overwrite_descriptor(mx::multiply(a, to_array(v, a.dtype()))); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def( |
| | "__rmul__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("multiplication", v); |
| | } |
| | return mx::multiply(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__truediv__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("division", v); |
| | } |
| | return mx::divide(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__itruediv__", |
| | [](mx::array& a, const ScalarOrArray v) -> mx::array& { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("inplace division", v); |
| | } |
| | if (!mx::issubdtype(a.dtype(), mx::inexact)) { |
| | throw std::invalid_argument( |
| | "In place division cannot cast to non-floating point type."); |
| | } |
| | a.overwrite_descriptor(divide(a, to_array(v, a.dtype()))); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def( |
| | "__rtruediv__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("division", v); |
| | } |
| | return mx::divide(to_array(v, a.dtype()), a); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__div__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("division", v); |
| | } |
| | return mx::divide(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__rdiv__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("division", v); |
| | } |
| | return mx::divide(to_array(v, a.dtype()), a); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__floordiv__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("floor division", v); |
| | } |
| | return mx::floor_divide(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__ifloordiv__", |
| | [](mx::array& a, const ScalarOrArray v) -> mx::array& { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("inplace floor division", v); |
| | } |
| | a.overwrite_descriptor(mx::floor_divide(a, to_array(v, a.dtype()))); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def( |
| | "__rfloordiv__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("floor division", v); |
| | } |
| | auto b = to_array(v, a.dtype()); |
| | return mx::floor_divide(b, a); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__mod__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("modulus", v); |
| | } |
| | return mx::remainder(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__imod__", |
| | [](mx::array& a, const ScalarOrArray v) -> mx::array& { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("inplace modulus", v); |
| | } |
| | a.overwrite_descriptor(mx::remainder(a, to_array(v, a.dtype()))); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def( |
| | "__rmod__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("modulus", v); |
| | } |
| | return mx::remainder(to_array(v, a.dtype()), a); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__eq__", |
| | [](const mx::array& a, |
| | const ScalarOrArray& v) -> std::variant<mx::array, bool> { |
| | if (!is_comparable_with_array(v)) { |
| | return false; |
| | } |
| | return mx::equal(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__lt__", |
| | [](const mx::array& a, const ScalarOrArray v) -> mx::array { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("less than", v); |
| | } |
| | return mx::less(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__le__", |
| | [](const mx::array& a, const ScalarOrArray v) -> mx::array { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("less than or equal", v); |
| | } |
| | return mx::less_equal(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__gt__", |
| | [](const mx::array& a, const ScalarOrArray v) -> mx::array { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("greater than", v); |
| | } |
| | return mx::greater(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__ge__", |
| | [](const mx::array& a, const ScalarOrArray v) -> mx::array { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("greater than or equal", v); |
| | } |
| | return mx::greater_equal(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__ne__", |
| | [](const mx::array& a, |
| | const ScalarOrArray v) -> std::variant<mx::array, bool> { |
| | if (!is_comparable_with_array(v)) { |
| | return true; |
| | } |
| | return mx::not_equal(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def("__neg__", [](const mx::array& a) { return -a; }) |
| | .def("__bool__", [](mx::array& a) { return nb::bool_(to_scalar(a)); }) |
| | .def( |
| | "__repr__", |
| | [](mx::array& a) { |
| | nb::gil_scoped_release nogil; |
| | std::ostringstream os; |
| | os << a; |
| | return os.str(); |
| | }) |
| | .def( |
| | "__matmul__", |
| | [](const mx::array& a, mx::array& other) { |
| | return mx::matmul(a, other); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__imatmul__", |
| | [](mx::array& a, mx::array& other) -> mx::array& { |
| | a.overwrite_descriptor(mx::matmul(a, other)); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def( |
| | "__pow__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("power", v); |
| | } |
| | return mx::power(a, to_array(v, a.dtype())); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__rpow__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("power", v); |
| | } |
| | return mx::power(to_array(v, a.dtype()), a); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__ipow__", |
| | [](mx::array& a, const ScalarOrArray v) -> mx::array& { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("inplace power", v); |
| | } |
| | a.overwrite_descriptor(mx::power(a, to_array(v, a.dtype()))); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def( |
| | "__invert__", |
| | [](const mx::array& a) { |
| | if (mx::issubdtype(a.dtype(), mx::inexact)) { |
| | throw std::invalid_argument( |
| | "Floating point types not allowed with bitwise inversion."); |
| | } |
| | if (a.dtype() == mx::bool_) { |
| | return mx::logical_not(a); |
| | } |
| | return mx::bitwise_invert(a); |
| | }) |
| | .def( |
| | "__and__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("bitwise and", v); |
| | } |
| | auto b = to_array(v, a.dtype()); |
| | if (mx::issubdtype(a.dtype(), mx::inexact) || |
| | mx::issubdtype(b.dtype(), mx::inexact)) { |
| | throw std::invalid_argument( |
| | "Floating point types not allowed with bitwise and."); |
| | } |
| | return mx::bitwise_and(a, b); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__iand__", |
| | [](mx::array& a, const ScalarOrArray v) -> mx::array& { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("inplace bitwise and", v); |
| | } |
| | auto b = to_array(v, a.dtype()); |
| | if (mx::issubdtype(a.dtype(), mx::inexact) || |
| | mx::issubdtype(b.dtype(), mx::inexact)) { |
| | throw std::invalid_argument( |
| | "Floating point types not allowed with bitwise and."); |
| | } |
| | a.overwrite_descriptor(mx::bitwise_and(a, b)); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def( |
| | "__or__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("bitwise or", v); |
| | } |
| | auto b = to_array(v, a.dtype()); |
| | if (mx::issubdtype(a.dtype(), mx::inexact) || |
| | mx::issubdtype(b.dtype(), mx::inexact)) { |
| | throw std::invalid_argument( |
| | "Floating point types not allowed with bitwise or."); |
| | } |
| | return mx::bitwise_or(a, b); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__ior__", |
| | [](mx::array& a, const ScalarOrArray v) -> mx::array& { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("inplace bitwise or", v); |
| | } |
| | auto b = to_array(v, a.dtype()); |
| | if (mx::issubdtype(a.dtype(), mx::inexact) || |
| | mx::issubdtype(b.dtype(), mx::inexact)) { |
| | throw std::invalid_argument( |
| | "Floating point types not allowed with bitwise or."); |
| | } |
| | a.overwrite_descriptor(mx::bitwise_or(a, b)); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def( |
| | "__lshift__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("left shift", v); |
| | } |
| | auto b = to_array(v, a.dtype()); |
| | if (mx::issubdtype(a.dtype(), mx::inexact) || |
| | mx::issubdtype(b.dtype(), mx::inexact)) { |
| | throw std::invalid_argument( |
| | "Floating point types not allowed with left shift."); |
| | } |
| | return mx::left_shift(a, b); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__ilshift__", |
| | [](mx::array& a, const ScalarOrArray v) -> mx::array& { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("inplace left shift", v); |
| | } |
| | auto b = to_array(v, a.dtype()); |
| | if (mx::issubdtype(a.dtype(), mx::inexact) || |
| | mx::issubdtype(b.dtype(), mx::inexact)) { |
| | throw std::invalid_argument( |
| | "Floating point types not allowed with left shift."); |
| | } |
| | a.overwrite_descriptor(mx::left_shift(a, b)); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def( |
| | "__rshift__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("right shift", v); |
| | } |
| | auto b = to_array(v, a.dtype()); |
| | if (mx::issubdtype(a.dtype(), mx::inexact) || |
| | mx::issubdtype(b.dtype(), mx::inexact)) { |
| | throw std::invalid_argument( |
| | "Floating point types not allowed with right shift."); |
| | } |
| | return mx::right_shift(a, b); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__irshift__", |
| | [](mx::array& a, const ScalarOrArray v) -> mx::array& { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("inplace right shift", v); |
| | } |
| | auto b = to_array(v, a.dtype()); |
| | if (mx::issubdtype(a.dtype(), mx::inexact) || |
| | mx::issubdtype(b.dtype(), mx::inexact)) { |
| | throw std::invalid_argument( |
| | "Floating point types not allowed with right shift."); |
| | } |
| | a.overwrite_descriptor(mx::right_shift(a, b)); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def( |
| | "__xor__", |
| | [](const mx::array& a, const ScalarOrArray v) { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("bitwise xor", v); |
| | } |
| | auto b = to_array(v, a.dtype()); |
| | if (mx::issubdtype(a.dtype(), mx::inexact) || |
| | mx::issubdtype(b.dtype(), mx::inexact)) { |
| | throw std::invalid_argument( |
| | "Floating point types not allowed with bitwise xor."); |
| | } |
| | return mx::bitwise_xor(a, b); |
| | }, |
| | "other"_a) |
| | .def( |
| | "__ixor__", |
| | [](mx::array& a, const ScalarOrArray v) -> mx::array& { |
| | if (!is_comparable_with_array(v)) { |
| | throw_invalid_operation("inplace bitwise xor", v); |
| | } |
| | auto b = to_array(v, a.dtype()); |
| | if (mx::issubdtype(a.dtype(), mx::inexact) || |
| | mx::issubdtype(b.dtype(), mx::inexact)) { |
| | throw std::invalid_argument( |
| | "Floating point types not allowed bitwise xor."); |
| | } |
| | a.overwrite_descriptor(mx::bitwise_xor(a, b)); |
| | return a; |
| | }, |
| | "other"_a, |
| | nb::rv_policy::none) |
| | .def("__int__", [](mx::array& a) { return nb::int_(to_scalar(a)); }) |
| | .def("__float__", [](mx::array& a) { return nb::float_(to_scalar(a)); }) |
| | .def( |
| | "__format__", |
| | [](mx::array& a, nb::object format_spec) { |
| | if (nb::len(nb::str(format_spec)) > 0 && a.ndim() > 0) { |
| | throw nb::type_error( |
| | "unsupported format string passed to mx.array.__format__"); |
| | } else if (a.ndim() == 0) { |
| | auto obj = to_scalar(a); |
| | return nb::cast<std::string>( |
| | nb::handle(PyObject_Format(obj.ptr(), format_spec.ptr()))); |
| | } else { |
| | nb::gil_scoped_release nogil; |
| | std::ostringstream os; |
| | os << a; |
| | return os.str(); |
| | } |
| | }) |
| | .def( |
| | "flatten", |
| | [](const mx::array& a, |
| | int start_axis, |
| | int end_axis, |
| | const mx::StreamOrDevice& s) { |
| | return mx::flatten(a, start_axis, end_axis, s); |
| | }, |
| | "start_axis"_a = 0, |
| | "end_axis"_a = -1, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | See :func:`flatten`. |
| | )pbdoc") |
| | .def( |
| | "reshape", |
| | [](const mx::array& a, nb::args shape_, mx::StreamOrDevice s) { |
| | mx::Shape shape; |
| | if (!nb::isinstance<int>(shape_[0])) { |
| | shape = nb::cast<mx::Shape>(shape_[0]); |
| | } else { |
| | shape = nb::cast<mx::Shape>(shape_); |
| | } |
| | return mx::reshape(a, std::move(shape), s); |
| | }, |
| | "shape"_a, |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | Equivalent to :func:`reshape` but the shape can be passed either as a |
| | :obj:`tuple` or as separate arguments. |
| | |
| | See :func:`reshape` for full documentation. |
| | )pbdoc") |
| | .def( |
| | "squeeze", |
| | [](const mx::array& a, |
| | const IntOrVec& v, |
| | const mx::StreamOrDevice& s) { |
| | if (std::holds_alternative<std::monostate>(v)) { |
| | return mx::squeeze(a, s); |
| | } else if (auto pv = std::get_if<int>(&v); pv) { |
| | return mx::squeeze(a, *pv, s); |
| | } else { |
| | return mx::squeeze(a, std::get<std::vector<int>>(v), s); |
| | } |
| | }, |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | See :func:`squeeze`. |
| | )pbdoc") |
| | .def( |
| | "abs", |
| | &mx::abs, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`abs`.") |
| | .def( |
| | "__abs__", |
| | [](const mx::array& a) { return mx::abs(a); }, |
| | "See :func:`abs`.") |
| | .def( |
| | "square", |
| | &mx::square, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`square`.") |
| | .def( |
| | "sqrt", |
| | &mx::sqrt, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`sqrt`.") |
| | .def( |
| | "rsqrt", |
| | &mx::rsqrt, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`rsqrt`.") |
| | .def( |
| | "reciprocal", |
| | &mx::reciprocal, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`reciprocal`.") |
| | .def( |
| | "exp", |
| | &mx::exp, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`exp`.") |
| | .def( |
| | "log", |
| | &mx::log, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`log`.") |
| | .def( |
| | "log2", |
| | &mx::log2, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`log2`.") |
| | .def( |
| | "log10", |
| | &mx::log10, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`log10`.") |
| | .def( |
| | "sin", |
| | &mx::sin, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`sin`.") |
| | .def( |
| | "cos", |
| | &mx::cos, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`cos`.") |
| | .def( |
| | "log1p", |
| | &mx::log1p, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`log1p`.") |
| | .def( |
| | "all", |
| | [](const mx::array& a, |
| | const IntOrVec& axis, |
| | bool keepdims, |
| | mx::StreamOrDevice s) { |
| | return mx::all(a, get_reduce_axes(axis, a.ndim()), keepdims, s); |
| | }, |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`all`.") |
| | .def( |
| | "any", |
| | [](const mx::array& a, |
| | const IntOrVec& axis, |
| | bool keepdims, |
| | mx::StreamOrDevice s) { |
| | return mx::any(a, get_reduce_axes(axis, a.ndim()), keepdims, s); |
| | }, |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`any`.") |
| | .def( |
| | "moveaxis", |
| | &mx::moveaxis, |
| | "source"_a, |
| | "destination"_a, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`moveaxis`.") |
| | .def( |
| | "swapaxes", |
| | &mx::swapaxes, |
| | "axis1"_a, |
| | "axis2"_a, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`swapaxes`.") |
| | .def( |
| | "transpose", |
| | [](const mx::array& a, nb::args axes_, mx::StreamOrDevice s) { |
| | if (axes_.size() == 0) { |
| | return mx::transpose(a, s); |
| | } |
| | std::vector<int> axes; |
| | if (!nb::isinstance<int>(axes_[0])) { |
| | axes = nb::cast<std::vector<int>>(axes_[0]); |
| | } else { |
| | axes = nb::cast<std::vector<int>>(axes_); |
| | } |
| | return mx::transpose(a, axes, s); |
| | }, |
| | "axes"_a, |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | Equivalent to :func:`transpose` but the axes can be passed either as |
| | a tuple or as separate arguments. |
| | |
| | See :func:`transpose` for full documentation. |
| | )pbdoc") |
| | .def_prop_ro( |
| | "T", |
| | [](const mx::array& a) { return mx::transpose(a); }, |
| | "Equivalent to calling ``self.transpose()`` with no arguments.") |
| | .def( |
| | "sum", |
| | [](const mx::array& a, |
| | const IntOrVec& axis, |
| | bool keepdims, |
| | mx::StreamOrDevice s) { |
| | return mx::sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s); |
| | }, |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`sum`.") |
| | .def( |
| | "prod", |
| | [](const mx::array& a, |
| | const IntOrVec& axis, |
| | bool keepdims, |
| | mx::StreamOrDevice s) { |
| | return mx::prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s); |
| | }, |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`prod`.") |
| | .def( |
| | "min", |
| | [](const mx::array& a, |
| | const IntOrVec& axis, |
| | bool keepdims, |
| | mx::StreamOrDevice s) { |
| | return mx::min(a, get_reduce_axes(axis, a.ndim()), keepdims, s); |
| | }, |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`min`.") |
| | .def( |
| | "max", |
| | [](const mx::array& a, |
| | const IntOrVec& axis, |
| | bool keepdims, |
| | mx::StreamOrDevice s) { |
| | return mx::max(a, get_reduce_axes(axis, a.ndim()), keepdims, s); |
| | }, |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`max`.") |
| | .def( |
| | "logcumsumexp", |
| | [](const mx::array& a, |
| | std::optional<int> axis, |
| | bool reverse, |
| | bool inclusive, |
| | mx::StreamOrDevice s) { |
| | if (axis) { |
| | return mx::logcumsumexp(a, *axis, reverse, inclusive, s); |
| | } else { |
| | |
| | |
| | return mx::logcumsumexp( |
| | mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); |
| | } |
| | }, |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "reverse"_a = false, |
| | "inclusive"_a = true, |
| | "stream"_a = nb::none(), |
| | "See :func:`logcumsumexp`.") |
| | .def( |
| | "logsumexp", |
| | [](const mx::array& a, |
| | const IntOrVec& axis, |
| | bool keepdims, |
| | mx::StreamOrDevice s) { |
| | return mx::logsumexp( |
| | a, get_reduce_axes(axis, a.ndim()), keepdims, s); |
| | }, |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`logsumexp`.") |
| | .def( |
| | "mean", |
| | [](const mx::array& a, |
| | const IntOrVec& axis, |
| | bool keepdims, |
| | mx::StreamOrDevice s) { |
| | return mx::mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s); |
| | }, |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`mean`.") |
| | .def( |
| | "std", |
| | [](const mx::array& a, |
| | const IntOrVec& axis, |
| | bool keepdims, |
| | int ddof, |
| | mx::StreamOrDevice s) { |
| | return mx::std( |
| | a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); |
| | }, |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | "ddof"_a = 0, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`std`.") |
| | .def( |
| | "var", |
| | [](const mx::array& a, |
| | const IntOrVec& axis, |
| | bool keepdims, |
| | int ddof, |
| | mx::StreamOrDevice s) { |
| | return mx::var( |
| | a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); |
| | }, |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | "ddof"_a = 0, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`var`.") |
| | .def( |
| | "split", |
| | [](const mx::array& a, |
| | const std::variant<int, mx::Shape>& indices_or_sections, |
| | int axis, |
| | mx::StreamOrDevice s) { |
| | if (auto pv = std::get_if<int>(&indices_or_sections); pv) { |
| | return mx::split(a, *pv, axis, s); |
| | } else { |
| | return mx::split( |
| | a, std::get<mx::Shape>(indices_or_sections), axis, s); |
| | } |
| | }, |
| | "indices_or_sections"_a, |
| | "axis"_a = 0, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`split`.") |
| | .def( |
| | "argmin", |
| | [](const mx::array& a, |
| | std::optional<int> axis, |
| | bool keepdims, |
| | mx::StreamOrDevice s) { |
| | if (axis) { |
| | return mx::argmin(a, *axis, keepdims, s); |
| | } else { |
| | return mx::argmin(a, keepdims, s); |
| | } |
| | }, |
| | "axis"_a = std::nullopt, |
| | "keepdims"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`argmin`.") |
| | .def( |
| | "argmax", |
| | [](const mx::array& a, |
| | std::optional<int> axis, |
| | bool keepdims, |
| | mx::StreamOrDevice s) { |
| | if (axis) { |
| | return mx::argmax(a, *axis, keepdims, s); |
| | } else { |
| | return mx::argmax(a, keepdims, s); |
| | } |
| | }, |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`argmax`.") |
| | .def( |
| | "cumsum", |
| | [](const mx::array& a, |
| | std::optional<int> axis, |
| | bool reverse, |
| | bool inclusive, |
| | mx::StreamOrDevice s) { |
| | if (axis) { |
| | return mx::cumsum(a, *axis, reverse, inclusive, s); |
| | } else { |
| | |
| | |
| | return mx::cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s); |
| | } |
| | }, |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "reverse"_a = false, |
| | "inclusive"_a = true, |
| | "stream"_a = nb::none(), |
| | "See :func:`cumsum`.") |
| | .def( |
| | "cumprod", |
| | [](const mx::array& a, |
| | std::optional<int> axis, |
| | bool reverse, |
| | bool inclusive, |
| | mx::StreamOrDevice s) { |
| | if (axis) { |
| | return mx::cumprod(a, *axis, reverse, inclusive, s); |
| | } else { |
| | |
| | |
| | return mx::cumprod( |
| | mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); |
| | } |
| | }, |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "reverse"_a = false, |
| | "inclusive"_a = true, |
| | "stream"_a = nb::none(), |
| | "See :func:`cumprod`.") |
| | .def( |
| | "cummax", |
| | [](const mx::array& a, |
| | std::optional<int> axis, |
| | bool reverse, |
| | bool inclusive, |
| | mx::StreamOrDevice s) { |
| | if (axis) { |
| | return mx::cummax(a, *axis, reverse, inclusive, s); |
| | } else { |
| | |
| | |
| | return mx::cummax( |
| | mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); |
| | } |
| | }, |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "reverse"_a = false, |
| | "inclusive"_a = true, |
| | "stream"_a = nb::none(), |
| | "See :func:`cummax`.") |
| | .def( |
| | "cummin", |
| | [](const mx::array& a, |
| | std::optional<int> axis, |
| | bool reverse, |
| | bool inclusive, |
| | mx::StreamOrDevice s) { |
| | if (axis) { |
| | return mx::cummin(a, *axis, reverse, inclusive, s); |
| | } else { |
| | |
| | |
| | return mx::cummin( |
| | mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); |
| | } |
| | }, |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "reverse"_a = false, |
| | "inclusive"_a = true, |
| | "stream"_a = nb::none(), |
| | "See :func:`cummin`.") |
| | .def( |
| | "round", |
| | [](const mx::array& a, int decimals, mx::StreamOrDevice s) { |
| | return mx::round(a, decimals, s); |
| | }, |
| | "decimals"_a = 0, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`round`.") |
| | .def( |
| | "diagonal", |
| | [](const mx::array& a, |
| | int offset, |
| | int axis1, |
| | int axis2, |
| | mx::StreamOrDevice s) { |
| | return mx::diagonal(a, offset, axis1, axis2, s); |
| | }, |
| | "offset"_a = 0, |
| | "axis1"_a = 0, |
| | "axis2"_a = 1, |
| | "stream"_a = nb::none(), |
| | "See :func:`diagonal`.") |
| | .def( |
| | "diag", |
| | [](const mx::array& a, int k, mx::StreamOrDevice s) { |
| | return mx::diag(a, k, s); |
| | }, |
| | "k"_a = 0, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | Extract a diagonal or construct a diagonal matrix. |
| | )pbdoc") |
| | .def( |
| | "conj", |
| | [](const mx::array& a, mx::StreamOrDevice s) { |
| | return mx::conjugate(to_array(a), s); |
| | }, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`conj`.") |
| | .def( |
| | "view", |
| | [](const ScalarOrArray& a, |
| | const mx::Dtype& dtype, |
| | mx::StreamOrDevice s) { return mx::view(to_array(a), dtype, s); }, |
| | "dtype"_a, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | "See :func:`view`."); |
| | } |
| |
|